# -*- coding: utf-8 -*-
"""
⚠️ 过拟合警告：本模块已改为默认 dry-run 模式。
   基于小样本（<200笔）的胜率统计不可靠，每日调优=追认噪声。
   如需实际修改权重，必须显式传 --apply。
   建议：冻结权重至少1个月，用回测而非实盘验证因子有效性。

weight_optimizer.py  —  因子权重分析工具
运行方式：
  python weight_optimizer.py            # 仅分析（dry-run），不修改文件
  python weight_optimizer.py --apply    # ⚠️ 实际修改config.py并重启trader

输出：
  weight_reports/YYYY-MM-DD.md          # 每日调优报告
  weight_reports/weight_history.json    # 权重历史记录
"""

import re
import os
import sys
import json
import time
import logging
import argparse
import subprocess
from datetime import datetime, timedelta
from collections import defaultdict

BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, BASE_DIR)
from notify.telegram import tg_send

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(message)s")
logger = logging.getLogger("WeightOptimizer")

# ── 路径配置 ──────────────────────────────────────────────────
LOG_FILE    = os.path.join(BASE_DIR, "my_trader.log")
CONFIG_FILE = os.path.join(BASE_DIR, "config.py")
REPORT_DIR  = os.path.join(BASE_DIR, "weight_reports")
HISTORY_FILE= os.path.join(REPORT_DIR, "weight_history.json")

# ── 因子键名 ──────────────────────────────────────────────────
FK = ["TR","OB","TK","OI","FR","MP","VD","BTC","GM","IV","EX","LC","MR","SM"]
# config.py 中对应的变量名
CONFIG_KEYS = {
    "TR":  "W_TREND",      "OB":  "W_ORDERBOOK",  "TK":  "W_TAKER",
    "OI":  "W_OI_CHANGE",  "FR":  "W_FUNDING",    "MP":  "W_MAXPAIN",
    "VD":  "W_VOL_DELTA",  "BTC": "W_BTC_CORR",   "GM":  "W_GAMMA",
    "IV":  "W_IV",         "EX":  "W_TAKER_EXHAUST","LC": "W_LIQ_COOLDOWN",
    "MR":  "W_MEAN_REVERT","SM":  "W_SMART_MONEY",
    "RT":  "W_RETAIL_POSITION","MM":"W_MM_POSITION",
    "LT":  "W_LIQ_TRIGGER","TX":"W_TOXIC_FLOW",
}

# 因子分类：先验（可预测未来）vs 后验（反映当前状态）
# 只对先验指标用胜率调权；后验指标维持不变
# 因子分类：先验（可预测未来）vs 后验（反映当前状态）
# 只对先验指标用胜率调权；后验指标维持不变
PRIOR_FACTORS    = {"TR","OB","OI","FR","BTC","GM","MP","IV"}   # 先验：有预测价值
POSTERIOR_FACTORS= {"TK","VD","EX","LC","MR","SM","NEWS"}       # 后验：反映已发生状态

# 权重调整限制
MIN_W     = 0.00   # 最低权重
MIN_W     = 0.00   # 最低权重
MAX_W     = 0.20   # 单因子最高权重
MAX_CHG   = 0.015  # 单次最大变化幅度（保守：每次最多±1.5%）
SMOOTH_K  = 0.3    # 新信号权重（0~1，越小越保守）；历史锚定权重 = 1 - SMOOTH_K
MIN_WIN_RATE_SAMPLES = 15  # 样本不足时不调整
LOOKBACK_DAYS = 7  # 分析过去N天的交易

