# -*- coding: utf-8 -*-
"""
PaperExecutor — 模拟账户执行器
完全本地模拟，不发送任何真实订单。复刻 OrderExecutor 接口。
"""
import time
import math
import logging
import threading
from collections import defaultdict, deque
from dataclasses import dataclass, field
from typing import Optional

# Application-wide named logger; NOTE(review): presumably shared with the rest
# of the MyTrader code base rather than a per-module __name__ logger — confirm.
logger = logging.getLogger(name="MyTrader")


@dataclass
class PaperPosition:
    """A simulated open position held by the paper account.

    Attributes:
        symbol: instrument key the position belongs to.
        direction: "LONG" or "SHORT".
        entry_price: price at which the position was entered.
        size: number of contracts.
        leverage: leverage used when the position was opened.
        tp_price: take-profit trigger price; 0 disables it.
        sl_price: stop-loss trigger price; 0 disables it.
        opened_at: creation time as returned by ``time.time()``.
    """

    symbol: str
    direction: str
    entry_price: float
    size: float
    leverage: int = 10
    tp_price: float = 0
    sl_price: float = 0
    opened_at: float = field(default_factory=time.time)


@dataclass
class PaperTrade:
    """One simulated fill recorded in the paper account's trade list.

    Attributes:
        symbol: instrument key.
        action: "OPEN_LONG", "OPEN_SHORT" or "CLOSE".
        price: fill price.
        size: number of contracts filled.
        pnl: realized PnL of the fill (0 for opens).
        pnl_pct: realized PnL as a percentage (0 for opens).
        reason: free-form tag describing why the fill happened.
        ts: fill time as returned by ``time.time()``.
    """

    symbol: str
    action: str
    price: float
    size: float
    pnl: float = 0
    pnl_pct: float = 0
    reason: str = ""
    ts: float = field(default_factory=time.time)


