# Day 13: 因子合成 - 等权、IC加权、ICIR加权
# 测试价值+动量+质量三因子组合

# ==================== Switchable parameter (edit here) ====================
VERSION = 'v1'  # 'v1': equal weight | 'v2': IC-weighted | 'v3': ICIR-weighted | 'v4': single-factor baseline
# ================================================================

import pandas as pd
import numpy as np
from jqdata import *

"""
回测参数：2020-01-01 至 2026-02-01，100万，沪深300，月度调仓，持仓20只

V1（等权）：1/3, 1/3, 1/3
V2（IC加权）：基于过去12个月IC均值动态调权
V3（ICIR加权）：基于过去12个月IR动态调权
V4（单因子基准）：只用行业中性化PCF
"""

def initialize(context):
    """Prepare strategy state, platform options and the rebalance schedule.

    Stores the portfolio size and the index code on the global ``g`` object,
    creates the per-factor bookkeeping used for IC tracking, and registers
    ``trade`` with ``run_monthly``.

    Args:
        context: Strategy context supplied by the platform (not read here).
    """
    factor_keys = ('pcf', 'momentum', 'turnover')

    g.index = '000300.XSHG'
    g.stock_num = 20

    # One list of IC observations per factor; update_ic_history appends to it.
    g.ic_history = {key: [] for key in factor_keys}
    # Factor values saved by the previous rebalance. They stay None until the
    # first rebalance has run, so an IC is only ever computed from factor
    # values that were known before the return period.
    g.last_factors = dict.fromkeys(factor_keys)

    set_option('use_real_price', True)
    log.set_level('order', 'error')
    # NOTE(review): the second argument is presumably the trading day of the
    # month on which to run -- confirm against the run_monthly documentation.
    run_monthly(trade, 1)

# ==================== 因子预处理 ====================

def winsorize(series):
    """Clip *series* to the range spanned by its 5% and 95% quantiles.

    Args:
        series: pandas Series of numeric factor values.

    Returns:
        A new Series in which values below the 5% quantile are raised to it
        and values above the 95% quantile are lowered to it.
    """
    lower_bound = series.quantile(0.05)
    upper_bound = series.quantile(0.95)
    return series.clip(lower=lower_bound, upper=upper_bound)

def standardize(series):
    """Return the z-score of *series*: (value - mean) / standard deviation.

    When the series has no usable dispersion, every score is ``series * 0``
    instead of a division result. That covers two cases:

    * the standard deviation is exactly 0 (all values equal), and
    * the standard deviation is NaN, which is what ``Series.std()`` yields
      when fewer than two non-NaN values are present (e.g. a single stock).

    Args:
        series: pandas Series of numeric factor values.

    Returns:
        A Series with the same index as *series*.
    """
    std = series.std()
    # NaN must be checked explicitly: ``nan == 0`` is False, which would
    # otherwise fall through to the division and produce all-NaN scores.
    if pd.isna(std) or std == 0:
        return series * 0
    return (series - series.mean()) / std

def neutralize_industry(factor_df, stocks, date):
    """Standardize factor values separately inside each industry group.

    Industry membership is looked up with ``get_industries('sw_l1', ...)``
    and ``get_industry_stocks``. Each industry that contains more than one
    stock with a factor value is z-scored on its own; every other stock in
    *stocks* keeps the initial score of 0.0.

    If the industry lookup raises, or at most one industry contains any of
    *stocks*, the function falls back to a plain cross-sectional
    ``standardize(factor_df)``.

    Args:
        factor_df: pandas Series of factor values indexed by stock code
            (despite the name, it is indexed and sliced as a Series).
        stocks: List of stock codes that defines the index of the result.
        date: Date passed through to the industry queries.

    Returns:
        A Series of industry-relative scores indexed by *stocks*, or the
        plain standardized *factor_df* in the fallback cases.
    """
    from jqdata import get_industries

    # Groups are keyed by industry code. The industry display name was only
    # ever used as a dictionary key, and a code is unique where a name may
    # not be.
    industry_members = {}
    try:
        industries = get_industries('sw_l1', date=date)
        for ind_code in industries.index:
            # A set makes each membership test O(1) instead of scanning the
            # constituent list once per stock.
            constituents = set(get_industry_stocks(ind_code, date=date))
            matched = [s for s in stocks if s in constituents]
            if matched:
                industry_members[ind_code] = matched
    except Exception as e:
        # Best-effort: keep the strategy running, but make the degradation
        # visible instead of swallowing it silently. Unlike a bare except,
        # this lets KeyboardInterrupt / SystemExit propagate.
        log.warn(f"行业中性化失败，退化为全样本标准化: {e}")
        return standardize(factor_df)

    if len(industry_members) <= 1:
        return standardize(factor_df)

    neutral = pd.Series(0.0, index=stocks)
    for ind_stocks in industry_members.values():
        common = [s for s in ind_stocks if s in factor_df.index]
        # A single stock has no within-industry dispersion; leave it at 0.0.
        if len(common) > 1:
            neutral.loc[common] = standardize(factor_df.loc[common])
    return neutral

