# -*- coding: utf-8 -*-
"""
多因子综合策略
- 利用现有的 MarketAnalyzer 15因子评分
- 总分 >= +0.25 → LONG, <= -0.25 → SHORT (score_threshold)
- 前置过滤: pre_filter >= 0.15 才开仓 (prefilter_threshold)
- 多时间框架共振确认: 1h/4h 趋势方向需与信号方向一致
- 开仓时按 ATR 设置止损(2×ATR)/止盈(3×ATR)；持仓中评分回归或方向翻转时平仓
"""
import logging
import math

from .base_strategy import BaseStrategy, StrategyConfig, Signal

# Application-wide logger, looked up by the fixed name "MyTrader" rather than
# __name__ (presumably shared with the rest of the app — confirm).
logger: logging.Logger = logging.getLogger(name="MyTrader")


class MultiFactorStrategy(BaseStrategy):
    """Multi-factor composite strategy driven by an externally injected analyzer.

    On every bar the analyzer is asked for a fresh analysis. A LONG/SHORT
    signal is emitted when the pre-filter passes, the total score crosses
    ``score_threshold`` and both the 1h and 4h trend readings agree with the
    direction. Stop-loss / take-profit prices are placed at a multiple of the
    current ATR. An open position is closed when the score reverts past
    +/-0.1 or the analyzer reports a flip penalty below -0.5.
    """

    def __init__(self, executor, analyzer=None, config: StrategyConfig = None):
        """Create the strategy.

        Args:
            executor: Order executor. This class calls
                ``executor.get_position(symbol)`` and ``executor.okx.balance()``.
            analyzer: MarketAnalyzer instance, injected from outside (can also
                be set later via ``set_analyzer``). While it is None,
                ``on_bar`` always returns an empty ``Signal()``.
            config: Strategy configuration. When omitted, an ETH config with
                max_position_pct=0.3, cooldown_bars=1, stop_loss_pct=0.02 and
                take_profit_pct=0.04 is used.
        """
        if config is None:
            config = StrategyConfig(
                name="multi_factor", symbol="ETH",
                max_position_pct=0.3, cooldown_bars=1,
                stop_loss_pct=0.02, take_profit_pct=0.04,
            )
        super().__init__(executor, config)
        self.analyzer = analyzer  # MarketAnalyzer instance (injected externally)
        # Entry threshold on |total_score| (lowered to suit the low pre-filter values).
        self.score_threshold = 0.25
        # Minimum pre_filter required to open (0.4 was too strict; observed mean ~0.15).
        self.prefilter_threshold = 0.15
        # Distance of the stop-loss / take-profit from the entry price, in ATRs.
        self.sl_atr_mult = 2
        self.tp_atr_mult = 3
        self._last_score = 0
        self._last_factors = {}

    def set_analyzer(self, analyzer):
        """Inject (or replace) the external MarketAnalyzer."""
        self.analyzer = analyzer

    def on_bar(self, o: float, h: float, l: float, c: float, v: float, timestamp: int) -> Signal:
        """Evaluate the latest analyzer output and return a trading signal.

        Of the bar itself only ``h``/``l`` (ATR fallback) and ``c`` (signal
        price) are used; the factor scores come from ``self.analyzer.analyze()``.

        Returns:
            Signal: LONG/SHORT with ATR-based ``sl_price``/``tp_price`` when the
            entry conditions hold, CLOSE when an open position should be
            exited, otherwise an empty ``Signal()``.
        """
        if self.analyzer is None:
            return Signal()

        # Run one analysis cycle to get the latest scores.
        try:
            result = self.analyzer.analyze()
        except Exception as e:
            logger.warning("[multi_factor] analyze error: %s", e)
            return Signal()

        # A None result used to crash on .get(); an empty result can never
        # produce a signal, so both mean "no signal".
        if not result:
            return Signal()

        # `or 0` / `or {}` also cover keys that are present with a None value.
        total_score = result.get("total_score") or 0
        pre_filter = result.get("pre_filter") or 0
        t1h = result.get("t1h") or 0
        t4h = result.get("t4h") or 0
        flip_penalty = result.get("flip_penalty") or 0

        self._last_score = total_score
        self._last_factors = result.get("scores") or {}

        atr_val = self._current_atr(h, l)

        # Entry: pre-filter + score threshold + 1h/4h trend agreement.
        # NOTE(review): entries are emitted regardless of the current position;
        # confirm BaseStrategy/executor handles stacking or reversal as intended.
        if pre_filter >= self.prefilter_threshold:
            if total_score >= self.score_threshold and t1h >= 0 and t4h >= 0:
                sl = c - atr_val * self.sl_atr_mult
                tp = c + atr_val * self.tp_atr_mult
                return Signal("LONG", size=self._calc_size(total_score), price=c,
                              score=total_score,
                              reason=f"多因子LONG score={total_score:.3f} pf={pre_filter:.2f}",
                              tp_price=tp, sl_price=sl)

            if total_score <= -self.score_threshold and t1h <= 0 and t4h <= 0:
                sl = c + atr_val * self.sl_atr_mult
                tp = c - atr_val * self.tp_atr_mult
                return Signal("SHORT", size=self._calc_size(abs(total_score)), price=c,
                              score=total_score,
                              reason=f"多因子SHORT score={total_score:.3f} pf={pre_filter:.2f}",
                              tp_price=tp, sl_price=sl)

        # Exit: score reverted against the position, or a strong direction flip.
        pos = self.executor.get_position(self.symbol)
        if pos["long"] and pos["long"]["qty"] > 0:
            if total_score < -0.1 or flip_penalty < -0.5:
                return Signal("CLOSE", price=c,
                              reason=f"多因子平多 score={total_score:.3f} flip={flip_penalty:.2f}")
        if pos["short"] and pos["short"]["qty"] > 0:
            if total_score > 0.1 or flip_penalty < -0.5:
                return Signal("CLOSE", price=c,
                              reason=f"多因子平空 score={total_score:.3f} flip={flip_penalty:.2f}")

        return Signal()

    def _current_atr(self, h: float, l: float) -> float:
        """Return the latest 14-period ATR, falling back to the bar range ``h - l``.

        The fallback is used when the ATR series is empty, or its last value is
        missing, zero or non-finite (e.g. NaN, as some ATR implementations emit
        during warm-up), so SL/TP prices derived from it are never NaN.
        """
        atr_vals = self.atr(14)
        # len() rather than truthiness: also works for NumPy arrays / pandas
        # Series, whose truth value is ambiguous and raises.
        if atr_vals is None or len(atr_vals) == 0:
            return h - l
        last = atr_vals[-1]
        if last is None:
            return h - l
        try:
            last = float(last)
        except (TypeError, ValueError):
            return h - l
        if last == 0 or not math.isfinite(last):
            return h - l
        return last

    def _calc_size(self, score: float) -> float:
        """Return an order size that grows with signal strength.

        Base size is ``equity * max_position_pct / 2000`` contracts, scaled by
        ``1 + |score|``, rounded to one decimal and floored at 1. Sizing is
        best-effort: any failure (e.g. the balance lookup) yields a size of 1.
        """
        try:
            _, equity = self.executor.okx.balance()
            # NOTE(review): 2000 looks like an assumed per-contract notional
            # (roughly an ETH price) — confirm and consider deriving it.
            base = equity * self.cfg.max_position_pct / 2000  # base contract count
            return max(1, round(base * (1 + abs(score)), 1))
        except Exception as e:
            logger.warning("[multi_factor] size calc failed, using 1: %s", e)
            return 1
