# -*- coding: utf-8 -*-
"""
strategy/copy_trader.py — 链上跟单引擎
紧盯盈利最优的链上地址，解析其 DEX/Perps 操作，生成跟单信号。

数据源: Etherscan API (txlist + balance)
更新频率: 60s
跟踪范围: 最近 5 分钟内交易

信号逻辑:
  - 地址买入 ETH → bullish (+1)
  - 地址卖出 ETH → bearish (-1)
  - 多地址加权聚合 → copy_score [-1, +1]
  - 历史准确率 → 置信度加权
"""

import json
import os
import re
import time
import logging
from datetime import datetime, timezone
from collections import deque

from exchange.http import http_get

logger = logging.getLogger("MyTrader")

# ── DEX router addresses & swap function selectors ──────────────
# Keys are lower-cased so they can be compared directly with Etherscan's
# (lower-case) "to" field.
DEX_ROUTERS = {
    "0x7a250d5630B4cF539739dF2C5dAcb4c659F2488D".lower(): "Uniswap V2",
    "0xE592427A0AEce92De3Edee1F18E0157C05861564".lower(): "Uniswap V3",
    "0x68b3465833fb72a70ecdff279ebf668e295bcb4c".lower(): "Uniswap Universal",
    "0x1111111254EEB25477B68fb85Ed929f73A960582".lower(): "1inch V5",
    "0x1111111254fb6c44bAC0beD2854e76F90643097d".lower(): "1inch V4",
    "0xDef1C0ded9bec7F1a1670819833240f027b25EfF".lower(): "0x Exchange",
    "0x881D40237659C251811CEC9c364ef91dC08D300C".lower(): "Metamask Swap",
}

# Swap function selectors (first 4 bytes of calldata, "0x" + 8 hex chars).
# The three sets below must stay disjoint: a selector whose direction depends
# on the swap path must only be in GENERIC_SWAP_SIGS, otherwise a
# fixed-direction set would classify it before the path is inspected.

# Token -> ETH: the caller receives ETH, i.e. BUYS ETH.
SWAP_SIGS_BUY_ETH = {
    # swapExactTokensForETH / swapTokensForExactETH /
    # swapExactTokensForETHSupportingFeeOnTransferTokens
    "0x18cbafe5", "0x4a25d94a", "0x791ac947",
}
# ETH -> token: the caller spends ETH, i.e. SELLS ETH.
SWAP_SIGS_SELL_ETH = {
    # swapExactETHForTokens / swapETHForExactTokens /
    # swapExactETHForTokensSupportingFeeOnTransferTokens
    "0x7ff36ab5", "0xfb3bdb41", "0xb6f9de95",
}
# Token -> token: swapExactTokensForTokens / swapTokensForExactTokens.
# The direction has to be read from the `path` argument in the calldata.
GENERIC_SWAP_SIGS = {"0x38ed1739", "0x8803dbee"}

# WETH contract address
WETH = "0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2".lower()
# Stablecoin contract addresses
STABLES = {
    "0xdAC17F958D2ee523a2206206994597C13D831ec7".lower(): "USDT",
    "0xA0b86991c6218b36c1d19D4a2e9Eb0cE3606eB48".lower(): "USDC",
    "0x6B175474E89094C44Da98b954EedeAC495271d0F".lower(): "DAI",
}

# Exchange deposit wallets: ETH sent here is treated as a likely sale.
EXCHANGE_DEPOSIT = {
    "0x28C6c06298d514Db089934071355E5743bf21d60".lower(): "Binance 14",
    "0x9696f59E4d72E237BE84fFD425DCaD154Bf96976".lower(): "Binance 15",
    "0xF977814e90dA44bFA03b6295A0616a897441aceC".lower(): "Binance 16",
    "0x21a31Ee1afC51d94C2eFcCAa2092aD1028285549".lower(): "Binance 17",
    "0xBE0eB53F46cd790Cd13851d5EFf43D12404d33E8".lower(): "Binance 7",
}

# ── Per-address PnL tracking ────────────────────────────────────
ADDRESS_STATE_FILE = "copy_trader_state.json"
TRADE_WINDOW_S = 300  # only transactions from the last 5 minutes produce signals