# ==================== 因子计算 ====================

def get_pcf_factor(stocks, date):
    """Build the industry-neutral value factor from the PCF ratio.

    Only rows with a strictly positive ``pcf_ratio`` are queried. The ratio
    is winsorized, standardized and negated (so a lower ratio gives a higher
    score) before being passed to ``neutralize_industry``.

    Args:
        stocks: List of stock codes to query.
        date: Query date handed to ``get_fundamentals``.

    Returns:
        A Series of scores indexed by stock code, or an empty float Series
        when the query returns no rows.
    """
    pcf_query = query(valuation.code, valuation.pcf_ratio).filter(
        valuation.code.in_(stocks),
        valuation.pcf_ratio > 0,
    )
    table = get_fundamentals(pcf_query, date)
    if table.empty:
        return pd.Series(dtype=float)

    raw_pcf = table.set_index('code')['pcf_ratio']
    # Sign flip: cheap stocks (small ratio) should rank at the top.
    value_score = -standardize(winsorize(raw_pcf))
    return neutralize_industry(value_score, raw_pcf.index.tolist(), date)

def get_momentum_factor(stocks, date):
    """Build the momentum factor from the return over the last 21 bars.

    For each stock the return is ``last valid close / first valid close - 1``
    over the fetched window; stocks with fewer than two valid closes are
    skipped. The raw returns are winsorized and standardized.

    Args:
        stocks: List of stock codes.
        date: End date of the price window.

    Returns:
        A Series of standardized momentum scores indexed by stock code, or
        an empty float Series when no data is available or an error occurs
        (the error is logged).
    """
    try:
        # NOTE(review): bar frequency is get_price's default -- presumably
        # daily; confirm against the platform documentation.
        bars = get_price(stocks, end_date=date, count=21,
                         fields=['close'], panel=False, fq='pre')
        if bars is None or bars.empty:
            return pd.Series(dtype=float)

        closes = bars.pivot(index='time', columns='code', values='close')
        window_returns = {}
        for code in closes.columns:
            valid = closes[code].dropna()
            if len(valid) < 2:
                continue
            window_returns[code] = valid.iloc[-1] / valid.iloc[0] - 1

        raw = pd.Series(window_returns)
        return pd.Series(dtype=float) if raw.empty else standardize(winsorize(raw))
    except Exception as e:
        log.error(f"动量因子失败: {e}")
        return pd.Series(dtype=float)

def get_roe_factor(stocks, date):
    """Build the quality factor from ROE.

    Only rows with a strictly positive ``roe`` are queried; the values are
    winsorized and then standardized (no industry neutralization).

    Args:
        stocks: List of stock codes to query.
        date: Query date handed to ``get_fundamentals``.

    Returns:
        A Series of standardized ROE scores indexed by stock code, or an
        empty float Series when the query returns no rows.
    """
    roe_query = query(valuation.code, indicator.roe).filter(
        valuation.code.in_(stocks),
        indicator.roe > 0,
    )
    table = get_fundamentals(roe_query, date)
    if table.empty:
        return pd.Series(dtype=float)

    roe = table.set_index('code')['roe']
    return standardize(winsorize(roe))

# ==================== 收益率计算（用于IC） ====================

def get_monthly_returns(stocks, date):
    """Return each stock's simple return over the last 22 price bars.

    The result is paired with the factor values saved at the previous
    rebalance to compute that period's IC.

    For each stock the return is ``last valid close / first valid close - 1``
    over the fetched window; stocks with fewer than two valid closes are
    omitted.

    Args:
        stocks: List of stock codes.
        date: End date of the price window.

    Returns:
        A Series of returns indexed by stock code, or an empty float Series
        when no data is available or an error occurs (the error is logged).
    """
    try:
        price_df = get_price(stocks, end_date=date, count=22,
                             fields=['close'], panel=False, fq='pre')
        if price_df is None or price_df.empty:
            return pd.Series(dtype=float)
        pp = price_df.pivot(index='time', columns='code', values='close')
        rets = {}
        for stock in pp.columns:
            ps = pp[stock].dropna()
            if len(ps) >= 2:
                rets[stock] = ps.iloc[-1] / ps.iloc[0] - 1
        return pd.Series(rets)
    except Exception as e:
        # An empty result makes the caller skip this month's IC update, so
        # the reason must be logged (same handling as get_momentum_factor).
        # Unlike a bare except, KeyboardInterrupt / SystemExit propagate.
        log.error(f"月度收益率计算失败: {e}")
        return pd.Series(dtype=float)

