"""
Day 6：波动率因子单因子选股策略
100天经典策略学习计划

测试因子：20日历史波动率、60日历史波动率、特质波动率
方法：选波动率最低的N只股票（低波异象），等权持有
股票池：沪深300
调仓频率：每月

使用方法：修改 FACTOR_NAME 切换测试不同因子
回测建议：2020-01-01 至 2026-02-01，初始资金100万
"""


# ========== Factor switch ==========
# Options: 'vol_20d', 'vol_60d', 'idiosyncratic_vol'
FACTOR_NAME = 'vol_20d'
# ===================================


def initialize(context):
    """Strategy entry point: set parameters, configure the backtest, schedule rebalancing.

    Args:
        context: Strategy context object passed in by the caller; it is not
            read here.
    """
    set_params()
    set_backtest()
    # Schedule `rebalance` monthly with monthday=1 at 09:31.
    # NOTE(review): presumably monthday counts trading days, not calendar
    # days -- confirm against the platform documentation.
    run_monthly(rebalance, monthday=1, time='09:31')


def set_params():
    """Store the strategy settings on the global ``g`` object.

    Sets the stock universe index, the number of holdings, the active factor
    (taken from ``FACTOR_NAME``) and the market index used for the
    idiosyncratic-volatility regression.
    """
    index_code = '000300.XSHG'

    g.factor = FACTOR_NAME
    g.stock_num = 20
    g.stock_pool = index_code
    g.market_index = index_code  # used to compute idiosyncratic volatility


def set_backtest():
    """Configure the backtest: benchmark, pricing option, trading costs and log level."""
    # Benchmark is the same index code used as the stock pool ('000300.XSHG').
    set_benchmark('000300.XSHG')
    # Turn on the platform's 'use_real_price' option.
    set_option('use_real_price', True)
    # Fixed slippage of 0.02.
    # NOTE(review): presumably an absolute price spread -- confirm the unit.
    set_slippage(FixedSlippage(0.02))
    # Per-trade cost: buy_cost=0.0003, sell_cost=0.0013, minimum 5 per trade.
    # NOTE(review): set_commission/PerTrade may be superseded by
    # set_order_cost/OrderCost on newer platform versions -- confirm.
    set_commission(PerTrade(buy_cost=0.0003, sell_cost=0.0013, min_cost=5))
    # Set the 'order' log category to the 'error' level.
    log.set_level('order', 'error')


def rebalance(context):
    """Pick the lowest-factor stocks from the pool and rebalance into them.

    Args:
        context: Strategy context, forwarded to ``get_factor_data`` and
            ``adjust_portfolio``.

    Does nothing when no factor data is available.
    """
    universe = filter_stocks(get_index_stocks(g.stock_pool))
    factor_df = get_factor_data(universe, context)

    # No factor values to rank: leave the portfolio as it is.
    if factor_df is None or factor_df.empty:
        return

    # Low-volatility anomaly: smaller factor value ranks first.
    ranked = factor_df.sort_values(by=g.factor, ascending=True)
    target = ranked['code'].iloc[:g.stock_num].tolist()

    log.info(f'[{g.factor}] 选出{len(target)}只，前3: {target[:3]}')

    adjust_portfolio(context, target)


def get_factor_data(stocks, context):
    """Compute the volatility factor selected by ``g.factor`` for ``stocks``.

    Supported factors:
        * ``'vol_20d'`` / ``'vol_60d'``: standard deviation of daily close-to-close
          returns over the last 20 / 60 returns.
        * ``'idiosyncratic_vol'``: standard deviation of the residuals of a
          60-return market-model regression against ``g.market_index``.

    Args:
        stocks: List of security codes to evaluate.
        context: Strategy context; ``context.previous_date`` is used as the
            last day of the price window.

    Returns:
        A DataFrame with columns ``'code'`` and the factor name, containing only
        strictly positive, non-NaN factor values; or ``None`` when there is no
        input, no price data, no usable regression, or the factor is unknown.
    """
    if not stocks:
        return None

    if g.factor in ('vol_20d', 'vol_60d'):
        lookback = 20 if g.factor == 'vol_20d' else 60
        return _historical_vol(stocks, context, lookback, g.factor)

    if g.factor == 'idiosyncratic_vol':
        return _idiosyncratic_vol(stocks, context, lookback=60)

    # A mistyped factor name would otherwise make the strategy idle silently.
    log.warn(f'[{g.factor}] 未知因子，无法计算')
    return None


def _daily_close(security, context, count):
    """Fetch ``count`` daily closes ending on the previous trading day."""
    # The window ends at context.previous_date rather than context.current_dt:
    # a daily bar dated "today" requested at 09:31 can already contain today's
    # close, which would leak future data into the factor.
    return get_price(
        security,
        end_date=context.previous_date,
        count=count,
        frequency='daily',
        fields=['close'],
        panel=False,
    )


