# -*- coding: utf-8 -*-
"""
MySQL storage layer — connection pool + CRUD
"""
import json
import time
import logging
from contextlib import contextmanager
from datetime import datetime
from config import DB_ENABLED

import pymysql
from pymysql.cursors import DictCursor
from dbutils.pooled_db import PooledDB

# Shared logger for this storage layer; messages are emitted under the "MyTrader" name.
logger = logging.getLogger("MyTrader")

# Primary connection pool. None until init_pool() assigns a PooledDB instance.
_pool = None


def init_pool(host, port, user, password, database, pool_size=3):
    """Create the primary MySQL connection pool and publish it in ``_pool``.

    The call is a no-op when a pool already exists. The new pool is only
    published after a test connection has answered a ping, so a failed
    initialisation leaves ``_pool`` untouched and the call can be retried.

    Args:
        host, port, user, password, database: MySQL connection settings,
            forwarded to pymysql through PooledDB.
        pool_size: used as ``maxcached``; ``maxconnections`` is
            ``pool_size + 2``.

    Raises:
        Exception: whatever PooledDB / pymysql raise when the server cannot
            be reached or the ping fails.
    """
    global _pool
    if _pool:
        return
    pool = PooledDB(
        creator=pymysql,
        maxconnections=pool_size + 2,
        mincached=2,
        maxcached=pool_size,
        blocking=True,
        host=host,
        port=port,
        user=user,
        password=password,
        database=database,
        charset='utf8mb4',
        cursorclass=DictCursor,
        connect_timeout=5,
    )
    # Probe the server before exposing the pool to the rest of the module.
    try:
        conn = pool.connection()
        try:
            conn.ping()
        finally:
            conn.close()
    except Exception:
        # Do not keep idle connections of a pool that will never be used.
        pool.close()
        raise
    _pool = pool
    logger.info(f"[MySQL] 连接池已初始化 {host}:{port}/{database}")


def _get_conn():
    """Borrow a connection from the primary pool.

    Raises:
        RuntimeError: if the pool has not been created yet (``_pool`` is
            still None), instead of an opaque AttributeError on None.
    """
    if _pool is None:
        raise RuntimeError("MySQL pool is not initialized; call init_pool() first")
    return _pool.connection()


@contextmanager
def get_cursor():
    """Context manager yielding a cursor on a pooled connection.

    On normal exit the transaction is committed. If the body raises an
    ``Exception`` the transaction is rolled back and the exception is
    re-raised. The cursor and the connection handle are closed in every case.
    """
    conn = _get_conn()
    try:
        cur = conn.cursor()
        try:
            yield cur
            conn.commit()
        except Exception:
            try:
                conn.rollback()
            except Exception as rollback_err:
                # A dead connection must not hide the error that caused the rollback.
                logger.warning("[MySQL] rollback failed: %s", rollback_err)
            raise
        finally:
            cur.close()
    finally:
        conn.close()


# ═══════════════════════════════════════════════════════════
# Trade Records
# ═══════════════════════════════════════════════════════════
def insert_trade(symbol, direction, entry_price, size, leverage, tp_price, sl_price,
                 entry_score, equity_before, is_bot=True):
    """Insert a newly opened trade into ``trades`` and return the cursor's ``lastrowid``.

    ``opened_at`` is stamped with ``datetime.utcnow()`` (naive UTC).
    """
    with get_cursor() as cur:
        opened_at = datetime.utcnow()
        params = (
            symbol, direction, entry_price, size, leverage,
            tp_price, sl_price, entry_score, equity_before, is_bot,
            opened_at,
        )
        sql = """
            INSERT INTO trades (symbol, direction, entry_price, size, leverage,
                  tp_price, sl_price, entry_score, equity_before, is_bot, opened_at)
            VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)
        """
        cur.execute(sql, params)
        new_id = cur.lastrowid
    return new_id