# ── 正则 ──────────────────────────────────────────────────────
factor_re = re.compile(
    r"(\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}).*"
    r"TR:([+-][0-9.]+) OB:([+-][0-9.]+) FR:([+-][0-9.]+) TK:([+-][0-9.]+) "
    r"OI:([+-][0-9.]+) MP:([+-][0-9.]+) VD:([+-][0-9.]+) BTC:([+-][0-9.]+) "
    r"GM:([+-][0-9.]+) IV:([+-][0-9.]+) EX:([+-][0-9.]+) LC:([+-][0-9.]+) "
    r"MR:([+-][0-9.]+) SM:([+-][0-9.]+).*=> ([+-][0-9.]+) -> (LONG|SHORT|WAIT)"
)
open_re  = re.compile(r"(\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}).*触发 (LONG|SHORT).*执行开仓")
pos_re   = re.compile(r"(\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}).*\[持仓\] ETH (long|short) [0-9.]+张 @ \$[0-9.]+ P.L:\$([+-][0-9.]+)")


def ts2sec(ts):
    return datetime.strptime(ts, "%Y-%m-%d %H:%M:%S").timestamp()


def read_current_weights():
    """从 config.py 读取当前权重"""
    weights = {}
    content = open(CONFIG_FILE, encoding="utf-8").read()
    for k, var in CONFIG_KEYS.items():
        m = re.search(rf"^{var}\s*=\s*([0-9.]+)", content, re.MULTILINE)
        weights[k] = float(m.group(1)) if m else 0.0
    return weights


def parse_trades(lookback_days=7):
    """解析日志，返回交易列表，每笔含 factors + pnl_final"""
    cutoff = (datetime.now() - timedelta(days=lookback_days)).strftime("%Y-%m-%d")
    logger.info(f"解析 {lookback_days} 天内交易（{cutoff} 以后）...")

    lines = open(LOG_FILE, encoding="utf-8", errors="replace").readlines()
    all_factors = []
    all_opens   = []
    all_pos     = []

    for line in lines:
        if line[:10] < cutoff:
            continue
        mf = factor_re.search(line)
        if mf:
            g = mf.groups()
            fdict = {k: float(v) for k, v in
                     zip(["TR","OB","FR","TK","OI","MP","VD","BTC","GM","IV","EX","LC","MR","SM"],
                         g[1:15])}
            all_factors.append((g[0], fdict, float(g[15]), g[16]))
        mo = open_re.search(line)
        if mo:
            all_opens.append((mo.group(1), mo.group(2)))
        mp = pos_re.search(line)
        if mp:
            all_pos.append((mp.group(1), float(mp.group(3))))

    logger.info(f"  因子行={len(all_factors)}  开仓={len(all_opens)}  持仓快照={len(all_pos)}")

    trades = []
    for open_ts, direction in all_opens:
        ots = ts2sec(open_ts)
        # 配对最近因子
        best = None
        for (fts, fd, sc, di) in all_factors:
            ft = ts2sec(fts)
            if -90 < ft - ots < 10:
                best = (fts, fd, sc, di)
        if not best:
            continue
        # 开仓后最终 PnL（取开仓后5分钟到1小时内的最后一条持仓记录）
        pnl_final = None
        for (pts, pnl) in all_pos:
            pt = ts2sec(pts)
            if 60 < pt - ots < 3600:
                pnl_final = pnl
        if pnl_final is None:
            continue
        trades.append({
            "ts":        open_ts,
            "direction": direction,
            "factors":   best[1],
            "score":     best[2],
            "pnl":       pnl_final,
        })

    logger.info(f"  配对成功交易: {len(trades)} 笔")
    return trades


def calc_win_rates(trades):
    """计算每个因子的胜率和平均绝对值"""
    stats = {k: {"correct":0,"total":0,"sum_abs":0.0,"sum_pnl_correct":0.0,"sum_pnl_wrong":0.0}
             for k in FK}

    for t in trades:
        won = t["pnl"] > 0
        for k in FK:
            v = t["factors"].get(k, 0)
            if abs(v) < 0.05:
                continue
            # 因子预测方向（OB反向）
            pred_long = (v > 0) if k != "OB" else (v < 0)
            is_long   = t["direction"] == "LONG"
            correct   = (pred_long == is_long) == won
            stats[k]["total"]    += 1
            stats[k]["sum_abs"]  += abs(v)
            if correct:
                stats[k]["correct"] += 1
                stats[k]["sum_pnl_correct"] += abs(t["pnl"])
            else:
                stats[k]["sum_pnl_wrong"] += abs(t["pnl"])

    results = {}
    for k in FK:
        s = stats[k]
        n = s["total"]
        results[k] = {
            "win_rate":  s["correct"]/n if n > 0 else 0.5,
            "avg_abs":   s["sum_abs"]/n if n > 0 else 0.0,
            "n":         n,
        }
    return results


