# -*- coding: utf-8 -*-
"""
FactorEngine — 因子计算后台引擎
- 持续运行，每个周期计算所有品种的全量因子
- 写入 factor_snapshots 供回测和展示
- 追踪每个因子的长期表现 (accuracy/contribution/hit_rate)
- 支持跨品种因子对比

用法:
  engine = FactorEngine(okx, deribit, instruments)
  engine.start()  # 后台线程
  engine.get_factor_performance("BTC", "trend")  # 查看单个因子长期表现
"""
import time
import math
import logging
import threading
import json
from collections import deque, defaultdict
from dataclasses import dataclass, field
from typing import Optional

logger = logging.getLogger("MyTrader")


@dataclass
class FactorSnapshot:
    """One factor reading for a single symbol at a single point in time."""

    # Instrument the reading belongs to.
    symbol: str
    # Time the reading was taken.
    timestamp: float
    # Price reported alongside the factor scores.
    price: float
    # Mapping of factor name -> score.
    factors: dict
    # Aggregate score across all factors.
    total_score: float
    # Trade direction label: LONG / SHORT / WAIT.
    direction: str
    # Free-form extra values attached to the snapshot; a fresh dict per instance.
    metadata: dict = field(default_factory=dict)


@dataclass
class FactorPerformance:
    """Running, long-term statistics for one factor of one symbol."""

    factor_name: str
    symbol: str
    # Number of samples folded into the statistics so far.
    count: int = 0
    # Running mean of the factor value.
    avg_value: float = 0
    # Running standard deviation of the factor value.
    std_value: float = 0
    # Largest / smallest value seen; the defaults are sentinels that the
    # first real sample is expected to replace.
    max_value: float = -999
    min_value: float = 999
    # Directional accuracy: factor direction vs. the actual price direction.
    directional_accuracy: float = 0
    # Share of the total score attributed to this factor, in percent.
    contribution_pct: float = 0
    # Signal hit rate (score > threshold -> price moved the same way).
    signal_hit_rate: float = 0
    # Timestamp of the most recent sample.
    last_updated: float = 0
    # Most recent samples, oldest first; a fresh list per instance.
    history: list = field(default_factory=list)


