# -*- coding: utf-8 -*-
"""
BaseStrategy — 策略基类
所有策略继承此类，实现 on_bar() 方法
"""
import time
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Optional

# Shared logger: uses the fixed name "MyTrader" (not __name__), so strategy
# messages go to the same logger as the rest of the application.
logger = logging.getLogger("MyTrader")


@dataclass
class StrategyConfig:
    """Per-strategy configuration.

    The ``*_pct`` values are fractions of price (0.02 means 2%); BaseStrategy
    multiplies the signal price by ``1 ± pct`` when deriving TP/SL.
    """
    name: str = field(default="base")              # strategy name, used as log prefix
    symbol: str = field(default="ETH")             # traded instrument
    enabled: bool = field(default=True)            # whether the strategy is enabled
    max_position_pct: float = field(default=0.3)   # maximum position ratio
    cooldown_bars: int = field(default=1)          # cooldown between openings, in bars
    max_daily_trades: int = field(default=20)      # maximum number of trades per day
    stop_loss_pct: float = field(default=0.02)     # default stop-loss distance
    take_profit_pct: float = field(default=0.04)   # default take-profit distance


@dataclass
class Signal:
    """A trading signal produced by a strategy for one bar.

    ``tp_price`` / ``sl_price`` left at 0 mean "unset"; BaseStrategy then
    derives them from ``price`` and the configured percentages.
    """
    action: str = field(default="WAIT")   # one of LONG / SHORT / CLOSE / WAIT
    size: float = field(default=0)        # suggested position size (contracts)
    price: float = field(default=0)       # price at which the signal fired
    score: float = field(default=0)       # signal strength, expected in [-1, 1]
    reason: str = field(default="")       # human-readable explanation
    tp_price: float = field(default=0)    # take-profit price (0 = unset)
    sl_price: float = field(default=0)    # stop-loss price (0 = unset)


@dataclass
class StrategyState:
    """Mutable runtime counters and bookkeeping for one strategy instance."""
    active: bool = field(default=True)                    # runtime on/off flag
    bars_processed: int = field(default=0)                # bars fed so far
    signals_generated: int = field(default=0)             # non-WAIT signals seen
    trades_opened: int = field(default=0)                 # successful position openings
    last_signal: Optional[Signal] = field(default=None)   # most recent non-WAIT signal
    last_trade_ts: float = field(default=0)               # time.time() of the last opening
    pnl_total: float = field(default=0.0)                 # cumulative PnL (not updated by BaseStrategy)