def suggest_weights(win_rates, current_weights):
    """
    权重调整策略：
      - 先验因子（OB/FR/BTC/GM/OI/TR/IV/MP）：根据历史胜率调整
      - 后验因子（TK/VD/EX/LC/MR/NEWS/SM）：维持不变，胜率对其无意义
      - 样本不足(< MIN_WIN_RATE_SAMPLES) → 不调整
      - 胜率噪声区 [45%, 55%] → 不调整
      - 单次最大变化 ≤ MAX_CHG（±1.5%）
      - 贝叶斯平滑：历史权重锚定 70%，新信号 30%
    """
    suggested = {}
    for k in FK:
        wr  = win_rates[k]["win_rate"]
        n   = win_rates[k]["n"]
        cur = current_weights.get(k, 0.0)

        # 后验指标 → 直接保持不变
        if k in POSTERIOR_FACTORS:
            suggested[k] = cur
            continue

        # 先验指标：样本不足 → 不调整
        if n < MIN_WIN_RATE_SAMPLES:
            suggested[k] = cur
            continue

        # 先验指标：噪声区 → 不调整
        if 0.45 <= wr <= 0.55:
            suggested[k] = cur
            continue

        # 先验指标：胜率映射到调整幅度
        score = (wr - 0.5) * 2.0   # [-1, 1]
        if cur == 0.0 and score > 0:
            target = 0.02           # 零权重但有效 → 小幅激活
        elif cur == 0.0:
            target = 0.0
        else:
            target = cur * (1.0 + score * 0.30)
        target = max(MIN_W, min(MAX_W, target))

        # 贝叶斯平滑 + 步长限制
        blended = cur * (1.0 - SMOOTH_K) + target * SMOOTH_K
        delta   = max(-MAX_CHG, min(MAX_CHG, blended - cur))
        suggested[k] = round(cur + delta, 4)

    # NEWS 权重 = 1 - sum(其他)，变化幅度同样限制
    news_cur = current_weights.get("NEWS", 0.15)
    news_raw = round(1.0 - sum(suggested.values()), 4)
    news_delta = max(-MAX_CHG, min(MAX_CHG, news_raw - news_cur))
    news_new   = max(0.08, min(0.25, round(news_cur + news_delta, 4)))

    # 补偿误差到 LC（后验，本身不参与胜率调整，用作余量缓冲）
    err = round(1.0 - sum(suggested.values()) - news_new, 4)
    suggested["LC"] = round(suggested.get("LC", 0) + err, 4)
    suggested["NEWS"] = news_new
    return suggested


def update_config(suggested, dry_run=False):
    """更新 config.py 中的权重"""
    content = open(CONFIG_FILE, encoding="utf-8").read()
    for k, var in CONFIG_KEYS.items():
        new_val = suggested.get(k, 0.0)
        content = re.sub(
            rf"^({var}\s*=\s*)[0-9.]+",
            lambda m, v=new_val: f"{m.group(1)}{v}",
            content, flags=re.MULTILINE
        )
    if dry_run:
        logger.info("[dry-run] config.py 不修改")
        return
    open(CONFIG_FILE, "w", encoding="utf-8").write(content)
    logger.info("config.py 已更新")


def restart_trader():
    """重启 trader 进程"""
    import psutil
    for proc in psutil.process_iter(["pid","name","cmdline"]):
        try:
            cmdline = " ".join(proc.info["cmdline"] or [])
            if "my_trader.py" in cmdline and "python" in proc.info["name"].lower():
                proc.kill()
                logger.info(f"已终止旧 trader pid={proc.info['pid']}")
                time.sleep(2)
                break
        except Exception:
            pass
    subprocess.Popen(
        ["powershell", "-WindowStyle", "Minimized", "-Command",
         f"cd '{BASE_DIR}'; python my_trader.py"],
        creationflags=subprocess.CREATE_NEW_CONSOLE
    )
    logger.info("已重启 trader")