class CopyTrader:
    """
    On-chain copy-trading engine.

    - Polls a configured list of addresses through the Etherscan API.
    - Classifies each recent transaction by ETH direction (buy / sell).
    - Tracks a per-address hit rate that callers feed back via ``update_pnl``.
    - Aggregates signals into a weighted copy score in [-1, +1].
    """

    def __init__(self, addresses: list, etherscan_key: str):
        """
        Args:
            addresses: entries such as
                ``{"addr": "0x...", "label": "AlphaTrader1", "weight": 1.0}``.
                Only ``addr`` is required; ``label`` defaults to the address
                prefix and ``weight`` to 1.0 (see ``poll``).
            etherscan_key: Etherscan API key sent with every request.

        Persisted state is loaded from ADDRESS_STATE_FILE.
        """
        self.addresses = addresses
        self.etherscan_key = etherscan_key
        self._last_block: dict = {}   # addr -> highest block number seen (stored as str)
        self._tx_log: dict = {}       # addr -> {tx_hash: timeStamp} of hashes already emitted
        self._pnl_tracker: dict = {}  # addr -> {"calls": int, "wins": int, "total_pnl_eth": float}
        self._load_state()

    # ─── State persistence ──────────────────────────────────────
    def _load_state(self):
        """Restore PnL stats and last-seen blocks from ADDRESS_STATE_FILE.

        A missing file is normal (first run). An unreadable or malformed
        file is logged and ignored, so the engine starts with empty state
        instead of crashing.
        """
        if not os.path.exists(ADDRESS_STATE_FILE):
            return
        try:
            with open(ADDRESS_STATE_FILE, "r", encoding="utf-8") as f:
                s = json.load(f)
        except (OSError, ValueError) as e:  # ValueError covers JSON/Unicode decode errors
            logger.warning("[CopyTrader] failed to load %s: %s", ADDRESS_STATE_FILE, e)
            return
        if not isinstance(s, dict):
            logger.warning("[CopyTrader] ignoring malformed state file %s", ADDRESS_STATE_FILE)
            return
        self._pnl_tracker = s.get("pnl", {})
        self._last_block = s.get("last_block", {})

    def _save_state(self):
        """Persist PnL stats and last-seen blocks to ADDRESS_STATE_FILE.

        The JSON is written to a temporary file and then swapped in with
        ``os.replace``. A crash mid-write therefore cannot leave a truncated
        file behind; ``_load_state`` would discard such a file and lose all
        PnL history. Persistence is best-effort: failures are logged, not
        raised.
        """
        tmp_path = ADDRESS_STATE_FILE + ".tmp"
        try:
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump({
                    "pnl": self._pnl_tracker,
                    "last_block": self._last_block,
                    "updated": datetime.now().isoformat(),
                }, f, indent=2, ensure_ascii=False)
            os.replace(tmp_path, ADDRESS_STATE_FILE)
        except (OSError, TypeError, ValueError) as e:
            logger.warning("[CopyTrader] failed to save %s: %s", ADDRESS_STATE_FILE, e)

    # ─── Etherscan API ──────────────────────────────────────────
    def _fetch_txlist(self, address: str, start_block: int = 0) -> list:
        """Return up to the 20 newest normal transactions of ``address``.

        The query starts 500 blocks before ``start_block`` (clamped at 0),
        presumably as a safety margin against late indexing. Any API or
        transport error is logged at debug level and yields ``[]``.
        """
        params = {
            "chainid": 1,
            "module": "account",
            "action": "txlist",
            "address": address,
            "startblock": max(0, start_block - 500),
            "endblock": 99999999,
            "page": 1,
            "offset": 20,
            "sort": "desc",
            "apikey": self.etherscan_key,
        }
        try:
            r = http_get("https://api.etherscan.io/v2/api", params=params, timeout=10)
            d = r.json()
            if d.get("status") == "1":
                return d.get("result", [])
            if d.get("message") == "No transactions found":
                return []
            logger.debug(f"[CopyTrader] Etherscan error for {address[:10]}: {d.get('message')}")
            return []
        except Exception as e:  # http_get's exception types are not known here
            logger.debug(f"[CopyTrader] API error for {address[:10]}: {e}")
            return []

    def _fetch_eth_balance(self, address: str) -> float:
        """Return the ETH balance of ``address`` (wei / 1e18), or 0.0 on any failure."""
        params = {
            "chainid": 1,
            "module": "account",
            "action": "balance",
            "address": address,
            "apikey": self.etherscan_key,
        }
        try:
            r = http_get("https://api.etherscan.io/v2/api", params=params, timeout=8)
            d = r.json()
            if d.get("status") == "1":
                return int(d.get("result", "0")) / 1e18
        except Exception as e:  # best-effort lookup
            logger.debug(f"[CopyTrader] balance error for {address[:10]}: {e}")
        return 0.0

    # ─── Transaction classification ─────────────────────────────
    def _classify_tx(self, tx: dict) -> dict:
        """
        Classify one Etherscan ``txlist`` entry.

        Returns ``{"action": str, "value_eth": float, "protocol": str}`` where
        ``action`` is one of:
          - "BUY" / "SELL": the ETH direction of a decoded swap or CEX deposit;
          - "UNKNOWN_SWAP": a call to a known DEX router that we cannot decode;
          - "TRANSFER": a plain ETH transfer of more than 1 ETH;
          - "UNKNOWN": anything else.

        ``value_eth`` is the ETH attached to the transaction (its ``value``).
        It is typically 0 for token->ETH and token->token swaps, even when
        ETH changes hands inside the swap.
        """
        to_addr = (tx.get("to") or "").lower()  # "to" is empty for contract creation
        input_data = tx.get("input") or "0x"
        value_eth = int(tx.get("value") or 0) / 1e18
        method_sig = input_data[:10].lower() if len(input_data) >= 10 else "0x"

        # 1. Call to a known DEX router.
        if to_addr in DEX_ROUTERS:
            protocol = DEX_ROUTERS[to_addr]
            # Token->token selectors are checked first. Their direction comes
            # from the swap path, so they must never be caught by the
            # fixed-direction sets below.
            if method_sig in GENERIC_SWAP_SIGS:
                direction = self._parse_swap_path(input_data, to_addr)
                if direction:
                    return {"action": direction, "value_eth": value_eth, "protocol": protocol}
            elif method_sig in SWAP_SIGS_SELL_ETH:
                return {"action": "SELL", "value_eth": value_eth, "protocol": protocol}
            elif method_sig in SWAP_SIGS_BUY_ETH:
                return {"action": "BUY", "value_eth": value_eth, "protocol": protocol}
            # Other router calls (or an undecodable path): likely a swap, direction unknown.
            if method_sig != "0x" and len(input_data) > 10:
                return {"action": "UNKNOWN_SWAP", "value_eth": value_eth, "protocol": protocol}

        # 2. ETH sent to a known exchange deposit wallet -> treated as selling.
        if to_addr in EXCHANGE_DEPOSIT and value_eth > 0:
            return {"action": "SELL", "value_eth": value_eth,
                    "protocol": f"CEX Deposit ({EXCHANGE_DEPOSIT[to_addr]})"}

        # 3. Large plain ETH transfer (no calldata). Exchange deposits were
        #    already handled above.
        if value_eth > 1 and method_sig == "0x":
            return {"action": "TRANSFER", "value_eth": value_eth, "protocol": "Transfer"}

        return {"action": "UNKNOWN", "value_eth": value_eth, "protocol": "N/A"}

    @staticmethod
    def _hex_body(address: str) -> str:
        """Return ``address`` without its "0x" prefix, as it appears in ABI calldata."""
        return address[2:] if address.startswith("0x") else address

    def _parse_swap_path(self, input_data: str, router: str) -> str:
        """Infer the ETH direction of a token->token swap from its calldata.

        This is a heuristic, not a full ABI decode. ABI-encoded addresses
        are 40 hex chars with no "0x" prefix, each right-aligned in a
        64-char word. The function finds the first WETH address in the
        arguments and looks for a known stablecoin address:
          - in the preceding 80 chars: stable -> WETH, returns "BUY";
          - in the following 80 chars: WETH -> stable, returns "SELL".
        Returns "" when no such pair is found. ``router`` is currently unused.
        """
        data = input_data[10:].lower()  # drop "0x" + 4-byte selector
        idx = data.find(self._hex_body(WETH))
        if idx < 0:
            return ""
        before = data[max(0, idx - 80):idx]
        after = data[idx + 40:idx + 120]  # starts right after the 40-char WETH address
        for stable in STABLES:
            stable_hex = self._hex_body(stable)
            if stable_hex in before:
                return "BUY"   # stable -> WETH = buying ETH
            if stable_hex in after:
                return "SELL"  # WETH -> stable = selling ETH
        return ""

    # ─── Main polling loop ──────────────────────────────────────
    def poll(self) -> list:
        """
        Poll every tracked address and return new copy signals.

        Each signal is a dict with keys "addr", "label", "action"
        ("BUY"|"SELL"), "value_eth", "protocol", "weight", "ts" (the block
        timestamp) and "tx_hash".

        A signal is produced only for a transaction that:
          - is newer than TRADE_WINDOW_S seconds;
          - did not revert (``isError != "1"``);
          - has not already been emitted.
        State is saved at the end of every poll.
        """
        signals = []
        now = time.time()
        cutoff = now - TRADE_WINDOW_S

        for entry in self.addresses:
            addr = entry["addr"].lower()
            label = entry.get("label", addr[:10])
            weight = entry.get("weight", 1.0)
            last_block = int(self._last_block.get(addr, 0))

            txs = self._fetch_txlist(addr, last_block)
            if not txs:
                continue

            # Remember the newest block seen for this address.
            new_max_block = max(int(t.get("blockNumber", 0)) for t in txs)
            if new_max_block > last_block:
                self._last_block[addr] = str(new_max_block)

            seen = self._tx_log.setdefault(addr, {})
            for tx in txs:
                tx_ts = int(tx.get("timeStamp", 0))
                if tx_ts < cutoff:
                    continue
                # A reverted transaction moved no funds; there is nothing to copy.
                if tx.get("isError") == "1":
                    continue

                tx_hash = tx.get("hash", "")
                if tx_hash in seen:
                    continue

                cls = self._classify_tx(tx)
                if cls["action"] not in ("BUY", "SELL"):
                    continue

                signals.append({
                    "addr": addr,
                    "label": label,
                    "action": cls["action"],
                    "value_eth": cls["value_eth"],
                    "protocol": cls["protocol"],
                    "weight": weight,
                    "ts": tx_ts,
                    "tx_hash": tx_hash,
                })
                logger.info(
                    f"[CopyTrader] {label}({addr[:8]}..) {cls['action']} "
                    f"{cls['value_eth']:.2f}ETH via {cls['protocol']}"
                )
                seen[tx_hash] = tx_ts

            # A hash older than the window can never pass the timestamp filter
            # again, so dropping it is safe. This bounds memory without
            # evicting entries that are still needed for de-duplication.
            for stale in [h for h, ts in seen.items() if ts < cutoff]:
                del seen[stale]

        self._save_state()
        return signals

    def get_score(self, signals: list) -> dict:
        """
        Aggregate signals into a copy-trading score.

        Each signal's effective weight is:
          ``weight * confidence * (1 + min(value_eth, 10) / 20)``.
        ``confidence`` is the address's Laplace-smoothed win rate,
        (wins + 1) / (calls + 2), once it has at least 3 recorded calls;
        before that it is 0.5.

        Returns a dict with these keys:
          - "score": float in [-1, +1];
          - "detail": str;
          - "confidence": float in [0, 1];
          - "buy_weight", "sell_weight": float.
        All keys are present even when there are no signals.
        """
        if not signals:
            return {"score": 0.0, "detail": "无跟单信号", "confidence": 0.0,
                    "buy_weight": 0.0, "sell_weight": 0.0}

        buy_weight = 0.0
        sell_weight = 0.0
        details = []

        for sig in signals:
            pnl_info = self._pnl_tracker.get(sig["addr"], {})
            calls = pnl_info.get("calls", 0)
            confidence = 0.5
            if calls >= 3:
                confidence = (pnl_info.get("wins", 0) + 1) / (calls + 2)
            effective_weight = sig["weight"] * confidence * (1 + min(sig["value_eth"], 10) / 20)

            if sig["action"] == "BUY":
                buy_weight += effective_weight
            else:
                sell_weight += effective_weight

            details.append(
                f"{sig['label']}({sig['addr'][:6]}..) {sig['action']} "
                f"{sig['value_eth']:.1f}ETH conf={confidence:.0%}"
            )

        total = buy_weight + sell_weight
        score = (buy_weight - sell_weight) / total if total > 0 else 0
        score = max(-1, min(1, score))

        return {
            "score": round(score, 3),
            "detail": " | ".join(details),
            "confidence": round(min(total / len(signals), 1.0), 2),
            "buy_weight": round(buy_weight, 3),
            "sell_weight": round(sell_weight, 3),
        }

    def update_pnl(self, addr: str, action: str, entry_price: float, close_price: float):
        """
        Record the outcome of one copied trade for ``addr`` and persist it.

        A "BUY" wins when close_price > entry_price; any other action wins
        when close_price < entry_price. Despite its key name,
        ``total_pnl_eth`` accumulates this sign-adjusted price difference in
        the caller's price units, not an ETH amount.
        """
        p = self._pnl_tracker.setdefault(
            addr, {"calls": 0, "wins": 0, "total_pnl_eth": 0.0})
        pnl = close_price - entry_price if action == "BUY" else entry_price - close_price
        p["calls"] = p.get("calls", 0) + 1
        p["total_pnl_eth"] = p.get("total_pnl_eth", 0.0) + pnl
        if pnl > 0:
            p["wins"] = p.get("wins", 0) + 1
        self._save_state()

    def get_top_addresses(self, n: int = 2) -> list:
        """Return the ``n`` tracked addresses with the highest raw win rate."""
        def win_rate(info: dict) -> float:
            return info.get("wins", 0) / max(info.get("calls", 0), 1)

        ranked = sorted(self._pnl_tracker.items(),
                        key=lambda item: win_rate(item[1]), reverse=True)
        return [
            {"addr": addr, "win_rate": round(win_rate(info), 3),
             "calls": info.get("calls", 0),
             "total_pnl": round(info.get("total_pnl_eth", 0), 3)}
            for addr, info in ranked[:n]
        ]
