# -*- coding: utf-8 -*-
"""
海龟交易法 — 20日/55日唐奇安通道突破
原始规则:
- Entry: 收盘价突破前20日最高价 → LONG, 跌破前20日最低价 → SHORT
- Filter: 收盘价相对前55日通道中轴（最高价与最低价的均值）确认趋势方向
- Exit: 收盘价反向突破前10日通道
- Position Sizing: ATR-based (1N = ATR(20))，每1N风险1%资金
"""
import logging
from .base_strategy import BaseStrategy, StrategyConfig, Signal

logger = logging.getLogger("MyTrader")


class TurtleStrategy(BaseStrategy):
    """Turtle trading strategy based on Donchian channel breakouts.

    Rules implemented by ``on_bar``:

    - Entry: close above the prior ``entry_period``-bar high -> LONG;
      close below the prior ``entry_period``-bar low -> SHORT.
    - Filter: entries are only taken on the side of the midpoint of the
      prior ``filter_period``-bar channel (``filter_period = 0`` disables it).
    - Exit: close breaks the opposite side of the prior ``exit_period``-bar
      channel while a position in that direction is open.
    - Stops/targets: stop loss ``atr_mult`` ATRs away, take profit at twice
      that distance.
    - Sizing: ``risk_pct`` of account equity per 1 ATR (1N) of movement.
    """

    def __init__(self, executor, config: StrategyConfig = None):
        """Create the strategy.

        Args:
            executor: Order/position executor passed through to ``BaseStrategy``.
            config: Strategy configuration; when ``None`` a default ETH
                configuration is built.
        """
        if config is None:
            config = StrategyConfig(
                name="turtle", symbol="ETH",
                max_position_pct=0.25, cooldown_bars=2,
                stop_loss_pct=0.02, take_profit_pct=0.04,
            )
        super().__init__(executor, config)
        self.entry_period = 20   # entry channel length (bars)
        self.exit_period = 10    # exit channel length (bars)
        self.filter_period = 55  # trend-filter channel length (0 = no filter)
        self.atr_period = 20     # ATR lookback used as 1N
        self.atr_mult = 2.0      # stop-loss distance in ATR multiples
        self.risk_pct = 0.01     # fraction of equity risked per 1N of movement

    def _prior_high(self, period: int, default: float) -> float:
        """Return the highest high of the ``period`` bars before the current bar.

        Falls back to ``default`` when fewer than ``period + 1`` highs exist.
        """
        # Assumes self._highs already ends with the current bar (the trailing
        # [:-1] excludes it); -(period + 1) keeps a full window of `period`
        # prior bars instead of period - 1.
        if len(self._highs) > period:
            return max(self._highs[-(period + 1):-1])
        return default

    def _prior_low(self, period: int, default: float) -> float:
        """Return the lowest low of the ``period`` bars before the current bar.

        Falls back to ``default`` when fewer than ``period + 1`` lows exist.
        """
        # Same windowing assumption as _prior_high.
        if len(self._lows) > period:
            return min(self._lows[-(period + 1):-1])
        return default

    def on_bar(self, o: float, h: float, l: float, c: float, v: float, timestamp: int) -> Signal:
        """Evaluate one bar and return a trading signal.

        Args:
            o, h, l, c, v: Bar open/high/low/close/volume (only h, l, c are used).
            timestamp: Bar timestamp (unused here).

        Returns:
            A LONG/SHORT entry signal with TP/SL prices, a CLOSE exit signal,
            or an empty ``Signal()`` when nothing applies or history is too short.
        """
        n = self._bar_count
        if n < self.filter_period + 5:
            return Signal()

        high_20 = self._prior_high(self.entry_period, h)
        low_20 = self._prior_low(self.entry_period, l)
        high_10 = self._prior_high(self.exit_period, h)
        low_10 = self._prior_low(self.exit_period, l)

        # Trend filter: compare the close with the midpoint of the longer channel.
        trend_up = True
        trend_down = True
        if self.filter_period > 0 and len(self._highs) > self.filter_period:
            mid_55 = (self._prior_high(self.filter_period, h)
                      + self._prior_low(self.filter_period, l)) / 2
            trend_up = c > mid_55
            trend_down = c < mid_55

        atr_val = self.atr(self.atr_period)[-1] if n >= self.atr_period else (h - l)

        # NOTE: entry signals are checked before exits, so a reverse breakout
        # is returned as a new entry rather than a CLOSE of the open position.
        if c > high_20 and trend_up:
            tp = c + atr_val * self.atr_mult * 2
            sl = c - atr_val * self.atr_mult
            return Signal("LONG", size=self._calc_size(atr_val), price=c,
                          score=0.8, reason=f"海龟突破LONG 20H={high_20:.1f}",
                          tp_price=tp, sl_price=sl)

        if c < low_20 and trend_down:
            tp = c - atr_val * self.atr_mult * 2
            sl = c + atr_val * self.atr_mult
            return Signal("SHORT", size=self._calc_size(atr_val), price=c,
                          score=-0.8, reason=f"海龟突破SHORT 20L={low_20:.1f}",
                          tp_price=tp, sl_price=sl)

        # Exit: close breaks the opposite side of the shorter exit channel.
        pos = self.executor.get_position(self.symbol)
        if pos["long"] and pos["long"]["qty"] > 0 and c < low_10:
            return Signal("CLOSE", price=c, reason=f"海龟离场 LONG 10L={low_10:.1f}")
        if pos["short"] and pos["short"]["qty"] > 0 and c > high_10:
            return Signal("CLOSE", price=c, reason=f"海龟离场 SHORT 10H={high_10:.1f}")

        return Signal()

    def _calc_size(self, atr_val: float) -> float:
        """ATR-based position size: risk ``risk_pct`` of equity per 1N.

        Args:
            atr_val: Current ATR (1N) in price units.

        Returns:
            ``equity * risk_pct / atr_val`` rounded to 0.1, never below 1.
            Returns 1 when the ATR is not a positive number or the balance
            lookup/computation fails (best-effort fallback, logged).
        """
        if not atr_val > 0:  # also rejects NaN, which compares False
            return 1
        try:
            _, equity = self.executor.okx.balance()
            risk_amount = equity * self.risk_pct
            return max(1, round(risk_amount / atr_val, 1))
        except Exception:
            logger.warning("turtle: position sizing failed, using minimum size 1",
                           exc_info=True)
            return 1