def close_trade(trade_id, exit_price, pnl, pnl_pct, exit_reason, equity_after):
    """Write the exit fields of the ``trades`` row with id ``trade_id``.

    ``closed_at`` is stamped with ``datetime.utcnow()`` (naive UTC).
    Returns nothing.
    """
    with get_cursor() as cur:
        closed_at = datetime.utcnow()
        sql = """
            UPDATE trades SET exit_price=%s, pnl=%s, pnl_pct=%s,
                   exit_reason=%s, equity_after=%s, closed_at=%s
            WHERE id=%s
        """
        params = (exit_price, pnl, pnl_pct, exit_reason, equity_after, closed_at, trade_id)
        cur.execute(sql, params)


def get_open_trades():
    """Return all ``trades`` rows whose ``closed_at`` is NULL, newest ``opened_at`` first."""
    query = "SELECT * FROM trades WHERE closed_at IS NULL ORDER BY opened_at DESC"
    with get_cursor() as cur:
        cur.execute(query)
        open_rows = cur.fetchall()
    return open_rows


def get_recent_trades(limit=50):
    """Return up to ``limit`` rows from ``trades``, newest ``opened_at`` first."""
    query = "SELECT * FROM trades ORDER BY opened_at DESC LIMIT %s"
    with get_cursor() as cur:
        cur.execute(query, (limit,))
        recent_rows = cur.fetchall()
    return recent_rows


# ═══════════════════════════════════════════════════════════
# Factor Snapshots
# ═══════════════════════════════════════════════════════════
def insert_factor_snapshot(symbol, price, direction, total_score, scores, raw):
    """Insert one row into ``factor_snapshots``.

    Args:
        symbol, price, direction, total_score: stored as given.
        scores: mapping of factor name to score. Factors in the first group
            below are stored as NULL when absent, factors in the second
            group as 0.
        raw: mapping read for ``fr_raw`` (stored as ``funding_rate``),
            ``t1h``, ``t4h``, ``momentum_adj``, ``flip_penalty`` and
            ``pre_filter``. ``pre_filter`` contributes its ``score`` entry
            only when it is a dict; otherwise NULL is stored.

    ``snapshot_at`` is stamped with ``datetime.utcnow()`` (naive UTC).
    """
    # Factors that default to NULL when missing from ``scores``.
    null_default_factors = (
        'trend', 'orderbook', 'funding', 'taker', 'oi', 'maxpain',
        'vol_delta', 'btc_corr', 'gamma', 'iv', 'exhaust', 'liq_cool',
        'mean_revert', 'smart_money',
    )
    # Factors that default to 0 when missing from ``scores``.
    zero_default_factors = (
        'mtf', 'ob_liq', 'low_lev', 'liq_ex',
        'retail', 'mm', 'liq_trigger', 'toxic',
    )

    pre_filter = raw.get('pre_filter')
    pre_filter_score = pre_filter.get('score') if isinstance(pre_filter, dict) else None

    with get_cursor() as cur:
        # (column, value) pairs in table-column order.
        fields = [
            ('symbol', symbol),
            ('price', price),
            ('direction', direction),
            ('total_score', total_score),
        ]
        fields += [(name, scores.get(name)) for name in null_default_factors]
        fields += [(name, scores.get(name, 0)) for name in zero_default_factors]
        fields += [
            ('funding_rate', raw.get('fr_raw')),
            ('t1h', raw.get('t1h')),
            ('t4h', raw.get('t4h')),
            ('momentum_adj', raw.get('momentum_adj')),
            ('flip_penalty', raw.get('flip_penalty')),
            ('pre_filter', pre_filter_score),
            ('snapshot_at', datetime.utcnow()),
        ]
        # Column names are the hard-coded constants above (never caller input).
        # Deriving the placeholders from the same list keeps the column,
        # placeholder and parameter counts equal by construction.
        columns = ", ".join(name for name, _ in fields)
        placeholders = ",".join(["%s"] * len(fields))
        cur.execute(
            f"INSERT INTO factor_snapshots ({columns}) VALUES ({placeholders})",
            tuple(value for _, value in fields),
        )