def explain_change(k, wr, n, cur, new):
    """生成每个因子权重变化的原因说明"""
    chg = new - cur

    # 后验指标不参与胜率调整
    if k in POSTERIOR_FACTORS:
        if abs(chg) < 0.001:
            return "后验指标（反映已发生状态），不依据胜率调整"
        else:
            return f"后验指标，误差补偿 {chg:+.4f}"

    # 先验指标
    if abs(chg) < 0.001:
        if n < MIN_WIN_RATE_SAMPLES:
            return f"先验指标，样本不足({n}笔<{MIN_WIN_RATE_SAMPLES})，维持"
        if 0.45 <= wr <= 0.55:
            return f"先验指标，胜率{wr*100:.0f}%在噪声区[45%,55%]，维持"
        return "先验指标，变化极小"

    direction = "提权" if chg > 0 else "降权"
    if wr >= 0.70:   quality = "强正效应"
    elif wr >= 0.60: quality = "正效应"
    elif wr <= 0.30: quality = "强负效应"
    elif wr <= 0.40: quality = "负效应"
    else:            quality = "弱效应"
    return (f"先验指标，胜率{wr*100:.0f}%(n={n})，{quality}，"
            f"{direction}{abs(chg):.4f}"
            f"（锚定{int((1-SMOOTH_K)*100)}%历史+步长±{MAX_CHG}）")


