"""
Day 9：情绪因子单因子选股策略
100天经典策略学习计划

测试因子：ARBR、PSY、OBV、资金流向
方法：根据情绪指标信号选股，等权持有
股票池：沪深300
调仓频率：每月

使用方法：修改 FACTOR_NAME 切换测试不同因子
回测建议：2020-01-01 至 2026-02-01，初始资金100万

注意：情绪指标通常适合短期交易，月度调仓可能不匹配
"""


# ========== 切换因子 ==========
# 可选: 'arbr', 'psy', 'obv', 'money_flow'
FACTOR_NAME = 'arbr'
# ==============================


def initialize(context):
    set_params()
    set_backtest()
    run_monthly(rebalance, monthday=1, time='09:31')


def set_params():
    g.stock_pool = '000300.XSHG'
    g.stock_num = 20
    g.factor = FACTOR_NAME


def set_backtest():
    set_benchmark('000300.XSHG')
    set_option('use_real_price', True)
    set_slippage(FixedSlippage(0.02))
    set_commission(PerTrade(buy_cost=0.0003, sell_cost=0.0013, min_cost=5))
    log.set_level('order', 'error')


def rebalance(context):
    stocks = get_index_stocks(g.stock_pool)
    stocks = filter_stocks(stocks)

    df = get_factor_data(stocks, context)

    if df is None or df.empty:
        return

    # 情绪因子：根据不同指标选股
    df = df.sort_values(g.factor, ascending=False)
    target = df['code'].head(g.stock_num).tolist()

    log.info(f'[{g.factor}] 选出{len(target)}只，前3: {target[:3]}')

    adjust_portfolio(context, target)


def get_factor_data(stocks, context):
    """获取情绪因子数据"""
    import pandas as pd
    import numpy as np

    # 获取足够长的历史数据
    prices = get_price(
        stocks,
        end_date=context.current_dt,
        count=40,  # 足够计算各种指标
        frequency='daily',
        fields=['open', 'close', 'high', 'low', 'volume'],
        panel=False
    )

    if prices is None or prices.empty:
        return None

    results = []

    for stock in stocks:
        stock_data = prices[prices['code'] == stock].sort_values('time')
        if len(stock_data) < 20:
            continue

        open_price = stock_data['open'].values
        close = stock_data['close'].values
        high = stock_data['high'].values
        low = stock_data['low'].values
        volume = stock_data['volume'].values

        try:
            if g.factor == 'arbr':
                score = calculate_arbr(open_price, close, high, low)
            elif g.factor == 'psy':
                score = calculate_psy(close)
            elif g.factor == 'obv':
                score = calculate_obv(close, volume)
            elif g.factor == 'money_flow':
                score = calculate_money_flow(close, volume)
            else:
                continue

            if score is not None and not np.isnan(score):
                results.append({'code': stock, g.factor: score})
        except:
            continue

    if not results:
        return None

    df = pd.DataFrame(results)
    return df


def calculate_arbr(open_price, close, high, low, period=26):
    """
    ARBR人气意愿指标
    选股逻辑：AR和BR都在80-120区间（情绪适中）
    返回：100 - |AR+BR-200|（越接近200越好）
    """
    import pandas as pd

    df = pd.DataFrame({
        'open': open_price,
        'close': close,
        'high': high,
        'low': low
    })

    # 计算AR
    ho = df['high'] - df['open']
    ol = df['open'] - df['low']
    ar = ho.rolling(window=period).sum() / ol.rolling(window=period).sum() * 100

    # 计算BR（需要昨收）
    df['prev_close'] = df['close'].shift(1)
    hc = df['high'] - df['prev_close']
    cl = df['prev_close'] - df['low']
    br = hc.rolling(window=period).sum() / cl.rolling(window=period).sum() * 100

    ar_value = ar.iloc[-1]
    br_value = br.iloc[-1]

    # AR和BR都在80-120区间最好
    if 80 <= ar_value <= 120 and 80 <= br_value <= 120:
        # 越接近100越好
        return 200 - abs(ar_value - 100) - abs(br_value - 100)
    else:
        return -999


def calculate_psy(close, period=12):
    """
    PSY心理线指标
    选股逻辑：PSY在40-60区间（情绪中性）
    返回：100 - |PSY-50|（越接近50越好）
    """
    import pandas as pd

    prices = pd.Series(close)

    # 计算涨跌
    change = prices.diff()
    up_days = (change > 0).astype(int)

    # 计算PSY
    psy = up_days.rolling(window=period).sum() / period * 100

    psy_value = psy.iloc[-1]

    # PSY在40-60区间最好
    if 40 <= psy_value <= 60:
        return 100 - abs(psy_value - 50)
    else:
        return -999


def calculate_obv(close, volume):
    """
    OBV能量潮指标
    选股逻辑：OBV呈上升趋势
    返回：OBV的20日线性回归斜率（越大越好）
    """
    import pandas as pd
    import numpy as np

    df = pd.DataFrame({'close': close, 'volume': volume})

    # 计算OBV
    df['change'] = df['close'].diff()
    df['direction'] = np.where(df['change'] > 0, 1,
                               np.where(df['change'] < 0, -1, 0))
    df['obv'] = (df['direction'] * df['volume']).cumsum()

    # 计算OBV的趋势（线性回归斜率）
    obv_values = df['obv'].iloc[-20:].values
    if len(obv_values) < 20:
        return -999

    x = np.arange(len(obv_values))
    slope, _ = np.polyfit(x, obv_values, 1)

    # 标准化斜率（除以OBV均值，避免大盘股占优）
    obv_mean = obv_values.mean()
    if obv_mean != 0:
        normalized_slope = slope / abs(obv_mean) * 1000
        return float(normalized_slope)
    else:
        return -999


def calculate_money_flow(close, volume, period=20):
    """
    资金流向（简化版）
    选股逻辑：过去N日资金净流入为正
    返回：资金净流入占比
    """
    import pandas as pd
    import numpy as np

    df = pd.DataFrame({'close': close, 'volume': volume})

    # 计算价格变化
    df['change'] = df['close'].diff()

    # 简化逻辑：价格上涨时的成交量视为买入，下跌时视为卖出
    df['money_in'] = np.where(df['change'] > 0, df['volume'], 0)
    df['money_out'] = np.where(df['change'] < 0, df['volume'], 0)

    # 计算过去N日的资金流向
    money_in_sum = df['money_in'].iloc[-period:].sum()
    money_out_sum = df['money_out'].iloc[-period:].sum()

    total = money_in_sum + money_out_sum
    if total > 0:
        net_flow_ratio = (money_in_sum - money_out_sum) / total * 100
        return float(net_flow_ratio)
    else:
        return -999


def filter_stocks(stocks):
    """过滤ST和停牌"""
    current_data = get_current_data()
    return [s for s in stocks
            if not current_data[s].is_st
            and not current_data[s].paused]


def adjust_portfolio(context, target):
    """调仓：先卖后买"""
    for stock in list(context.portfolio.positions):
        if stock not in target:
            order_target(stock, 0)

    if target:
        per_value = context.portfolio.total_value * 0.95 / len(target)
        for stock in target:
            order_target_value(stock, per_value)
