# -*- coding: utf-8 -*-
"""
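⚠️ Overfitting warning: weights are tuned from the correlation between snapshot factors
   and PnL, which is completely unreliable when there are very few samples (<50 trades).
   Recommendation: freeze the weights and use this module for analysis only instead of
   letting it update config_weights.json automatically.
   NOTE: WeightUpdater.update() as written still writes config_weights.json when it runs.

weight_updater.py — self-updating factor-weight module
How it works: read historical position-open snapshots from trade_snapshots.json,
      estimate how each factor relates to the final PnL, and adjust the weights accordingly.
Usage:
  from weight_updater import WeightUpdater
  updater = WeightUpdater()
  new_weights = updater.update()  # returns the updated weights dict and also writes config_weights.json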
"""

import json, os, math, logging
from datetime import datetime

logger = logging.getLogger("WeightUpdater")

# Files used by this module (bare names, so they resolve against the working directory).
SNAP_FILE = "trade_snapshots.json"
WEIGHT_FILE = "config_weights.json"  # persisted weights
UPDATE_LOG = "weight_update.log"  # update history

# Initial weights (synced from config.py).
BASE_WEIGHTS = {
    "TR": 0.15,
    "OB": 0.06,
    "TK": 0.07,
    "OI": 0.01,
    "FR": 0.02,
    "MP": 0.02,
    "VD": 0.06,
    "BTC": 0.07,
    "GM": 0.06,
    "IV": 0.01,
    "EX": 0.07,
    "LC": 0.15,
    "MR": 0.06,
    "NEWS": 0.19,
    "SM": 0.00,
}

# Lower and upper bound applied to an adjusted weight.
W_MIN = 0.01
W_MAX = 0.20

# Learning rate (size of each adjustment step).
LR = 0.15

# Minimum number of samples required.
MIN_SAMPLES = 5


def _as_float(value, default=0.0):
    """Coerce ``value`` to a finite float; return ``default`` when that is impossible."""
    try:
        number = float(value)
    except (TypeError, ValueError, OverflowError):
        return default
    return number if math.isfinite(number) else default


def _load_snapshots():
    """Read the trade snapshots and turn them into normalised records.

    Returns:
        A list with one dict per usable snapshot, holding the keys
        ``factors`` (short factor code -> float), ``action``, ``score``,
        ``dir_sign`` (+1 when action is ``'LONG'``, otherwise -1),
        ``score_aligned`` (+1 when score and direction share a sign, else -1),
        ``price``, ``tp`` and ``sl`` (floats, 0.0 when absent or unusable).
        An empty list is returned when the file is missing, unreadable,
        not valid JSON, or does not contain a JSON array.

    Snapshots without a usable price or without an action are skipped, as are
    entries that are not JSON objects.
    """
    try:
        with open(SNAP_FILE, encoding='utf-8') as fh:
            snaps = json.load(fh)
    except FileNotFoundError:
        # No snapshots recorded yet: a normal situation, nothing to report.
        return []
    except (OSError, ValueError) as exc:
        # json.JSONDecodeError and UnicodeDecodeError are ValueError subclasses.
        logger.warning(f"[权重] 读取 {SNAP_FILE} 失败: {exc}")
        return []

    if not isinstance(snaps, list):
        logger.warning(f"[权重] {SNAP_FILE} 格式无效（应为列表），已忽略")
        return []

    # Long factor names as stored in snapshots -> short codes used in BASE_WEIGHTS.
    factor_map = {
        'trend': 'TR', 'orderbook': 'OB', 'taker': 'TK', 'oi': 'OI',
        'funding': 'FR', 'maxpain': 'MP', 'vol_delta': 'VD', 'btc_corr': 'BTC',
        'gamma': 'GM', 'iv': 'IV', 'exhaust': 'EX', 'liq_cool': 'LC',
        'mean_revert': 'MR', 'news': 'NEWS', 'smart_money': 'SM',
    }

    records = []
    for snap in snaps:
        if not isinstance(snap, dict):
            continue

        price = _as_float(snap.get('price'))
        action = snap.get('action', '')
        if not price or not action:
            continue

        factors = snap.get('factors')
        if not isinstance(factors, dict):
            # Covers both a missing key and an explicit JSON null.
            factors = {}
        score = _as_float(snap.get('score'))

        # Simplification: no real PnL is available here, so only record whether
        # the score points the same way as the position (real PnL can be wired
        # in later).
        # NOTE(review): every action other than 'LONG' is treated as short.
        dir_sign = 1 if action == 'LONG' else -1
        score_aligned = 1 if score * dir_sign > 0 else -1

        # A factor may be stored under its long name or its short code;
        # anything missing or non-numeric becomes 0.0.
        factor_values = {}
        for long_name, short_name in factor_map.items():
            raw = factors.get(long_name, factors.get(short_name, 0))
            factor_values[short_name] = _as_float(raw)

        records.append({
            'factors': factor_values,
            'action': action,
            'score': score,
            'dir_sign': dir_sign,
            'score_aligned': score_aligned,
            'price': price,
            'tp': _as_float(snap.get('tp')),
            'sl': _as_float(snap.get('sl')),
        })

    return records


def _estimate_pnl(records):
    """
    用相邻快照估算盈亏：
    对每条开仓记录，找下一条不同方向/平仓的快照，计算价格变化×方向=盈亏符号
    返回每条记录附加 pnl_sign: +1(盈) / -1(亏) / 0(未知)
    """
    for i, r in enumerate(records):
        r['pnl_sign'] = 0
        price_in  = r['price']
        dir_sign  = r['dir_sign']
        tp, sl    = r['tp'], r['sl']

        if tp > 0 and sl > 0 and price_in > 0:
            # 用TP/SL距离估算期望盈亏符号：盈亏比>1认为正期望
            tp_dist = abs(tp - price_in)
            sl_dist = abs(sl - price_in)
            r['pnl_sign'] = 1 if tp_dist > sl_dist else -1

        # 如果有下一条快照，用实际价格变化
        if i + 1 < len(records):
            next_price = records[i+1]['price']
            if next_price and price_in:
                raw_pnl = (next_price - price_in) * dir_sign
                r['pnl_sign'] = 1 if raw_pnl > 0 else (-1 if raw_pnl < 0 else 0)

    return records


def _compute_correlations(records):
    """Score how well each factor's direction agrees with the trade outcome.

    This is not a statistical correlation coefficient. For every record with a
    non-zero ``pnl_sign`` and every factor whose absolute value exceeds 0.05,
    the factor contributes ``+pnl_sign`` when it points the same way as the
    position and ``-pnl_sign`` when it points the other way. The per-factor
    sum is then divided by the number of contributions.

    Returns:
        ``(corr, counts)``: two dicts keyed like ``BASE_WEIGHTS``; ``corr``
        holds the averaged score (0.0 when a factor never contributed) and
        ``counts`` the number of contributions.
    """
    factor_keys = list(BASE_WEIGHTS)
    totals = dict.fromkeys(factor_keys, 0.0)
    hits = dict.fromkeys(factor_keys, 0)

    for rec in records:
        outcome = rec.get('pnl_sign', 0)
        if outcome == 0:
            # Unknown outcome: the record carries no information.
            continue
        values = rec['factors']
        side = rec['dir_sign']
        for key in factor_keys:
            value = values.get(key, 0)
            # Near-zero factor readings are treated as "no opinion".
            if abs(value) > 0.05:
                same_direction = value * side > 0
                # Aligned and profitable -> positive; aligned and losing -> negative.
                totals[key] += outcome if same_direction else -outcome
                hits[key] += 1

    # Turn the sums into averages; untouched factors keep their 0.0.
    averaged = {
        key: (totals[key] / hits[key] if hits[key] > 0 else totals[key])
        for key in factor_keys
    }
    return averaged, hits


def _update_weights(current_weights, correlations, counts):
    """Adjust the factor weights according to their agreement scores.

    Each factor in ``BASE_WEIGHTS`` with at least ``MIN_SAMPLES`` samples is
    moved by ``LR * correlation * weight`` and clamped to ``[W_MIN, W_MAX]``.
    A positive score raises the weight and a negative one lowers it. ``SM`` is
    then forced to 0 and all other weights are divided by the total and
    rounded to 4 decimals.

    The clamp happens before normalisation, so a normalised weight can end up
    outside ``[W_MIN, W_MAX]``, and rounding means the result may not sum to
    exactly 1.

    Args:
        current_weights: factor code -> current weight. Factors missing here
            fall back to their ``BASE_WEIGHTS`` value. Not modified.
        correlations: factor code -> agreement score.
        counts: factor code -> number of samples behind that score.

    Returns:
        ``(new_weights, adjustments)``. ``adjustments`` holds the raw delta per
        factor, before clamping and normalisation (0 when skipped).
    """
    # Seed from BASE_WEIGHTS so a factor absent from current_weights (for
    # example a weights file written before the factor existed) is kept in the
    # result instead of silently disappearing when it has too few samples.
    new_weights = {**BASE_WEIGHTS, **current_weights}
    adjustments = {}

    for key in BASE_WEIGHTS:
        if counts.get(key, 0) < MIN_SAMPLES:
            adjustments[key] = 0
            continue
        weight = new_weights[key]
        delta = LR * correlations.get(key, 0) * weight
        new_weights[key] = max(W_MIN, min(W_MAX, weight + delta))
        adjustments[key] = delta

    # SM is pinned to 0 regardless of the data.
    new_weights['SM'] = 0.00

    # Normalise so the weights sum to (approximately) 1.
    total = sum(new_weights.values())
    if total > 0:
        for key in new_weights:
            if key != 'SM':
                new_weights[key] = round(new_weights[key] / total, 4)

    return new_weights, adjustments


def load_weights():
    """Load the current weights, preferring the persisted file.

    Returns:
        The ``weights`` dict stored in ``WEIGHT_FILE`` when that file can be
        read and contains one; otherwise a fresh copy of ``BASE_WEIGHTS``.
        The module-level ``BASE_WEIGHTS`` object itself is never returned, so
        callers cannot modify the defaults by accident.
    """
    try:
        with open(WEIGHT_FILE, encoding='utf-8') as fh:
            data = json.load(fh)
    except FileNotFoundError:
        # Nothing persisted yet: start from the defaults.
        return dict(BASE_WEIGHTS)
    except (OSError, ValueError) as exc:
        # json.JSONDecodeError and UnicodeDecodeError are ValueError subclasses.
        logger.warning(f"[权重] 读取 {WEIGHT_FILE} 失败，使用默认权重: {exc}")
        return dict(BASE_WEIGHTS)

    weights = data.get('weights') if isinstance(data, dict) else None
    if not isinstance(weights, dict):
        logger.warning(f"[权重] {WEIGHT_FILE} 中没有有效的权重，使用默认权重")
        return dict(BASE_WEIGHTS)

    logger.info(f"[权重] 从 {WEIGHT_FILE} 加载权重 (版本{data.get('version',0)})")
    return weights


def save_weights(weights, meta=None):
    """Persist the weights to ``WEIGHT_FILE``.

    The stored document holds ``weights``, an incrementing ``version``, an
    ``updated`` ISO timestamp and ``meta``.

    Args:
        weights: factor code -> weight; must be JSON-serialisable.
        meta: optional extra information stored alongside (defaults to ``{}``).

    Returns:
        The version number written, or 0 when saving failed. Failures are
        logged, never raised.
    """
    # Read the previous version number. An unreadable or corrupt existing file
    # must not block the save, otherwise it could never be repaired.
    previous_version = 0
    try:
        with open(WEIGHT_FILE, encoding='utf-8') as fh:
            existing = json.load(fh)
    except FileNotFoundError:
        pass  # first save
    except (OSError, ValueError) as exc:
        logger.warning(f"[权重] 现有 {WEIGHT_FILE} 无法读取，版本号重新计数: {exc}")
    else:
        if isinstance(existing, dict):
            stored = existing.get('version', 0)
            if isinstance(stored, (int, float)):
                previous_version = stored

    version = previous_version + 1
    data = {
        'weights': weights,
        'version': version,
        'updated': datetime.now().isoformat(),
        'meta': meta or {},
    }

    # Write to a sibling temp file and swap it in, so a failed dump leaves the
    # previous file intact rather than truncated.
    tmp_path = f"{WEIGHT_FILE}.tmp"
    try:
        with open(tmp_path, 'w', encoding='utf-8') as fh:
            json.dump(data, fh, ensure_ascii=False, indent=2)
        os.replace(tmp_path, WEIGHT_FILE)
    except Exception as e:
        # Saving is best-effort by contract: report and signal failure with 0.
        logger.warning(f"[权重] 保存失败: {e}")
        try:
            os.remove(tmp_path)
        except OSError:
            pass  # temp file was never created or is already gone
        return 0

    logger.info(f"[权重] 已保存 v{version}")
    return version


def append_update_log(entry):
    """Append one entry, followed by a newline, to ``UPDATE_LOG``.

    Logging is best-effort: a failure to write is reported through the module
    logger and never raised to the caller.
    """
    try:
        with open(UPDATE_LOG, 'a', encoding='utf-8') as f:
            f.write(f"{entry}\n")
    except (OSError, ValueError) as exc:
        # ValueError covers text that cannot be encoded as UTF-8.
        logger.warning(f"[权重] 写入 {UPDATE_LOG} 失败: {exc}")


class WeightUpdater:
    """Keeps the current factor weights and re-estimates them from snapshots."""

    def __init__(self):
        self.current_weights = load_weights()

    def update(self, force=False):
        """Run one weight update.

        - Needs at least MIN_SAMPLES usable snapshots; otherwise the current
          weights are returned unchanged and nothing is written.
        - Returns the new weights dict.
        - Also writes config_weights.json and weight_update.log.

        ``force`` is accepted but not consulted by this method.
        """
        records = _load_snapshots()
        sample_count = len(records)
        if sample_count < MIN_SAMPLES:
            logger.info(f"[权重] 样本不足({sample_count}<{MIN_SAMPLES})，跳过更新")
            return self.current_weights

        records = _estimate_pnl(records)
        corr, counts = _compute_correlations(records)
        new_weights, adjustments = _update_weights(self.current_weights, corr, counts)

        # Build the human-readable report: header first, then one row per factor.
        stamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        effective = len([r for r in records if r.get('pnl_sign') != 0])
        report_lines = [
            "\n" + "=" * 60,
            f"权重更新 {stamp}",
            f"样本数: {sample_count}  有效样本: {effective}",
            f"{'因子':<8} {'旧权重':>8} {'新权重':>8} {'调整':>8} {'相关性':>8} {'样本':>6}",
            '-' * 56,
        ]
        for name in sorted(BASE_WEIGHTS):
            before = self.current_weights.get(name, BASE_WEIGHTS[name])
            after = new_weights.get(name, before)
            delta = adjustments.get(name, 0)
            score = corr.get(name, 0)
            samples = counts.get(name, 0)
            # Arrow only for adjustments larger than 0.001 in either direction.
            if delta > 0.001:
                arrow = '⬆'
            elif delta < -0.001:
                arrow = '⬇'
            else:
                arrow = ' '
            report_lines.append(
                f"{name:<8} {before:>8.4f} {after:>8.4f} {delta:>+8.4f} "
                f"{score:>+8.3f} {samples:>6} {arrow}"
            )

        log_entry = '\n'.join(report_lines)
        logger.info(log_entry)
        append_update_log(log_entry)

        rounded_corr = {}
        for name, value in corr.items():
            rounded_corr[name] = round(value, 3)
        save_weights(new_weights, meta={
            'samples': sample_count,
            'correlations': rounded_corr,
        })

        self.current_weights = new_weights
        return new_weights

    def get_weights(self):
        """Return the weights dict currently held by this instance."""
        return self.current_weights

    def report(self):
        """Return the current weights as a short one-line string (SM omitted)."""
        weights = self.current_weights
        parts = []
        for name in BASE_WEIGHTS:
            if name == 'SM':
                continue
            parts.append(f"{name}:{weights.get(name,0):.3f}")
        return ' '.join(parts)


# ── When run as a script: print the weights, run one update, print the result
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s %(levelname)s: %(message)s')
    updater = WeightUpdater()
    print("当前权重:", updater.report())
    print("执行更新...")
    weights_before = updater.get_weights()
    weights_after = updater.update()
    print("更新后权重:", updater.report())
    # update() returns the very same dict object when it skips for lack of
    # samples, and a newly built dict after a real update. Only claim that
    # files were written in the second case.
    # NOTE(review): write failures inside update() are only logged, so these
    # messages still cannot confirm the files were actually written.
    if weights_after is weights_before:
        print("\n样本不足，未执行更新，未写入任何文件")
    else:
        print(f"\n详细日志已写入 {UPDATE_LOG}")
        print(f"权重已保存到 {WEIGHT_FILE}")