def save_report(today, trades, win_rates, current_weights, suggested, dry_run=False):
    """
    报告系统：
      1. weight_reports/YYYY-MM-DD_HH.md  → 单次调优详细报告（追加同天多次）
      2. weight_reports/CHANGELOG.md       → 追加写入，记录每次变化+原因
      3. weight_reports/weight_history.json→ 结构化历史，支持图表
    """
    os.makedirs(REPORT_DIR, exist_ok=True)
    now_str    = datetime.now().strftime("%H:%M:%S")
    now_full   = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    report_path = os.path.join(REPORT_DIR, f"{today}.md")
    changelog_path = os.path.join(REPORT_DIR, "CHANGELOG.md")

    win_cnt  = sum(1 for t in trades if t["pnl"] > 0.001)
    loss_cnt = sum(1 for t in trades if t["pnl"] < -0.001)
    flat_cnt = len(trades) - win_cnt - loss_cnt
    total_pnl = sum(t["pnl"] for t in trades)

    # ── 1. 单次详细报告（追加，同天多次调优都记录）──────────────────
    sep = "\n\n---\n"
    block = [
        f"\n## 调优记录 {now_full}{' [dry-run]' if dry_run else ''}",
        f"分析交易 **{len(trades)}** 笔 | 盈{win_cnt}/亏{loss_cnt}/平{flat_cnt} | 总PnL: **{total_pnl:+.4f}**",
        "",
        "| 因子 | 类型 | 样本 | 胜率 | 当前权重 | 新权重 | 变化 | 原因 |",
        "|------|------|------|------|---------|--------|------|------|",
    ]
    for k in FK + ["NEWS"]:
        wr_info = win_rates.get(k, {"win_rate": 0.5, "avg_abs": 0.0, "n": 0})
        wr  = wr_info["win_rate"]
        n   = wr_info["n"]
        cur = current_weights.get(k, 0.0)
        new = suggested.get(k, cur)
        chg = new - cur
        chg_s  = f"`{chg:+.4f}`"
        arrow  = "⬆️" if chg > 0.001 else ("⬇️" if chg < -0.001 else "—")
        wr_s   = f"{wr*100:.0f}%" if n >= MIN_WIN_RATE_SAMPLES else f"≈{wr*100:.0f}%(n={n})"
        reason = explain_change(k, wr, n, cur, new)
        ftype  = "🔭先验" if k in PRIOR_FACTORS else "📊后验"
        block.append(f"| **{k}** | {ftype} | {n} | {wr_s} | {cur:.4f} | {new:.4f} | {chg_s} {arrow} | {reason} |")

    block.append(f"\n权重合计: **{sum(suggested.values()):.4f}**")

    # 文件存在则追加，不存在则写标题
    if not os.path.exists(report_path):
        header = f"# 权重调优报告 {today}\n"
        with open(report_path, "w", encoding="utf-8") as f:
            f.write(header)

    with open(report_path, "a", encoding="utf-8") as f:
        f.write(sep + "\n".join(block))
    logger.info(f"详细报告已追加: {report_path}")

    # ── 2. CHANGELOG.md 追加（简洁变化日志）──────────────────────────
    changed = [(k, current_weights.get(k,0), suggested.get(k,0))
               for k in FK+["NEWS"] if abs(suggested.get(k,0)-current_weights.get(k,0)) >= 0.001]

    changelog_block = [
        f"\n### {now_full}{' [dry-run]' if dry_run else ''}"
        f"  ·  {len(trades)}笔交易  ·  PnL:{total_pnl:+.4f}",
    ]
    if changed:
        for k, cur, new in changed:
            wr     = win_rates.get(k, {}).get("win_rate", 0.5)
            n      = win_rates.get(k, {}).get("n", 0)
            chg    = new - cur
            arrow  = "⬆️" if chg > 0 else "⬇️"
            ftype  = "先验" if k in PRIOR_FACTORS else "后验"
            reason = explain_change(k, wr, n, cur, new)
            changelog_block.append(
                f"- {arrow} **{k}**({ftype}): {cur:.4f} → {new:.4f}  _{reason}_"
            )
    else:
        changelog_block.append("- 所有因子权重无显著变化（均在噪声区或样本不足）")

    if not os.path.exists(changelog_path):
        with open(changelog_path, "w", encoding="utf-8") as f:
            f.write("# 权重变化日志 (CHANGELOG)\n\n每次调优的变化记录，含原因说明。\n")

    with open(changelog_path, "a", encoding="utf-8") as f:
        f.write("\n" + "\n".join(changelog_block))
    logger.info(f"CHANGELOG 已追加: {changelog_path}")

    # ── 3. weight_history.json 追加（图表数据源）─────────────────────
    history = {}
    if os.path.exists(HISTORY_FILE):
        try:
            history = json.load(open(HISTORY_FILE, encoding="utf-8"))
        except Exception:
            pass

    # 每个时间戳一条记录（支持同天多次）
    ts_key = datetime.now().strftime("%Y-%m-%d %H:%M")
    history[ts_key] = {
        "date":      today,
        "time":      now_str,
        "dry_run":   dry_run,
        "trades":    len(trades),
        "win_rate_overall": round(win_cnt / len(trades), 3) if trades else 0,
        "total_pnl": round(total_pnl, 4),
        "weights":   {k: suggested.get(k, 0) for k in FK + ["NEWS"]},
        "win_rates": {k: {
            "wr": round(win_rates.get(k, {}).get("win_rate", 0.5), 3),
            "n":  win_rates.get(k, {}).get("n", 0),
        } for k in FK},
        "changes": {k: round(suggested.get(k,0) - current_weights.get(k,0), 4)
                    for k in FK + ["NEWS"]
                    if abs(suggested.get(k,0) - current_weights.get(k,0)) >= 0.0005},
    }

    with open(HISTORY_FILE, "w", encoding="utf-8") as f:
        json.dump(history, f, ensure_ascii=False, indent=2)
    logger.info(f"历史记录已更新: {HISTORY_FILE}  (共{len(history)}条)")
    return report_path, changelog_path