def get_recent_snapshots(symbol=None, limit=100):
    """Return up to ``limit`` ``factor_snapshots`` rows, newest ``snapshot_at`` first.

    When ``symbol`` is truthy only rows for that symbol are returned.
    """
    if symbol:
        query = "SELECT * FROM factor_snapshots WHERE symbol=%s ORDER BY snapshot_at DESC LIMIT %s"
        args = (symbol, limit)
    else:
        query = "SELECT * FROM factor_snapshots ORDER BY snapshot_at DESC LIMIT %s"
        args = (limit,)
    with get_cursor() as cur:
        cur.execute(query, args)
        snapshots = cur.fetchall()
    return snapshots


# ═══════════════════════════════════════════════════════════
# Equity History
# ═══════════════════════════════════════════════════════════
def insert_equity(equity, available, peak=None, drawdown_pct=None, margin_used=None, positions=0):
    """Append one row to ``equity_history``.

    ``recorded_at`` is stamped with ``datetime.utcnow()`` (naive UTC).
    """
    with get_cursor() as cur:
        recorded_at = datetime.utcnow()
        row = (equity, available, peak, drawdown_pct, margin_used, positions, recorded_at)
        sql = """
            INSERT INTO equity_history (equity, available, peak, drawdown_pct, margin_used, positions, recorded_at)
            VALUES (%s,%s,%s,%s,%s,%s,%s)
        """
        cur.execute(sql, row)


def get_equity_history(hours=24):
    """Return ``equity_history`` rows recorded within the last ``hours`` hours, oldest first.

    The cutoff is computed with ``UTC_TIMESTAMP()`` because ``recorded_at`` is
    written by ``insert_equity`` as naive UTC (``datetime.utcnow()``). Using
    ``NOW()`` would shift the window by the server/session time-zone offset
    whenever that zone is not UTC.
    """
    with get_cursor() as cur:
        cur.execute("""
            SELECT * FROM equity_history
            WHERE recorded_at >= DATE_SUB(UTC_TIMESTAMP(), INTERVAL %s HOUR)
            ORDER BY recorded_at ASC
        """, (hours,))
        return cur.fetchall()


# ═══════════════════════════════════════════════════════════
# Position Snapshots
# ═══════════════════════════════════════════════════════════
def insert_position_snapshot(positions):
    """Bulk-insert one ``position_snapshots`` row per entry of ``positions``.

    Does nothing when ``positions`` is empty or None. Each entry is read as a
    mapping with required keys ``symbol``, ``side``, ``entry``, ``qty``,
    ``margin`` and ``upl``. Optional keys: ``last`` (defaults to ``entry``),
    ``lever`` (defaults to 10, converted with ``int``) and ``is_bot``
    (defaults to 1). All rows of one call share the same naive-UTC timestamp.
    """
    if not positions:
        return
    with get_cursor() as cur:
        stamp = datetime.utcnow()

        def as_row(pos):
            # Column order must match the INSERT statement below.
            return (
                pos['symbol'], pos['side'], pos['entry'], pos.get('last', pos['entry']),
                pos['qty'], int(pos.get('lever', 10)), pos['margin'], pos['upl'],
                pos.get('is_bot', 1), stamp,
            )

        batch = [as_row(pos) for pos in positions]
        cur.executemany("""
            INSERT INTO position_snapshots
                (symbol, direction, entry_price, mark_price, qty, leverage, margin, upl, is_bot, recorded_at)
            VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)
        """, batch)


