# -*- coding: utf-8 -*-
"""
GoldAnalyzer — XAU 分析引擎
复用 BTC MarketAnalyzer 的评分引擎/动量/方向判断逻辑
"""
import math, time, logging
from collections import deque

from config import (
    ASSET_FACTOR_CONFIG,
    SCORE_THRESHOLD, MIN_FACTORS_AGREE, REQUIRE_TREND_ALIGN,
    PROXIES,
    logger as _root_logger,
)
from exchange.gold import (
    comex_price, comex_candles, comex_volume_oi, comex_oi_history,
    cot_positioning, gld_flow, macro_drivers, calc_rsi, calc_bb,
)

# Module logger. It deliberately uses the application-wide "MyTrader" name
# instead of __name__ (presumably the logger that `config` sets up — verify).
logger = logging.getLogger(name="MyTrader")


class GoldAnalyzer:
    """XAU (COMEX gold futures) multi-factor analyzer.

    Each ``_*_score`` method turns one market or macro input into a score
    clipped to [-1, 1] (positive = bullish). :meth:`analyze` combines the
    scores with per-factor weights from ``ASSET_FACTOR_CONFIG["XAU"]``. It
    uses the normalized net-difference scheme shared with the BTC analyzer
    and returns a LONG / SHORT / WAIT decision.
    """

    # Columns of the shared ``factor_snapshots`` table (same table as BTC).
    # XAU factors are stored in BTC-named columns; unmapped columns get 0.
    _SNAPSHOT_COLUMNS = (
        "symbol", "price", "direction", "total_score",
        "trend", "orderbook", "funding", "taker", "oi", "maxpain",
        "vol_delta", "btc_corr", "gamma", "iv", "exhaust", "liq_cool",
        "mean_revert", "smart_money", "mtf", "ob_liq", "low_lev", "liq_ex",
        "retail", "mm", "liq_trigger", "toxic",
        "funding_rate", "t1h", "t4h", "pre_filter",
    )
    # Placeholders are generated from the column list so the placeholder
    # count always equals the column count.
    _SNAPSHOT_SQL = (
        "INSERT INTO factor_snapshots ("
        + ", ".join(_SNAPSHOT_COLUMNS)
        + ") VALUES ("
        + ", ".join(["%s"] * len(_SNAPSHOT_COLUMNS))
        + ")"
    )

    def __init__(self):
        cfg = ASSET_FACTOR_CONFIG.get("XAU", {})
        # factor name -> (tag, weight); the weight is read as ``[1]``
        self.factor_cfg = cfg.get("factors", {})
        self.threshold = cfg.get("score_threshold", 0.40)
        self.stop_loss = cfg.get("stop_loss_pct", 0.003)  # not read within this class
        self.vol_history = deque(maxlen=20)     # one COMEX volume reading per _volume_score call
        self.oi_history = deque(maxlen=20)      # not read within this class
        self.score_history = deque(maxlen=30)   # pre-adjustment net scores from analyze()
        self.factor_history = deque(maxlen=10)  # not read within this class
        # Cross-call state (initialised here instead of lazily via hasattr)
        self._prev_price = None        # COMEX price from the previous direction check
        self._hyg_prev = None          # HYG price from the previous macro call
        # Raw readings behind the most recent RSI / COT / Bollinger scores
        self._last_rsi = None
        self._last_cot_net_pct = None
        self._last_pct_b = None

    # ═══════════════════════════════════════
    # Price direction (shared by volume and OI factors)
    # ═══════════════════════════════════════
    def _price_direction(self, px, update=True):
        """Return +1 / -1 / 0 when ``px`` is above / below / equal to the previous price.

        With no previous price the result is 0. With ``update=True`` the given
        price becomes the reference for the next call.
        """
        prev = getattr(self, "_prev_price", None) or px
        if update:
            self._prev_price = px
        return 1 if px > prev else (-1 if px < prev else 0)

    # ═══════════════════════════════════════
    # Trend (1H + 4H + 1D triple confirmation)
    # ═══════════════════════════════════════
    def _trend_score(self):
        """Trend vote from 5- vs 10-bar moving averages on 1H, 4H and 1D closes.

        Each timeframe votes +1 when MA5 > MA10 by more than 0.1%, -1 when it
        is more than 0.1% below, and 0 otherwise. The averages are simple
        means, despite older notes calling them EMAs. A timeframe with fewer
        than 10 candles votes 0.

        Returns:
            (score, t1h, t4h, t1d). The score is ±1.0 when all three votes
            agree, ±0.7 for two, ±0.3 for one (the bullish side is checked
            first), and 0 otherwise.
        """
        votes = {}
        for bar in ("1H", "4H", "1D"):
            candles = comex_candles(bar=bar, limit=20)
            if len(candles) < 10:
                votes[bar] = 0
                continue
            # Latest 20 bars; index -1 is treated as the newest bar throughout this class.
            closes = [c[4] for c in candles[-20:]]  # c[4] = close
            fast = sum(closes[-5:]) / 5
            slow = sum(closes[-10:]) / 10
            if fast > slow * 1.001:
                votes[bar] = 1
            elif fast < slow * 0.999:
                votes[bar] = -1
            else:
                votes[bar] = 0
        t1h, t4h, t1d = votes["1H"], votes["4H"], votes["1D"]

        pos = sum(1 for v in votes.values() if v > 0)
        neg = sum(1 for v in votes.values() if v < 0)
        if pos >= 3: score = 1.0
        elif neg >= 3: score = -1.0
        elif pos >= 2: score = 0.7
        elif neg >= 2: score = -0.7
        elif pos >= 1: score = 0.3
        elif neg >= 1: score = -0.3
        else: score = 0
        logger.info(f"[XAU Trend] 1H:{t1h:+d} 4H:{t4h:+d} 1D:{t1d:+d} => {score:+.1f}")
        return score, t1h, t4h, t1d

    # ═══════════════════════════════════════
    # Multi-timeframe confluence (15m / 1H / 4H)
    # ═══════════════════════════════════════
    def _mtf_score(self):
        """Weighted MA5/MA10/MA20 stacking across 15m (0.3), 1H (0.4) and 4H (0.3).

        A timeframe scores ±1 for a full bullish/bearish stack, ±0.5 when only
        MA5 vs MA10 is ordered, and 0 otherwise. Timeframes with fewer than 20
        candles are skipped. The weighted sum is scaled by 1.5 and clipped to [-1, 1].
        """
        scores = []
        for bar, wt in [("15m", 0.3), ("1H", 0.4), ("4H", 0.3)]:
            candles = comex_candles(bar=bar, limit=24)
            if len(candles) < 20:
                continue
            closes = [c[4] for c in candles]
            ma5 = sum(closes[-5:]) / 5
            ma10 = sum(closes[-10:]) / 10
            ma20 = sum(closes[-20:]) / 20
            if ma5 > ma10 > ma20: s = 1.0
            elif ma5 < ma10 < ma20: s = -1.0
            elif ma5 > ma10: s = 0.5
            elif ma5 < ma10: s = -0.5
            else: s = 0
            scores.append(s * wt)
        total = sum(scores)
        logger.info(f"[XAU MTF] 15m/1H/4H共振 => {total:+.2f}")
        return max(-1, min(1, total * 1.5))

    # ═══════════════════════════════════════
    # Moving-average alignment (MA20 / MA50 / MA200) on daily closes
    # ═══════════════════════════════════════
    def _ma_alignment(self):
        """Score the price position vs. daily MA20/MA50/MA200 and their ordering.

        Returns ±1.0 when the price is above/below all three MAs and they are
        fully stacked, ±0.5 when the price is above/below all three without
        the stack, and 0 otherwise. Returns 0 with fewer than 50 daily candles.
        """
        candles = comex_candles(bar="1D", limit=200)
        if len(candles) < 50:
            return 0
        closes = [c[4] for c in candles]
        ma20 = sum(closes[-20:]) / 20
        ma50 = sum(closes[-50:]) / 50
        # "MA200" is the mean of up to the last 200 closes. With fewer than 100
        # available it falls back to MA50.
        long_window = closes[-200:]
        ma200 = sum(long_window) / len(long_window) if len(long_window) >= 100 else ma50
        last = closes[-1]
        above_all = last > ma20 and last > ma50 and last > ma200
        below_all = last < ma20 and last < ma50 and last < ma200
        if above_all and ma20 > ma50 > ma200: return 1.0   # classic bullish stack
        if below_all and ma20 < ma50 < ma200: return -1.0  # classic bearish stack
        if above_all: return 0.5
        if below_all: return -0.5
        return 0

    # ═══════════════════════════════════════
    # Momentum (12-bar ROC on 1H closes)
    # ═══════════════════════════════════════
    @staticmethod
    def _roc_pct(closes, period=12):
        """Percent change of the last close vs. the close ``period`` bars earlier.

        Returns 0.0 when there are not enough closes or the base close is not positive.
        """
        if len(closes) <= period:
            return 0.0
        base = closes[-1 - period]
        if base <= 0:
            return 0.0
        return (closes[-1] - base) / base * 100

    def _momentum_score(self):
        """12-bar rate of change on 1H closes, scaled so ±1.5% maps to ±1 (clipped)."""
        candles = comex_candles(bar="1H", limit=24)
        if len(candles) < 14:
            return 0
        closes = [c[4] for c in candles]
        roc = self._roc_pct(closes, 12)
        # Gold typically moves 0.5-2% per day
        score = max(-1, min(1, roc / 1.5))
        logger.info(f"[XAU Mom] ROC12={roc:+.2f}% => {score:+.2f}")
        return score

    # ═══════════════════════════════════════
    # Volume
    # ═══════════════════════════════════════
    def _volume_score(self, price_dir=None):
        """Score current COMEX volume against this analyzer's own rolling average.

        Each call appends one reading to ``self.vol_history``. Nothing is
        scored until 5 readings exist. The ratio (current volume / mean of the
        earlier readings) is signed by the price direction. The score is
        ``(ratio - 1) * 2`` while the ratio is within 5 of 1 and a flat ±0.5
        beyond that, clipped to [-1, 1].

        Args:
            price_dir: +1 / -1 / 0 price direction for this cycle. When omitted
                it is derived from ``comex_price()`` (and the reference price is updated).
        """
        vol, _avg_vol, _oi = comex_volume_oi()
        self.vol_history.append(vol)
        if len(self.vol_history) < 5:
            return 0
        # Rolling average over our own earlier readings, not the feed's avg_vol.
        previous = list(self.vol_history)[:-1]
        rolling_avg = sum(previous) / len(previous)
        if rolling_avg <= 0:
            return 0
        if price_dir is None:
            price_dir = self._price_direction(comex_price()[0])
        ratio = vol / rolling_avg
        raw_score = (ratio - 1) * 2 * price_dir if abs(ratio - 1) < 5 else price_dir * 0.5
        score = max(-1, min(1, raw_score))
        logger.info(f"[XAU Vol] vol={vol:.0f} roll_avg={rolling_avg:.0f} ratio={ratio:.2f}x dir={price_dir:+d} => {score:+.2f}")
        return score

    # ═══════════════════════════════════════
    # Open-interest change
    # ═══════════════════════════════════════
    def _oi_score(self, price_dir=None):
        """Relative OI change (recent vs. older readings) signed by price direction.

        A 5% OI change maps to ±1 (clipped). A flat price direction yields 0.

        Args:
            price_dir: +1 / -1 / 0 price direction for this cycle. When omitted
                it is derived from ``comex_price()`` without moving the
                reference price.
        """
        oi_list = comex_oi_history(10) or []
        if len(oi_list) < 5:
            return 0
        recent = oi_list[-3:]
        older = oi_list[:3] if len(oi_list) >= 6 else oi_list[:1]
        avg_recent = sum(recent) / len(recent)
        avg_older = sum(older) / len(older)
        if avg_older <= 0:
            return 0
        if price_dir is None:
            price_dir = self._price_direction(comex_price()[0], update=False)
        chg = (avg_recent - avg_older) / avg_older
        return max(-1, min(1, chg / 0.05 * price_dir))

    # ═══════════════════════════════════════
    # COT positioning (CFTC)
    # ═══════════════════════════════════════
    def _cot_score(self):
        """Contrarian COT score from net positioning as a % of open interest.

        net > 30%: crowded long, so bearish down to -1 at 60%.
        net < -20%: crowded short, so bullish up to +1 at -50%.
        Otherwise the score mildly follows the trend as ``net / 50``.
        """
        cot = cot_positioning()
        if not cot:
            self._last_cot_net_pct = None
            return 0
        net_pct = cot.get("net_pct_oi", 0)
        self._last_cot_net_pct = net_pct
        if net_pct > 30: score = -min(1, (net_pct - 30) / 30)
        elif net_pct < -20: score = min(1, (-net_pct - 20) / 30)
        else: score = net_pct / 50
        logger.info(f"[XAU COT] net={net_pct:.0f}% OI => {score:+.2f}")
        return round(score, 3)

    # ═══════════════════════════════════════
    # Volume / price divergence
    # ═══════════════════════════════════════
    def _vol_divergence(self):
        """Fade the last 3 1H bars' price move when recorded volume has shrunk below 70%.

        Shrinking volume with a rising price scores bearish; shrinking volume
        with a falling price scores bullish (selling exhausted). A 0.5% move
        maps to ±1. The volumes come from ``self.vol_history`` (one reading per
        analyze cycle), not from the candles.
        """
        if len(self.vol_history) < 3:
            return 0
        vols = list(self.vol_history)[-3:]
        candles = comex_candles(bar="1H", limit=3)
        if len(candles) < 3:
            return 0
        prices = [c[4] for c in candles]
        vol_decay = vols[-1] / max(vols[0], 1)
        price_chg = (prices[-1] - prices[0]) / prices[0] if prices[0] > 0 else 0
        if vol_decay < 0.7:
            return max(-1, min(1, -price_chg / 0.005))
        return 0

    # ═══════════════════════════════════════
    # Macro drivers
    # ═══════════════════════════════════════
    def _macro_scores(self):
        """Return (real_rate, dxy_inv, etf_flow, credit), each clipped to [-1, 1].

        real_rate: a rise in TIPS price means real yields fall, which is
            bullish for gold; +0.5% maps to +1.
        dxy_inv:   a rise in the dollar index is bearish for gold; +0.3% maps to -1.
        etf_flow:  GLD flow direction × volume ratio (ratio capped at 2). It is
            scored even when ``macro_drivers()`` returns nothing.
        credit:    a fall in HYG versus the previous call means risk-off, which
            is bullish for gold; -0.3% maps to +1.
        """
        macro = macro_drivers() or {}
        real_rate = dxy_inv = credit = 0

        if macro:
            tips_chg = macro.get("tips_change") or 0
            if tips_chg:
                real_rate = max(-1, min(1, tips_chg / 0.5))
            logger.info(f"[XAU Macro] TIPS chg={tips_chg:+.2f}% => real_rate={real_rate:+.2f}")

            dxy_chg = macro.get("dxy_change") or 0
            if dxy_chg:
                dxy_inv = max(-1, min(1, -dxy_chg / 0.3))
            logger.info(f"[XAU Macro] DXY chg={dxy_chg:+.2f}% => dxy_inv={dxy_inv:+.2f}")

            hyg_px = macro.get("hyg_price", 0)
            if hyg_px:
                prev_hyg = getattr(self, "_hyg_prev", None) or hyg_px
                hyg_change = (hyg_px / prev_hyg - 1) * 100
                self._hyg_prev = hyg_px
                credit = max(-1, min(1, -hyg_change / 0.3))
                logger.info(f"[XAU Credit] HYG => {credit:+.2f}")

        # GLD flow is an independent feed, so it does not depend on macro_drivers().
        etf_flow = 0
        etf = gld_flow()
        if etf:
            flow_dir = etf.get("flow_direction", 0)
            vol = etf.get("volume", 0)
            avg = etf.get("avg_volume", 1)
            ratio = vol / avg if avg > 0 else 1
            etf_flow = max(-1, min(1, flow_dir * min(2, ratio)))
            # ":+" (not ":+d") so a float flow_dir cannot crash the log call.
            logger.info(f"[XAU ETF] flow_dir={flow_dir:+} vol_ratio={ratio:.1f}x => {etf_flow:+.2f}")

        return real_rate, dxy_inv, etf_flow, credit

    # ═══════════════════════════════════════
    # RSI mean reversion
    # ═══════════════════════════════════════
    def _rsi_score(self):
        """RSI(14) on 1H closes: >70 scores bearish (-1 at 90), <30 bullish (+1 at 10), else 0.

        The raw RSI value is kept in ``self._last_rsi`` (None when unavailable).
        """
        candles = comex_candles(bar="1H", limit=24)
        if len(candles) < 15:
            self._last_rsi = None
            return 0
        closes = [c[4] for c in candles]
        rsi = calc_rsi(closes, 14)
        self._last_rsi = rsi
        if rsi > 70: return -min(1, (rsi - 70) / 20)
        if rsi < 30: return min(1, (30 - rsi) / 20)
        return 0

    # ═══════════════════════════════════════
    # Bollinger bands
    # ═══════════════════════════════════════
    def _bb_score(self):
        """Mean-reversion score from Bollinger %B (20, 2) on 1H closes.

        %B > 80 scores -0.7, > 60 scores -0.3, < 20 scores +0.7, < 40 scores
        +0.3, and anything else scores 0. The raw %B is kept in ``self._last_pct_b``.
        """
        self._last_pct_b = None
        candles = comex_candles(bar="1H", limit=30)
        if len(candles) < 20:
            return 0
        closes = [c[4] for c in candles]
        bb = calc_bb(closes, 20, 2)
        if not bb:
            return 0
        pct_b = bb["pct_b"]
        self._last_pct_b = pct_b
        # NOTE(review): thresholds assume pct_b is on a 0-100 scale — confirm against calc_bb.
        # TODO: a bandwidth "squeeze -> follow the breakout" rule was planned but is not implemented.
        if pct_b > 80: return -0.7
        if pct_b > 60: return -0.3
        if pct_b < 20: return 0.7
        if pct_b < 40: return 0.3
        return 0

    # ═══════════════════════════════════════
    # Volume exhaustion
    # ═══════════════════════════════════════
    @staticmethod
    def _exhaustion_from_candles(candles):
        """Exhaustion score from the last 6 OHLCV candles ([ts, o, h, l, c, v]).

        If the last 3 bars' volume is below 60% of the prior 3 bars' volume
        and their high-low range is below 70% of the prior 3 bars' range, the
        move is considered exhausted. The result is -0.5 after a rise and
        +0.5 after a fall over the window; otherwise it is 0. Fewer than 6
        candles give 0.
        """
        if len(candles) < 6:
            return 0
        window = candles[-6:]
        vols = [c[5] for c in window]
        ranges = [abs(c[2] - c[3]) for c in window]  # high - low
        vol_decay = sum(vols[3:]) / max(sum(vols[:3]), 1)
        range_decay = sum(ranges[3:]) / max(sum(ranges[:3]), 0.01)
        trend = 1 if window[-1][4] > window[0][4] else -1
        if vol_decay < 0.6 and range_decay < 0.7:
            return -trend * 0.5  # exhaustion: lean against the recent move
        return 0

    def _exhaustion_score(self):
        """Exhaustion score on the latest six 1H candles (see _exhaustion_from_candles)."""
        return self._exhaustion_from_candles(comex_candles(bar="1H", limit=6))

    # ═══════════════════════════════════════
    # Combination helpers
    # ═══════════════════════════════════════
    def _net_score(self, factor_data):
        """Normalized net bull-minus-bear score in [-1, 1], rounded to 3 decimals.

        ``factor_data`` maps factor name -> (score, tag). Weights come from
        ``self.factor_cfg[name][1]`` and default to 0.04 for unconfigured
        factors. Returns 0.0 when the weighted absolute sum is <= 0.01.
        """
        pairs = [
            (self.factor_cfg.get(name, (tag, 0.04))[1], score)
            for name, (score, tag) in factor_data.items()
        ]
        bull = sum(w * max(0, s) for w, s in pairs)
        bear = sum(w * abs(min(0, s)) for w, s in pairs)
        total_w = sum(w * abs(s) for w, s in pairs)
        total = round((bull - bear) / total_w, 3) if total_w > 0.01 else 0.0
        return max(-1, min(1, total))

    def _apply_extremes(self, direction, total, rsi_value, cot):
        """Apply the RSI-extreme damping and the COT confirmation boost.

        Args:
            direction: "LONG" / "SHORT" / "WAIT".
            total: signed net score.
            rsi_value: raw RSI reading (0-100) or None when unavailable.
            cot: COT factor score.

        Returns:
            (direction, total). A LONG with RSI > 65 or a SHORT with RSI < 35
            has its total scaled by 0.8 and drops to WAIT below the threshold.
            A remaining LONG/SHORT whose COT score agrees gets 0.05 more
            conviction in its own direction, clipped to [-1, 1].
        """
        if rsi_value is not None:
            if direction == "LONG" and rsi_value > 65:
                total = round(total * 0.8, 3)
                if total < self.threshold:
                    direction = "WAIT"
            elif direction == "SHORT" and rsi_value < 35:
                total = round(total * 0.8, 3)
                if abs(total) < self.threshold:
                    direction = "WAIT"

        if cot != 0 and direction != "WAIT":
            if direction == "LONG" and cot > 0:
                total = round(min(1, total + 0.05), 3)
            elif direction == "SHORT" and cot < 0:
                # Boost the short's conviction by making it MORE negative.
                total = round(max(-1, total - 0.05), 3)
        return direction, total

    def _snapshot_row(self, result):
        """Build the ``factor_snapshots`` parameter tuple (ordered as _SNAPSHOT_COLUMNS)."""
        scores = result["scores"]
        raw = result["raw"]
        # XAU factor -> reused BTC column; any column not listed here is written as 0.
        mapped = {
            "symbol": "XAU",
            "price": result["price"],
            "direction": result["direction"],
            "total_score": result["total"],
            "trend": scores.get("XTR", 0),
            "funding": scores.get("XEF", 0),
            "taker": scores.get("XVL", 0),
            "oi": scores.get("XOI", 0),
            "vol_delta": scores.get("XVX", 0),
            "btc_corr": scores.get("XCR", 0),
            "exhaust": scores.get("XEX", 0),
            "mean_revert": scores.get("XRSI", 0),
            "mtf": scores.get("XMF", 0),
            "low_lev": scores.get("XMA", 0),
            "t1h": raw.get("t1h", 0),
            "t4h": raw.get("t4h", 0),
        }
        return tuple(mapped.get(col, 0) for col in self._SNAPSHOT_COLUMNS)

    def _write_snapshot(self, result):
        """Best-effort insert of ``result`` into MySQL ``factor_snapshots``; failures are logged."""
        try:
            from storage.mysql_client import get_cursor
            with get_cursor() as cur:
                cur.execute(self._SNAPSHOT_SQL, self._snapshot_row(result))
        except Exception as exc:
            # Persistence must never break signal generation, but failures must be visible.
            logger.warning(f"[XAU] factor_snapshots insert failed: {exc}")

    # ═══════════════════════════════════════
    # Combined score (normalized net-difference method shared with BTC)
    # ═══════════════════════════════════════
    def analyze(self):
        """Run every factor and return the combined XAU signal.

        Returns:
            ``{"error": "no price", "price": 0}`` when no COMEX price is
            available. Otherwise a dict with ``symbol``, ``price``,
            ``direction`` ("LONG" / "SHORT" / "WAIT"), ``total`` (net score in
            [-1, 1]), ``scores`` (factor tag -> score) and ``raw``
            (diagnostics). The result is also written best-effort to MySQL.
        """
        px, _bid, _ask = comex_price()
        if not px:
            return {"error": "no price", "price": 0}

        # One price-direction reading per cycle, shared by the volume and OI
        # factors so the second one does not compare the price against itself.
        price_dir = self._price_direction(px)

        trend, t1h, t4h, t1d = self._trend_score()
        mtf = self._mtf_score()
        ma_align = self._ma_alignment()
        momentum = self._momentum_score()
        volume = self._volume_score(price_dir)
        oi = self._oi_score(price_dir)
        cot = self._cot_score()
        vol_div = self._vol_divergence()
        real_rate, dxy_inv, etf_flow, credit = self._macro_scores()
        rsi = self._rsi_score()
        bb = self._bb_score()
        exhaust = self._exhaustion_score()

        # factor name -> (score, tag); the tag keys result["scores"]
        factor_data = {
            "trend":     (trend,     "XTR"),
            "mtf":       (mtf,       "XMF"),
            "ma_align":  (ma_align,  "XMA"),
            "momentum":  (momentum,  "XMO"),
            "volume":    (volume,    "XVL"),
            "oi":        (oi,        "XOI"),
            "cot":       (cot,       "XCF"),
            "vol_div":   (vol_div,   "XVX"),
            "real_rate": (real_rate, "XRD"),
            "dxy_inv":   (dxy_inv,   "XDX"),
            "etf_flow":  (etf_flow,  "XEF"),
            "credit":    (credit,    "XCR"),
            "rsi":       (rsi,       "XRSI"),
            "bb":        (bb,        "XBB"),
            "exhaust":   (exhaust,   "XEX"),
        }

        total = self._net_score(factor_data)

        # Momentum adjustment: ±0.03 when the last three net scores share a
        # sign and are strengthening. History stores the pre-adjustment score.
        self.score_history.append(total)
        momentum_adj = 0
        last_three = list(self.score_history)[-3:]
        if len(last_three) == 3:
            if all(s > 0 for s in last_three) and last_three[-1] > last_three[0]:
                momentum_adj = 0.03
            elif all(s < 0 for s in last_three) and last_three[-1] < last_three[0]:
                momentum_adj = -0.03
        total = max(-1, min(1, round(total + momentum_adj, 3)))

        if total >= self.threshold:
            direction = "LONG"
        elif total <= -self.threshold:
            direction = "SHORT"
        else:
            direction = "WAIT"

        # Trend alignment: stand aside when the 1H and 4H votes point opposite ways.
        if direction != "WAIT" and REQUIRE_TREND_ALIGN and t1h and t4h and t1h != t4h:
            logger.info(f"[XAU] 1H({t1h:+d})/4H({t4h:+d})分歧 → WAIT")
            direction = "WAIT"

        # Multi-factor confirmation: enough factors must clearly agree.
        if direction != "WAIT":
            sign = 1 if direction == "LONG" else -1
            agree = sum(1 for s, _ in factor_data.values() if s * sign > 0.15)
            if agree < MIN_FACTORS_AGREE:
                logger.info(f"[XAU] 因子确认不足 {agree}/{MIN_FACTORS_AGREE} → WAIT")
                direction = "WAIT"

        # RSI extremes are judged on the raw RSI value, not the factor score.
        direction, total = self._apply_extremes(direction, total, self._last_rsi, cot)

        logger.info(f"[XAU] ${px:,.1f} | TR:{trend:+.1f} MTF:{mtf:+.2f} "
                   f"MA:{ma_align:+.1f} MOM:{momentum:+.2f} VL:{volume:+.2f} OI:{oi:+.2f} "
                   f"COT:{cot:+.2f} VD:{vol_div:+.2f} RR:{real_rate:+.2f} DXY:{dxy_inv:+.2f} "
                   f"ETF:{etf_flow:+.2f} CR:{credit:+.2f} RSI:{rsi:+.2f} BB:{bb:+.2f} EX:{exhaust:+.2f} "
                   f"=> {total:+.3f} -> {direction}")

        result = {
            "symbol": "XAU", "price": round(px, 2),
            "direction": direction, "total": total,
            "scores": {tag: s for s, tag in factor_data.values()},
            "raw": {
                "t1h": t1h, "t4h": t4h, "t1d": t1d,
                # NOTE: "cot_net_pct", "rsi_val" and "bb_pct" hold factor SCORES
                # (kept for backward compatibility); the *_raw keys hold readings.
                "cot_net_pct": cot,
                "real_rate": real_rate, "dxy_inv": dxy_inv,
                "etf_flow": etf_flow, "credit": credit,
                "rsi_val": rsi, "bb_pct": bb,
                "momentum_adj": momentum_adj,
                "cot_net_pct_raw": self._last_cot_net_pct,
                "rsi_raw": self._last_rsi,
                "pct_b_raw": self._last_pct_b,
            },
        }

        self._write_snapshot(result)
        return result