def build_tg_msg(today, trades, win_rates, current_weights, suggested):
    lines = [
        f"📊 <b>权重调优完成 {today}</b>",
        f"分析 <b>{len(trades)}</b> 笔交易（过去{LOOKBACK_DAYS}天）",
        f"总PnL: <b>{sum(t['pnl'] for t in trades):+.4f}</b>",
        "",
        "🔄 <b>权重变化（仅显示变化项）:</b>",
    ]
    for k in FK + ["NEWS"]:
        cur = current_weights.get(k, 0.0)
        new = suggested.get(k, cur)
        chg = new - cur
        if abs(chg) > 0.002:
            arrow = "⬆️" if chg > 0 else "⬇️"
            wr = win_rates.get(k, {}).get("win_rate", 0.5)
            n  = win_rates.get(k, {}).get("n", 0)
            lines.append(f"  {arrow} {k}: {cur:.3f}→{new:.3f}  胜率{wr*100:.0f}%(n={n})")
    lines += [
        "",
        "⭐ <b>最佳因子 Top3:</b>",
    ]
    sorted_wr = sorted([(k, win_rates.get(k,{}).get("win_rate",0.5), win_rates.get(k,{}).get("n",0))
                        for k in FK if win_rates.get(k,{}).get("n",0)>=10],
                       key=lambda x: -x[1])[:3]
    for k, wr, n in sorted_wr:
        lines.append(f"  {k}: 胜率{wr*100:.0f}% (n={n})")
    return "\n".join(lines)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--apply", action="store_true",
                        help="⚠️ 实际修改config.py（默认仅dry-run，防止过拟合）")
    parser.add_argument("--days", type=int, default=LOOKBACK_DAYS)
    parser.add_argument("--no-restart", action="store_true")
    parser.add_argument("--no-tg", action="store_true")
    args = parser.parse_args()
    # dry_run 默认True，需显式 --apply 才会实际修改权重
    args.dry_run = not args.apply

    today = datetime.now().strftime("%Y-%m-%d")
    logger.info(f"=== 权重调优开始 {today} ===")

    current_weights = read_current_weights()
    logger.info(f"当前权重: {current_weights}")

    trades = parse_trades(lookback_days=args.days)
    if len(trades) < 5:
        msg = f"[权重调优] 交易样本不足({len(trades)}笔)，跳过本次调优"
        logger.warning(msg)
        if not args.no_tg:
            tg_send(msg)
        return

    win_rates = calc_win_rates(trades)
    suggested = suggest_weights(win_rates, current_weights)

    # 打印对比
    logger.info("=" * 50)
    logger.info(f"{'因子':6} {'胜率':8} {'n':5} {'当前':8} {'建议':8} {'变化'}")
    for k in FK + ["NEWS"]:
        wr = win_rates.get(k, {}).get("win_rate", 0.5)
        n  = win_rates.get(k, {}).get("n", 0)
        cur = current_weights.get(k, 0)
        new = suggested.get(k, cur)
        logger.info(f"  {k:5} {wr*100:5.1f}%  {n:4}  {cur:.4f}  {new:.4f}  {new-cur:+.4f}")
    logger.info(f"  权重合计: {sum(suggested.values()):.4f}")
    logger.info("=" * 50)

    report_path, changelog_path = save_report(
        today, trades, win_rates, current_weights, suggested, dry_run=args.dry_run
    )

    if not args.dry_run:
        update_config(suggested)
        if not args.no_restart:
            try:
                restart_trader()
            except ImportError:
                logger.warning("psutil未安装，跳过自动重启。请手动重启trader。")

    if not args.no_tg:
        msg = build_tg_msg(today, trades, win_rates, current_weights, suggested)
        if args.dry_run:
            msg = "[dry-run]\n" + msg
        msg += f"\n\n📝 报告: weight_reports/{today}.md\n📋 日志: weight_reports/CHANGELOG.md"
        tg_send(msg)
        logger.info("TG通知已发送")

    logger.info(f"=== 调优完成 | 报告: {report_path} | 变化日志: {changelog_path} ===")

    # 自动生成权重图表
    try:
        import tools.weight_chart as weight_chart
        chart_path = weight_chart.main()
        logger.info(f"图表已生成: {chart_path}")
    except Exception as e:
        logger.warning(f"图表生成失败: {e}")


if __name__ == "__main__":
    main()
