# -*- coding: utf-8 -*-
"""
exchange/okx_ws.py
OKX WebSocket 订单簿实时接收 + 幌骗检测（v2）

幌骗检测核心逻辑（修正版）：
  - 跟踪"固定价格档位"上的挂单生命周期
  - 只有在"价格未触及该档位"的前提下订单消失，才判定为幌骗
  - 避免把正常成交（价格穿越档位）误判为撤单

检测流程：
  1. 发现某价格档位挂单量暴增（>基线3x）→ 记录为候选墙
  2. 每5秒检查该档位是否仍存在于订单簿
  3. 若订单簿中该档位消失/大幅缩减：
     - 若当前价格已触及该档位（被成交）→ 正常，忽略
     - 若当前价格未触及（仍有安全距离）→ 撤单，幌骗！
"""

import json
import time
import logging
import threading
from collections import deque, defaultdict

logger = logging.getLogger("MyTrader")

WS_URL  = "wss://ws.okx.com:8443/ws/v5/public"

# 检测参数
DETECT_INTERVAL  = 5.0    # 检测间隔（秒）
SURGE_RATIO      = 3.0    # 暴增倍率阈值
SURGE_MIN_VOL    = 15.0   # 最小挂单量（张）才进入监控
WITHDRAW_RATIO   = 0.60   # 撤单比例阈值（60%+视为撤单）
MONITOR_WINDOW   = 20.0   # 监控窗口（秒），超过则视为真实墙
PRICE_TOUCH_PCT  = 0.0003 # 价格触及判定：档位距当前价格<0.03%视为被吃
HISTORY_LEN      = 20     # 保留历史条数（用于计算基线）
PRICE_BUCKET     = 0.10   # 价格分箱精度（USDT），合并相近档位


def _bucket(price: float) -> float:
    """将价格归并到最近的桶，减少档位碎片"""
    return round(price / PRICE_BUCKET) * PRICE_BUCKET