def update_ic_history(context):
    """Append one IC observation per factor to ``g.ic_history``.

    The IC is the Spearman correlation between the factor values saved at
    the previous rebalance (``g.last_factors``) and the returns fetched now,
    so only information known before the return period is used.

    Nothing is appended when no previous factors exist yet or when the
    return series is empty. Otherwise every factor gets exactly one new
    entry; 0.0 is recorded when the saved factor is missing or empty, when
    10 or fewer stocks overlap with the returns, or when the correlation
    is NaN.

    Args:
        context: Strategy context; ``context.current_dt`` is the end date
            of the return window.
    """
    if g.last_factors['pcf'] is None:
        return

    universe = get_index_stocks(g.index)
    returns = get_monthly_returns(universe, context.current_dt)
    if returns.empty:
        return

    for key in ['pcf', 'momentum', 'turnover']:
        previous = g.last_factors[key]
        observation = 0.0  # default for every "cannot compute" case
        if previous is not None and not previous.empty:
            shared = previous.index.intersection(returns.index)
            if len(shared) > 10:
                rank_ic = previous.loc[shared].corr(returns.loc[shared], method='spearman')
                if not pd.isna(rank_ic):
                    observation = rank_ic
        g.ic_history[key].append(observation)

# ==================== 因子合成 ====================

def combine_equal(factors):
    """Combine factor scores with equal weights.

    Only stocks present in every factor are scored. The result lists those
    stocks in the order they appear in the first factor of the dict. That
    order is deterministic, unlike an order derived from a Python ``set``,
    which for string labels can vary between interpreter runs and makes
    tie-breaking in a later sort non-reproducible.

    Args:
        factors: Dict mapping factor name to a pandas Series of scores
            indexed by stock code. Any factor names are accepted.

    Returns:
        A Series holding the sum of ``score / number_of_factors`` over all
        factors, indexed by the common stocks. An empty float Series is
        returned when *factors* is empty or the factors share no stock.
    """
    series_list = list(factors.values())
    if not series_list:
        return pd.Series(dtype=float)

    # Start from the first factor's labels and keep only those that every
    # other factor also has; boolean masking preserves the original order.
    common = series_list[0].index.unique()
    for f in series_list[1:]:
        common = common[common.isin(f.index)]
    if common.empty:
        return pd.Series(dtype=float)

    result = pd.Series(0.0, index=common)
    for f in series_list:
        result += f.loc[common] / len(series_list)
    return result

def combine_ic(factors):
    """Combine factor scores weighted by their recent mean IC.

    Each factor's weight is its mean IC over the last (up to) 12 entries of
    ``g.ic_history`` divided by the sum of the absolute mean ICs. Weights
    therefore keep the sign of the IC, and their absolute values sum to 1.

    Falls back to ``combine_equal`` while fewer than 3 IC observations
    exist or when all mean ICs are zero.

    Only stocks present in every factor are scored, listed in the order of
    ``factors['pcf']``. That order is deterministic, unlike one derived
    from a Python ``set``.

    Args:
        factors: Dict with the keys 'pcf', 'momentum' and 'turnover', each a
            pandas Series of scores indexed by stock code.

    Returns:
        A Series of weighted scores indexed by the common stocks, or an
        empty float Series when the factors share no stock.
    """
    names = ['pcf', 'momentum', 'turnover']

    if len(g.ic_history['pcf']) < 3:
        return combine_equal(factors)

    ic_means = {name: np.mean(g.ic_history[name][-12:]) for name in names}

    total = sum(abs(v) for v in ic_means.values())
    if total == 0:
        return combine_equal(factors)

    weights = {k: v / total for k, v in ic_means.items()}
    # The 'turnover' slot is printed under the "ROE" label.
    log.info(f"IC权重: PCF={weights['pcf']:.2f}, Mom={weights['momentum']:.2f}, ROE={weights['turnover']:.2f}")

    # Keep pcf's label order and drop labels missing from any factor;
    # boolean masking preserves order, so the result is reproducible.
    common = factors['pcf'].index.unique()
    for f in factors.values():
        common = common[common.isin(f.index)]
    if common.empty:
        return pd.Series(dtype=float)

    result = pd.Series(0.0, index=common)
    for name in names:
        result += factors[name].loc[common] * weights[name]
    return result

