# -*- coding: utf-8 -*-
"""
signal_filter.py — 延迟感知信号过滤器

我们的系统存在REST轮询延迟(20-80s)，不能做高频抢跑，应该做"慢而准"的确认型系统。

核心思路:
  1. 不接受延迟过大的信号 (stale rejection)
  2. 信号发出后等待确认期 (post-signal validation)
  3. 如果价格已朝信号方向大幅移动 → 等回调再入场
  4. 高时间框架必须确认方向 (HTF alignment)
  5. 有毒订单流作为禁止入场信号 (veto)，而非入场信号
"""

import time
import math
import logging
from collections import deque
from dataclasses import dataclass, field
from typing import Optional

logger = logging.getLogger("MyTrader")


@dataclass
class FilteredSignal:
    """过滤后的信号"""
    action: str = "WAIT"       # LONG / SHORT / WAIT / VETO
    original_action: str = ""  # 原始信号方向
    entry_mode: str = ""       # immediate / pullback / confirmed
    target_price: float = 0    # 建议入场价
    confidence: float = 0      # 最终置信度 [0,1]
    reason: str = ""           # 过滤结果说明
    details: dict = field(default_factory=dict)


class DelayAwareFilter:
    """
    延迟感知信号过滤器

    配置:
      max_signal_age: 信号最大允许延迟(秒)，超时直接拒绝
      confirmation_period: 信号发出后的确认等待时间(秒)
      max_price_slippage: 价格已朝信号方向移动超过此比例 → 等回调
      pullback_pct: 回调目标比例
      htf_required: 是否要求高时间框架确认
    """

    def __init__(self, max_signal_age=120, confirmation_period=30,
                 max_price_slippage=0.005, pullback_pct=0.003, htf_required=True):
        self.max_signal_age = max_signal_age
        self.confirmation_period = confirmation_period
        self.max_price_slippage = max_price_slippage
        self.pullback_pct = pullback_pct
        self.htf_required = htf_required

        # 信号记忆
        self._pending_signals = {}    # {sig_id: {signal_data}}
        self._confirmed_events = deque(maxlen=200)  # 已确认的事件
        self._veto_events = deque(maxlen=100)       # 被否决的事件
        self._price_checkpoints = {}  # {sig_id: price_at_signal}

    # ═══════════════════════════════════════════════════════════
    #  1. 延迟检查 - 拒绝过期信号
    # ═══════════════════════════════════════════════════════════

    def check_freshness(self, signal_ts: float, now: float = None) -> tuple[bool, str]:
        """
        检查信号是否过期
        返回: (is_fresh, reason)
        """
        if now is None:
            now = time.time()
        age = now - signal_ts

        if age > self.max_signal_age:
            return False, f"信号过期({age:.0f}s > {self.max_signal_age}s)"
        elif age > self.max_signal_age * 0.5:
            return True, f"信号偏旧({age:.0f}s)，降权处理"
        else:
            return True, f"信号新鲜({age:.0f}s)"

    # ═══════════════════════════════════════════════════════════
    #  2. 价格滑移检查 - 信号发出后价格已跑远
    # ═══════════════════════════════════════════════════════════

    def check_slippage(self, direction: str, signal_price: float,
                        current_price: float) -> tuple[str, float, str]:
        """
        检查价格滑移
        返回: (entry_mode, suggested_entry, reason)

        - 价格未动或反方向: immediate入场 (好的入场点)
        - 价格朝信号方向小幅移动: immediate入场 (趋势确认)
        - 价格朝信号方向大幅移动: pullback入场 (等待回调)
        - 价格反方向大幅移动: VETO (信号可能错误)
        """
        if signal_price <= 0 or current_price <= 0:
            return "veto", 0, "价格无效"

        change_pct = (current_price - signal_price) / signal_price

        if direction == "LONG":
            if change_pct < -self.max_price_slippage * 2:
                # 价格大幅下跌 → 做多信号可能错误
                return "veto", 0, f"价格逆势大跌{change_pct:.2%}，可能信号错误"
            elif change_pct < -self.max_price_slippage:
                # 价格小幅下跌 → 更好的入场价
                return "immediate", current_price, f"价格回调{change_pct:.2%} → 入场"
            elif change_pct > self.max_price_slippage * 1.5:
                # 价格大幅上涨 → 已经跑了，等回调
                target = signal_price + (current_price - signal_price) * 0.5
                return "pullback", round(target, 2), f"价格已涨{change_pct:.2%} → 等回调至{target:.2f}"
            elif change_pct > 0:
                return "immediate", current_price, f"价格微涨{change_pct:.2%} → 趋势确认，入场"
            else:
                return "immediate", current_price, "价格未动 → 入场"

        else:  # SHORT
            if change_pct > self.max_price_slippage * 2:
                return "veto", 0, f"价格逆势大涨{change_pct:.2%}，可能信号错误"
            elif change_pct > self.max_price_slippage:
                return "immediate", current_price, f"价格反弹{change_pct:.2%} → 更好入场"
            elif change_pct < -self.max_price_slippage * 1.5:
                target = signal_price - (signal_price - current_price) * 0.5
                return "pullback", round(target, 2), f"价格已跌{abs(change_pct):.2%} → 等回调至{target:.2f}"
            elif change_pct < 0:
                return "immediate", current_price, f"价格微跌{abs(change_pct):.2%} → 趋势确认，入场"
            else:
                return "immediate", current_price, "价格未动 → 入场"

    # ═══════════════════════════════════════════════════════════
    #  3. HTF对齐检查 - 高时间框架必须确认
    # ═══════════════════════════════════════════════════════════

    def check_htf_alignment(self, direction: str, t1h: int, t4h: int) -> tuple[bool, str]:
        """
        检查高时间框架是否支持此方向
        t1h/t4h: +1(long), -1(short), 0(neutral)
        """
        if not self.htf_required:
            return True, "HTF检查已关闭"

        if direction == "LONG":
            if t4h == 1 and t1h >= 0:
                return True, "4H/1H均支持多头"
            elif t4h == 1:
                return True, "4H支持多头(1H中性)"
            elif t1h == 1:
                return True, "1H支持多头(4H中性)"
            elif t4h == -1:
                return False, "4H空头 → 拒绝做多"
            elif t1h == -1:
                return False, "1H空头 → 拒绝做多"
            else:
                return False, "HTF方向不明 → 需等待"

        else:  # SHORT
            if t4h == -1 and t1h <= 0:
                return True, "4H/1H均支持空头"
            elif t4h == -1:
                return True, "4H支持空头(1H中性)"
            elif t1h == -1:
                return True, "1H支持空头(4H中性)"
            elif t4h == 1:
                return False, "4H多头 → 拒绝做空"
            elif t1h == 1:
                return False, "1H多头 → 拒绝做空"
            else:
                return False, "HTF方向不明 → 需等待"

    # ═══════════════════════════════════════════════════════════
    #  4. 有毒流否决 - 有毒订单流时禁止入场
    # ═══════════════════════════════════════════════════════════

    def check_toxic_veto(self, direction: str, toxic_level: str,
                          toxic_bias: str) -> tuple[bool, str]:
        """
        有毒订单流否决逻辑:
        - danger级别: 完全禁止入场
        - warning级别 + 方向冲突: 禁止入场
        - watch级别: 降低置信度，但不禁止
        """
        if toxic_level == "danger":
            return False, f"有毒流danger级别 → 全面禁止入场"
        if toxic_level == "warning":
            # 有毒方向与入场方向一致 → 可能被操纵，禁止
            if (direction == "LONG" and toxic_bias == "buy_toxic") or \
               (direction == "SHORT" and toxic_bias == "sell_toxic"):
                return False, f"有毒流与入场方向一致 → 可能被操纵"
            # 有毒方向与入场方向相反 → 对手方在撤退，有利
            if (direction == "LONG" and toxic_bias == "sell_toxic") or \
               (direction == "SHORT" and toxic_bias == "buy_toxic"):
                return True, "对手方中毒 → 有利方向"
        if toxic_level == "watch":
            return True, f"有毒流轻度({toxic_bias}) → 降低仓位"
        return True, "无毒信号 → 正常入场"

    # ═══════════════════════════════════════════════════════════
    #  5. 综合过滤
    # ═══════════════════════════════════════════════════════════

    def filter(self, signal_action: str, signal_price: float, signal_ts: float,
               current_price: float, t1h: int = 0, t4h: int = 0,
               toxic_level: str = "normal", toxic_bias: str = "neutral") -> FilteredSignal:
        """
        综合信号过滤 — 依次检查:
          1. 延迟 → 过期直接拒绝
          2. 滑移 → 确定入场模式(immediate/pullback/veto)
          3. HTF → 高时间框架确认
          4. 有毒流 → 否决检查

        返回 FilteredSignal
        """
        now = time.time()

        # Step 1: 延迟检查
        fresh, fresh_reason = self.check_freshness(signal_ts, now)
        if not fresh:
            return FilteredSignal(
                action="WAIT", original_action=signal_action,
                reason=fresh_reason, confidence=0,
            )

        # Step 2: 滑移检查
        entry_mode, target_price, slip_reason = self.check_slippage(
            signal_action, signal_price, current_price)
        if entry_mode == "veto":
            self._veto_events.append({
                "ts": now, "action": signal_action,
                "reason": slip_reason, "signal_price": signal_price,
            })
            return FilteredSignal(
                action="VETO", original_action=signal_action,
                reason=slip_reason, confidence=0,
            )

        # Step 3: HTF对齐
        htf_ok, htf_reason = self.check_htf_alignment(signal_action, t1h, t4h)
        if not htf_ok:
            return FilteredSignal(
                action="WAIT", original_action=signal_action,
                reason=htf_reason, confidence=0.15,
                entry_mode=entry_mode, target_price=target_price,
            )

        # Step 4: 有毒流否决
        toxic_ok, toxic_reason = self.check_toxic_veto(
            signal_action, toxic_level, toxic_bias)
        if not toxic_ok:
            self._veto_events.append({
                "ts": now, "action": signal_action,
                "reason": toxic_reason,
            })
            return FilteredSignal(
                action="VETO", original_action=signal_action,
                reason=toxic_reason, confidence=0,
            )

        # 计算最终置信度
        confidence = 0.8  # 基准
        if "偏旧" in fresh_reason:
            confidence -= 0.2
        if entry_mode == "pullback":
            confidence -= 0.15
        if toxic_level == "watch":
            confidence -= 0.15
        if toxic_ok and "对手方中毒" in toxic_reason:
            confidence += 0.1

        reason_parts = [fresh_reason, slip_reason, htf_reason, toxic_reason]
        reason = " | ".join(p for p in reason_parts if p)

        result = FilteredSignal(
            action=signal_action,
            original_action=signal_action,
            entry_mode=entry_mode,
            target_price=target_price,
            confidence=round(min(1.0, max(0, confidence)), 3),
            reason=reason,
            details={
                "signal_price": signal_price,
                "current_price": current_price,
                "signal_age_s": round(now - signal_ts, 1),
                "price_change_pct": round((current_price - signal_price) / signal_price * 100, 3),
                "t1h": t1h, "t4h": t4h,
                "toxic_level": toxic_level, "toxic_bias": toxic_bias,
            },
        )

        self._pending_signals[id(result)] = {
            "filtered": result, "ts": now,
        }

        return result

    # ═══════════════════════════════════════════════════════════
    #  查询
    # ═══════════════════════════════════════════════════════════

    def get_recent_vetoes(self, n: int = 10) -> list:
        return list(self._veto_events)[-n:]

    def get_stats(self) -> dict:
        return {
            "pending_count": len(self._pending_signals),
            "veto_count": len(self._veto_events),
            "confirmed_count": len(self._confirmed_events),
            "recent_vetoes": self.get_recent_vetoes(5),
        }