class OKXBookWS:
    """
    WebSocket 订单簿订阅器 + 档位级幌骗检测
    使用方式：
        ws = OKXBookWS("ETH-USDT-SWAP", on_alert=my_callback, proxies=PROXIES)
        ws.start()
        summary = ws.get_summary()
        ws.stop()
    """

    def __init__(self, inst_id: str, on_alert=None, proxies: dict = None):
        self.inst_id  = inst_id
        self.on_alert = on_alert
        self.proxies  = proxies

        # 最新订单簿（线程安全）
        self._book_lock = threading.Lock()
        self._asks: list = []   # [[price_str, size_str, ...], ...]
        self._bids: list = []
        self._price: float = 0.0

        # 每个价格档位的历史均量（用于基线）
        # {bucket_price: deque([vol, vol, ...])}
        self._vol_history: dict = defaultdict(lambda: deque(maxlen=HISTORY_LEN))

        # 当前监控中的候选幌骗墙
        # {bucket_price: {"side", "ts", "peak_vol", "price_at_detect"}}
        self._watching: dict = {}

        # 对外状态
        self.alert_score: float = 0.0
        self.risk_level:  str   = "normal"
        self.last_alert:  dict  = None
        self._wall_side:  str   = None   # 最近触发的墙方向

        # 有毒订单流检测器
        self.toxic_flow = None  # 由外部注入或延迟初始化

        self._running      = False
        self._ws_thread    = None
        self._detect_thread = None

    # ─── 公开接口 ──────────────────────────────────────────
    def start(self):
        self._running = True
        self._ws_thread = threading.Thread(
            target=self._ws_loop, daemon=True, name="okx-ws-book")
        self._ws_thread.start()
        self._detect_thread = threading.Thread(
            target=self._detect_loop, daemon=True, name="okx-spoof-detect")
        self._detect_thread.start()
        logger.info(f"[WS订单簿] 已启动，标的={self.inst_id}")

    def stop(self):
        self._running = False

    def get_book(self):
        with self._book_lock:
            return list(self._asks), list(self._bids), self._price

    def get_summary(self) -> dict:
        return {
            "risk_level": self.risk_level,
            "score":      self.alert_score,
            "wall_active": len(self._watching) > 0,
            "wall_side":  self._wall_side,
            "last_alert": self.last_alert,
        }

    # ─── WebSocket 线程 ────────────────────────────────────
    def _ws_loop(self):
        import websocket
        while self._running:
            try:
                ws_kwargs = {}
                if self.proxies:
                    proxy_url = (self.proxies.get("http")
                                 or self.proxies.get("https", ""))
                    if proxy_url:
                        import re
                        m = re.search(r'(?:://)?([^:/]+):(\d+)', proxy_url)
                        if m:
                            ws_kwargs["http_proxy_host"] = m.group(1)
                            ws_kwargs["http_proxy_port"] = int(m.group(2))
                            ws_kwargs["proxy_type"]      = "http"

                ws = websocket.WebSocketApp(
                    WS_URL,
                    on_open=self._on_open,
                    on_message=self._on_message,
                    on_error=self._on_error,
                    on_close=self._on_close,
                )
                ws.run_forever(**ws_kwargs)
            except Exception as e:
                logger.warning(f"[WS订单簿] 连接异常: {e}，5s后重连")
            if self._running:
                time.sleep(5)

    def _on_open(self, ws):
        sub = json.dumps({
            "op": "subscribe",
            "args": [{"channel": "books5", "instId": self.inst_id}]
        })
        ws.send(sub)
        logger.info(f"[WS订单簿] 已订阅 books5/{self.inst_id}")

    def _on_message(self, ws, msg):
        try:
            d = json.loads(msg)
            if d.get("event"):
                return
            data_list = d.get("data", [])
            if not data_list:
                return
            book = data_list[0]
            with self._book_lock:
                self._asks = book.get("asks", [])
                self._bids = book.get("bids", [])
                if self._asks and self._bids:
                    self._price = (
                        float(self._asks[0][0]) + float(self._bids[0][0])
                    ) / 2
        except Exception as e:
            logger.debug(f"[WS订单簿] 消息解析: {e}")

    def _on_error(self, ws, error):
        logger.warning(f"[WS订单簿] WS错误: {error}")

    def _on_close(self, ws, code, msg):
        logger.info(f"[WS订单簿] 连接关闭: {code} {msg}")

    # ─── 幌骗检测线程 ──────────────────────────────────────
    def _detect_loop(self):
        time.sleep(3)
        while self._running:
            t0 = time.time()
            try:
                self._run_detect()
            except Exception as e:
                logger.debug(f"[幌骗检测] 异常: {e}")
            elapsed = time.time() - t0
            time.sleep(max(0, DETECT_INTERVAL - elapsed))

    def _run_detect(self):
        asks, bids, price = self.get_book()
        if not asks or not bids or price <= 0:
            return

        now = time.time()

        # ── Step1：构建当前档位快照（price→vol，分箱归并）──
        cur_ask: dict = {}  # bucket_price → vol
        cur_bid: dict = {}
        for a in asks:
            bp = _bucket(float(a[0]))
            cur_ask[bp] = cur_ask.get(bp, 0) + float(a[1])
        for b in bids:
            bp = _bucket(float(b[0]))
            cur_bid[bp] = cur_bid.get(bp, 0) + float(b[1])

        # ── Step2：更新历史均量 ──
        all_buckets = set(cur_ask) | set(cur_bid)
        for bp in all_buckets:
            vol = cur_ask.get(bp, 0) + cur_bid.get(bp, 0)
            self._vol_history[bp].append(vol)

        # ── Step3：检测新大墙出现 ──
        for bp, vol in {**cur_ask, **cur_bid}.items():
            if bp in self._watching:
                continue  # 已在监控
            hist = self._vol_history[bp]
            if len(hist) < 3:
                continue
            baseline = sum(list(hist)[:-1]) / (len(hist) - 1) or 1
            if vol < SURGE_MIN_VOL:
                continue
            ratio = vol / baseline
            if ratio >= SURGE_RATIO:
                side = "sell" if bp in cur_ask and cur_ask.get(bp, 0) > cur_bid.get(bp, 0) else "buy"
                self._watching[bp] = {
                    "side":             side,
                    "ts":               now,
                    "peak_vol":         vol,
                    "price_at_detect":  price,
                    "baseline":         baseline,
                }
                self._wall_side = side
                self.risk_level = "watch"
                logger.info(
                    f"[幌骗检测] ⚠️ 大{side}墙 @ ${bp:.2f}"
                    f" 挂单={vol:.1f}（基线{baseline:.1f}的{ratio:.1f}x）"
                    f" 当前价=${price:.2f}"
                )

        # ── Step4：检查监控中的档位 ──
        to_remove = []
        for bp, ev in list(self._watching.items()):
            elapsed = now - ev["ts"]
            side    = ev["side"]

            # 当前该档位的量
            cur_vol = (cur_ask.get(bp, 0) if side == "sell"
                       else cur_bid.get(bp, 0))
            if cur_vol > ev["peak_vol"]:
                ev["peak_vol"] = cur_vol  # 峰值更新

            withdraw = 1 - cur_vol / ev["peak_vol"] if ev["peak_vol"] > 0 else 0

            # 判断价格是否已经触及该档位（正常成交）
            if side == "sell":
                # 卖墙：价格从下方接近，触及意味着价格 >= bp*(1-PRICE_TOUCH_PCT)
                price_touched = price >= bp * (1 - PRICE_TOUCH_PCT)
            else:
                # 买墙：价格从上方接近，触及意味着价格 <= bp*(1+PRICE_TOUCH_PCT)
                price_touched = price <= bp * (1 + PRICE_TOUCH_PCT)

            if price_touched:
                # 价格已到达该档位 → 正常成交，不是幌骗
                logger.info(
                    f"[幌骗检测] ✅ {side}墙 @ ${bp:.2f}"
                    f" 被正常成交（价格=${price:.2f}），排除幌骗"
                )
                to_remove.append(bp)
                continue

            if elapsed > MONITOR_WINDOW:
                # 超时仍存在 → 真实支撑/阻力
                if withdraw < 0.3:
                    logger.info(
                        f"[幌骗检测] ✅ {side}墙 @ ${bp:.2f}"
                        f" 稳定 {elapsed:.0f}s，判定真实支撑/阻力"
                    )
                to_remove.append(bp)
                self.alert_score *= 0.6
                continue

            # 在监控窗口内，价格未触及，但量大幅缩减 → 幌骗！
            if withdraw >= WITHDRAW_RATIO:
                direction = "上涨" if side == "sell" else "下跌"
                dist_pct  = abs(price - bp) / bp * 100
                score     = round(withdraw * (1 if side == "sell" else -1), 3)
                self.alert_score = score
                self.risk_level  = "danger"
                alert = {
                    "type":          "spoof",
                    "side":          side,
                    "level_price":   bp,
                    "withdraw_pct":  round(withdraw * 100, 1),
                    "elapsed":       round(elapsed, 1),
                    "price":         price,
                    "dist_pct":      round(dist_pct, 3),
                    "direction":     direction,
                    "score":         score,
                    "msg": (
                        f"‼️ 幌骗告警：{side}墙 @ ${bp:.2f}"
                        f" 在{elapsed:.1f}s内撤单{withdraw*100:.0f}%"
                        f"（价格=${price:.2f}，距档位{dist_pct:.2f}%，未被成交）"
                        f" 预期{direction}"
                    )
                }
                self.last_alert = alert
                self._wall_side = side
                logger.warning(f"[幌骗检测] {alert['msg']}")
                if self.on_alert:
                    self.on_alert(alert)
                to_remove.append(bp)

            elif withdraw >= 0.35:
                logger.info(
                    f"[幌骗检测] ⚠️ {side}墙 @ ${bp:.2f}"
                    f" 撤单{withdraw*100:.0f}%（{elapsed:.0f}s），监控中"
                    f" 距价格{abs(price-bp)/bp*100:.2f}%"
                )

        for bp in to_remove:
            self._watching.pop(bp, None)

        # 喂入有毒订单流检测器
        if self.toxic_flow:
            self.toxic_flow.feed_orderbook(asks, bids, price)
            tox_summary = self.toxic_flow.analyze()
            if tox_summary.level in ("warning", "danger"):
                logger.info(f"[有毒流] 等级={tox_summary.level} 评分={tox_summary.score:.2f} bias={tox_summary.bias}")
                # 触发有毒流告警（独立于幌骗告警）
                if tox_summary.level == "danger" and self.on_alert and self.toxic_flow.predict_move()["action"] != "WAIT":
                    pred = self.toxic_flow.predict_move()
                    self.on_alert({
                        "type": "toxic_flow",
                        "level": tox_summary.level,
                        "score": tox_summary.score,
                        "direction": pred["action"],
                        "confidence": pred["confidence"],
                        "msg": f"⚠️ 有毒订单流: {tox_summary.level} "
                               f"预测方向={pred['action']}(置信度{pred['confidence']:.1%}) "
                               f"{pred.get('mm_detail', '')}",
                    })

        # 评分衰减
        if self.risk_level != "danger":
            self.alert_score *= 0.85
            if abs(self.alert_score) < 0.02:
                self.alert_score = 0.0
                if not self._watching:
                    self.risk_level = "normal"
                    self._wall_side = None