def combine_icir(factors):
    """Combine factor scores weighted by their recent IC information ratio.

    For each factor the IR is ``mean(IC) / std(IC)`` over the last (up to)
    12 entries of ``g.ic_history``, using ``np.std`` with its default
    settings; a factor whose IC standard deviation is not positive gets an
    IR of 0.0. Each weight is the IR divided by the sum of the absolute
    IRs, so weights keep the sign of the IR.

    Falls back to ``combine_equal`` while fewer than 3 IC observations
    exist or when all IRs are zero.

    Only stocks present in every factor are scored, listed in the order of
    ``factors['pcf']``. That order is deterministic, unlike one derived
    from a Python ``set``.

    Args:
        factors: Dict with the keys 'pcf', 'momentum' and 'turnover', each a
            pandas Series of scores indexed by stock code.

    Returns:
        A Series of weighted scores indexed by the common stocks, or an
        empty float Series when the factors share no stock.
    """
    names = ['pcf', 'momentum', 'turnover']

    if len(g.ic_history['pcf']) < 3:
        return combine_equal(factors)

    irs = {}
    for name in names:
        recent = g.ic_history[name][-12:]
        ic_std = np.std(recent)
        # A constant IC history has no measurable stability; give it no weight.
        irs[name] = np.mean(recent) / ic_std if ic_std > 0 else 0.0

    total = sum(abs(v) for v in irs.values())
    if total == 0:
        return combine_equal(factors)

    weights = {k: v / total for k, v in irs.items()}
    # The 'turnover' slot is printed under the "ROE" label.
    log.info(f"ICIR权重: PCF={weights['pcf']:.2f}, Mom={weights['momentum']:.2f}, ROE={weights['turnover']:.2f}")

    # Keep pcf's label order and drop labels missing from any factor;
    # boolean masking preserves order, so the result is reproducible.
    common = factors['pcf'].index.unique()
    for f in factors.values():
        common = common[common.isin(f.index)]
    if common.empty:
        return pd.Series(dtype=float)

    result = pd.Series(0.0, index=common)
    for name in names:
        result += factors[name].loc[common] * weights[name]
    return result

# ==================== 交易 ====================

def trade(context):
    """Run one rebalance: update IC history, score stocks, place orders.

    Steps:
        1. Update the IC history from last period's factors and the returns
           fetched now.
        2. Compute the three factors for the current index constituents and
           stop (with a warning) if any of them is empty.
        3. Save copies of the factors for the next IC update.
        4. Combine the factors according to the module-level ``VERSION``;
           an unrecognised value falls back to equal weighting.
        5. Take the ``g.stock_num`` highest scores, sell every holding that
           is not among them, and order an equal share of the portfolio
           value for each selected stock.

    Args:
        context: Strategy context supplying ``current_dt`` and ``portfolio``.
    """
    universe = get_index_stocks(g.index)

    # Step 1: last period's factors + the returns fetched now -> IC history.
    update_ic_history(context)

    # Step 2: current factor values.
    pcf = get_pcf_factor(universe, context.current_dt)
    mom = get_momentum_factor(universe, context.current_dt)
    roe = get_roe_factor(universe, context.current_dt)

    if any(values.empty for values in (pcf, mom, roe)):
        log.warn(f"因子为空: PCF={len(pcf)}, Mom={len(mom)}, ROE={len(roe)}")
        return

    # The ROE series is stored under the 'turnover' key.
    factors = {'pcf': pcf, 'momentum': mom, 'turnover': roe}

    # Step 3: keep copies for next month's IC calculation.
    g.last_factors = {name: values.copy() for name, values in factors.items()}

    # Step 4: pick the combiner for the configured version.
    combiners = {
        'v1': combine_equal,
        'v2': combine_ic,
        'v3': combine_icir,
        'v4': lambda fs: fs['pcf'],  # single-factor baseline
    }
    combined = combiners.get(VERSION, combine_equal)(factors)

    if combined.empty:
        return

    # Step 5: select and trade.
    ranked = combined.sort_values(ascending=False)
    target = ranked.head(g.stock_num).index.tolist()

    holdings = list(context.portfolio.positions.keys())
    for code in holdings:
        if code in target:
            continue
        order_target(code, 0)

    if not target:
        return

    share = 1.0 / len(target)
    for code in target:
        # total_value is read per order on purpose, exactly as the portfolio
        # reports it at that moment.
        order_target_value(code, context.portfolio.total_value * share)
