# -*- coding: utf-8 -*-
"""
OrderExecutor — 统一订单执行接口
支持: 市价/限价 开仓/平仓, 止盈止损, 仓位查询, 批量操作
"""
import time
import logging
from collections import defaultdict

logger = logging.getLogger("MyTrader")


class OrderResult:
    """Outcome of a single order request.

    Instances are truthy exactly when the request succeeded, so callers can
    write ``if executor.market_open(...):``.

    Attributes:
        success: whether the exchange accepted the request.
        ord_id: exchange order id ("" when unavailable).
        detail: short human-readable summary or error message.
    """

    def __init__(self, success, ord_id="", detail=""):
        self.success = success
        self.ord_id = ord_id
        self.detail = detail

    def __bool__(self):
        # __bool__ must return a real bool; coerce in case a truthy non-bool
        # (e.g. 1) was stored, which would otherwise raise TypeError.
        return bool(self.success)

    def __repr__(self):
        return f"OrderResult({'OK' if self.success else 'FAIL'}, {self.ord_id}, {self.detail})"


class OrderExecutor:
    """Unified order executor.

    Wraps an exchange client (``okx``) behind a small API: market/limit opens,
    market closes, TP/SL attachment, position queries, simple per-symbol risk
    checks and market-data helpers.

    Order methods report exchange errors and exceptions as a failed
    ``OrderResult`` instead of raising. The exception is an unknown symbol,
    which raises ``ValueError`` from ``_inst_cfg``; ``scale_in`` reports it as a
    failed result instead.

    ``instruments`` maps symbol -> config dict. Only the ``"inst"`` key (the
    exchange instrument id) is read here.
    """

    def __init__(self, okx, instruments: dict, default_leverage: int = 10):
        """
        Args:
            okx: exchange client used for every request.
            instruments: ``{symbol: {"inst": inst_id, ...}}``.
            default_leverage: leverage applied before every open order.
        """
        self.okx = okx
        self.instruments = instruments
        self.leverage = default_leverage
        # Successful opens per symbol (keyed by the symbol as passed by the caller).
        # Not reset automatically; see reset_daily_stats().
        self._daily_count = defaultdict(int)  # {symbol: count}
        self._last_order_ts = {}  # {symbol: epoch seconds of last successful open}
        self._daily_pnl = 0.0  # day's accumulated PnL (not updated inside this class)

    # ── Basic orders ──────────────────────────────

    def market_open(self, symbol: str, direction: str, size: float,
                    tp_price: float = 0, sl_price: float = 0) -> "OrderResult":
        """Open a position with a market order, optionally attaching TP/SL.

        Args:
            symbol: instrument key (case-insensitive lookup).
            direction: ``"LONG"`` (any case) opens long; any other value opens short.
            size: order size, passed to the client unchanged.
            tp_price: take-profit trigger price; ``0``/``None``/negative means none.
            sl_price: stop-loss trigger price; ``0``/``None``/negative means none.

        Returns:
            ``OrderResult`` carrying the exchange order id on success.

        Raises:
            ValueError: if *symbol* is not configured.
        """
        cfg = self._inst_cfg(symbol)
        inst_id = cfg["inst"]
        pos_side = self._pos_side(direction)
        side = "buy" if pos_side == "long" else "sell"

        try:
            # Leverage is (re)applied before every open.
            self.okx.set_leverage(inst_id, self.leverage, pos_side)
            r = self.okx.place_order(inst_id, side, pos_side, size,
                                     tp=self._opt_price(tp_price),
                                     sl=self._opt_price(sl_price))
            if r.get("code") != "0":
                msg = r.get("msg", "unknown")
                logger.error(f"[Order] {direction} {symbol} 失败: {msg}")
                return OrderResult(False, detail=msg)

            # The exchange accepted the order. A missing or empty "data" payload
            # must not turn this into a reported failure: the order is live.
            ord_id = self._first_data(r).get("ordId", "")
            self._record_open(symbol)
            logger.info(f"[Order] {direction} {symbol} {size}张 @market → {ord_id}")
            return OrderResult(True, ord_id, f"market {direction} {size}张")
        except Exception as e:
            logger.error(f"[Order] {symbol} 异常: {e}")
            return OrderResult(False, detail=str(e)[:200])

    def limit_open(self, symbol: str, direction: str, size: float,
                   price: float, tp_price: float = 0, sl_price: float = 0,
                   px_decimals: int = 2) -> "OrderResult":
        """Open a position with an isolated-margin limit order.

        If either *tp_price* or *sl_price* is non-zero, a separate TP/SL algo
        order is placed once the limit order is accepted. A failure there is
        logged and flagged in ``detail``, but the result stays successful
        because the entry order is already live.

        Args:
            symbol: instrument key (case-insensitive lookup).
            direction: ``"LONG"`` (any case) opens long; any other value opens short.
            size: order size, sent as ``str(size)``.
            price: limit price, rounded to *px_decimals* places before sending.
            tp_price: take-profit trigger price (0 = none).
            sl_price: stop-loss trigger price (0 = none).
            px_decimals: decimals used for the price string. The default of 2
                matches the previous hard-coded value; increase it for
                low-priced instruments.

        Returns:
            ``OrderResult`` carrying the exchange order id on success.

        Raises:
            ValueError: if *symbol* is not configured.
        """
        cfg = self._inst_cfg(symbol)
        inst_id = cfg["inst"]
        pos_side = self._pos_side(direction)
        side = "buy" if pos_side == "long" else "sell"

        try:
            self.okx.set_leverage(inst_id, self.leverage, pos_side)
            # Limit orders go through the raw POST endpoint to set a different ordType.
            r = self.okx.post("/api/v5/trade/order", {
                "instId": inst_id, "tdMode": "isolated",
                "side": side, "posSide": pos_side,
                "ordType": "limit", "sz": str(size),
                "px": str(round(price, px_decimals)),
            })
            if r.get("code") != "0":
                msg = r.get("msg", "")
                logger.error(f"[Order] limit {direction} {symbol} 失败: {msg}")
                return OrderResult(False, detail=msg)
            ord_id = self._first_data(r).get("ordId", "")
            self._record_open(symbol)
            logger.info(f"[Order] {direction} {symbol} {size}张 @{price} limit → {ord_id}")
        except Exception as e:
            logger.error(f"[Order] limit {symbol} 异常: {e}")
            return OrderResult(False, detail=str(e)[:200])

        detail = f"limit {direction} {size}张 @{price}"
        # Attach TP/SL. This code previously sat after the return and never ran.
        # NOTE(review): the limit order may not be filled yet; confirm the
        # exchange accepts a TP/SL algo order before the position exists.
        if tp_price or sl_price:
            try:
                algo = self.okx.place_algo_tpsl(inst_id, pos_side, size, tp=tp_price, sl=sl_price)
                if isinstance(algo, dict) and algo.get("code", "0") != "0":
                    raise RuntimeError(algo.get("msg", "unknown"))
            except Exception as e:
                logger.error(f"[Order] limit {symbol} TP/SL 失败: {e}")
                detail += " (TP/SL failed)"
        return OrderResult(True, ord_id, detail)

    def market_close(self, symbol: str, pos_side: str) -> "OrderResult":
        """Close the *pos_side* (e.g. ``"long"``/``"short"``) position of *symbol* at market.

        Raises:
            ValueError: if *symbol* is not configured.
        """
        cfg = self._inst_cfg(symbol)
        inst_id = cfg["inst"]
        try:
            r = self.okx.close_position(inst_id, pos_side)
            if r.get("code") == "0":
                logger.info(f"[Order] 平仓 {symbol} {pos_side}")
                return OrderResult(True, detail=f"close {symbol} {pos_side}")
            msg = r.get("msg", "")
            logger.error(f"[Order] close {symbol} {pos_side} 失败: {msg}")
            return OrderResult(False, detail=msg)
        except Exception as e:
            logger.error(f"[Order] close {symbol} 异常: {e}")
            return OrderResult(False, detail=str(e)[:200])

    def scale_in(self, symbol: str, direction: str, size: float,
                 tp_pct: float = 0.02, sl_pct: float = 0.01) -> "OrderResult":
        """Add to a position at market, with TP/SL set relative to the last price.

        Args:
            symbol: instrument key (case-insensitive lookup).
            direction: ``"LONG"`` (any case) for long; anything else is short.
            size: order size.
            tp_pct: take-profit distance as a fraction (0.02 = 2%); ``<= 0`` disables TP.
            sl_pct: stop-loss distance as a fraction; ``<= 0`` disables SL.

        Returns:
            The ``market_open`` result. If the price cannot be fetched or
            anything raises (including an unknown symbol), returns a failed
            ``OrderResult`` instead.
        """
        try:
            price = self.ticker_price(symbol)
            if price <= 0:
                return OrderResult(False, detail="价格获取失败")
            sign = 1 if direction.upper() == "LONG" else -1
            # A zero pct would put the trigger exactly at the current price
            # (instant trigger), so treat <= 0 as "not set". market_open ignores 0.
            tp_price = price * (1 + sign * tp_pct) if tp_pct > 0 else 0
            sl_price = price * (1 - sign * sl_pct) if sl_pct > 0 else 0
            return self.market_open(symbol, direction, size, tp_price, sl_price)
        except Exception as e:
            logger.error(f"[Order] scale_in {symbol} 异常: {e}")
            return OrderResult(False, detail=str(e)[:200])

    # ── Position management ───────────────────────

    def get_position(self, symbol: str) -> dict:
        """Return ``{"long": leg|None, "short": leg|None}`` for *symbol*.

        Each leg is a dict with float fields ``qty``, ``entry``, ``upl``,
        ``margin`` and ``lever``. Empty or missing API fields become 0.0, or
        10.0 for ``lever``. Only rows with ``posSide`` "long"/"short" are used.
        NOTE(review): rows in net mode (``posSide == "net"``) are ignored.

        Raises:
            ValueError: if *symbol* is not configured.
        """
        cfg = self._inst_cfg(symbol)
        result = {"long": None, "short": None}
        for row in self.okx.positions(cfg["inst"]) or []:
            side = row.get("posSide", "")
            if side in result:
                result[side] = self._parse_leg(row)
        return result

    def has_position(self, symbol: str) -> bool:
        """Return True if either the long or the short leg has a positive quantity."""
        pos = self.get_position(symbol)
        return any(leg is not None and leg["qty"] > 0
                   for leg in (pos["long"], pos["short"]))

    def get_net_position(self, symbol: str) -> float:
        """Net quantity: long qty minus short qty (positive = net long, negative = net short)."""
        pos = self.get_position(symbol)
        net = 0.0
        if pos["long"] and pos["long"]["qty"] > 0:
            net += pos["long"]["qty"]
        if pos["short"] and pos["short"]["qty"] > 0:
            net -= pos["short"]["qty"]
        return net

    # ── Risk control ──────────────────────────────

    def check_daily_limit(self, symbol: str, max_trades: int = 20) -> bool:
        """Return True while the recorded open count for *symbol* is below *max_trades*.

        Counts are kept per symbol string as passed to the open methods. They
        are only cleared by ``reset_daily_stats``.
        """
        return self._daily_count.get(symbol, 0) < max_trades

    def check_cooldown(self, symbol: str, cooldown_sec: float = 30) -> bool:
        """Return True if at least *cooldown_sec* seconds passed since the last successful open."""
        last = self._last_order_ts.get(symbol, 0)
        return time.time() - last >= cooldown_sec

    def reset_daily_stats(self) -> None:
        """Clear the per-symbol open counters and the PnL accumulator.

        Nothing in this class calls it; the owner should call it at day rollover.
        Cooldown timestamps are kept.
        """
        self._daily_count.clear()
        self._daily_pnl = 0.0

    # ── Utilities ─────────────────────────────────

    def _inst_cfg(self, symbol: str) -> dict:
        """Return the instrument config for *symbol*, matching keys case-insensitively.

        Tries the upper-cased key first, then any key that is equal ignoring case.

        Raises:
            ValueError: if no configured key matches.
        """
        wanted = symbol.upper()
        cfg = self.instruments.get(wanted)
        if cfg:
            return cfg
        for key, value in self.instruments.items():
            if key.upper() == wanted:
                return value
        raise ValueError(f"Unknown instrument: {symbol} (available: {list(self.instruments.keys())})")

    @staticmethod
    def _pos_side(direction: str) -> str:
        """Map a direction to a posSide: "LONG" (any case) -> "long", anything else -> "short"."""
        return "long" if direction.upper() == "LONG" else "short"

    @staticmethod
    def _opt_price(value):
        """Return *value* if it is a positive price, else ``None`` (meaning "not set")."""
        return value if value and value > 0 else None

    @staticmethod
    def _to_float(value, default: float = 0.0) -> float:
        """Convert an API field to float, mapping ``None``/``""`` to *default*.

        Numeric API fields may arrive as strings, possibly empty, and
        ``float("")`` would raise.
        """
        if value is None or value == "":
            return float(default)
        return float(value)

    @staticmethod
    def _first_data(resp: dict) -> dict:
        """Return the first dict in ``resp["data"]``, or ``{}`` if it is absent or empty."""
        data = resp.get("data") or [{}]
        first = data[0]
        return first if isinstance(first, dict) else {}

    def _parse_leg(self, row: dict) -> dict:
        """Convert one raw position row into the leg dict used by ``get_position``."""
        return {
            "qty": self._to_float(row.get("pos")),
            "entry": self._to_float(row.get("avgPx")),
            "upl": self._to_float(row.get("upl")),
            "margin": self._to_float(row.get("margin")),
            "lever": self._to_float(row.get("lever"), 10.0),
        }

    def _record_open(self, symbol: str) -> None:
        """Bump the open counter and last-order timestamp after a successful open."""
        self._daily_count[symbol] += 1
        self._last_order_ts[symbol] = time.time()

    def ticker_price(self, symbol: str) -> float:
        """Return the last traded price of *symbol* (0.0 if unavailable)."""
        cfg = self._inst_cfg(symbol)
        ticker = self.okx.ticker(cfg["inst"]) or {}
        return self._to_float(ticker.get("last"))

    def candles_ohlc(self, symbol: str, bar: str = "1H", limit: int = 100) -> list:
        """Return candles as dicts with keys ``ts, open, high, low, close, vol``.

        Raw rows are ``[ts, open, high, low, close, vol, ...]``. They are
        returned in reverse of the client's order; presumably that is
        oldest-first, since OKX lists newest first (confirm against the client).
        """
        cfg = self._inst_cfg(symbol)
        data = self.okx.candles(cfg["inst"], bar=bar, limit=limit) or []
        return [
            {
                "ts": int(c[0]), "open": float(c[1]), "high": float(c[2]),
                "low": float(c[3]), "close": float(c[4]), "vol": float(c[5]),
            }
            for c in reversed(data)
        ]

    def cancel_all_orders(self, symbol: str):
        """Best-effort cancel of all pending orders for *symbol*.

        Failures are logged, not raised.
        NOTE(review): confirm the "/api/v5/trade/cancel-all-orders" endpoint
        exists in the OKX v5 API. The documented batch cancel is
        "/api/v5/trade/cancel-batch-orders".
        """
        cfg = self._inst_cfg(symbol)
        try:
            r = self.okx.post("/api/v5/trade/cancel-all-orders", {"instId": cfg["inst"]})
        except Exception as e:
            logger.warning(f"[Order] cancel_all {symbol} 异常: {e}")
            return
        if isinstance(r, dict) and r.get("code") not in (None, "0"):
            logger.warning(f"[Order] cancel_all {symbol} 失败: {r.get('msg', '')}")
