# -*- coding: utf-8 -*-
"""
toxic_flow.py — 有毒订单流检测
检测类型:
  1. Spoofing       — 幌骗挂单（大单快速撤单）
  2. Quote Stuffing — 报价填充（高频挂撤单扰乱市场）
  3. Layering       — 分层挂单（多档位同时挂大单）
  4. Momentum Ignition — 动量点火（集中成交触发连锁反应）
  5. Iceberg/Hidden — 冰山订单检测（大量碎单同向成交）
  6. Wash Trading   — 虚假成交量检测

输出: ToxicFlowSummary {score, level, bias, alerts, details}
"""

import time
import math
import logging
from collections import deque, defaultdict
from dataclasses import dataclass, field
from typing import Optional

logger = logging.getLogger("MyTrader")


@dataclass
class ToxicAlert:
    """A single toxic-order-flow alert produced by one detector."""
    type: str           # detector tag; this file emits stuffing / layering / ignition / iceberg / wash / imbalance
    severity: str       # low / medium / high / critical
    direction: str      # buy / sell / neutral
    score: float        # strength of the signal, intended range 0~1
    detail: str         # human-readable description
    price: float = 0
    timestamp: float = field(default_factory=time.time)


@dataclass
class ToxicFlowSummary:
    """Aggregate toxic-order-flow assessment, as returned by ToxicFlowDetector.analyze()."""
    score: float = 0          # composite toxicity, 0~1
    level: str = "normal"     # normal / watch / warning / danger
    bias: str = "neutral"     # analyze() writes buy_toxic / sell_toxic / neutral (side with more directional alerts)
    alerts: list = field(default_factory=list)  # merged ToxicAlert objects
    details: dict = field(default_factory=dict)  # analyze() fills alert_count / type_breakdown / max_severity


