# -*- coding: utf-8 -*-
"""
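Macro Multi-Asset Correlation Engine
Data sources: OKX (crypto), Binance (PAXG), Yahoo Finance (US equities / crude oil / US Treasuries / US dollar)
Computes: 1W/1M/3M Pearson correlation matrices of log returns + rolling correlations,
plus relative liquidity (latest daily volume vs. its multi-year history)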
"""
import math, json, os, sys, time
from datetime import datetime, timedelta

_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, _ROOT)
from config import MACRO_ASSETS
from exchange.http import http_get

# Asset registry from config.MACRO_ASSETS; entries are read elsewhere in this
# module for their "source", "symbol", "label" and "category" fields.
ASSETS = MACRO_ASSETS

# Correlation timeframes. "bar" is the candle size requested from each source and
# "limit" the bar count, chosen so limit x bar spans roughly the labelled window
# (168 x 1H = 7 days, 180 x 4H = 30 days, 90 x 1D = 90 days).
TIMEFRAMES = {
    "1W": dict(bar="1H", limit=168, label="1 Week"),
    "1M": dict(bar="4H", limit=180, label="1 Month"),
    "3M": dict(bar="1D", limit=90, label="3 Months"),
}


def _fetch_okx(inst_id, bar, limit):
    """Fetch close prices for an OKX instrument, oldest first.

    OKX lists candles newest first, so the parsed closes are reversed before
    returning. Any request or parse failure yields [].
    """
    try:
        query = {"instId": inst_id, "bar": bar, "limit": str(limit)}
        resp = http_get("https://www.okx.com/api/v5/market/candles",
                        params=query, timeout=12)
        rows = resp.json().get("data", [])
        closes = [float(row[4]) for row in rows]  # index 4 = close price
        closes.reverse()
        return closes
    except Exception:
        return []


def _fetch_binance(symbol, bar, limit):
    """Fetch close prices for a Binance spot symbol, oldest first.

    *bar* uses this module's labels ("1H", "4H", "1D"); anything else falls back
    to hourly klines. Any request or parse failure yields [].
    """
    intervals = {"1H": "1h", "4H": "4h", "1D": "1d"}
    try:
        resp = http_get(
            "https://api.binance.com/api/v3/klines",
            params={"symbol": symbol, "interval": intervals.get(bar, "1h"), "limit": limit},
            timeout=12,
        )
        klines = resp.json()
        return [float(kline[4]) for kline in klines]  # index 4 = close price
    except Exception:
        return []


def _fetch_yahoo(symbol, bar, limit):
    """Yahoo Finance via yfinance library (handles crumb/cookie)"""
    try:
        import yfinance as yf
        interval = {"1H": "1h", "4H": "1h", "1D": "1d"}.get(bar, "1h")
        if bar == "1H":
            period = "7d" if limit <= 168 else "14d"
        elif bar == "4H":
            period = "1mo" if limit <= 180 else "2mo"
        else:
            period = "3mo" if limit >= 90 else "1mo"

        ticker = yf.Ticker(symbol)
        df = ticker.history(period=period, interval=interval)
        if df.empty:
            return []
        closes = df['Close'].dropna().tolist()
        return closes
    except Exception:
        return []


def fetch_prices(asset_key, bar, limit):
    """Look up *asset_key* in ASSETS and fetch its closes from the configured source.

    Returns [] for unknown keys, empty configs, or a "source" other than
    "okx", "binance" or "yahoo".
    """
    spec = ASSETS.get(asset_key)
    if not spec:
        return []
    source, symbol = spec["source"], spec["symbol"]
    routes = (
        ("okx", _fetch_okx),
        ("binance", _fetch_binance),
        ("yahoo", _fetch_yahoo),
    )
    for name, fetch in routes:
        if source == name:
            return fetch(symbol, bar, limit)
    return []


def log_returns(prices):
    """Per-bar log returns ln(p[i] / p[i-1]) of an oldest-first price list.

    Returns len(prices) - 1 values ([] for fewer than two prices). A pair that
    involves a non-positive or NaN price has no defined log return. Futures can
    settle below zero and feeds occasionally emit zeros, and math.log would raise
    and abort analyze(). Such pairs yield 0.0, so the series keeps its length and
    stays positionally aligned with the other assets.
    """
    return [
        math.log(cur / prev) if prev > 0 and cur > 0 else 0.0
        for prev, cur in zip(prices, prices[1:])
    ]


def pearson(x, y):
    """Pearson correlation of the trailing overlap of two series, rounded to 4 dp.

    Only the last min(len(x), len(y)) elements of each series are compared. The
    result is 0.0 when fewer than 5 points overlap or when either side is
    (numerically) constant.
    """
    overlap = min(len(x), len(y))
    if overlap < 5:
        return 0.0
    xs = x[-overlap:]
    ys = y[-overlap:]
    x_bar = sum(xs) / overlap
    y_bar = sum(ys) / overlap
    dev_x = [a - x_bar for a in xs]
    dev_y = [b - y_bar for b in ys]
    cov = sum(p * q for p, q in zip(dev_x, dev_y))
    sx = math.sqrt(sum(p ** 2 for p in dev_x))
    sy = math.sqrt(sum(q ** 2 for q in dev_y))
    # Written as "not (both large)" so a NaN spread also falls back to 0.0.
    if not (sx > 1e-10 and sy > 1e-10):
        return 0.0
    return round(cov / (sx * sy), 4)


def rolling_correlation(x, y, window=30):
    """Rolling Pearson correlation of two return series over a sliding window.

    Both series are aligned on their last n = min(len(x), len(y)) elements.
    Returns (timestamps, values). values[k] is the correlation of the k-th window,
    rounded to 4 dp, or 0.0 when either window is flat. timestamps are just the
    indices 0..len(values)-1.

    Returns ([], []) when *window* < 2 or when the data is shorter than *window*.
    A correlation needs at least two points, and window == 0 would otherwise
    divide by zero.
    """
    n = min(len(x), len(y))
    if window < 2 or n < window:
        return [], []
    rx, ry = x[-n:], y[-n:]
    values = []
    for end in range(window, n + 1):
        wx, wy = rx[end - window:end], ry[end - window:end]
        mx, my = sum(wx) / window, sum(wy) / window
        num = sum((a - mx) * (b - my) for a, b in zip(wx, wy))
        da = math.sqrt(sum((a - mx) ** 2 for a in wx))
        db = math.sqrt(sum((b - my) ** 2 for b in wy))
        r = num / (da * db) if da > 1e-10 and db > 1e-10 else 0.0
        values.append(round(r, 4))
    return list(range(len(values))), values


def analyze(tf="1W"):
    """主分析: 计算指定时间框架的相关性矩阵"""
    cfg = TIMEFRAMES.get(tf, TIMEFRAMES["1W"])
    bar, limit = cfg["bar"], cfg["limit"]

    symbols = list(ASSETS.keys())
    labels = {k: v["label"] for k, v in ASSETS.items()}
    categories = {k: v["category"] for k, v in ASSETS.items()}

    # Fetch all prices
    prices_raw = {}
    for s in symbols:
        px = fetch_prices(s, bar, limit)
        if px:
            prices_raw[s] = px

    # Calculate log returns
    returns = {}
    for s, px in prices_raw.items():
        ret = log_returns(px)
        if ret:
            returns[s] = ret

    # Correlation matrix
    active = [s for s in symbols if s in returns]
    matrix = {}
    for si in active:
        matrix[si] = {}
        for sj in active:
            matrix[si][sj] = pearson(returns[si], returns[sj])

    # Latest prices
    prices = {s: round(prices_raw[s][-1], 2) for s in active if prices_raw[s]}

    # Volatility (annualized)
    vols = {}
    for s in active:
        ret = returns[s]
        if ret:
            daily_vol = (sum((r - sum(ret)/len(ret))**2 for r in ret) / len(ret)) ** 0.5
            # Annualize based on bar
            if bar == "1H":
                annual = daily_vol * math.sqrt(365 * 24)
            elif bar == "4H":
                annual = daily_vol * math.sqrt(365 * 6)
            else:
                annual = daily_vol * math.sqrt(365)
            vols[s] = round(annual * 100, 1)

    # BTC correlations (sorted)
    btc_corrs = []
    if "BTC" in matrix:
        for s in active:
            if s != "BTC" and s in matrix["BTC"]:
                btc_corrs.append({"symbol": s, "label": labels.get(s, s),
                                  "corr": matrix["BTC"][s],
                                  "category": categories.get(s, "")})
    btc_corrs.sort(key=lambda x: -abs(x["corr"]))

    # Rolling correlations (key pairs vs BTC)
    rolling = {}
    key_pairs = ["SPX", "XAU", "WTI", "DXY", "T10", "NDQ"]
    if "BTC" in returns:
        for s in key_pairs:
            if s in returns and s != "BTC":
                ts, vals = rolling_correlation(returns["BTC"], returns[s],
                                              window=min(30, len(returns["BTC"])-1))
                if vals:
                    rolling[s] = {"ts": ts, "values": vals, "label": labels.get(s, s)}

    return {
        "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        "timeframe": tf,
        "timeframe_label": cfg["label"],
        "symbols": active,
        "labels": labels,
        "categories": categories,
        "matrix": matrix,
        "prices": prices,
        "volatilities": vols,
        "btc_correlations": btc_corrs,
        "rolling": rolling,
        "n_assets": len(active),
    }


def fetch_liquidity():
    """
    各品种相对流动性：当前成交量 vs 过去3年最高日成交量
    返回: {symbol: {current_vol, max_3yr_vol, ratio_pct, percentile, ...}}
    统一用百分比衡量，跨品种可比
    """
    liquidity = {}

    # ── 加密货币: CoinGecko 全球聚合 → Binance(最大所) → OKX 降级 ──
    cg_ids = {"BTC": "bitcoin"}
    binance_syms = {"BTC": "BTCUSDT"}
    for key in ["BTC"]:
        vols = []
        source = ""
        unit = "USDT"
        # 1) CoinGecko 全球聚合
        try:
            cg_id = cg_ids.get(key)
            if cg_id:
                now_ts = int(time.time())
                three_yr_ago = now_ts - 3 * 365 * 86400
                r = http_get(f"https://api.coingecko.com/api/v3/coins/{cg_id}/market_chart/range",
                             params={"vs_currency": "usd", "from": three_yr_ago, "to": now_ts}, timeout=20)
                data = r.json()
                vols_raw = data.get("total_volumes", [])
                vols = [v[1] for v in vols_raw if v[1] > 0]
                source = f"CoinGecko ({len(vols)}d global)"
                unit = "USD"
                time.sleep(1.5)  # Rate limit
        except Exception:
            pass

        # 2) Binance 全球最大交易所 (USDT本位合约日线, 3年=1095根)
        if len(vols) < 30:
            try:
                bsym = binance_syms.get(key)
                if bsym:
                    r = http_get("https://api.binance.com/api/v3/klines",
                                 params={"symbol": bsym, "interval": "1d", "limit": 1095}, timeout=15)
                    candles = r.json()
                    vols = [float(c[5]) for c in candles if float(c[5]) > 0]  # quote volume
                    source = f"Binance ({len(vols)}d, largest CEX)"
                    unit = "USDT"
            except Exception:
                pass

        # 3) OKX 降级
        if len(vols) < 30:
            try:
                cfg = ASSETS.get(key)
                if cfg:
                    r = http_get("https://www.okx.com/api/v5/market/candles",
                                 params={"instId": cfg["symbol"], "bar": "1D", "limit": "1095"}, timeout=15)
                    candles = r.json().get("data", [])
                    vols = [float(c[5]) for c in candles if float(c[5]) > 0]
                    source = f"OKX ({len(vols)}d)"
                    unit = "USDT"
            except Exception:
                pass

        if len(vols) < 30:
            continue
        current = vols[-1] if vols else 0
        max_3yr = max(vols)
        ratio = round(current / max_3yr * 100, 1) if max_3yr > 0 else 0
        sorted_vols = sorted(vols)
        rank = sum(1 for v in sorted_vols if v <= current)
        percentile = round(rank / len(sorted_vols) * 100, 1)
        liquidity[key] = {
            "current_vol": round(current, 0),
            "max_3yr_vol": round(max_3yr, 0),
            "ratio_pct": ratio,
            "percentile": percentile,
            "n_days": len(vols),
            "unit": unit,
            "source": source,
        }

    # ── 传统资产: Yahoo Finance 3年日线成交量 ──
    # symbol -> (yfinance ticker, multiplier for notional)
    # 流动性: (yfinance symbol, 合约乘数, 说明)
    # 期货/ETF/现货 — 均取3年日线成交量
    yahoo_assets = {
        # 贵金属
        "XAU": ("GLD", 1,      "GLD 黄金ETF"),
        "SVR": ("SI=F", 1,     "COMEX 白银期货"),
        "PLT": ("PL=F", 1,     "NYMEX 铂金期货"),
        # 能源
        "WTI": ("CL=F", 1,     "NYMEX 原油期货"),
        "BNO": ("NG=F", 1,     "NYMEX 天然气期货"),
        # 工业金属
        "CPR": ("HG=F", 1,     "COMEX 铜期货"),
        "ALU": ("ALI=F", 1,    "LME 铝期货"),
        # 美股 (期货/ETF)
        "SPX": ("ES=F", 50,    "CME 标普500 E-mini"),
        "NDQ": ("NQ=F", 20,    "CME 纳斯达克 E-mini"),
        "RUT": ("IWM", 1,      "iShares 罗素2000 ETF"),
        "DJI": ("YM=F", 5,     "CME 道琼斯 E-mini"),
        # 全球股票
        "FTSE": ("EWU", 1,     "iShares 英国 ETF"),
        "N225": ("EWJ", 1,     "iShares 日本 ETF"),
        "DAX":  ("EWG", 1,     "iShares 德国 ETF"),
        "HSI":  ("EWH", 1,     "iShares 香港 ETF"),
        # 汇率 (ETF代理)
        "DXY": ("UUP", 1,      "Invesco 美元指数 ETF"),
        "EUR": ("FXE", 1,      "Invesco 欧元 ETF"),
        "JPY": ("FXY", 1,      "Invesco 日元 ETF"),
        "GBP": ("FXB", 1,      "Invesco 英镑 ETF"),
        "CNH": ("CYB", 1,      "WisdomTree 人民币 ETF"),
        "AUD": ("FXA", 1,      "Invesco 澳元 ETF"),
        # 利率 (期货/ETF)
        "T2Y": ("ZT=F", 2000,  "CME 2年期T-Note"),
        "T10": ("ZN=F", 1000,  "CME 10年期T-Note"),
        "T30": ("ZB=F", 1000,  "CME 30年期T-Bond"),
        "TIP": ("TIP", 1,      "iShares TIPS ETF"),
        # 信用
        "HYG": ("HYG", 1,      "iShares 高收益债 ETF"),
        "LQD": ("LQD", 1,      "iShares 投资级债 ETF"),
        # 波动率/另类
        "VIX": ("UVXY", 1,     "ProShares VIX ETF"),
        "EEM": ("EEM", 1,      "iShares 新兴市场 ETF"),
        "REIT":("VNQ", 1,      "Vanguard 不动产 ETF"),
    }
    try:
        import yfinance as yf
        for key, (sym, mult, desc) in yahoo_assets.items():
            if key in liquidity:
                continue
            try:
                t = yf.Ticker(sym)
                df = t.history(period="3y", interval="1d")
                if df.empty or "Volume" not in df.columns:
                    continue
                vols = df["Volume"].dropna().tolist()
                if len(vols) < 50:
                    continue
                current = vols[-1] if vols else 0
                max_3yr = max(vols)
                ratio = round(current / max_3yr * 100, 1) if max_3yr > 0 else 0
                sorted_vols = sorted(vols)
                rank = sum(1 for v in sorted_vols if v <= current)
                percentile = round(rank / len(sorted_vols) * 100, 1)
                unit = "contracts" if mult > 1 else "shares"

                liquidity[key] = {
                    "current_vol": round(current, 0),
                    "max_3yr_vol": round(max_3yr, 0),
                    "ratio_pct": ratio,
                    "percentile": percentile,
                    "n_days": len(vols),
                    "unit": unit,
                    "source": f"{desc} ({len(vols)}d)",
                }
            except Exception:
                pass
    except Exception:
        pass

    return liquidity


def to_json(data):
    """Flatten an analyze() result into the frontend payload and attach liquidity.

    The correlation matrix becomes a list of rows ordered like data["symbols"]
    (missing pairs read as 0). Rolling series keep only "label" and "values".
    fetch_liquidity() is called on every invocation.
    """
    liquidity = fetch_liquidity()
    order = data["symbols"]
    corr = data["matrix"]

    grid = []
    for row_sym in order:
        row = corr.get(row_sym, {})
        grid.append([row.get(col_sym, 0) for col_sym in order])

    rolling = {}
    for key, series in data["rolling"].items():
        rolling[key] = {"label": series["label"], "values": series["values"]}

    return {
        "ts": data["timestamp"],
        "tf": data["timeframe"],
        "tf_label": data["timeframe_label"],
        "symbols": order,
        "labels": data["labels"],
        "categories": data["categories"],
        "prices": data["prices"],
        "vols": data["volatilities"],
        "n": data["n_assets"],
        "matrix": grid,
        "btc_corr": data["btc_correlations"],
        "rolling": rolling,
        "liquidity": liquidity,
    }


if __name__ == "__main__":
    # CLI: optional first argument selects the timeframe key (default "1W");
    # prints the frontend JSON payload. sys is already imported at module top.
    cli_args = sys.argv[1:]
    tf = cli_args[0] if cli_args else "1W"
    print(json.dumps(to_json(analyze(tf)), ensure_ascii=False, indent=2))
