# -*- coding: utf-8 -*-
"""
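chart_snapshot.py — render the main chart for the last N hours as a PNG
(ETH price on the left y-axis, total score on the right y-axis).
Returns raw PNG bytes, ready to pass straight to tg_send_photo().

Dependency: matplotlib (optional — if it is not installed, make_snapshot()
returns empty bytes instead of raising).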
"""

import io
import logging
import os
import sys
from datetime import datetime, timedelta

_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Put the project root on sys.path so the `tools` package imported below resolves.
sys.path.insert(0, _ROOT)

# Fonts able to render the Chinese axis labels/title, in preference order.
# 'Microsoft YaHei' only exists on Windows; the others cover macOS/Linux, and
# 'DejaVu Sans' (bundled with matplotlib) is the final fallback so a font is
# always found.
_CJK_FONTS = (
    'Microsoft YaHei',
    'SimHei',
    'PingFang SC',
    'Noto Sans CJK SC',
    'WenQuanYi Zen Hei',
    'DejaVu Sans',
)

try:
    import matplotlib
    # Headless backend; must be selected before pyplot is imported.
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt
    import matplotlib.dates as mdates
    # Use the generic 'sans-serif' family with a preference list instead of a
    # single hard-coded Windows font: matplotlib picks the first installed one.
    matplotlib.rcParams['font.family'] = 'sans-serif'
    matplotlib.rcParams['font.sans-serif'] = (
        list(_CJK_FONTS) + list(matplotlib.rcParams['font.sans-serif'])
    )
    # Render minus signs with ASCII '-' (CJK fonts often lack U+2212).
    matplotlib.rcParams['axes.unicode_minus'] = False
    _HAS_MPL = True
except ImportError:
    # matplotlib is optional; make_snapshot() checks this flag and returns b''.
    _HAS_MPL = False

from tools.log_parser import parse_rows, FK, FC, FN


_TS_FMT = '%Y-%m-%d %H:%M:%S'
# |total score| beyond this is treated as a directional signal (LONG / SHORT).
_SIGNAL_TH = 0.25


def _parse_ts(value):
    """Parse a log timestamp in ``_TS_FMT``; return None if missing or malformed."""
    try:
        return datetime.strptime(value, _TS_FMT)
    except (TypeError, ValueError):
        return None


def _forward_fill(values):
    """Return *values* with each None replaced by the latest preceding non-None.

    Leading Nones (nothing to carry forward yet) stay None.
    """
    filled = []
    last = None
    for v in values:
        if v is not None:
            last = v
        filled.append(last)
    return filled


def _draw_series(ax_price, ax_score, times, prices, scores):
    """Plot the price line on *ax_price* and score line + signal bands on *ax_score*."""
    # Price line (gold, left axis); entries still None after forward-fill are skipped.
    px_pairs = [(t, p) for t, p in zip(times, prices) if p is not None]
    if px_pairs:
        px_t, px_v = zip(*px_pairs)
        ax_price.plot(px_t, px_v, color='#f0c040', linewidth=2.0, label='ETH 价格',
                      solid_capstyle='round', zorder=3)

    # Score line (blue, right axis).
    sc_pairs = [(t, s) for t, s in zip(times, scores) if s is not None]
    if sc_pairs:
        sc_t, sc_v = zip(*sc_pairs)
        # Very faint background bands where the score is beyond the signal
        # threshold. Using the x-axis transform puts y in axes coordinates
        # (0..1), so the bands always span the full plot height whatever the
        # score y-limits turn out to be (a fixed -2..2 band in data units fell
        # short once |score| exceeded ~1.67).
        band = ax_score.get_xaxis_transform()
        ax_score.fill_between(sc_t, 0, 1, where=[s > _SIGNAL_TH for s in sc_v],
                              transform=band, alpha=0.04, color='#44aaff', zorder=0)
        ax_score.fill_between(sc_t, 0, 1, where=[s < -_SIGNAL_TH for s in sc_v],
                              transform=band, alpha=0.04, color='#ff6b6b', zorder=0)
        ax_score.plot(sc_t, sc_v, color='#44aaff', linewidth=2.0, label='总分',
                      solid_capstyle='round', zorder=4)

    # Dashed ±threshold lines and the zero line on the score axis.
    ax_score.axhline(_SIGNAL_TH, color='#4ade8044', linewidth=1, linestyle='--', zorder=2)
    ax_score.axhline(-_SIGNAL_TH, color='#f8717144', linewidth=1, linestyle='--', zorder=2)
    ax_score.axhline(0, color='#333333', linewidth=0.8, zorder=2)


def _draw_events(ax, events, cutoff):
    """Mark trade events stamped at or after *cutoff* on the price axis *ax*.

    'open' events with a truthy price get a faint solid vertical line plus a
    marker at that price (green ^ for dir == 'LONG', red v otherwise);
    'tp' / 'sl' / 'close' events get a dotted vertical line. Events whose
    'ts' is missing or malformed are skipped.
    """
    for ev in events:
        et = _parse_ts(ev.get('ts'))
        if et is None or et < cutoff:
            continue
        kind = ev.get('type')
        if kind == 'open' and ev.get('price'):
            is_long = ev.get('dir') == 'LONG'
            color = '#2ed573' if is_long else '#ff4757'
            marker = '^' if is_long else 'v'
            ax.axvline(et, color=color + '55', linewidth=1, linestyle='-', zorder=2)
            ax.scatter([et], [ev['price']], color=color, s=60,
                       marker=marker, zorder=5, edgecolors='#0d1117', linewidths=0.5)
        elif kind in ('tp', 'sl', 'close'):
            color = '#f0c040' if kind == 'tp' else '#ff6348'
            ax.axvline(et, color=color + '44', linewidth=0.8, linestyle=':', zorder=2)


def _style_axes(ax_price, ax_score, prices, scores):
    """Set y-limits, tick/label colours, spines, the x date format and the grid."""
    # Price axis (left): pad the data range by 8 %. A flat series falls back
    # to a 0.1 % (or, at price 0, a 1.0) pad so the limits are never identical.
    vp = [p for p in prices if p is not None]
    if vp:
        lo, hi = min(vp), max(vp)
        pad = (hi - lo) * 0.08 or abs(lo) * 0.001 or 1.0
        ax_price.set_ylim(lo - pad, hi + pad)
    ax_price.tick_params(axis='y', colors='#f0c04099', labelsize=9)
    ax_price.tick_params(axis='x', colors='#555555', labelsize=9)
    ax_price.set_ylabel('价格 (USD)', color='#f0c04099', fontsize=9)
    ax_price.spines[:].set_color('#21262d')
    ax_price.yaxis.set_tick_params(which='both', colors='#f0c04099')

    # Score axis (right): symmetric about 0 with 20 % headroom, never narrower
    # than ±0.40 so the ±threshold lines always stay visible.
    vs = [s for s in scores if s is not None]
    if vs:
        s_lim = max(max(abs(s) for s in vs) * 1.2, 0.40)
        ax_score.set_ylim(-s_lim, s_lim)
    ax_score.tick_params(axis='y', colors='#44aaff99', labelsize=9)
    ax_score.set_ylabel('总分', color='#44aaff99', fontsize=9)
    ax_score.spines[:].set_color('#21262d')

    # X axis: month-day hour:minute labels, horizontal; twinx shares this axis.
    ax_price.xaxis.set_major_formatter(mdates.DateFormatter('%m-%d %H:%M'))
    ax_price.xaxis.set_major_locator(mdates.AutoDateLocator())
    plt.setp(ax_price.get_xticklabels(), rotation=0, ha='center')

    ax_price.grid(True, color='#1c2128', linewidth=0.5, axis='both')


def _snapshot_title(prices, scores, hours, now):
    """Build the title: last price, last score, signal direction, window and time."""
    last_price = next((p for p in reversed(prices) if p is not None), 0)
    last_score = next((s for s in reversed(scores) if s is not None), 0)
    if last_score > _SIGNAL_TH:
        dir_str = '▲ LONG'
    elif last_score < -_SIGNAL_TH:
        dir_str = '▼ SHORT'
    else:
        dir_str = '— WAIT'
    now_str = now.strftime('%Y-%m-%d %H:%M')
    return (f'ETH  ${last_price:,.2f}  总分 {last_score:+.3f}  {dir_str}'
            f'  ({hours}h  {now_str})')


def _render_snapshot(hours, log_path):
    """Load log rows, draw the chart and return PNG bytes (b'' if too little data)."""
    rows, events = parse_rows(log_path=log_path, max_rows=10000)
    if not rows:
        return b''

    now = datetime.now()
    # Drop sub-second precision so the cutoff matches second-resolution log
    # timestamps: a row stamped in the cutoff's own second is kept.
    cutoff = (now - timedelta(hours=hours)).replace(microsecond=0)

    # Keep only rows inside the window; rows whose timestamp cannot be parsed
    # are skipped instead of aborting the whole chart.
    points = []
    for r in rows:
        t = _parse_ts(r['ts'])
        if t is not None and t >= cutoff:
            points.append((t, r['price'], r['total']))
    if len(points) < 3:
        return b''

    times = [t for t, _, _ in points]
    prices = _forward_fill([p for _, p, _ in points])  # bridge missing prices
    scores = [s for _, _, s in points]

    fig, ax_price = plt.subplots(figsize=(12, 5))
    try:
        fig.patch.set_facecolor('#0d1117')
        ax_price.set_facecolor('#0d1117')
        ax_score = ax_price.twinx()  # total score on the right axis

        _draw_series(ax_price, ax_score, times, prices, scores)
        _draw_events(ax_price, events, cutoff)
        _style_axes(ax_price, ax_score, prices, scores)

        # Combined legend for both axes. It is attached to the twin axes
        # because twinx axes are drawn on top of the host axes; a legend on
        # ax_price would be painted over by the score line.
        h1, l1 = ax_price.get_legend_handles_labels()
        h2, l2 = ax_score.get_legend_handles_labels()
        ax_score.legend(h1 + h2, l1 + l2,
                        loc='upper left', fontsize=9, framealpha=0.3,
                        facecolor='#161b22', edgecolor='#30363d',
                        labelcolor='#c9d1d9')

        ax_price.set_title(_snapshot_title(prices, scores, hours, now),
                           color='#c9d1d9', fontsize=10, pad=8)

        fig.tight_layout(pad=1.2)
        fig.subplots_adjust(right=0.88)  # leave room for the score-axis label

        buf = io.BytesIO()
        fig.savefig(buf, format='png', dpi=130, facecolor=fig.get_facecolor())
        return buf.getvalue()
    finally:
        # Always release the figure: pyplot keeps open figures alive until
        # closed, so an exception mid-draw would otherwise leak it.
        plt.close(fig)


def make_snapshot(hours=4, log_path=None) -> bytes:
    """
    Render the main chart (ETH price + total score) for the last *hours* hours.

    Args:
        hours: length of the look-back window, in hours.
        log_path: forwarded unchanged to parse_rows().

    Returns:
        PNG bytes ready for tg_send_photo(), or b'' when matplotlib is not
        installed, fewer than 3 rows with a valid timestamp fall inside the
        window, or anything fails while parsing or drawing (the traceback is
        logged rather than raised).
    """
    if not _HAS_MPL:
        return b''
    try:
        return _render_snapshot(hours, log_path)
    except Exception:
        # Honour the documented "empty bytes on failure" contract, but log the
        # traceback so the failure is not silently swallowed.
        logging.getLogger(__name__).exception(
            'make_snapshot failed (hours=%s, log_path=%s)', hours, log_path)
        return b''