class ToxicFlowDetector:
    """
    有毒订单流检测器
    从订单簿快照 + 成交数据中识别6种操纵行为
    """

    def __init__(self, max_history: int = 100):
        self.max_history = max_history

        # 订单簿历史快照
        self._book_snapshots = deque(maxlen=max_history)
        self._book_update_rates = deque(maxlen=30)  # 每秒更新次数

        # 成交流
        self._trade_history = deque(maxlen=500)
        self._trade_imbalance = deque(maxlen=50)  # 每窗口买卖比

        # 档位追踪 (用于分层检测)
        self._level_walls = defaultdict(lambda: deque(maxlen=20))  # {price_level: [vol_history]}

        # 冰山订单追踪
        self._iceberg_suspects = {}  # {price_level: {count, start_ts}}

        # 状态
        self.last_update = 0
        self._alert_history = deque(maxlen=50)
        self._quote_throttle = 0  # 报价填充计数
        self._quote_throttle_reset = 0

    # ═══════════════════════════════════════════════════════════
    #  Feed methods — 外部调用
    # ═══════════════════════════════════════════════════════════

    def feed_orderbook(self, asks: list, bids: list, mid_price: float):
        """喂入订单簿快照"""
        now = time.time()

        # 记录更新频率
        self._book_update_rates.append(now)
        if self._book_update_rates[-1] - (self._book_update_rates[0] if len(self._book_update_rates) > 1 else now) > 0:
            pass  # update_rates are timestamps, computed below

        snapshot = {
            "ts": now, "asks": [list(a) for a in asks[:20]],
            "bids": [list(b) for b in bids[:20]], "mid": mid_price,
        }
        self._book_snapshots.append(snapshot)

        # 更新档位追踪
        for a in asks[:10]:
            px = round(float(a[0]), 1)
            vol = float(a[1])
            self._level_walls[px].append({"ts": now, "vol": vol, "side": "ask"})
        for b in bids[:10]:
            px = round(float(b[0]), 1)
            vol = float(b[1])
            self._level_walls[px].append({"ts": now, "vol": vol, "side": "bid"})

        self.last_update = now

    def feed_trades(self, trades: list):
        """喂入成交记录 [{side, price, size, ts}, ...]"""
        now = time.time()
        for t in trades:
            self._trade_history.append({
                "ts": t.get("ts", now),
                "side": t.get("side", ""),
                "price": float(t.get("price", 0)),
                "size": float(t.get("size", 0)),
            })

    # ═══════════════════════════════════════════════════════════
    #  1. Quote Stuffing — 报价填充检测
    # ═══════════════════════════════════════════════════════════

    def _detect_quote_stuffing(self) -> list[ToxicAlert]:
        """检测报价填充：每秒订单簿更新次数异常高"""
        alerts = []
        now = time.time()
        rates = list(self._book_update_rates)

        if len(rates) < 5:
            return alerts

        # 最近10秒内的更新次数
        recent = [t for t in rates if now - t < 10]
        rate_per_sec = len(recent) / 10 if recent else 0

        # 正常WS推送约10次/秒（books5），超过20次可能是填充攻击
        if rate_per_sec > 20:
            severity = "critical" if rate_per_sec > 50 else "high" if rate_per_sec > 35 else "medium"
            alerts.append(ToxicAlert(
                type="stuffing", severity=severity, direction="neutral",
                score=min(1.0, rate_per_sec / 50),
                detail=f"报价填充: {rate_per_sec:.0f}次/s（正常~10/s）",
                timestamp=now,
            ))
        elif rate_per_sec > 15:
            alerts.append(ToxicAlert(
                type="stuffing", severity="low", direction="neutral",
                score=rate_per_sec / 50,
                detail=f"报价加速: {rate_per_sec:.0f}次/s",
                timestamp=now,
            ))

        return alerts

    # ═══════════════════════════════════════════════════════════
    #  2. Layering — 分层挂单检测
    # ═══════════════════════════════════════════════════════════

    def _detect_layering(self) -> list[ToxicAlert]:
        """检测分层挂单：多个相邻档位同时出现大额挂单"""
        alerts = []
        if len(self._book_snapshots) < 2:
            return alerts

        snap = self._book_snapshots[-1]
        mid = snap["mid"]
        if mid <= 0:
            return alerts

        # 检查卖方分层
        ask_levels = [(round(float(a[0]), 1), float(a[1])) for a in snap["asks"][:15]]
        # 检查买方分层
        bid_levels = [(round(float(b[0]), 1), float(b[1])) for b in snap["bids"][:15]]

        for side, levels in [("sell", ask_levels), ("buy", bid_levels)]:
            # 找连续3+档挂单量都超过阈值的区间
            threshold = 20  # 最小张数
            consecutive = []
            for px, vol in levels:
                if vol >= threshold:
                    consecutive.append((px, vol))
                else:
                    if len(consecutive) >= 3:
                        break
                    consecutive = []

            if len(consecutive) >= 3:
                total_vol = sum(v for _, v in consecutive)
                price_range = (consecutive[-1][0] - consecutive[0][0]) / mid * 100
                severity = "high" if len(consecutive) >= 5 else "medium"
                alerts.append(ToxicAlert(
                    type="layering", severity=severity, direction=side,
                    score=min(1.0, len(consecutive) / 8),
                    detail=f"{'卖' if side == 'sell' else '买'}方{len(consecutive)}层挂单 "
                           f"总量{total_vol:.0f}张（范围{price_range:.2f}%）",
                    price=mid, timestamp=snap["ts"],
                ))

        return alerts

    # ═══════════════════════════════════════════════════════════
    #  3. Momentum Ignition — 动量点火检测
    # ═══════════════════════════════════════════════════════════

    def _detect_momentum_ignition(self) -> list[ToxicAlert]:
        """检测动量点火：短时间窗口内单边成交量激增"""
        alerts = []
        now = time.time()

        if len(self._trade_history) < 20:
            return alerts

        # 最近5秒内的成交
        recent = [t for t in self._trade_history if now - t["ts"] < 5]
        if len(recent) < 10:
            return alerts

        buy_vol = sum(t["size"] for t in recent if t["side"] == "buy")
        sell_vol = sum(t["size"] for t in recent if t["side"] == "sell")
        total = buy_vol + sell_vol
        if total <= 0:
            return alerts

        # 买卖失衡
        imbalance = abs(buy_vol - sell_vol) / total

        # 对比过去30秒的基线
        baseline_trades = [t for t in self._trade_history if 5 < now - t["ts"] < 35]
        if baseline_trades:
            baseline_total = sum(t["size"] for t in baseline_trades)
            baseline_rate = baseline_total / 30  # 每秒均量
            current_rate = total / 5
            volume_surge = current_rate / baseline_rate if baseline_rate > 0 else 1
        else:
            volume_surge = 1

        if imbalance > 0.65 and volume_surge > 2:
            direction = "buy" if buy_vol > sell_vol else "sell"
            severity = "critical" if imbalance > 0.85 else "high" if imbalance > 0.75 else "medium"
            alerts.append(ToxicAlert(
                type="ignition", severity=severity, direction=direction,
                score=min(1.0, imbalance * volume_surge / 5),
                detail=f"动量点火 {direction}: 失衡{imbalance:.1%} "
                       f"量激增{volume_surge:.1f}x（{total:.0f}张/5s）",
                price=recent[-1]["price"] if recent else 0, timestamp=now,
            ))

        return alerts

    # ═══════════════════════════════════════════════════════════
    #  4. Iceberg / Hidden Order — 冰山订单检测
    # ═══════════════════════════════════════════════════════════

    def _detect_iceberg(self) -> list[ToxicAlert]:
        """检测冰山订单：同一价位反复出现等量成交"""
        alerts = []
        now = time.time()

        if len(self._trade_history) < 30:
            return alerts

        # 按价位聚合成交
        price_buckets = defaultdict(list)
        for t in self._trade_history:
            if now - t["ts"] < 30:
                bp = round(t["price"], 1)
                price_buckets[bp].append(t["size"])

        for bp, sizes in price_buckets.items():
            if len(sizes) < 5:
                continue
            # 检测等量成交模式（冰山订单特征：连续等量碎单）
            avg_size = sum(sizes) / len(sizes)
            if avg_size < 1:
                continue
            # 方差/均值比小 = 等量
            variance = sum((s - avg_size) ** 2 for s in sizes) / len(sizes)
            cv = math.sqrt(variance) / avg_size if avg_size > 0 else 1  # 变异系数

            if cv < 0.3 and len(sizes) >= 5:  # 等量模式
                direction = "neutral"
                score = min(1.0, len(sizes) / 20)
                alerts.append(ToxicAlert(
                    type="iceberg", severity="low" if cv < 0.3 else "medium",
                    direction=direction,
                    score=score,
                    detail=f"疑似冰山订单 @${bp:.1f}: {len(sizes)}笔均量{avg_size:.1f}（CV={cv:.2f}）",
                    price=bp, timestamp=now,
                ))

        return alerts

    # ═══════════════════════════════════════════════════════════
    #  5. Wash Trading — 虚假成交量检测
    # ═══════════════════════════════════════════════════════════

    def _detect_wash_trading(self) -> list[ToxicAlert]:
        """检测虚假成交：价格几乎不变但成交量异常大"""
        alerts = []
        now = time.time()

        if len(self._trade_history) < 50:
            return alerts

        recent = [t for t in self._trade_history if now - t["ts"] < 30]
        if len(recent) < 20:
            return alerts

        prices = [t["price"] for t in recent]
        volumes = [t["size"] for t in recent]

        avg_price = sum(prices) / len(prices)
        price_range = (max(prices) - min(prices)) / avg_price * 100 if avg_price > 0 else 0
        total_vol = sum(volumes)

        # 价格窄幅 + 大成交量 = 可疑洗盘
        if price_range < 0.1 and total_vol > 500:
            # 买卖交替 = 更可疑
            sides = [t["side"] for t in recent]
            alternations = sum(1 for i in range(1, len(sides)) if sides[i] != sides[i - 1])
            alt_ratio = alternations / len(sides)
            severity = "high" if alt_ratio > 0.6 else "medium"

            alerts.append(ToxicAlert(
                type="wash", severity=severity, direction="neutral",
                score=min(1.0, total_vol / 2000),
                detail=f"疑似wash trading: 价格波动{price_range:.2f}% "
                       f"成交{total_vol:.0f}张 交替率{alt_ratio:.1%}",
                price=avg_price, timestamp=now,
            ))

        return alerts

    # ═══════════════════════════════════════════════════════════
    #  6. Order Book Imbalance — 订单簿失衡检测
    # ═══════════════════════════════════════════════════════════

    def _detect_book_imbalance(self) -> list[ToxicAlert]:
        """检测订单簿深度严重失衡（可能预示大单方向）"""
        alerts = []
        if len(self._book_snapshots) < 2:
            return alerts

        snap = self._book_snapshots[-1]
        mid = snap["mid"]
        if mid <= 0:
            return alerts

        # 前5档买卖比
        ask_vol = sum(float(a[1]) for a in snap["asks"][:5])
        bid_vol = sum(float(b[1]) for b in snap["bids"][:5])
        total = ask_vol + bid_vol
        if total <= 0:
            return alerts

        ratio = ask_vol / bid_vol if bid_vol > 0 else 10

        if ratio > 3:
            alerts.append(ToxicAlert(
                type="imbalance", severity="high" if ratio > 5 else "medium",
                direction="sell", score=min(1.0, ratio / 8),
                detail=f"卖盘深度失衡: ask/bid={ratio:.1f}x ({ask_vol:.0f}/{bid_vol:.0f})",
                price=mid, timestamp=snap["ts"],
            ))
        elif ratio < 0.33:
            alerts.append(ToxicAlert(
                type="imbalance", severity="high" if ratio < 0.2 else "medium",
                direction="buy", score=min(1.0, (1 / max(ratio, 0.01)) / 8),
                detail=f"买盘深度失衡: bid/ask={1/ratio:.1f}x ({bid_vol:.0f}/{ask_vol:.0f})",
                price=mid, timestamp=snap["ts"],
            ))

        return alerts

    # ═══════════════════════════════════════════════════════════
    #  Composite — 综合评估
    # ═══════════════════════════════════════════════════════════

    def analyze(self) -> ToxicFlowSummary:
        """运行全部检测，返回综合评估"""
        alerts = []
        alerts += self._detect_quote_stuffing()
        alerts += self._detect_layering()
        alerts += self._detect_momentum_ignition()
        alerts += self._detect_iceberg()
        alerts += self._detect_wash_trading()
        alerts += self._detect_book_imbalance()

        # 去重（同type同direction合并）
        merged = self._merge_alerts(alerts)

        # 综合评分：加权平均，严重度越高权重越大
        weights = {"low": 0.3, "medium": 0.6, "high": 0.8, "critical": 1.0}
        total_weight = sum(weights.get(a.severity, 0.5) * a.score for a in merged)
        num = len(merged)
        composite_score = min(1.0, (total_weight / max(num, 1)) * min(num, 3) / 3) if num > 0 else 0

        # 风险等级
        if composite_score >= 0.7:
            level = "danger"
        elif composite_score >= 0.4:
            level = "warning"
        elif composite_score >= 0.15:
            level = "watch"
        else:
            level = "normal"

        # 方向偏向
        buy_count = sum(1 for a in merged if a.direction == "buy")
        sell_count = sum(1 for a in merged if a.direction == "sell")
        if sell_count > buy_count + 1:
            bias = "sell_toxic"
        elif buy_count > sell_count + 1:
            bias = "buy_toxic"
        else:
            bias = "neutral"

        # 详情统计
        type_counts = defaultdict(int)
        for a in merged:
            type_counts[a.type] += 1

        self._alert_history.extend(merged)

        return ToxicFlowSummary(
            score=round(composite_score, 3),
            level=level,
            bias=bias,
            alerts=merged,
            details={
                "alert_count": len(merged),
                "type_breakdown": dict(type_counts),
                "max_severity": max((a.severity for a in merged), key=lambda s: {"low": 0, "medium": 1, "high": 2, "critical": 3}.get(s, 0)) if merged else "none",
            },
        )

    # ═══════════════════════════════════════════════════════════
    #  7. Market Maker Retreat — 做市商退避检测
    # ═══════════════════════════════════════════════════════════

    def detect_mm_retreat(self) -> dict:
        """
        检测做市商退避信号:
        - 买卖价差扩大（spread widening）
        - 深度变薄（depth thinning）
        - 挂单量骤降（liquidity evaporation）
        返回: {retreating: bool, score: float, direction: str, detail: str}
        """
        if len(self._book_snapshots) < 10:
            return {"retreating": False, "score": 0, "direction": "neutral", "detail": ""}

        now = time.time()
        snap = self._book_snapshots[-1]
        mid = snap["mid"]
        if mid <= 0:
            return {"retreating": False, "score": 0, "direction": "neutral", "detail": ""}

        # 基线：过去30秒的均值
        baseline_snaps = [s for s in self._book_snapshots if now - s["ts"] < 30]
        if len(baseline_snaps) < 5:
            baseline_snaps = list(self._book_snapshots)[-10:]

        # 1. 价差检测
        cur_spread_pct = 0
        if snap["asks"] and snap["bids"]:
            best_ask = float(snap["asks"][0][0])
            best_bid = float(snap["bids"][0][0])
            cur_spread_pct = (best_ask - best_bid) / mid * 100

        base_spreads = []
        for bs in baseline_snaps:
            if bs["asks"] and bs["bids"]:
                ba = float(bs["asks"][0][0])
                bb = float(bs["bids"][0][0])
                base_spreads.append((ba - bb) / bs["mid"] * 100 if bs["mid"] > 0 else 0)

        avg_spread = sum(base_spreads) / len(base_spreads) if base_spreads else cur_spread_pct
        spread_widening = cur_spread_pct / avg_spread if avg_spread > 0 else 1

        # 2. 深度检测（前5档总挂单量）
        cur_ask_depth = sum(float(a[1]) for a in snap["asks"][:5])
        cur_bid_depth = sum(float(b[1]) for b in snap["bids"][:5])

        base_ask_depths = []
        base_bid_depths = []
        for bs in baseline_snaps:
            base_ask_depths.append(sum(float(a[1]) for a in bs["asks"][:5]))
            base_bid_depths.append(sum(float(b[1]) for b in bs["bids"][:5]))

        avg_ask_depth = sum(base_ask_depths) / len(base_ask_depths) if base_ask_depths else cur_ask_depth
        avg_bid_depth = sum(base_bid_depths) / len(base_bid_depths) if base_bid_depths else cur_bid_depth

        ask_thinning = 1 - cur_ask_depth / avg_ask_depth if avg_ask_depth > 0 else 0
        bid_thinning = 1 - cur_bid_depth / avg_bid_depth if avg_bid_depth > 0 else 0

        # 3. 综合判断
        retreat_score = 0
        retreat_side = "neutral"
        details = []

        if spread_widening > 1.5:
            retreat_score += 0.3
            details.append(f"价差扩大{spread_widening:.1f}x")

        if ask_thinning > 0.4:
            retreat_score += 0.35
            retreat_side = "sell"
            details.append(f"卖盘深度缩减{ask_thinning:.0%}")

        if bid_thinning > 0.4:
            retreat_score += 0.35
            retreat_side = "buy" if retreat_side == "neutral" else "both"
            details.append(f"买盘深度缩减{bid_thinning:.0%}")

        # 单边退避：一方深度大幅减少而另一方保持 → 预示方向性行情
        if abs(ask_thinning - bid_thinning) > 0.3 and retreat_score > 0.3:
            if ask_thinning > bid_thinning:
                retreat_side = "sell_mm_gone"  # 卖方做市商退避 → 价格可能上行
            else:
                retreat_side = "buy_mm_gone"   # 买方做市商退避 → 价格可能下行

        return {
            "retreating": retreat_score > 0.35,
            "score": round(min(1.0, retreat_score), 3),
            "direction": retreat_side,
            "spread_widening": round(spread_widening, 2),
            "ask_depth_change": round(-ask_thinning * 100, 1),
            "bid_depth_change": round(-bid_thinning * 100, 1),
            "detail": "; ".join(details) if details else "正常",
        }

    # ═══════════════════════════════════════════════════════════
    #  8. One-Sided Move Predictor — 单边行情预判
    # ═══════════════════════════════════════════════════════════

    def predict_move(self) -> dict:
        """
        综合有毒订单流 + MM退避 → 预判单边行情方向与强度
        核心逻辑：
          - 有毒sell方 + 买方MM退避 → 价格下行（空头机会）
          - 有毒buy方 + 卖方MM退避 → 价格上行（多头机会）
          - 只有单方有毒但MM未退避 → 大概率假突破
        """
        summary = self.analyze()
        mm = self.detect_mm_retreat()

        toxic_bias = summary.bias
        mm_retreat = mm["direction"]
        toxic_score = summary.score
        mm_score = mm["score"]

        # 方向推断
        move_direction = "WAIT"
        move_confidence = 0

        # 有毒sell + 买方MM退避 → 下行（做市商不接卖单，价格必跌）
        if toxic_bias == "sell_toxic" and mm_retreat in ("buy_mm_gone", "both"):
            move_direction = "SHORT"
            move_confidence = (toxic_score + mm_score) / 2

        # 有毒buy + 卖方MM退避 → 上行（做市商不接买单，价格必涨）
        elif toxic_bias == "buy_toxic" and mm_retreat in ("sell_mm_gone", "both"):
            move_direction = "LONG"
            move_confidence = (toxic_score + mm_score) / 2

        # 只有有毒信号，MM未退避 → 可能假信号
        elif toxic_score > 0.3 and mm_score < 0.2:
            move_direction = "WAIT"
            move_confidence = toxic_score * 0.3  # 大打折扣
            detail_extra = "MM未退避，可能假信号"

        # MM已退避但无毒信号 → 可能是正常流动性变化
        elif mm_score > 0.3 and toxic_score < 0.2:
            move_direction = "WAIT"
            move_confidence = mm_score * 0.2

        # 有明确方向信号
        if move_direction != "WAIT" and move_confidence > 0.4:
            action = "STRONG_" + move_direction if move_confidence > 0.7 else move_direction
        elif move_confidence > 0.25:
            action = move_direction
        else:
            action = "WAIT"

        return {
            "action": action,
            "direction": move_direction,
            "confidence": round(move_confidence, 3),
            "toxic_score": toxic_score,
            "mm_retreat_score": mm_score,
            "toxic_bias": toxic_bias,
            "mm_direction": mm_retreat,
            "toxic_detail": summary.details,
            "mm_detail": mm["detail"],
        }

    def _merge_alerts(self, alerts: list) -> list:
        """合并同类告警"""
        groups = defaultdict(list)
        for a in alerts:
            key = (a.type, a.direction)
            groups[key].append(a)

        merged = []
        for (atype, adir), items in groups.items():
            if len(items) == 1:
                merged.append(items[0])
            else:
                # 取最高严重度
                best = max(items, key=lambda a: {"low": 0, "medium": 1, "high": 2, "critical": 3}.get(a.severity, 0))
                best.score = min(1.0, sum(i.score for i in items) / len(items) + 0.1)
                best.detail += f" (+{len(items)-1}相关信号)"
                merged.append(best)

        return merged

    def get_recent_alerts(self, n: int = 10) -> list:
        return list(self._alert_history)[-n:]

    def reset(self):
        self._alert_history.clear()
        self._book_snapshots.clear()
        self._trade_history.clear()
        self._level_walls.clear()
