# -*- coding: utf-8 -*-
"""
Price Anomaly Detector — 价格异动检测
检测类型:
  1. Flash Crash/Pump  — 短时间内价格剧烈波动
  2. Volume Spike      — 成交量突然暴增
  3. Spread Widening   — 买卖价差异常扩大
  4. OI Anomaly        — 持仓量异动（暴增/暴跌）
  5. Liquidation Cascade — 连续大额清算 (not yet implemented: check_all() runs checks 1-4 only)

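当检测到异动时 → 通过网关通知战略部调查原因并记录 (relaying is done by the caller via get_unrelayed_events() / mark_relayed(); this module does not contact the gateway itself)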
"""
import time
import math
import logging
from collections import deque
from dataclasses import dataclass, field
from typing import Optional

# Application-wide logger: getLogger returns the same instance for every caller asking for "MyTrader".
logger = logging.getLogger(name="MyTrader")


@dataclass
class AnomalyEvent:
    """A single detected price anomaly.

    Attributes:
        type: event kind - flash_crash / flash_pump / vol_spike / spread / oi / liquidation.
        severity: one of info / warning / critical.
        symbol: symbol the event refers to.
        price: price at detection time.
        detail: human-readable description of the anomaly.
        metrics: detector-specific numbers; a fresh dict per instance.
        timestamp: creation time in epoch seconds, taken from ``time.time()``.
        relayed: whether the strategy desk has already been notified.
    """

    type: str
    severity: str
    symbol: str
    price: float
    detail: str
    metrics: dict = field(default_factory=dict)
    timestamp: float = field(default_factory=time.time)
    relayed: bool = field(default=False)


class PriceAnomalyDetector:
    """Price anomaly detector.

    Usage pattern:
    - The caller feeds price/volume samples with ``feed_price`` and spread
      samples with ``feed_spread``.
    - It then calls ``check_all`` periodically.
    - Four detectors run: flash crash/pump, volume spike, spread widening and
      open-interest (OI) change.
    - Every detector has a per-symbol cooldown, so one ongoing anomaly is not
      reported over and over.
    - Detected events accumulate in ``self.events`` (trimmed to the latest 200)
      until the caller relays them.
    """

    def __init__(self, okx_client, instruments: dict, max_history: int = 200):
        """Create a detector.

        Args:
            okx_client: client providing ``open_interest_hist(inst, period=...)``;
                used only by the OI check.
            instruments: mapping of upper-case symbol -> config dict with an
                ``"inst"`` key (the instrument id passed to ``okx_client``).
            max_history: number of price samples retained.
        """
        self.okx = okx_client
        self.instruments = instruments

        # Price history (parallel deques, always appended/cleared together).
        # NOTE(review): this buffer is shared by all symbols; feeding several
        # symbols into one detector mixes their prices.
        self._prices = deque(maxlen=max_history)
        self._timestamps = deque(maxlen=max_history)

        # Volume history per symbol: deque of (timestamp, volume)
        self._volumes: dict[str, deque] = {}

        # Spread history per symbol: deque of (timestamp, spread_pct)
        self._spreads: dict[str, deque] = {}

        # Detected anomaly events, oldest first
        self.events: list[AnomalyEvent] = []
        # Intended for de-duplicating relayed events; currently only cleared in reset()
        self._relayed_events = set()

        # Thresholds
        self.flash_threshold_pct = 2.0     # flash crash/pump: price move >= 2%
        self.flash_window_sec = 60          # flash detection window: 60 seconds
        self.vol_spike_ratio = 3.0          # volume spike: >= 3x the baseline average
        self.spread_threshold_pct = 0.15    # abnormal spread: average >= 0.15%
        self.oi_change_threshold = 0.10     # OI anomaly: relative change >= 10%

        # Cooldowns: "<kind>_<symbol>" -> epoch seconds of the last emitted
        # event, so the same kind of anomaly is not re-reported for N seconds.
        self._cooldowns: dict[str, float] = {}

        logger.info("[异动检测] 已初始化")

    def feed_price(self, symbol: str, price: float, volume: float = 0):
        """Record a price sample and its volume, stamped with the current time.

        The price goes into the shared price buffer; the volume goes into the
        per-symbol volume buffer (latest 100 samples).
        """
        now = time.time()
        self._prices.append(price)
        self._timestamps.append(now)

        if symbol not in self._volumes:
            self._volumes[symbol] = deque(maxlen=100)
        self._volumes[symbol].append((now, volume))

    def feed_spread(self, symbol: str, spread_pct: float):
        """Record a bid/ask spread sample (percent) for *symbol* (latest 50 kept)."""
        if symbol not in self._spreads:
            self._spreads[symbol] = deque(maxlen=50)
        self._spreads[symbol].append((time.time(), spread_pct))

    def check_all(self, symbol: str = "ETH") -> list[AnomalyEvent]:
        """Run every detector for *symbol* and return the newly detected events.

        Returns an empty list when no positive price has been fed yet. New
        events are also appended to ``self.events``, which keeps only the
        latest 200.
        """
        current_price = self._prices[-1] if self._prices else 0
        if current_price <= 0:
            return []

        # Detector order is preserved in the returned list.
        candidates = (
            self._check_flash(symbol, current_price),        # 1. flash crash / pump
            self._check_volume_spike(symbol),                # 2. volume spike
            self._check_spread(symbol),                      # 3. spread widening
            self._check_oi_anomaly(symbol, current_price),   # 4. OI anomaly
        )
        new_events = [event for event in candidates if event is not None]

        self.events.extend(new_events)
        if len(self.events) > 200:
            self.events = self.events[-200:]

        return new_events

    # ═══════════════════════════════════════════════════════════
    #  Detectors
    # ═══════════════════════════════════════════════════════════

    def _in_cooldown(self, key: str, now: float, seconds: float) -> bool:
        """Return True if an event for *key* was emitted less than *seconds* ago."""
        last = self._cooldowns.get(key)
        return last is not None and now - last < seconds

    def _check_flash(self, symbol: str, current_price: float) -> Optional[AnomalyEvent]:
        """Detect a flash crash / pump.

        Fires in either of two cases:
        - the net move from the first in-window price to *current_price* is at
          least ``flash_threshold_pct``;
        - the high-low range is at least 1.5x that threshold.

        Severity is "critical" when the net move is at least 2x the threshold.
        Requires 5 buffered prices. Cooldown: 120 s per symbol.
        """
        now = time.time()
        cooldown_key = f"flash_{symbol}"
        if self._in_cooldown(cooldown_key, now, 120):
            return None

        if len(self._prices) < 5:
            return None

        # Prices inside the detection window. zip() walks both deques once;
        # positional indexing into a deque is O(n) away from its ends.
        recent_prices = [p for p, ts in zip(self._prices, self._timestamps)
                         if now - ts < self.flash_window_sec]
        if len(recent_prices) < 3:
            # Too few in-window samples: fall back to the last 10 samples,
            # regardless of their age (non-empty since >= 5 are buffered).
            recent_prices = list(self._prices)[-10:]

        low = min(recent_prices)
        high = max(recent_prices)
        if low <= 0:
            return None

        # Net change from the earliest considered price to now
        change_pct = (current_price - recent_prices[0]) / recent_prices[0] * 100
        range_pct = (high - low) / low * 100

        if abs(change_pct) >= self.flash_threshold_pct or range_pct >= self.flash_threshold_pct * 1.5:
            direction = "暴拉" if change_pct > 0 else "闪崩"
            severity = "critical" if abs(change_pct) >= self.flash_threshold_pct * 2 else "warning"
            self._cooldowns[cooldown_key] = now

            return AnomalyEvent(
                type="flash_crash" if change_pct < 0 else "flash_pump",
                severity=severity,
                symbol=symbol,
                price=current_price,
                detail=f"{direction}: {abs(change_pct):.2f}% ({self.flash_window_sec}s内) "
                       f"范围{range_pct:.2f}% 低{low:.1f} 高{high:.1f}",
                metrics={
                    "data_source": "okx_rest",
                    "change_pct": round(change_pct, 3),
                    "range_pct": round(range_pct, 3),
                    "low": round(low, 2),
                    "high": round(high, 2),
                    "window_sec": self.flash_window_sec,
                },
            )
        return None

    def _check_volume_spike(self, symbol: str) -> Optional[AnomalyEvent]:
        """Detect a sudden volume surge.

        Compares the average volume of samples from the last 60 s with samples
        60-300 s old. Fires when the ratio reaches ``vol_spike_ratio``; "critical"
        at 5x or more.

        Requires:
        - at least 10 buffered samples;
        - at least 3 recent samples;
        - at least one baseline sample.

        Cooldown: 180 s per symbol.
        """
        now = time.time()
        cooldown_key = f"vol_{symbol}"
        if self._in_cooldown(cooldown_key, now, 180):
            return None

        if symbol not in self._volumes or len(self._volumes[symbol]) < 10:
            return None

        vols = list(self._volumes[symbol])
        recent = [v[1] for v in vols if now - v[0] < 60]
        baseline = [v[1] for v in vols if 60 < now - v[0] < 300]

        if len(recent) < 3 or not baseline:
            return None

        avg_recent = sum(recent) / len(recent)
        avg_baseline = sum(baseline) / len(baseline)

        if avg_baseline <= 0:
            return None

        ratio = avg_recent / avg_baseline

        if ratio >= self.vol_spike_ratio:
            self._cooldowns[cooldown_key] = now
            return AnomalyEvent(
                type="vol_spike",
                severity="warning" if ratio < 5 else "critical",
                symbol=symbol,
                price=self._prices[-1] if self._prices else 0,
                detail=f"成交量暴增: {ratio:.1f}x 基线 ({avg_recent:.0f} vs {avg_baseline:.0f}/period)",
                metrics={"ratio": round(ratio, 2), "current": round(avg_recent, 1), "baseline": round(avg_baseline, 1)},
            )
        return None

    def _check_spread(self, symbol: str) -> Optional[AnomalyEvent]:
        """Detect an abnormally wide bid/ask spread.

        Averages the spread samples from the last 60 s (requires 5 buffered
        samples). Fires when the average reaches ``spread_threshold_pct``.
        Severity is "info" below 0.3%, otherwise "warning".

        Cooldown: 120 s per symbol, like the other detectors. Without it, a
        persistently wide spread would add an event on every check and crowd
        other events out of the capped ``self.events`` list.
        """
        now = time.time()
        cooldown_key = f"spread_{symbol}"
        if self._in_cooldown(cooldown_key, now, 120):
            return None

        if symbol not in self._spreads or len(self._spreads[symbol]) < 5:
            return None

        recent = [s[1] for s in self._spreads[symbol] if now - s[0] < 60]
        if not recent:
            return None

        avg_spread = sum(recent) / len(recent)

        if avg_spread >= self.spread_threshold_pct:
            self._cooldowns[cooldown_key] = now
            return AnomalyEvent(
                type="spread",
                severity="info" if avg_spread < 0.3 else "warning",
                symbol=symbol,
                price=self._prices[-1] if self._prices else 0,
                detail=f"价差异常扩大: {avg_spread:.3f}%",
                metrics={"spread_pct": round(avg_spread, 4)},
            )
        return None

    def _check_oi_anomaly(self, symbol: str, current_price: float) -> Optional[AnomalyEvent]:
        """Detect an open-interest surge or collapse.

        Outside the cooldown, every call queries
        ``okx.open_interest_hist(inst, period="5m")`` and compares row 0 with
        row 2. Fires when the relative change reaches ``oi_change_threshold``;
        "critical" at 20% or more.

        Errors are deliberately best-effort: any exception is logged at debug
        level and treated as "no anomaly". Cooldown: 300 s per symbol.
        """
        now = time.time()
        cooldown_key = f"oi_{symbol}"
        if self._in_cooldown(cooldown_key, now, 300):
            return None

        try:
            cfg = self.instruments.get(symbol.upper())
            if not cfg:
                return None
            # Fetch OI history (network call)
            oi_data = self.okx.open_interest_hist(cfg["inst"], period="5m")
            if not oi_data or len(oi_data) < 3:
                return None

            # NOTE(review): assumes rows are newest-first with the OI value at
            # index 1, so row 2 is two periods earlier - confirm against the client.
            current_oi = float(oi_data[0][1])
            prev_oi = float(oi_data[2][1])

            if prev_oi <= 0:
                return None

            change = (current_oi - prev_oi) / prev_oi

            if abs(change) >= self.oi_change_threshold:
                self._cooldowns[cooldown_key] = now
                direction = "暴增" if change > 0 else "暴跌"
                return AnomalyEvent(
                    type="oi",
                    severity="warning" if abs(change) < 0.2 else "critical",
                    symbol=symbol,
                    price=current_price,
                    detail=f"OI{direction}: {change*100:+.1f}% ({prev_oi:.0f} → {current_oi:.0f})",
                    metrics={"change_pct": round(change*100, 2), "prev": round(prev_oi, 1), "current": round(current_oi, 1)},
                )
        except Exception as e:
            logger.debug(f"[异动] OI检查异常: {e}")
        return None

    # ═══════════════════════════════════════════════════════════
    #  Notification interface
    # ═══════════════════════════════════════════════════════════

    def get_unrelayed_events(self) -> list[AnomalyEvent]:
        """Return stored events not yet relayed to the strategy desk."""
        return [e for e in self.events if not e.relayed]

    def mark_relayed(self, event: AnomalyEvent):
        """Flag *event* as relayed so ``get_unrelayed_events`` skips it."""
        event.relayed = True

    def get_recent_events(self, n: int = 10) -> list[dict]:
        """Return the latest *n* events as plain dicts (oldest first).

        ``n <= 0`` yields an empty list. Slicing with ``[-0:]`` would otherwise
        return every event.
        """
        if n <= 0:
            return []
        return [
            {"ts": time.strftime("%H:%M:%S", time.localtime(e.timestamp)),
             "type": e.type, "severity": e.severity, "symbol": e.symbol,
             "price": e.price, "detail": e.detail, "relayed": e.relayed}
            for e in self.events[-n:]
        ]

    def reset(self):
        """Clear all buffered samples, events and cooldowns.

        Thresholds and client configuration are kept.
        """
        self._prices.clear()
        self._timestamps.clear()
        self._volumes.clear()
        self._spreads.clear()
        self.events.clear()
        self._relayed_events.clear()
        self._cooldowns.clear()