class FactorEngine:
    """Background engine that periodically computes and tracks factor scores.

    Every cycle the worker thread asks the analyzer for the factor scores of
    each configured instrument, stores one ``FactorSnapshot`` per instrument
    in a bounded in-memory buffer, writes it to the ``factor_snapshots`` table
    on a best-effort basis, and folds it into per-factor running statistics.

    The snapshot buffer and the statistics table are shared between the
    worker thread and the query methods, so every access to them goes through
    ``self._lock``.
    """

    # Minimum number of seconds between two recorded computation rounds.
    _DEDUPE_SECONDS = 10
    # Maximum number of samples kept in each FactorPerformance.history.
    _HISTORY_LIMIT = 500
    # Accuracy/contribution are recomputed every this many cycles ...
    _ACCURACY_EVERY_CYCLES = 20
    # ... once at least this many snapshots are buffered.
    _ACCURACY_MIN_SNAPSHOTS = 100
    # Score columns of factor_snapshots written as NULL when the score is missing.
    _DB_FACTOR_COLUMNS = (
        'trend', 'orderbook', 'funding', 'taker', 'oi', 'maxpain',
        'vol_delta', 'btc_corr', 'gamma', 'iv', 'exhaust', 'liq_cool',
        'mean_revert', 'news', 'smart_money',
    )
    # Score columns of factor_snapshots written as 0 when the score is missing.
    _DB_EXTRA_COLUMNS = ('mtf', 'ob_liq', 'low_lev', 'liq_ex')

    def __init__(self, okx_client, instruments: dict, deribit_client=None,
                 symbol: str = "ETH", interval: int = 60):
        """Create an idle engine; call ``start()`` to launch the worker thread.

        Args:
            okx_client: Client handed to the analyzer as its first argument.
            instruments: Mapping whose keys are the symbols to analyze.
            deribit_client: Optional client handed to the analyzer as its
                second argument.
            symbol: Stored on the instance; not used by the engine itself.
            interval: Seconds to wait between two cycles.
        """
        self.okx = okx_client
        self.deribit = deribit_client
        self.instruments = instruments
        self.symbol = symbol
        self.interval = interval

        # Factor definitions.
        self.factor_names = [
            "trend", "orderbook", "funding", "taker", "oi", "maxpain",
            "vol_delta", "btc_corr", "gamma", "iv", "exhaust", "liq_cool",
            "mean_revert", "news", "smart_money",
        ]

        # Bounded snapshot buffer and per-factor statistics (key: "symbol:factor").
        self._snapshots: deque = deque(maxlen=2000)
        self._performance: dict[str, FactorPerformance] = {}
        # Guards _snapshots, _performance and _pending_stats.
        self._lock = threading.RLock()
        # Number of buffered snapshots not yet folded into the statistics.
        self._pending_stats = 0

        # Worker thread state.
        self._running = False
        self._thread: Optional[threading.Thread] = None
        self._stop_event = threading.Event()
        self._cycle_count = 0
        self._last_snapshot_ts = 0

        # Analyzer reference (created lazily to avoid a circular import).
        self._analyzer = None

        logger.info(f"[FactorEngine] 初始化 symbol={symbol} interval={interval}s")

    # ═══════════════════════════════════════════════════════════
    #  Lifecycle
    # ═══════════════════════════════════════════════════════════

    def start(self):
        """Start the background worker thread; no-op if already running."""
        if self._running:
            return
        # Each worker gets its own stop event, so a worker that was told to
        # stop can never be revived by a later start().
        stop_event = threading.Event()
        self._stop_event = stop_event
        self._running = True
        self._thread = threading.Thread(
            target=self._loop, args=(stop_event,), name="FactorEngine", daemon=True)
        self._thread.start()
        logger.info("[FactorEngine] 后台线程已启动")

    def stop(self):
        """Ask the worker thread to exit; returns without waiting for it."""
        self._running = False
        # Wakes the worker immediately if it is waiting between cycles.
        self._stop_event.set()

    def is_running(self) -> bool:
        """Return True between ``start()`` and ``stop()``."""
        return self._running

    def _loop(self, stop_event=None):
        """Worker body: compute, update statistics, wait, repeat.

        Args:
            stop_event: Event that terminates this particular worker; defaults
                to the engine's current stop event.
        """
        if stop_event is None:
            stop_event = self._stop_event
        logger.info(f"[FactorEngine] 开始循环 interval={self.interval}s")
        while self._running and not stop_event.is_set():
            try:
                self._compute_and_record()
                self._update_performance()
                self._cycle_count += 1
            except Exception as e:
                # Boundary handler: one bad cycle must not kill the worker.
                logger.error(f"[FactorEngine] 循环异常: {e}", exc_info=True)
            # Interruptible sleep: returns early as soon as stop() is called.
            stop_event.wait(self.interval)

    # ═══════════════════════════════════════════════════════════
    #  Factor computation
    # ═══════════════════════════════════════════════════════════

    def _ensure_analyzer(self) -> bool:
        """Create the analyzer on first use; return False if that fails."""
        if self._analyzer is not None:
            return True
        try:
            # Imported here rather than at module level to avoid a circular import.
            from strategy.analyzer import MarketAnalyzer
            self._analyzer = MarketAnalyzer(self.okx, self.deribit)
        except Exception as e:
            logger.warning(f"[FactorEngine] Analyzer 加载失败: {e}")
            return False
        logger.info("[FactorEngine] MarketAnalyzer 已加载")
        return True

    def _compute_and_record(self):
        """Compute factors for every instrument and record one snapshot each.

        Skipped entirely when the previous round was recorded less than
        ``_DEDUPE_SECONDS`` ago or when the analyzer cannot be loaded. A
        failure for one symbol is logged and does not affect the others.
        """
        now = time.time()
        if now - self._last_snapshot_ts < self._DEDUPE_SECONDS:  # dedupe
            return
        if not self._ensure_analyzer():
            return

        # Iterate over a copy of the keys so a concurrent change to the
        # instruments mapping cannot break the loop.
        for sym in list(self.instruments):
            try:
                total_score, result = self._analyzer.analyze(sym)
                scores = result.get("scores", {})
                snapshot = FactorSnapshot(
                    symbol=sym,
                    timestamp=now,
                    price=result.get("price", 0),
                    factors={k: float(v) for k, v in scores.items()},
                    total_score=total_score,
                    direction=result.get("direction", "WAIT"),
                    metadata={
                        "t1h": result.get("t1h", 0),
                        "t4h": result.get("t4h", 0),
                        # NOTE(review): constant placeholder, not read from the
                        # analyzer result - confirm whether that is intended.
                        "fng": 50,
                        "pre_filter": result.get("pre_filter", 0),
                    },
                )
            except Exception as e:
                logger.warning(f"[FactorEngine] {sym} 计算失败: {e}")
                continue

            with self._lock:
                # A full deque evicts its oldest entry; _update_performance
                # caps the pending count accordingly.
                self._snapshots.append(snapshot)
                self._pending_stats += 1

            # Best-effort persistence; _write_to_db handles its own errors.
            self._write_to_db(snapshot)

        self._last_snapshot_ts = now

    def compute_now(self) -> dict:
        """Run the analyzer once for every instrument, without recording.

        Returns:
            ``{symbol: analyzer result}`` for every symbol whose analysis
            succeeded; empty when the analyzer cannot be loaded.
        """
        results = {}
        if not self._ensure_analyzer():
            return results

        for sym in list(self.instruments):
            try:
                results[sym] = self._analyzer.analyze(sym)
            except Exception as e:
                logger.warning(f"[FactorEngine] {sym} 计算失败: {e}")
        return results

    # ═══════════════════════════════════════════════════════════
    #  Factor performance tracking
    # ═══════════════════════════════════════════════════════════

    def _update_performance(self):
        """Fold every not-yet-processed snapshot into the running statistics.

        All snapshots recorded since the previous call are processed exactly
        once, so every instrument of a cycle is counted and a skipped
        computation round does not double-count the previous snapshot.
        """
        with self._lock:
            pending = min(self._pending_stats, len(self._snapshots))
            self._pending_stats = 0
            if pending:
                # The unprocessed snapshots are the newest `pending` entries.
                for snap in list(self._snapshots)[-pending:]:
                    self._accumulate(snap)
            buffered = len(self._snapshots)

        # Accuracy/contribution are comparatively expensive: refresh them only
        # every _ACCURACY_EVERY_CYCLES cycles once enough data is buffered.
        if (self._cycle_count % self._ACCURACY_EVERY_CYCLES == 0
                and buffered >= self._ACCURACY_MIN_SNAPSHOTS):
            self._calc_accuracy()

    def _accumulate(self, snap):
        """Add one snapshot's factor scores to the per-factor statistics.

        The caller must hold ``self._lock``.
        """
        for fname, score in snap.factors.items():
            key = f"{snap.symbol}:{fname}"
            perf = self._performance.get(key)
            if perf is None:
                perf = FactorPerformance(factor_name=fname, symbol=snap.symbol)
                self._performance[key] = perf

            perf.count += 1
            perf.last_updated = snap.timestamp
            perf.max_value = max(perf.max_value, score)
            perf.min_value = min(perf.min_value, score)
            perf.history.append(
                {"ts": snap.timestamp, "value": score, "price": snap.price})
            if len(perf.history) > self._HISTORY_LIMIT:
                del perf.history[:-self._HISTORY_LIMIT]

            # Online mean / population standard deviation (Welford update):
            # (count - 1) * std**2 is the previous sum of squared deviations.
            delta = score - perf.avg_value
            perf.avg_value += delta / perf.count
            if perf.count > 1:
                sq_dev = ((perf.count - 1) * perf.std_value ** 2
                          + delta * (score - perf.avg_value))
                # Clamp guards against a tiny negative from float rounding.
                perf.std_value = math.sqrt(max(sq_dev, 0.0) / perf.count)
            else:
                perf.std_value = 0

    def _calc_accuracy(self):
        """Recompute directional accuracy and contribution share per factor.

        Directional accuracy is the fraction of consecutive same-symbol
        snapshot pairs where the earlier factor value and the following price
        change have the same sign. Pairs are formed per symbol, so snapshots
        of different instruments interleaved in the buffer do not break them.

        Contribution is ``|avg|`` of the factor divided by the sum of ``|avg|``
        over all tracked factors of the same symbol, in percent.
        """
        with self._lock:
            snaps = list(self._snapshots)
            if len(snaps) < self._ACCURACY_MIN_SNAPSHOTS:
                return

            # Group once per call instead of rescanning the buffer per factor.
            by_symbol = defaultdict(list)
            for snap in snaps:
                by_symbol[snap.symbol].append(snap)

            abs_avg_sum = defaultdict(float)
            for key, perf in self._performance.items():
                abs_avg_sum[key.split(":", 1)[0]] += abs(perf.avg_value)

            for key, perf in self._performance.items():
                sym, fname = key.split(":", 1)
                series = by_symbol.get(sym, [])
                hits = 0
                total = 0
                for prev, curr in zip(series, series[1:]):
                    factor_val = prev.factors.get(fname, 0)
                    if factor_val * (curr.price - prev.price) > 0:
                        hits += 1
                    total += 1
                if total > 0:
                    perf.directional_accuracy = round(hits / total, 4)

                denom = abs_avg_sum[sym]
                if denom > 0:
                    perf.contribution_pct = round(
                        abs(perf.avg_value) / denom * 100, 1)

    # ═══════════════════════════════════════════════════════════
    #  Query interface
    # ═══════════════════════════════════════════════════════════

    def get_factor_performance(self, symbol: Optional[str] = None,
                               factor: Optional[str] = None) -> list:
        """Return factor statistics, optionally filtered by symbol and/or factor.

        Returns:
            A list of dicts sorted by descending ``|avg|``.
        """
        with self._lock:
            entries = list(self._performance.items())

        results = []
        for key, perf in entries:
            sym, fname = key.split(":", 1)
            if symbol and sym != symbol:
                continue
            if factor and fname != factor:
                continue
            results.append({
                "symbol": sym,
                "factor": fname,
                "count": perf.count,
                "avg": round(perf.avg_value, 4),
                "std": round(perf.std_value, 4),
                "max": round(perf.max_value, 4),
                "min": round(perf.min_value, 4),
                "accuracy": perf.directional_accuracy,
                "contribution_pct": perf.contribution_pct,
                "updated": time.strftime("%H:%M:%S", time.localtime(perf.last_updated)),
            })
        return sorted(results, key=lambda row: -abs(row["avg"]))

    def get_recent_snapshots(self, symbol: Optional[str] = None, n: int = 50) -> list:
        """Return the most recent ``n`` snapshots, oldest first.

        When ``symbol`` is given the filter is applied before truncation, so
        up to ``n`` snapshots of that symbol are returned. ``n <= 0`` yields
        an empty list.
        """
        if n <= 0:
            return []
        with self._lock:
            snaps = list(self._snapshots)
        if symbol:
            snaps = [s for s in snaps if s.symbol == symbol]
        return [
            {"ts": time.strftime("%H:%M:%S", time.localtime(s.timestamp)),
             "symbol": s.symbol, "price": s.price, "total": s.total_score,
             "direction": s.direction, "factors": s.factors}
            for s in snaps[-n:]
        ]

    def get_status(self) -> dict:
        """Return a summary of the engine state."""
        with self._lock:
            symbols = sorted({s.symbol for s in self._snapshots})
            snapshot_count = len(self._snapshots)
            tracked = len(self._performance)
        if self._last_snapshot_ts:
            last_update = time.strftime(
                "%H:%M:%S", time.localtime(self._last_snapshot_ts))
        else:
            last_update = "--"
        return {
            "running": self._running,
            "cycle": self._cycle_count,
            "symbols": symbols,
            "snapshots": snapshot_count,
            "factors_tracked": tracked,
            "last_update": last_update,
        }

    # ═══════════════════════════════════════════════════════════
    #  DB persistence
    # ═══════════════════════════════════════════════════════════

    def _write_to_db(self, snap: "FactorSnapshot"):
        """Insert one snapshot into the ``factor_snapshots`` table (best effort).

        Does nothing unless the ``DB_ENABLED`` environment variable is unset
        or equals ``true`` (case-insensitive). Any failure is logged at debug
        level and swallowed so persistence problems never stop the engine.
        ``snapshot_at`` is the snapshot's own timestamp as a naive UTC datetime.
        """
        try:
            import os
            if os.environ.get('DB_ENABLED', 'true').lower() != 'true':
                return
            from datetime import datetime, timezone
            from storage.mysql_client import get_cursor

            scores = snap.factors
            meta = snap.metadata
            # Naive UTC value, matching what utcnow() used to produce.
            snapshot_at = datetime.fromtimestamp(
                snap.timestamp, tz=timezone.utc).replace(tzinfo=None)
            params = (
                snap.symbol, snap.price, snap.direction, snap.total_score,
                *(scores.get(name) for name in self._DB_FACTOR_COLUMNS),
                *(scores.get(name, 0) for name in self._DB_EXTRA_COLUMNS),
                meta.get('fng'), meta.get('t1h'),
                meta.get('t4h'), meta.get('pre_filter'),
                snapshot_at,
            )
            with get_cursor() as cur:
                cur.execute("""
                    INSERT INTO factor_snapshots (
                        symbol, price, direction, total_score,
                        trend, orderbook, funding, taker, oi, maxpain,
                        vol_delta, btc_corr, gamma, iv, exhaust, liq_cool,
                        mean_revert, news, smart_money, mtf, ob_liq, low_lev, liq_ex,
                        fng, t1h, t4h, pre_filter, snapshot_at
                    ) VALUES (%s,%s,%s,%s, %s,%s,%s,%s,%s,%s, %s,%s,%s,%s,%s,%s, %s,%s,%s,%s,%s,%s,%s, %s,%s,%s,%s,%s)
                """, params)
        except Exception as e:
            logger.debug(f"[FactorEngine] DB写入跳过: {e}")