# ═══════════════════════════════════════════════════════════
# System Events
# ═══════════════════════════════════════════════════════════
def insert_event(level, source, event_type, message, details=None):
    """Append one row to ``system_events``.

    ``details`` is serialised with ``json.dumps(..., ensure_ascii=False)`` when
    truthy; a falsy value (None, empty dict, ...) is stored as NULL.
    ``created_at`` is stamped with ``datetime.utcnow()`` (naive UTC).
    """
    with get_cursor() as cur:
        details_json = None
        if details:
            details_json = json.dumps(details, ensure_ascii=False)
        row = (level, source, event_type, message, details_json, datetime.utcnow())
        sql = """
            INSERT INTO system_events (level, source, event_type, message, details, created_at)
            VALUES (%s,%s,%s,%s,%s,%s)
        """
        cur.execute(sql, row)


def get_recent_events(limit=50, level=None):
    """Return up to ``limit`` ``system_events`` rows, newest ``created_at`` first.

    When ``level`` is truthy only rows with that level are returned.
    """
    if level:
        query = "SELECT * FROM system_events WHERE level=%s ORDER BY created_at DESC LIMIT %s"
        args = (level, limit)
    else:
        query = "SELECT * FROM system_events ORDER BY created_at DESC LIMIT %s"
        args = (limit,)
    with get_cursor() as cur:
        cur.execute(query, args)
        events = cur.fetchall()
    return events


# ═══════════════════════════════════════════════════════════
# News Memory
# ═══════════════════════════════════════════════════════════
def insert_news(title, content="", impact="", score=None, source=None, price_at=None):
    """Append one row to ``news_memory``.

    ``impact`` is truncated to its first 200 characters; ``None`` is treated
    like the default empty string instead of raising ``TypeError`` on the
    slice. ``created_at`` is stamped with ``datetime.utcnow()`` (naive UTC).
    """
    impact_text = (impact or "")[:200]
    with get_cursor() as cur:
        cur.execute("""
            INSERT INTO news_memory (title, content, impact, score, source, price_at, created_at)
            VALUES (%s,%s,%s,%s,%s,%s,%s)
        """, (title, content, impact_text, score, source, price_at, datetime.utcnow()))


def get_recent_news(hours=72):
    """Return ``news_memory`` rows created within the last ``hours`` hours, newest first.

    The cutoff is computed with ``UTC_TIMESTAMP()`` because ``created_at`` is
    written by ``insert_news`` as naive UTC (``datetime.utcnow()``). Using
    ``NOW()`` would shift the window by the server/session time-zone offset
    whenever that zone is not UTC.
    """
    with get_cursor() as cur:
        cur.execute("""
            SELECT * FROM news_memory
            WHERE created_at >= DATE_SUB(UTC_TIMESTAMP(), INTERVAL %s HOUR)
            ORDER BY created_at DESC
        """, (hours,))
        return cur.fetchall()


# ═══════════════════════════════════════════════════════════
# Raw Data (persisted immediately on arrival — write-before-process)
# ═══════════════════════════════════════════════════════════

def insert_raw_ticker(symbol, price, bid=None, ask=None, volume_24h=None, source='okx_rest'):
    """Append one price snapshot to ``raw_ticker``.

    ``recorded_at`` is stamped with ``datetime.utcnow()`` (naive UTC).
    """
    with get_cursor() as cur:
        recorded_at = datetime.utcnow()
        row = (symbol, price, bid, ask, volume_24h, source, recorded_at)
        sql = """
            INSERT INTO raw_ticker (symbol, price, bid, ask, volume_24h, source, recorded_at)
            VALUES (%s,%s,%s,%s,%s,%s,%s)
        """
        cur.execute(sql, row)


def insert_raw_orderbook(symbol, best_bid, best_ask, bid_depth_5=None,
                         ask_depth_5=None, spread_pct=None, mid_price=None, source='okx_ws'):
    """Append one order-book snapshot to ``raw_orderbook``.

    ``recorded_at`` is stamped with ``datetime.utcnow()`` (naive UTC).
    """
    with get_cursor() as cur:
        recorded_at = datetime.utcnow()
        row = (
            symbol, best_bid, best_ask, bid_depth_5, ask_depth_5,
            spread_pct, mid_price, source, recorded_at,
        )
        sql = """
            INSERT INTO raw_orderbook (symbol, best_bid, best_ask, bid_depth_5,
                      ask_depth_5, spread_pct, mid_price, source, recorded_at)
            VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s)
        """
        cur.execute(sql, row)