def _historical_vol(stocks, context, lookback, factor_name):
    """Return per-stock standard deviation of the last ``lookback`` daily returns."""
    import pandas as pd

    prices = _daily_close(stocks, context, lookback + 1)
    if prices is None or prices.empty:
        return None

    # Long format (time, code, close) -> one column of closes per stock.
    closes = prices.pivot(index='time', columns='code', values='close')

    # The first row of pct_change is always NaN, so drop it.
    returns = closes.pct_change().iloc[1:]
    volatility = returns.std()

    df = pd.DataFrame({
        'code': volatility.index,
        factor_name: volatility.values,
    })
    df = df.dropna()
    # Zero volatility means a flat price series, which is not a usable signal.
    return df[df[factor_name] > 0]


def _idiosyncratic_vol(stocks, context, lookback, min_obs=20):
    """Return per-stock residual volatility from a market-model regression.

    Each stock's daily returns are regressed on the returns of
    ``g.market_index`` over the same dates; the factor is the sample standard
    deviation of the residuals. Stocks with fewer than ``min_obs`` valid
    date-matched observations are skipped.
    """
    import numpy as np
    import pandas as pd

    stock_prices = _daily_close(stocks, context, lookback + 1)
    market_prices = _daily_close(g.market_index, context, lookback + 1)

    if (stock_prices is None or stock_prices.empty
            or market_prices is None or market_prices.empty):
        return None

    # NOTE(review): a single-security query is expected to come back indexed
    # by time; if it arrives in long format instead, index it by 'time'.
    if 'time' in market_prices.columns:
        market_prices = market_prices.set_index('time')

    raw_market = market_prices['close'].pct_change().iloc[1:]
    market_returns = pd.Series(
        raw_market.values,
        index=pd.to_datetime(raw_market.index),
    )

    stock_pivot = stock_prices.pivot(index='time', columns='code', values='close')
    stock_returns = stock_pivot.pct_change().iloc[1:]

    # Put market returns on the stocks' date axis so that each observation is
    # paired with the market return of the same day. Dropping a stock's NaNs
    # and truncating by position would pair its latest returns with the
    # market's oldest ones.
    stock_dates = pd.to_datetime(stock_returns.index)
    market_arr = market_returns.reindex(stock_dates).values.astype(float)

    idio_vols = {}
    for code in stock_returns.columns:
        stock_arr = stock_returns[code].values.astype(float)

        # Keep only days on which both series have a finite return.
        valid = np.isfinite(stock_arr) & np.isfinite(market_arr)
        if valid.sum() < min_obs:
            continue

        x = market_arr[valid]
        y = stock_arr[valid]

        # Market model: y = alpha + beta * x + epsilon.
        try:
            beta, alpha = np.polyfit(x, y, 1)
        except (np.linalg.LinAlgError, ValueError, TypeError):
            # Degenerate fit for this stock; skip it and keep the rest.
            continue

        residuals = y - (alpha + beta * x)
        idio_vols[code] = np.std(residuals, ddof=1)

    if not idio_vols:
        return None

    df = pd.DataFrame({
        'code': list(idio_vols.keys()),
        'idiosyncratic_vol': list(idio_vols.values()),
    })
    # NaN compares False here, so this also drops undefined values.
    return df[df['idiosyncratic_vol'] > 0]


def filter_stocks(stocks):
    """Return the securities in ``stocks`` that are neither ST-flagged nor paused.

    Args:
        stocks: Iterable of security codes.

    Returns:
        A list of the codes that pass both checks, in their original order.
    """
    snapshot = get_current_data()

    tradable = []
    for code in stocks:
        # Skip ST names first; `paused` is only checked for non-ST stocks.
        if snapshot[code].is_st:
            continue
        if snapshot[code].paused:
            continue
        tradable.append(code)
    return tradable


def adjust_portfolio(context, target):
    """Rebalance to an equal-weight portfolio of ``target``, selling before buying.

    Args:
        context: Strategy context; ``context.portfolio`` supplies the current
            positions and total value.
        target: List of security codes to hold after rebalancing.
    """
    wanted = set(target)

    # Iterate over a snapshot of the holdings while placing sell orders.
    for held in tuple(context.portfolio.positions):
        if held in wanted:
            continue
        order_target(held, 0)

    if not target:
        return

    # Allocate 95% of total value, split equally across the targets.
    per_value = context.portfolio.total_value * 0.95 / len(target)
    for code in target:
        order_target_value(code, per_value)
