# -*- coding: utf-8 -*-
"""
网格交易策略
- 在指定价格区间内布设等距买卖挂单
- 价格触及网格线 → 成交后自动在反向布设新挂单
- 支持: 中性网格 / 偏多网格 / 偏空网格
- 每成交一对买卖 → 锁定网格利润
"""
import logging
from .base_strategy import BaseStrategy, StrategyConfig, Signal

logger = logging.getLogger("MyTrader")


class GridStrategy(BaseStrategy):
    """网格交易策略"""

    def __init__(self, executor, config: StrategyConfig = None):
        if config is None:
            config = StrategyConfig(
                name="grid", symbol="ETH",
                max_position_pct=0.4, cooldown_bars=0,
                stop_loss_pct=0.05, take_profit_pct=0.003,
            )
        super().__init__(executor, config)
        self.grid_count = 10           # 网格层数
        self.grid_range_pct = 0.05     # 网格总范围 (±5%)
        self.bias = "neutral"          # neutral / long / short
        self._grid_levels = []         # [(price, side), ...]
        self._filled = set()           # 已成交价格
        self._initialized = False
        self._reference_price = 0

    def set_bias(self, bias: str):
        """设置网格偏向: neutral, long, short"""
        if bias in ("neutral", "long", "short"):
            self.bias = bias

    def on_bar(self, o: float, h: float, l: float, c: float, v: float, timestamp: int) -> Signal:
        # 首次初始化：基于当前价建立网格
        if not self._initialized:
            self._reference_price = c
            self._build_grid(c)
            self._initialized = True
            logger.info(f"[grid] 网格初始化 ref={c} levels={len(self._grid_levels)} bias={self.bias}")
            # 立即布设反向第一层
            return self._check_and_signal(c)

        # 每一根bar检查是否需要重新布网（价格偏离超过一半范围）
        if self._reference_price > 0:
            deviation = abs(c - self._reference_price) / self._reference_price
            if deviation > self.grid_range_pct * 0.6:
                # 清空旧网格，以新价格重建
                self._filled.clear()
                self._build_grid(c)
                self._reference_price = c
                logger.info(f"[grid] 网格重置 ref={c} deviation={deviation:.3f}")

        return self._check_and_signal(c)

    def _build_grid(self, center: float):
        """构建网格层级"""
        self._grid_levels = []
        step = center * self.grid_range_pct / (self.grid_count / 2)

        bias_shift = 0
        if self.bias == "long":
            bias_shift = step * 2  # 买方网格更密
        elif self.bias == "short":
            bias_shift = -step * 2

        start = center - self.grid_range_pct * center + bias_shift
        for i in range(self.grid_count):
            price = round(start + i * step, 2)
            if price <= 0:
                continue
            # 下半是买单，上半是卖单
            if price < center:
                side = "buy"
            elif price > center:
                side = "sell"
            else:
                continue  # 跳过中心价
            self._grid_levels.append((price, side))

        self._grid_levels.sort(key=lambda x: x[0])

    def _check_and_signal(self, current_price: float) -> Signal:
        """检查是否需要触发网格信号"""
        pos = self.executor.get_position(self.symbol)
        has_long = pos["long"] and pos["long"]["qty"] > 0
        has_short = pos["short"] and pos["short"]["qty"] > 0

        # 找到当前价附近的未成交网格线
        for price, side in self._grid_levels:
            if price in self._filled:
                continue
            # 买单：价格触及或低于网格线
            if side == "buy" and current_price <= price:
                if not has_long:  # 没有多仓才买
                    self._filled.add(price)
                    return Signal("LONG", size=self._calc_size(), price=current_price,
                                 score=0.5, reason=f"网格买入 {price}")
            # 卖单：价格触及或高于网格线
            elif side == "sell" and current_price >= price:
                if not has_short:  # 没有空仓才卖
                    self._filled.add(price)
                    return Signal("SHORT", size=self._calc_size(), price=current_price,
                                 score=-0.5, reason=f"网格卖出 {price}")

        # 网格止盈：已有仓位且反向触及下一级
        if has_long and current_price > pos["long"]["entry"] * (1 + self.grid_range_pct / self.grid_count):
            return Signal("CLOSE", price=current_price, reason=f"网格止盈 +{self.grid_range_pct / self.grid_count:.4f}")

        if has_short and current_price < pos["short"]["entry"] * (1 - self.grid_range_pct / self.grid_count):
            return Signal("CLOSE", price=current_price, reason=f"网格止盈 +{self.grid_range_pct / self.grid_count:.4f}")

        return Signal()

    def _calc_size(self) -> float:
        try:
            _, equity = self.executor.okx.balance()
            per_grid = equity * self.cfg.max_position_pct / self.grid_count / 1000
            return max(1, round(per_grid, 1))
        except Exception:
            return 1
