# -*- coding: utf-8 -*-
"""
趋势跟踪策略
- EMA交叉: 快线上穿慢线 → LONG, 下穿 → SHORT
- ADX过滤器: ADX≥25才交易（强趋势市）
- 移动止盈: 价格回撤超过ATR的N倍时平仓
- 支持: 单EMA交叉 / 三EMA系统
"""
import logging
from .base_strategy import BaseStrategy, StrategyConfig, Signal

# Module logger; uses the fixed application name "MyTrader" instead of __name__.
logger = logging.getLogger(name="MyTrader")


class TrendFollowingStrategy(BaseStrategy):
    """Trend-following strategy.

    Entries: EMA(fast) crossing above / below EMA(slow), only while ADX >= threshold.
    Exits:   MACD DIF (EMA_fast - EMA_slow) crossing its DEA signal line against the
             position, or price breaching an ATR-based trailing stop.
    """

    def __init__(self, executor, config: StrategyConfig = None):
        if config is None:
            config = StrategyConfig(
                name="trend", symbol="ETH",
                max_position_pct=0.3, cooldown_bars=2,
                stop_loss_pct=0.02, take_profit_pct=0.06,
            )
        super().__init__(executor, config)
        self.fast_ema = 12           # fast EMA period
        self.slow_ema = 26           # slow EMA period
        self.signal_ema = 9          # DEA (MACD signal line) period, applied to DIF
        self.adx_period = 14         # ADX period
        self.adx_threshold = 25      # minimum ADX for an entry (strong-trend filter)
        self.trail_atr_mult = 3.0    # trailing-stop distance in ATR multiples
        self._prev_fast = None
        self._prev_slow = None
        # Extreme close since the last entry signal, used for the trailing stop.
        self._highest_since_entry = 0
        self._lowest_since_entry = float("inf")

    def on_bar(self, o: float, h: float, l: float, c: float, v: float, timestamp: int) -> Signal:
        """Evaluate one bar and return a trading signal.

        Args:
            o, h, l, c, v: the bar's open/high/low/close/volume (only h, l, c are read).
            timestamp: bar timestamp (unused).

        Returns:
            A LONG / SHORT entry Signal, a CLOSE Signal, or an empty ``Signal()``
            when no action is required (including during warm-up).
        """
        n = self._bar_count
        min_bars = self.slow_ema + self.adx_period + 5
        if n < min_bars:
            return Signal()

        # EMA crossover inputs
        prices = self.get_prices()
        ema_fast = self.ema(prices, self.fast_ema)
        ema_slow = self.ema(prices, self.slow_ema)

        fast = ema_fast[-1]
        slow = ema_slow[-1]
        prev_fast = ema_fast[-2] if len(ema_fast) >= 2 else fast
        prev_slow = ema_slow[-2] if len(ema_slow) >= 2 else slow

        # MACD exit inputs. DEA must be an EMA of DIF, not an EMA of price: EMA9 of price
        # sits above EMA12 whenever price rises, which would close longs in uptrends.
        dif, dea = self._macd_dif_dea(ema_fast, ema_slow)

        # ADX trend-strength filter
        adx_val = self._adx(self.adx_period)
        strong_trend = adx_val is not None and adx_val >= self.adx_threshold

        atr_vals = self.atr(14)
        # Fall back to the current bar's range when ATR is unavailable or zero.
        atr_val = atr_vals[-1] if atr_vals and atr_vals[-1] else (h - l)

        # Long entry: golden cross confirmed by ADX
        if prev_fast <= prev_slow and fast > slow and strong_trend:
            sl = c - atr_val * 2
            tp = c + atr_val * 4
            score = 0.6 + (adx_val / 100) * 0.4 if adx_val else 0.6
            self._highest_since_entry = c
            return Signal("LONG", size=self._calc_size(atr_val), price=c, score=score,
                          reason=f"趋势LONG 金叉 ADX={adx_val:.0f} {self.fast_ema}/{self.slow_ema}",
                          tp_price=tp, sl_price=sl)

        # Short entry: death cross confirmed by ADX
        if prev_fast >= prev_slow and fast < slow and strong_trend:
            sl = c + atr_val * 2
            tp = c - atr_val * 4
            score = -(0.6 + (adx_val / 100) * 0.4) if adx_val else -0.6
            self._lowest_since_entry = c
            return Signal("SHORT", size=self._calc_size(atr_val), price=c, score=score,
                          reason=f"趋势SHORT 死叉 ADX={adx_val:.0f} {self.fast_ema}/{self.slow_ema}",
                          tp_price=tp, sl_price=sl)

        # Exits: MACD signal-line cross or ATR trailing stop
        pos = self.executor.get_position(self.symbol)
        if pos["long"] and pos["long"]["qty"] > 0:
            self._highest_since_entry = max(self._highest_since_entry, c)
            trail_stop = self._highest_since_entry - atr_val * self.trail_atr_mult
            # DIF crossed below DEA (upward momentum faded) or close fell through the trail.
            macd_exit = dif < dea
            if macd_exit or c < trail_stop:
                return Signal("CLOSE", price=c,
                              reason=f"趋势平多 {'信号线下穿' if macd_exit else '移动止盈'} H={self._highest_since_entry:.1f}")

        if pos["short"] and pos["short"]["qty"] > 0:
            self._lowest_since_entry = min(self._lowest_since_entry, c)
            trail_stop = self._lowest_since_entry + atr_val * self.trail_atr_mult
            # DIF crossed above DEA (downward momentum faded) or close rose through the trail.
            macd_exit = dif > dea
            if macd_exit or c > trail_stop:
                return Signal("CLOSE", price=c,
                              reason=f"趋势平空 {'信号线上穿' if macd_exit else '移动止盈'} L={self._lowest_since_entry:.1f}")

        return Signal()

    def _macd_dif_dea(self, ema_fast, ema_slow):
        """Return the latest ``(DIF, DEA)`` pair of the MACD built from two EMA series.

        DIF = EMA_fast - EMA_slow; DEA = EMA(DIF, ``self.signal_ema``).
        When no DEA value can be produced, DEA falls back to DIF, so no MACD exit fires.
        """
        k = min(len(ema_fast), len(ema_slow))
        # Align both series on their most recent values in case the EMA helper
        # returns lists of different lengths; skip any missing warm-up entries.
        pairs = zip(ema_fast[len(ema_fast) - k:], ema_slow[len(ema_slow) - k:])
        dif_series = [f - s for f, s in pairs if f is not None and s is not None]
        if not dif_series:
            return 0.0, 0.0
        dea_series = self.ema(dif_series, self.signal_ema)
        dif = dif_series[-1]
        dea = dea_series[-1] if dea_series and dea_series[-1] is not None else dif
        return dif, dea

    @staticmethod
    def _directional_index(pdm_smooth: float, ndm_smooth: float, tr_smooth: float) -> float:
        """Return DX (0-100) from smoothed +DM, -DM and true range; 0 when undefined."""
        if tr_smooth <= 0:
            return 0.0
        pdi = pdm_smooth / tr_smooth * 100
        ndi = ndm_smooth / tr_smooth * 100
        total = pdi + ndi
        return abs(pdi - ndi) / total * 100 if total > 0 else 0.0

    def _adx(self, period: int = 14) -> float:
        """Compute Wilder's ADX from the stored high/low/close history.

        Only the most recent ``period * 5`` bars are used, which keeps the cost
        per bar bounded while giving the Wilder smoothing room to settle.

        Args:
            period: smoothing period for TR/DM and for the DX -> ADX average.

        Returns:
            ADX rounded to 0.1, or 0 when fewer than ``2 * period`` bars exist.
        """
        n = self._bar_count
        if n < period * 2:
            return 0

        highs, lows, closes = self._highs, self._lows, self._prices
        # TR and DM need the previous bar, so the first usable index is 1.
        start = max(1, n - period * 5)
        tr_vals = []
        plus_dm = []
        minus_dm = []
        for i in range(start, n):
            prev_close = closes[i - 1]
            tr_vals.append(max(
                highs[i] - lows[i],
                abs(highs[i] - prev_close),
                abs(lows[i] - prev_close),
            ))
            up = highs[i] - highs[i - 1]
            down = lows[i - 1] - lows[i]
            plus_dm.append(up if up > down and up > 0 else 0)
            minus_dm.append(down if down > up and down > 0 else 0)

        # Wilder smoothing, running-sum form: seed with the first `period` sums,
        # then S = S - S/period + x. Both DI ratios use the same scale.
        tr_smooth = sum(tr_vals[:period])
        pdm_smooth = sum(plus_dm[:period])
        ndm_smooth = sum(minus_dm[:period])
        dx_vals = [self._directional_index(pdm_smooth, ndm_smooth, tr_smooth)]
        for tr, pdm, ndm in zip(tr_vals[period:], plus_dm[period:], minus_dm[period:]):
            tr_smooth += tr - tr_smooth / period
            pdm_smooth += pdm - pdm_smooth / period
            ndm_smooth += ndm - ndm_smooth / period
            dx_vals.append(self._directional_index(pdm_smooth, ndm_smooth, tr_smooth))

        # ADX: seed with the mean of the first `period` DX values, then Wilder-smooth.
        # With n >= 2 * period there are always at least `period` DX values.
        adx = sum(dx_vals[:period]) / period
        for dx in dx_vals[period:]:
            adx = (adx * (period - 1) + dx) / period

        return round(adx, 1)

    def _calc_size(self, atr_val: float) -> float:
        """Size a position so that ``size * ATR`` is about 1.5% of account equity.

        The result is rounded to 0.1 and never below 1. When ATR is non-positive or
        missing, or the balance lookup fails, the fallback size 1 is returned.
        """
        if not atr_val or atr_val <= 0:
            return 1
        try:
            # NOTE(review): assumes balance() returns (available, equity) — confirm.
            _, equity = self.executor.okx.balance()
            risk = equity * 0.015
            return max(1, round(risk / atr_val, 1))
        except Exception as exc:  # best-effort sizing: fall back rather than drop the signal
            logger.warning("趋势策略仓位计算失败, 使用默认仓位 1: %s", exc)
            return 1
