# -*- coding: utf-8 -*-
"""
均值回归策略
- 布林带回归: 价格触及下轨 → LONG, 触及上轨 → SHORT
- RSI极端回归: RSI<30 → LONG, RSI>70 → SHORT
- 两个条件同时满足才发信号（双重确认）
"""
import logging
from .base_strategy import BaseStrategy, StrategyConfig, Signal

logger = logging.getLogger("MyTrader")


class MeanReversionStrategy(BaseStrategy):
    """Mean-reversion strategy combining Bollinger Bands with RSI extremes.

    Entry (both conditions must hold on the same bar):
      * LONG  when the close is at/below the lower band AND RSI <= ``rsi_oversold``.
      * SHORT when the close is at/above the upper band AND RSI >= ``rsi_overbought``.

    Entry signals carry a take-profit at the band middle (the SMA) and a
    stop-loss ``ATR_SL_MULT`` ATRs beyond the entry close. Entries are skipped
    when the band has zero width (flat prices), since the take-profit would
    then equal the entry price.

    When no entry fires and a long or short position is open, a CLOSE signal is
    emitted once the close is back within ``MID_EXIT_TOL`` (relative) of the SMA.
    """

    ATR_PERIOD = 14        # ATR lookback used for the stop-loss distance
    ATR_SL_MULT = 1.5      # stop-loss distance from the entry close, in ATRs
    MID_EXIT_TOL = 0.002   # exit when |close - SMA| < 0.2% of |SMA|
    WARMUP_EXTRA = 5       # bars required beyond bb_period before trading

    def __init__(self, executor, config: "StrategyConfig | None" = None):
        """Create the strategy.

        Args:
            executor: Order/position backend; this class calls
                ``executor.get_position(symbol)`` and ``executor.okx.balance()``.
            config: Strategy configuration. When omitted, a default config named
                "mean_revert" on "ETH" is used (max_position_pct=0.2,
                cooldown_bars=3, stop_loss_pct=0.015, take_profit_pct=0.025).
        """
        if config is None:
            config = StrategyConfig(
                name="mean_revert", symbol="ETH",
                max_position_pct=0.2, cooldown_bars=3,
                stop_loss_pct=0.015, take_profit_pct=0.025,
            )
        super().__init__(executor, config)
        self.bb_period = 20       # Bollinger Band lookback (bars)
        self.bb_std = 2.0         # band half-width in standard deviations
        self.rsi_period = 14      # RSI lookback (bars)
        self.rsi_oversold = 30    # RSI at/below this counts as oversold
        self.rsi_overbought = 70  # RSI at/above this counts as overbought

    def on_bar(self, o: float, h: float, l: float, c: float, v: float, timestamp: int) -> Signal:
        """Evaluate one bar and return a trading signal.

        Args:
            o, h, l, c, v: Bar open/high/low/close/volume. Only ``h``, ``l`` and
                ``c`` are read; ``h - l`` is the ATR fallback.
            timestamp: Bar timestamp (not used by this strategy).

        Returns:
            ``Signal("LONG"|"SHORT", ...)`` on an entry, ``Signal("CLOSE", ...)``
            on a mean-reversion exit, otherwise an empty ``Signal()``.
        """
        if self._bar_count < self.bb_period + self.WARMUP_EXTRA:
            return Signal()

        # Bollinger Bands: SMA +/- bb_std * population std over bb_period closes.
        window = self.get_prices(self.bb_period)[-self.bb_period:]
        sma_val = sum(window) / self.bb_period
        variance = sum((p - sma_val) ** 2 for p in window) / self.bb_period
        std_val = variance ** 0.5
        upper = sma_val + self.bb_std * std_val
        lower = sma_val - self.bb_std * std_val
        band_width = upper - lower

        # Close position inside the band: 0 = lower band, 1 = upper band.
        bb_pct = (c - lower) / band_width if band_width > 0 else 0.5

        # Latest RSI; neutral 50 when the indicator has no values yet.
        rsi_val = self._last(self.rsi(self.rsi_period), 50)

        # A zero-width band (flat prices) makes TP == entry price: no trade.
        if band_width > 0:
            # Long mean reversion: lower band + RSI oversold.
            if c <= lower and rsi_val <= self.rsi_oversold:
                # TP at the middle band; SL ATR_SL_MULT ATRs below the close.
                sl = c - self._current_atr(h, l) * self.ATR_SL_MULT
                tp = sma_val
                # Deeper oversold -> higher score; range [0.5, 1.0] here.
                score = min(1.0, (self.rsi_oversold - rsi_val) / self.rsi_oversold + 0.5)
                return Signal("LONG", size=self._calc_size(), price=c, score=score,
                              reason=f"均值回归LONG BB%={bb_pct:.2f} RSI={rsi_val:.0f}",
                              tp_price=tp, sl_price=sl)

            # Short mean reversion: upper band + RSI overbought.
            if c >= upper and rsi_val >= self.rsi_overbought:
                # TP at the middle band; SL ATR_SL_MULT ATRs above the close.
                sl = c + self._current_atr(h, l) * self.ATR_SL_MULT
                tp = sma_val
                score = min(1.0, (rsi_val - self.rsi_overbought) / (100 - self.rsi_overbought) + 0.5)
                # Short signals report a negative score.
                return Signal("SHORT", size=self._calc_size(), price=c, score=-score,
                              reason=f"均值回归SHORT BB%={bb_pct:.2f} RSI={rsi_val:.0f}",
                              tp_price=tp, sl_price=sl)

        # Take-profit exit: an open position and price back near the middle band.
        pos = self.executor.get_position(self.symbol)
        holding = any(pos[side] and pos[side]["qty"] > 0 for side in ("long", "short"))
        if holding and self._near_mid(c, sma_val):
            return Signal("CLOSE", price=c, reason=f"均值回归止盈 回归中轨 {sma_val:.1f}")

        return Signal()

    def _near_mid(self, price: float, mid: float) -> bool:
        """Return True when ``price`` is within ``MID_EXIT_TOL`` (relative) of ``mid``.

        Written without division so a zero SMA cannot raise, and uses ``abs(mid)``
        so a negative SMA does not make every price count as "near".
        """
        return abs(price - mid) < self.MID_EXIT_TOL * abs(mid)

    def _current_atr(self, h: float, l: float) -> float:
        """Return the latest ATR value, falling back to the bar range ``h - l``."""
        return self._last(self.atr(self.ATR_PERIOD), h - l)

    @staticmethod
    def _last(values, default: float) -> float:
        """Return ``values[-1]``, or ``default`` when ``values`` is None or empty.

        Checks ``len()`` instead of truthiness so the check also works should an
        indicator helper return a NumPy array; evaluating a multi-element array
        as a bool raises ``ValueError``.
        """
        if values is None or len(values) == 0:
            return default
        return values[-1]

    def _calc_size(self) -> float:
        """Return the order size derived from account equity.

        Size is ``equity * max_position_pct / 1000``, rounded to one decimal and
        floored at 1. NOTE(review): the unit behind the 1000 divisor is not
        visible here (presumably a per-contract value) — confirm.

        Sizing is best-effort: any failure (balance lookup or arithmetic) falls
        back to 1, and the failure is logged instead of silently swallowed.
        """
        try:
            _, equity = self.executor.okx.balance()
            return max(1, round(equity * self.cfg.max_position_pct / 1000, 1))
        except Exception:
            logger.warning("MeanReversionStrategy: position sizing failed, falling back to size=1",
                           exc_info=True)
            return 1