def insert_raw_candle(symbol, bar, ts, open_, high, low, close, volume, source='okx_rest'):
    """Store one candle in ``raw_candles`` using ``INSERT IGNORE``.

    ``recorded_at`` is stamped with ``datetime.utcnow()`` (naive UTC).
    NOTE(review): duplicate suppression presumably relies on a unique key in
    the table definition, which is not visible in this file — confirm.
    """
    with get_cursor() as cur:
        recorded_at = datetime.utcnow()
        row = (symbol, bar, ts, open_, high, low, close, volume, source, recorded_at)
        sql = """
            INSERT IGNORE INTO raw_candles (symbol, bar, ts, open, high, low, close, volume, source, recorded_at)
            VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)
        """
        cur.execute(sql, row)


def insert_raw_event(event_type, symbol, severity=None, direction=None,
                     price=None, detail=None, metrics=None, source=None):
    """Append one event (anomaly / toxic flow / spoofing / signal) to ``raw_events``.

    ``metrics`` is serialised with ``json.dumps(..., ensure_ascii=False)`` when
    truthy; a falsy value is stored as NULL. ``created_at`` is stamped with
    ``datetime.utcnow()`` (naive UTC).
    """
    with get_cursor() as cur:
        metrics_json = None
        if metrics:
            metrics_json = json.dumps(metrics, ensure_ascii=False)
        row = (
            event_type, symbol, severity, direction, price, detail,
            metrics_json, source, datetime.utcnow(),
        )
        sql = """
            INSERT INTO raw_events (event_type, symbol, severity, direction,
                      price, detail, metrics, source, created_at)
            VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s)
        """
        cur.execute(sql, row)


# ═══════════════════════════════════════════════════════════
# Cross-DB: writes into the strategy department's macro_decision database
# ═══════════════════════════════════════════════════════════

# Pool for the macro_decision database; created on first use by _get_macro_pool().
_MACRO_POOL = None


def _get_macro_pool():
    """Return the ``macro_decision`` connection pool, building it on the first call.

    Connection settings are imported from ``config`` at call time; the pool
    is cached in the module-level ``_MACRO_POOL``.
    """
    global _MACRO_POOL
    if not _MACRO_POOL:
        from config import DB_HOST, DB_PORT, DB_USER, DB_PASSWORD
        pool_options = dict(
            creator=pymysql,
            maxconnections=3,
            mincached=1,
            maxcached=2,
            blocking=True,
            host=DB_HOST,
            port=DB_PORT,
            user=DB_USER,
            password=DB_PASSWORD,
            database='macro_decision',
            charset='utf8mb4',
            cursorclass=DictCursor,
            connect_timeout=5,
        )
        _MACRO_POOL = PooledDB(**pool_options)
    return _MACRO_POOL


def insert_trade_log(strategy: str, symbol: str, side: str, size: float,
                     entry: float, exit_price: float = None, pnl: float = None,
                     reason: str = ""):
    """Best-effort insert into ``macro_decision.trade_log``.

    Args:
        strategy: turtle / mean_revert / multi_factor / grid / trend / manual.
        symbol: written to the table's ``asset`` column.
        side: lower-cased before being stored ('long' / 'short').
        size, entry, exit_price, pnl, reason: stored as given (``entry`` goes
            to ``entry_price``).

    Returns:
        True when the row was committed, False on any failure. Failures,
        including being unable to obtain a connection, are logged as
        warnings and never raised.
    """
    conn = None
    try:
        # Acquire inside the try so an unreachable database is reported as
        # False like every other failure, instead of escaping as an exception.
        conn = _get_macro_pool().connection()
        with conn.cursor() as cur:
            cur.execute("""
                INSERT INTO trade_log (strategy, asset, side, size, entry_price,
                       exit_price, pnl, reason)
                VALUES (%s,%s,%s,%s,%s,%s,%s,%s)
            """, (strategy, symbol, side.lower(), size, entry, exit_price, pnl, reason))
        conn.commit()
        return True
    except Exception as e:
        logger.warning(f"[trade_log] 写入失败: {e}")
        if conn is not None:
            try:
                conn.rollback()
            except Exception as rollback_err:
                # A dead connection cannot roll back; keep the best-effort contract.
                logger.warning("[trade_log] rollback failed: %s", rollback_err)
        return False
    finally:
        if conn is not None:
            conn.close()