class PaperExecutor:
    """Paper-trading executor: a fully local simulated account.

    Mirrors the public interface of ``OrderExecutor`` but never places real
    orders; the OKX client is used only to read market data (tickers and
    candles).

    Accounting conventions used throughout:

    * notional = price * size * CONTRACT_VALUE
    * margin   = entry notional / position leverage (locked while open)
    * fee      = notional * fee rate (taker for market fills and closes,
      maker for resting limit fills)
    * equity   = available + sum(margin + unrealized PnL) of open positions

    Positions are stored as ``{symbol: {"long": PaperPosition,
    "short": PaperPosition}}``; the pre-trade checks allow at most one side
    per symbol to be open at a time.
    """

    # Face value per contract. NOTE(review): the same 0.01 is applied to every
    # instrument — confirm this matches the contracts actually traded.
    CONTRACT_VALUE = 0.01

    def __init__(self, okx_client, instruments: dict, initial_equity: float = 10000,
                 default_leverage: int = 10, maker_fee: float = 0.0002, taker_fee: float = 0.0005):
        self.okx = okx_client          # used only for market data
        self.instruments = instruments
        self.leverage = default_leverage
        self.maker_fee = maker_fee
        self.taker_fee = taker_fee

        # Account state
        self.initial_equity = initial_equity
        self.equity = initial_equity
        self.available = initial_equity
        self.peak_equity = initial_equity
        self.realized_pnl = 0.0
        self.total_fees = 0.0

        # Positions: {symbol: {"long": PaperPosition, "short": PaperPosition}}
        self.positions: dict[str, dict] = {}
        self.position_history = deque(maxlen=500)
        self.trades: list[PaperTrade] = []

        # Orders and per-symbol trade throttling
        self.pending_orders: list[dict] = []
        self._daily_trade_count = defaultdict(int)
        self._trade_day = self._today()   # local calendar day the counters belong to
        self._last_trade_ts = {}

        # Statistics
        self.total_trades = 0
        self.winning_trades = 0
        self.losing_trades = 0
        self.consecutive_losses = 0
        self.max_drawdown_pct = 0.0
        self._equity_history = deque(maxlen=2000)
        self._equity_history.append((time.time(), initial_equity))

        # Risk controls
        self.max_daily_trades = 50
        self.min_trade_interval = 10  # seconds between trades on one symbol
        self.trading_enabled = True

        # Name of the active strategy (set by the runner, written to trade_log)
        self.active_strategy = "manual"

        logger.info(f"[Paper] 模拟账户初始化 权益=${initial_equity:.2f}")

    # ═══════════════════════════════════════════════════════════
    #  Market-data proxy (real prices from OKX)
    # ═══════════════════════════════════════════════════════════

    def ticker_price(self, symbol: str) -> float:
        """Return the last traded price for ``symbol``, or 0 on any failure.

        Callers treat a non-positive return as "no quote available".
        """
        try:
            cfg = self._inst_cfg(symbol)
            ticker = self.okx.ticker(cfg["inst"])
            return float(ticker.get("last", 0))
        except Exception:
            # Best-effort: a missing quote must not crash the trading loop.
            logger.debug(f"[Paper] {symbol} ticker fetch failed", exc_info=True)
            return 0

    def candles_ohlc(self, symbol: str, bar: str = "1H", limit: int = 100) -> list:
        """Return candles for ``symbol`` as dicts, or [] on any failure.

        The client's rows are reversed (presumably newest-first from OKX, so
        the result is oldest-first — confirm). Each row is read positionally as
        ts, open, high, low, close, vol.
        """
        try:
            cfg = self._inst_cfg(symbol)
            data = self.okx.candles(cfg["inst"], bar=bar, limit=limit)
            return [
                {
                    "ts": int(c[0]), "open": float(c[1]), "high": float(c[2]),
                    "low": float(c[3]), "close": float(c[4]), "vol": float(c[5]),
                }
                for c in reversed(data)
            ]
        except Exception:
            logger.debug(f"[Paper] {symbol} candles fetch failed", exc_info=True)
            return []

    # ═══════════════════════════════════════════════════════════
    #  Order entry (mirrors OrderExecutor)
    # ═══════════════════════════════════════════════════════════

    def market_open(self, symbol: str, direction: str, size: float,
                    tp_price: float = 0, sl_price: float = 0) -> bool:
        """Open a new position at the current ticker price (taker fill).

        Args:
            symbol: instrument key in ``self.instruments``.
            direction: "LONG" or "SHORT" (case-insensitive).
            size: number of contracts; must be > 0.
            tp_price: take-profit trigger price, 0 to disable.
            sl_price: stop-loss trigger price, 0 to disable.

        Returns:
            True if the position was opened. False if trading is paused, the
            input is invalid, the symbol already has a position, the daily
            limit or cooldown blocks it, no price is available, or funds are
            insufficient.
        """
        direction = direction.upper()
        if not self._can_open(symbol, direction, size):
            return False

        price = self.ticker_price(symbol)
        if price <= 0:
            logger.warning(f"[Paper] {symbol} 无法获取价格")
            return False

        return self._open_at(symbol, direction, size, price, tp_price, sl_price,
                             fee_rate=self.taker_fee, reason="market_open")

    def limit_open(self, symbol: str, direction: str, size: float,
                   price: float, tp_price: float = 0, sl_price: float = 0) -> bool:
        """Queue a resting limit order; ``update()`` fills it later.

        A LONG order fills once the ticker is at or below ``price``; a SHORT
        order once it is at or above. The fill happens at ``price`` itself and
        pays the maker fee. An order whose fill is refused by the pre-trade
        checks stays queued and is retried on the next ``update()``.
        Always returns True (the order was queued).
        """
        self.pending_orders.append({
            "type": "limit", "symbol": symbol, "direction": direction.upper(),
            "size": size, "price": price, "tp": tp_price, "sl": sl_price,
            "ts": time.time(),
        })
        logger.info(f"[Paper] 限价挂单 {direction} {symbol} {size}张 @{price:.2f}")
        return True

    def market_close(self, symbol: str, pos_side: str) -> bool:
        """Close the ``pos_side`` ("long"/"short") position at the ticker price.

        Returns False if there is no such open position or no price.
        """
        pos = self.positions.get(symbol, {}).get(str(pos_side).lower())
        if not self._is_open(pos):
            return False

        price = self.ticker_price(symbol)
        if price <= 0:
            return False

        return self._close_position(pos, price, "manual_close")

    def scale_in(self, symbol: str, direction: str, size: float,
                 tp_pct: float = 0.02, sl_pct: float = 0.01) -> bool:
        """Add ``size`` contracts to the same-direction position, or open one.

        The combined position gets a size-weighted average entry price. TP/SL
        are reset to ``tp_pct`` / ``sl_pct`` away from the current price.
        Refused when the opposite side is open or a pre-trade check fails.
        """
        direction = direction.upper()
        if not self._can_open(symbol, direction, size, allow_add=True):
            return False

        price = self.ticker_price(symbol)
        if price <= 0:
            return False

        is_long = direction == "LONG"
        tp = price * (1 + tp_pct) if is_long else price * (1 - tp_pct)
        sl = price * (1 - sl_pct) if is_long else price * (1 + sl_pct)
        return self._open_at(symbol, direction, size, price, tp, sl,
                             fee_rate=self.taker_fee, reason="scale_in")

    # ═══════════════════════════════════════════════════════════
    #  Position queries
    # ═══════════════════════════════════════════════════════════

    def get_position(self, symbol: str) -> dict:
        """Return ``{"long": info|None, "short": info|None}`` for ``symbol``.

        ``info`` holds qty, entry, upl (0 when no quote), margin and lever.
        """
        result = {"long": None, "short": None}
        sides = self.positions.get(symbol, {})
        open_sides = [(side, sides.get(side)) for side in ("long", "short")
                      if self._is_open(sides.get(side))]
        if not open_sides:
            return result

        price = self.ticker_price(symbol)  # one quote shared by both sides
        for side, pos in open_sides:
            upl = self._calc_upl(pos, price) if price > 0 else 0
            result[side] = {
                "qty": pos.size,
                "entry": pos.entry_price,
                "upl": round(upl, 4),
                "margin": self._margin(pos),
                "lever": pos.leverage,
            }
        return result

    def has_position(self, symbol: str) -> bool:
        """Return True if ``symbol`` has any open side (no market data needed)."""
        return any(self._is_open(pos) for pos in self.positions.get(symbol, {}).values())

    def get_net_position(self, symbol: str) -> float:
        """Return long size minus short size for ``symbol`` (0 when flat)."""
        sides = self.positions.get(symbol, {})
        net = 0
        long_pos, short_pos = sides.get("long"), sides.get("short")
        if self._is_open(long_pos):
            net += long_pos.size
        if self._is_open(short_pos):
            net -= short_pos.size
        return net

    # ═══════════════════════════════════════════════════════════
    #  Periodic update (call once per cycle)
    # ═══════════════════════════════════════════════════════════

    def update(self):
        """Run one cycle: fill limit orders, enforce TP/SL, and mark equity.

        1. Try to fill crossed pending limit orders.
        2. For every open position, close it if its take-profit or stop-loss
           has been reached. Otherwise value it at margin + unrealized PnL
           (margin only when no quote is available).
        3. Set equity = available + value of open positions, then update peak
           equity, max drawdown and the equity history.
        """
        now = time.time()

        self._check_limit_orders()

        open_value = 0.0
        for sym, sides in list(self.positions.items()):
            open_positions = [p for p in sides.values() if self._is_open(p)]
            if not open_positions:
                continue  # skip the ticker request for symbols with nothing open
            price = self.ticker_price(sym)
            for pos in open_positions:
                if price <= 0:
                    open_value += self._margin(pos)
                    continue
                reason = self._exit_reason(pos, price)
                if reason:
                    # Closing credits margin + PnL to self.available, so the
                    # position must not be counted again below.
                    self._close_position(pos, price, reason)
                    continue
                open_value += self._margin(pos) + self._calc_upl(pos, price)

        self.equity = self.available + open_value
        self.peak_equity = max(self.peak_equity, self.equity)
        dd = (self.peak_equity - self.equity) / self.peak_equity if self.peak_equity > 0 else 0
        self.max_drawdown_pct = max(self.max_drawdown_pct, dd)
        self._equity_history.append((now, self.equity))

    @staticmethod
    def _exit_reason(pos: PaperPosition, price: float) -> str:
        """Return "take_profit"/"stop_loss" if ``price`` hits a trigger, else "".

        TP and SL are checked independently; a zero trigger is disabled.
        """
        is_long = pos.direction == "LONG"
        if pos.tp_price > 0 and (price >= pos.tp_price if is_long else price <= pos.tp_price):
            return "take_profit"
        if pos.sl_price > 0 and (price <= pos.sl_price if is_long else price >= pos.sl_price):
            return "stop_loss"
        return ""

    def _check_limit_orders(self):
        """Fill pending limit orders whose price has been crossed.

        Fills are at the order's limit price with the maker fee. Orders that
        are not crossed, or whose fill is refused by the pre-trade checks,
        stay queued.
        """
        price_cache = {}
        remaining = []
        for order in self.pending_orders:
            sym = order["symbol"]
            if sym not in price_cache:
                price_cache[sym] = self.ticker_price(sym)
            current = price_cache[sym]
            limit = order["price"]

            crossed = current > 0 and (
                (order["direction"] == "LONG" and current <= limit)      # buy: price fell to limit
                or (order["direction"] == "SHORT" and current >= limit)  # sell: price rose to limit
            )
            filled = (
                crossed
                and self._can_open(sym, order["direction"], order["size"])
                and self._open_at(sym, order["direction"], order["size"], limit,
                                  order["tp"], order["sl"],
                                  fee_rate=self.maker_fee, reason="limit_fill")
            )
            if not filled:
                remaining.append(order)

        # Update in place so any outside reference to the list stays valid.
        self.pending_orders[:] = remaining

    # ═══════════════════════════════════════════════════════════
    #  Risk controls
    # ═══════════════════════════════════════════════════════════

    def check_daily_limit(self, symbol: str) -> bool:
        """Return True if ``symbol`` is below ``max_daily_trades`` opens today."""
        self._roll_daily_counters()
        return self._daily_trade_count.get(symbol, 0) < self.max_daily_trades

    def check_cooldown(self, symbol: str) -> bool:
        """Return True if at least ``min_trade_interval`` s passed since the last open."""
        last = self._last_trade_ts.get(symbol, 0)
        return time.time() - last >= self.min_trade_interval

    def get_drawdown_pct(self) -> float:
        """Return the maximum drawdown (fraction of peak equity) seen by ``update()``."""
        return self.max_drawdown_pct

    # ═══════════════════════════════════════════════════════════
    #  Statistics
    # ═══════════════════════════════════════════════════════════

    def get_summary(self) -> dict:
        """Return a snapshot of account state and trade statistics.

        ``win_rate`` is winning closes / all closes. ``total_trades`` counts
        every fill, opens and closes alike.
        """
        active_count = sum(
            1 for sides in self.positions.values()
            for pos in sides.values() if self._is_open(pos)
        )
        closed_trades = self.winning_trades + self.losing_trades
        win_rate = self.winning_trades / max(closed_trades, 1)

        return {
            "mode": "paper",
            "equity": round(self.equity, 2),
            "available": round(self.available, 2),
            "initial": self.initial_equity,
            "realized_pnl": round(self.realized_pnl, 4),
            "total_fees": round(self.total_fees, 4),
            "total_trades": self.total_trades,
            "win_rate": round(win_rate, 3),
            "consecutive_losses": self.consecutive_losses,
            "max_drawdown": f"{self.max_drawdown_pct*100:.2f}%",
            "active_positions": active_count,
            "pending_orders": len(self.pending_orders),
            "trading_enabled": self.trading_enabled,
        }

    def get_trade_history(self, n: int = 20) -> list:
        """Return the last ``n`` fills as dicts with a local-time timestamp string."""
        return [
            {"symbol": t.symbol, "action": t.action, "price": t.price,
             "size": t.size, "pnl": t.pnl, "reason": t.reason,
             "ts": time.strftime("%m-%d %H:%M:%S", time.localtime(t.ts))}
            for t in self.trades[-n:]
        ]

    def reset(self, equity: float = 10000):
        """Reset the account to ``equity``, discarding positions, orders and stats.

        Configuration survives the reset: market-data client, instruments,
        leverage, fee rates, risk limits and the active strategy name.
        """
        preserved = (self.active_strategy, self.max_daily_trades, self.min_trade_interval)
        self.__init__(self.okx, self.instruments, equity, self.leverage,
                      maker_fee=self.maker_fee, taker_fee=self.taker_fee)
        self.active_strategy, self.max_daily_trades, self.min_trade_interval = preserved
        logger.info(f"[Paper] 账户已重置 权益=${equity}")

    # ═══════════════════════════════════════════════════════════
    #  Internal helpers
    # ═══════════════════════════════════════════════════════════

    @staticmethod
    def _is_open(pos) -> bool:
        """Return True if ``pos`` is a position with a positive size."""
        return pos is not None and pos.size > 0

    def _notional(self, price: float, size: float) -> float:
        """Return the notional value of ``size`` contracts at ``price``."""
        return price * size * self.CONTRACT_VALUE

    def _margin(self, pos: PaperPosition) -> float:
        """Return the margin locked by ``pos`` (at the leverage it was opened with)."""
        return self._notional(pos.entry_price, pos.size) / pos.leverage

    def _locked_margin(self) -> float:
        """Return the total margin locked by all open positions."""
        return sum(self._margin(pos) for sides in self.positions.values()
                   for pos in sides.values() if self._is_open(pos))

    @staticmethod
    def _today() -> str:
        """Return the current local calendar date as YYYY-MM-DD."""
        return time.strftime("%Y-%m-%d")

    def _roll_daily_counters(self):
        """Clear the per-symbol daily trade counts when the local date changes."""
        today = self._today()
        if today != self._trade_day:
            self._trade_day = today
            self._daily_trade_count.clear()

    def _can_open(self, symbol: str, direction: str, size: float,
                  allow_add: bool = False) -> bool:
        """Run every pre-trade check for opening (or, with ``allow_add``, adding to) a position.

        ``direction`` must already be upper-case.
        """
        if not self.trading_enabled:
            logger.warning("[Paper] 交易已暂停")
            return False
        if direction not in ("LONG", "SHORT"):
            logger.warning(f"[Paper] invalid direction {direction!r}")
            return False
        if size <= 0:
            logger.warning(f"[Paper] invalid size {size}")
            return False

        sides = self.positions.get(symbol, {})
        same_key, other_key = ("long", "short") if direction == "LONG" else ("short", "long")
        if self._is_open(sides.get(other_key)) or (
                self._is_open(sides.get(same_key)) and not allow_add):
            logger.debug(f"[Paper] {symbol} 已有持仓，跳过开仓")
            return False
        if not self.check_daily_limit(symbol):
            logger.warning(f"[Paper] {symbol} 当日交易次数超限")
            return False
        if not self.check_cooldown(symbol):
            return False
        return True

    def _open_at(self, symbol: str, direction: str, size: float, price: float,
                 tp_price: float, sl_price: float, fee_rate: float, reason: str) -> bool:
        """Fill an open (or scale-in) of ``size`` contracts at ``price``.

        Assumes ``_can_open`` has passed. Returns False only when available
        funds cannot cover margin + fee.
        """
        pos_key = "long" if direction == "LONG" else "short"
        existing = self.positions.get(symbol, {}).get(pos_key)
        adding = self._is_open(existing)
        # A scale-in reuses the existing leverage so the margin released on
        # close (see _margin) equals the total margin locked here.
        leverage = existing.leverage if adding else self.leverage

        notional = self._notional(price, size)
        margin = notional / leverage
        fee = notional * fee_rate
        if margin + fee > self.available:
            logger.warning(f"[Paper] 资金不足: 需要{margin+fee:.2f} 可用{self.available:.2f}")
            return False

        self.available -= margin + fee
        self.total_fees += fee

        if adding:
            total = existing.size + size
            existing.entry_price = (existing.entry_price * existing.size + price * size) / total
            existing.size = total
            if tp_price > 0:
                existing.tp_price = tp_price
            if sl_price > 0:
                existing.sl_price = sl_price
        else:
            self.positions.setdefault(symbol, {})[pos_key] = PaperPosition(
                symbol=symbol, direction=direction, entry_price=price,
                size=size, leverage=leverage,
                tp_price=tp_price, sl_price=sl_price,
            )

        self.trades.append(PaperTrade(
            symbol=symbol, action=f"OPEN_{direction}", price=price,
            size=size, reason=reason,
        ))
        self.total_trades += 1
        self._roll_daily_counters()
        self._daily_trade_count[symbol] += 1
        self._last_trade_ts[symbol] = time.time()

        logger.info(f"[Paper] 开{direction} {symbol} {size}张 @{price:.2f} margin={margin:.2f}")
        return True

    def _calc_upl(self, pos: PaperPosition, price: float) -> float:
        """Return the unrealized PnL of ``pos`` at ``price`` (before fees)."""
        diff = price - pos.entry_price if pos.direction == "LONG" else pos.entry_price - price
        return diff * pos.size * self.CONTRACT_VALUE

    def _close_position(self, pos: PaperPosition, price: float, reason: str) -> bool:
        """Close ``pos`` at ``price`` with the taker fee and book the result.

        Returns the position's margin plus net PnL to ``available``, updates
        win/loss stats, best-effort writes the cross-DB trade log, records a
        CLOSE fill and removes the position. Always returns True.
        """
        upl = self._calc_upl(pos, price)
        fee = self._notional(price, pos.size) * self.taker_fee
        pnl = upl - fee

        self.available += self._margin(pos) + pnl
        self.realized_pnl += pnl
        self.total_fees += fee

        if pnl > 0:
            self.winning_trades += 1
            self.consecutive_losses = 0
        else:
            self.losing_trades += 1
            self.consecutive_losses += 1

        # Cross-database trade log (macro_decision.trade_log); best-effort only.
        try:
            from storage.mysql_client import insert_trade_log
            insert_trade_log(
                strategy=self.active_strategy, symbol=pos.symbol,
                side=pos.direction, size=pos.size,
                entry=pos.entry_price, exit_price=price, pnl=round(pnl, 4),
                reason=reason, status='closed',
            )
        except Exception:
            logger.debug("[Paper] trade_log write failed", exc_info=True)

        entry_notional = self._notional(pos.entry_price, pos.size)
        self.trades.append(PaperTrade(
            symbol=pos.symbol, action="CLOSE", price=price,
            size=pos.size, pnl=round(pnl, 4),
            pnl_pct=round(pnl / entry_notional * 100, 4) if entry_notional else 0.0,
            reason=reason,
        ))
        self.total_trades += 1

        pos_key = "long" if pos.direction == "LONG" else "short"
        if pos.symbol in self.positions:
            self.positions[pos.symbol].pop(pos_key, None)

        # Free cash plus margin still locked elsewhere; other positions'
        # unrealized PnL is refreshed by the next update().
        self.equity = self.available + self._locked_margin()

        logger.info(f"[Paper] 平{pos.direction} {pos.symbol} {pos.size}张 @{price:.2f} "
                    f"盈亏={pnl:.4f} ({reason})")
        return True

    def _inst_cfg(self, symbol: str) -> dict:
        """Return the instrument config for ``symbol`` (case-insensitive key).

        Raises:
            ValueError: if the symbol is not configured.
        """
        cfg = self.instruments.get(symbol.upper())
        if not cfg:
            raise ValueError(f"Unknown instrument: {symbol}")
        return cfg

    def cancel_all_orders(self, symbol: str = None):
        """Cancel pending limit orders for ``symbol``, or all orders when None."""
        if symbol:
            self.pending_orders = [o for o in self.pending_orders if o["symbol"] != symbol]
        else:
            self.pending_orders.clear()