class BaseStrategy(ABC):
    """Abstract base class for trading strategies.

    Subclasses implement :meth:`on_bar`, which receives one candle at a time
    and returns a :class:`Signal`. The base class:

    * stores close / high / low / volume history (:meth:`feed_bar`),
    * records recent non-WAIT signals,
    * routes signals to the order executor (:meth:`execute_signal`),
    * runs a minimal single-position backtest (:meth:`run_backtest`),
    * offers indicator helpers (SMA, EMA, rolling high/low, ATR, RSI).
    """

    # Maximum number of entries kept in ``_signals_history``.
    _SIGNAL_HISTORY_LIMIT = 500

    def __init__(self, executor, config: StrategyConfig):
        self.executor = executor   # OrderExecutor instance (duck-typed; see execute_signal for the calls used)
        self.cfg = config
        self.state = StrategyState()
        self._bar_count = 0
        self._prices = []          # close prices, oldest first
        self._highs = []           # high prices
        self._lows = []            # low prices
        self._volumes = []         # volumes
        self._signals_history = []  # recent non-WAIT signals (capped)

    @property
    def symbol(self) -> str:
        """Traded symbol, taken from the config."""
        return self.cfg.symbol

    @property
    def enabled(self) -> bool:
        """Whether the strategy is enabled in its config."""
        return self.cfg.enabled

    # ── Subclasses must implement ───────────────────────────

    @abstractmethod
    def on_bar(self, o: float, h: float, l: float, c: float, v: float, timestamp: int) -> Signal:
        """Process one candle and return a trading signal."""
        ...

    # ── Public methods ──────────────────────────────────────

    def feed_bar(self, o: float, h: float, l: float, c: float, v: float, timestamp: int = 0) -> Optional[Signal]:
        """Append a candle to the stored history, then call :meth:`on_bar`.

        Every non-WAIT signal is counted, stored as ``state.last_signal`` and
        appended to the signal history, which keeps only the latest
        ``_SIGNAL_HISTORY_LIMIT`` entries.

        Returns:
            Whatever ``on_bar`` returned (possibly None).
        """
        self._prices.append(c)
        self._highs.append(h)
        self._lows.append(l)
        self._volumes.append(v)
        self._bar_count += 1
        self.state.bars_processed += 1

        signal = self.on_bar(o, h, l, c, v, timestamp)
        if signal and signal.action != "WAIT":
            self.state.signals_generated += 1
            self.state.last_signal = signal
            self._signals_history.append({
                "ts": timestamp, "action": signal.action,
                "score": signal.score, "price": c
            })
            limit = self._SIGNAL_HISTORY_LIMIT
            if len(self._signals_history) > limit:
                self._signals_history = self._signals_history[-limit:]
        return signal

    def execute_signal(self, signal: Signal) -> bool:
        """Send a signal to the executor.

        * LONG / SHORT open a position, but only if the opening cooldown and
          daily trade limit allow it and the symbol has no open position.
        * CLOSE market-closes every side that holds a positive quantity. The
          opening throttles deliberately do not apply here: exits must never
          be blocked by entry limits.
        * WAIT and unknown actions do nothing.

        Returns:
            True if a position was opened or a CLOSE was processed,
            False otherwise.
        """
        if signal.action == "WAIT":
            return False
        if signal.action in ("LONG", "SHORT"):
            return self._open_position(signal)
        if signal.action == "CLOSE":
            self._close_all()
            return True
        return False

    def _open_position(self, signal: Signal) -> bool:
        """Open a LONG/SHORT position for ``signal`` when throttles and exposure allow it."""
        # NOTE(review): the cooldown is passed as cooldown_bars * 15. This is
        # presumably 15 time units per bar (e.g. a 15-minute chart); confirm the
        # unit against executor.check_cooldown.
        if not self.executor.check_cooldown(self.symbol, self.cfg.cooldown_bars * 15):
            logger.debug(f"[{self.cfg.name}] 冷却中，跳过信号")
            return False
        if not self.executor.check_daily_limit(self.symbol, self.cfg.max_daily_trades):
            logger.warning(f"[{self.cfg.name}] 当日交易次数超限")
            return False
        if self.executor.has_position(self.symbol):
            logger.debug(f"[{self.cfg.name}] 已有持仓，跳过开仓")
            return False

        is_long = signal.action == "LONG"
        tp_pct = self.cfg.take_profit_pct
        sl_pct = self.cfg.stop_loss_pct
        # An explicit (non-zero) TP/SL on the signal wins; otherwise derive it
        # from the signal price: TP above / SL below for longs, mirrored for shorts.
        tp = signal.tp_price or signal.price * ((1 + tp_pct) if is_long else (1 - tp_pct))
        sl = signal.sl_price or signal.price * ((1 - sl_pct) if is_long else (1 + sl_pct))
        result = self.executor.market_open(self.symbol, signal.action, signal.size,
                                           tp_price=tp, sl_price=sl)
        if result:
            self.state.trades_opened += 1
            self.state.last_trade_ts = time.time()
        return bool(result)

    def _close_all(self) -> None:
        """Market-close each side ("long"/"short") of ``self.symbol`` with a positive quantity."""
        # NOTE(review): assumes get_position returns a mapping of side ->
        # {"qty": ...} or a falsy value; a falsy result means nothing to close.
        pos = self.executor.get_position(self.symbol)
        if not pos:
            return
        for side in ("long", "short"):
            if pos[side] and pos[side]["qty"] > 0:
                self.executor.market_close(self.symbol, side)

    def run_backtest(self, candles: list, initial_equity: float = 10000) -> dict:
        """Run a minimal single-position backtest over ``candles``.

        Each candle is either a dict with ``open``/``high``/``low``/``close``
        keys (plus optional ``vol``/``ts``) or a sequence
        ``(open, high, low, close[, volume])``. Every candle goes through
        :meth:`feed_bar`, so the strategy's own history and state advance too.

        While flat, a LONG/SHORT signal with a positive size opens a position
        at that candle's close. An open position is exited at the close of
        the first candle that yields a CLOSE signal. Only realised PnL counts:
        a position still open after the last candle does not affect
        ``final_equity``. Fees and slippage are not modelled.

        Args:
            candles: candle records in chronological order.
            initial_equity: starting account value; must be > 0.

        Returns:
            dict with ``trades`` (number of entries), ``final_equity``,
            ``pnl_pct`` (percent return relative to ``initial_equity``) and
            ``trades_list`` (per-trade dicts with ``type``/``entry``/``size``,
            plus ``exit``/``pnl`` once closed).

        Raises:
            ValueError: if ``initial_equity`` is not positive.
        """
        if initial_equity <= 0:
            raise ValueError("initial_equity must be positive")

        trades = []
        equity = initial_equity
        position = 0
        direction = None

        for candle in candles:
            o, h, l, close, v, ts = self._unpack_candle(candle)
            signal = self.feed_bar(o, h, l, close, v, ts)
            if not signal:
                continue

            if position == 0:
                # Only directional signals with a positive size open a trade.
                # A CLOSE or zero-size signal while flat is ignored, so no
                # phantom position is booked.
                if signal.action in ("LONG", "SHORT") and signal.size > 0:
                    position = signal.size
                    direction = signal.action
                    trades.append({"type": direction, "entry": close, "size": position})
            elif signal.action == "CLOSE":
                pnl = (close - trades[-1]["entry"]) * position
                if direction == "SHORT":
                    pnl = -pnl
                equity += pnl
                trades[-1]["exit"] = close
                trades[-1]["pnl"] = pnl
                position = 0
                direction = None

        return {
            "trades": len(trades),
            "final_equity": equity,
            "pnl_pct": (equity / initial_equity - 1) * 100,
            "trades_list": trades,
        }

    @staticmethod
    def _unpack_candle(candle) -> tuple:
        """Normalise a dict or sequence candle into (open, high, low, close, volume, ts)."""
        if isinstance(candle, dict):
            return (candle["open"], candle["high"], candle["low"], candle["close"],
                    candle.get("vol", 0), candle.get("ts", 0))
        # Sequence candles carry no timestamp; the volume (5th field) is optional.
        return (candle[0], candle[1], candle[2], candle[3],
                candle[4] if len(candle) > 4 else 0, 0)

    @staticmethod
    def _tail(series: list, n: Optional[int]) -> list:
        """Return the last ``n`` items of ``series``.

        ``n=None`` returns the live series itself (not a copy). ``n <= 0``
        returns an empty list.
        """
        if n is None:
            return series
        if n <= 0:
            return []
        return series[-n:]

    def get_prices(self, n: Optional[int] = None) -> list:
        """Return the last ``n`` close prices (all of them when ``n`` is None)."""
        return self._tail(self._prices, n)

    def get_highs(self, n: Optional[int] = None) -> list:
        """Return the last ``n`` highs (all of them when ``n`` is None)."""
        return self._tail(self._highs, n)

    def get_lows(self, n: Optional[int] = None) -> list:
        """Return the last ``n`` lows (all of them when ``n`` is None)."""
        return self._tail(self._lows, n)

    def get_volumes(self, n: Optional[int] = None) -> list:
        """Return the last ``n`` volumes (all of them when ``n`` is None)."""
        return self._tail(self._volumes, n)

    def sma(self, data: list, period: int) -> list:
        """Simple moving average of ``data``.

        The output has the same length as ``data``. The first ``period - 1``
        entries are None because the window is not yet full.
        """
        result = []
        for i in range(len(data)):
            if i < period - 1:
                result.append(None)
            else:
                result.append(sum(data[i - period + 1:i + 1]) / period)
        return result

    def ema(self, data: list, period: int) -> list:
        """Exponential moving average of ``data``, seeded with ``data[0]``.

        Uses the smoothing factor ``k = 2 / (period + 1)``: ``result[0]`` is
        ``data[0]`` and each later value is ``(x - prev) * k + prev``.
        Returns ``[]`` for empty ``data``.
        """
        k = 2 / (period + 1)
        if not data:
            return []
        result = [data[0]]
        for x in data[1:]:
            prev = result[-1]
            result.append((x - prev) * k + prev)
        return result

    def highest(self, data: list, period: int) -> list:
        """Rolling maximum over the last ``period`` items (shorter window at the start)."""
        result = []
        for i in range(len(data)):
            start = max(0, i - period + 1)
            result.append(max(data[start:i + 1]))
        return result

    def lowest(self, data: list, period: int) -> list:
        """Rolling minimum over the last ``period`` items (shorter window at the start)."""
        result = []
        for i in range(len(data)):
            start = max(0, i - period + 1)
            result.append(min(data[start:i + 1]))
        return result

    def atr(self, period: int = 14) -> list:
        """Average True Range over the stored bars.

        The true range of bar i is max(high - low, |high - prev_close|,
        |low - prev_close|); bar 0 is measured against its own close.
        ``result[0]`` is bar 0's high - low range. For 0 < i < period the
        value is the running simple mean of true ranges 0..i. After that,
        Wilder smoothing is applied: ``(prev * (period - 1) + tr) / period``.
        """
        highs, lows, closes = self._highs, self._lows, self._prices
        result = []
        tr_sum = 0  # running sum of true ranges, used for the warm-up mean
        for i in range(len(closes)):
            prev_close = closes[i - 1] if i > 0 else closes[0]
            tr = max(highs[i] - lows[i],
                     abs(highs[i] - prev_close),
                     abs(lows[i] - prev_close))
            tr_sum += tr
            if i == 0:
                result.append(highs[0] - lows[0])
            elif i < period:
                result.append(tr_sum / (i + 1))
            else:
                result.append((result[-1] * (period - 1) + tr) / period)
        return result

    def rsi(self, period: int = 14) -> list:
        """Relative Strength Index of the stored closes (simple-average variant).

        Average gain and loss are plain means over the last ``period`` price
        changes (no Wilder smoothing). The output has the same length as the
        price history. The first ``period`` entries (or all of them, if there
        are fewer than ``period + 1`` prices) are the neutral value 50. A
        window with no losses gives 100 if prices rose and 50 if they were flat.
        """
        closes = self._prices
        if len(closes) < period + 1:
            return [50] * len(closes)
        gains = []
        losses = []
        for prev, cur in zip(closes, closes[1:]):
            delta = cur - prev
            gains.append(max(delta, 0))
            losses.append(max(-delta, 0))
        result = [50] * period
        for i in range(period, len(closes)):
            avg_gain = sum(gains[i - period:i]) / period
            avg_loss = sum(losses[i - period:i]) / period
            if avg_loss == 0:
                # RS is undefined here: report 100 for a pure uptrend and
                # neutral 50 for a flat window (instead of a false "overbought").
                result.append(100 if avg_gain > 0 else 50)
            else:
                rs = avg_gain / avg_loss
                result.append(100 - 100 / (1 + rs))
        return result

    # ── Status queries ──────────────────────────────────────

    def get_state_summary(self) -> dict:
        """Return a plain-dict snapshot of config and counters.

        The executor position query is best-effort: if it raises, the
        exception is logged at debug level and ``has_position`` is reported
        as False.
        """
        try:
            has_pos = self.executor.has_position(self.symbol)
        except Exception:
            logger.debug("[%s] has_position query failed", self.cfg.name, exc_info=True)
            has_pos = False
        return {
            "name": self.cfg.name,
            "symbol": self.cfg.symbol,
            "enabled": self.cfg.enabled,
            "bars": self.state.bars_processed,
            "signals": self.state.signals_generated,
            "trades": self.state.trades_opened,
            "has_position": has_pos,
        }