def update_trade_log(trade_id: int, exit_price: float, pnl: float, reason: str):
    """Best-effort update of the exit fields of one ``macro_decision.trade_log`` row.

    Sets ``exit_price``, ``pnl`` and ``reason`` on the row whose ``id`` equals
    ``trade_id``. Failures, including being unable to obtain a connection,
    are logged as warnings and never raised. Returns nothing.
    """
    conn = None
    try:
        # Acquire inside the try so an unreachable database is only logged,
        # like every other failure of this best-effort write.
        conn = _get_macro_pool().connection()
        with conn.cursor() as cur:
            cur.execute("""
                UPDATE trade_log SET exit_price=%s, pnl=%s, reason=%s
                WHERE id=%s
            """, (exit_price, pnl, reason, trade_id))
        conn.commit()
    except Exception as e:
        logger.warning(f"[trade_log] 更新失败: {e}")
        if conn is not None:
            try:
                conn.rollback()
            except Exception as rollback_err:
                # A dead connection cannot roll back; keep the best-effort contract.
                logger.warning("[trade_log] rollback failed: %s", rollback_err)
    finally:
        if conn is not None:
            conn.close()


def insert_flash_event(event_type: str, symbol: str, severity: str,
                       price: float, detail: str, metrics: dict = None):
    """Best-effort insert of a price anomaly into ``macro_decision.price_anomalies``.

    Args:
        event_type: flash_crash / flash_pump / vol_spike / oi_anomaly.
            ``direction`` is "up" when the name contains "pump" and "down"
            otherwise; ``volume_surge`` is 1 when it contains "vol_spike".
        symbol: written to the ``asset`` column.
        severity: accepted for interface compatibility; not written by this
            function.
        price: written to ``current_price``.
        detail: truncated to 500 characters and written to ``trigger_reason``.
        metrics: optional mapping; ``change_pct`` (default 0) is stored as an
            absolute float in ``price_change_pct``.

    Returns:
        True when the row was committed, False on any database or value
        failure inside the write. Such failures, including being unable to
        obtain a connection, are logged as warnings and never raised.
    """
    direction = "up" if "pump" in event_type else "down"
    conn = None
    try:
        change_pct = metrics.get("change_pct", 0) if metrics else 0
        # Acquire inside the try so an unreachable database is reported as
        # False like every other failure, instead of escaping as an exception.
        conn = _get_macro_pool().connection()
        with conn.cursor() as cur:
            cur.execute("""
                INSERT INTO price_anomalies
                    (asset, price_change_pct, current_price, direction, trigger_reason, volume_surge)
                VALUES (%s,%s,%s,%s,%s,%s)
            """, (symbol, abs(float(change_pct)), price, direction,
                  detail[:500], 1 if "vol_spike" in event_type else 0))
        conn.commit()
        return True
    except Exception as e:
        logger.warning(f"[price_anomalies] 写入失败: {e}")
        if conn is not None:
            try:
                conn.rollback()
            except Exception as rollback_err:
                # A dead connection cannot roll back; keep the best-effort contract.
                logger.warning("[price_anomalies] rollback failed: %s", rollback_err)
        return False
    finally:
        if conn is not None:
            conn.close()
