This repository has been archived on 2026-01-07. You can view files and clone it, but cannot push or open issues or pull requests.
Files
resources/stockapp/stat_yield_rate.py

282 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import argparse
import logging
import os
import sys
import time
from datetime import datetime

import numpy as np
import pymysql
# Command-line option defaults and the values each option accepts.
default_market_key = "hs300"
allowed_market_keys = ['hs300', 'sp500']  # values accepted by --market

default_min_stat_years = 5
allowed_min_stat_years = [3, 5]  # minimum holding periods (years) accepted by --min_stat_years

default_debug = False
# Command-line interface.
def parse_arguments(argv=None):
    """Parse the command-line options that select what this run computes.

    Args:
        argv: Argument list to parse. ``None`` (the default) lets argparse read
            ``sys.argv[1:]``, which is how the script itself calls it; passing
            an explicit list makes the parser usable from other callers.

    Returns:
        tuple: ``(min_stat_years, debug, market_key)``. Options not given on
        the command line take the module-level defaults.
    """
    parser = argparse.ArgumentParser(description="Run stock yield statistics.")
    # Market whose constituent list is processed.
    parser.add_argument('--market', type=str, choices=allowed_market_keys,
                        default=default_market_key,
                        help=f'Set market key for statistics (allowed: {allowed_market_keys}). Default is {default_market_key}.')
    # Minimum holding period, in whole years, for a buy/sell pair to be counted.
    parser.add_argument('--min_stat_years', type=int, choices=allowed_min_stat_years,
                        default=default_min_stat_years,
                        help=f'Set minimum years for statistics (allowed: {allowed_min_stat_years}). Default is {default_min_stat_years}.')
    # Debug mode restricts the run to a single stock.
    parser.add_argument('--debug', action='store_true', default=default_debug,
                        help='Enable debug mode (default: False).')
    # Let argparse apply the defaults rather than patching values afterwards
    # with a truthiness test, which would also replace legitimately falsy values.
    args = parser.parse_args(argv)
    return args.min_stat_years, args.debug, args.market
# Read the command-line options once at import time; the rest of the script
# reads these module-level values.
min_stat_years, debug, market_key = parse_arguments()

# One format shared by the log file and the console.
formatter = logging.Formatter('%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] - %(message)s')

# The log file name encodes the market and the minimum holding period, so runs
# with different options write to different files.
log_dir = './log'
# logging.FileHandler does not create missing directories; make sure the target
# exists so a fresh checkout does not die with FileNotFoundError at startup.
os.makedirs(log_dir, exist_ok=True)
log_filename = f'{log_dir}/stat_yield_{market_key}_{min_stat_years}years_rate.log'
# Log messages contain Chinese text; pin UTF-8 instead of relying on the
# platform's default file encoding.
file_handler = logging.FileHandler(log_filename, encoding='utf-8')
file_handler.setFormatter(formatter)
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setFormatter(formatter)
logging.basicConfig(level=logging.INFO, handlers=[file_handler, console_handler])
# MySQL connection settings. Each value may be overridden with an environment
# variable so credentials do not have to be committed alongside the code; the
# fallbacks are the values previously hard-coded here.
db_config = {
    'host': os.environ.get('STOCKDB_HOST', '172.18.0.2'),
    'user': os.environ.get('STOCKDB_USER', 'root'),
    'password': os.environ.get('STOCKDB_PASSWORD', 'mysqlpw'),
    'database': os.environ.get('STOCKDB_DATABASE', 'stockdb'),
}

# Per-market table names:
#   codes    - constituent list, read as (code, code_name)
#   his_data - price history rows, queried per code and ordered by time_key
#   stat_res - destination of the computed statistics; the name embeds the
#              minimum holding period so runs with different windows stay apart.
table_mapping = {
    "hs300": {
        "codes": "hs300",
        "his_data": "hs300_qfq_his",
        "stat_res": f"hs300_{min_stat_years}years_yield_stats_2410"
    },
    "sp500": {
        "codes": "sp500",
        "his_data": "sp500_qfq_his_202410",
        "stat_res": f"sp500_{min_stat_years}years_yield_stats_2410"
    }
}

# Single shared connection used by the query helpers; main() closes it.
connection = pymysql.connect(**db_config)
# Fetch the constituent list of a market.
def get_codes(table_mapping, index_name):
    """Return the ``(code, code_name)`` rows of the constituent table for *index_name*.

    In debug mode only the first row is fetched so a run finishes quickly.

    Args:
        table_mapping: market key -> table names; ``'codes'`` is queried.
        index_name: market key, e.g. ``"hs300"``.

    Returns:
        The rows from ``cursor.fetchall()``. On a MySQL error the failure is
        logged and an empty tuple is returned, so callers can always iterate
        the result (previously ``None`` leaked out and broke the caller's loop).
    """
    table = table_mapping[index_name]['codes']
    # Identifiers cannot be bound as query parameters, so the table name is
    # interpolated; it is expected to come from the in-file table_mapping.
    sql = f"SELECT code, code_name FROM {table}"
    if debug:
        sql += " LIMIT 1"
    try:
        with connection.cursor() as cursor:
            cursor.execute(sql)
            return cursor.fetchall()
    except pymysql.MySQLError as e:
        logging.error(f"Error occurred while reading {table} : {e}", exc_info=True)
        return ()
# Fetch the price history of one stock.
def get_historical_data(table_mapping, index_name, code):
    """Return every history row for *code*, ordered by ``time_key``.

    Rows are dicts (DictCursor). If MySQL reports an error, it is logged and
    ``None`` is returned.
    """
    his_table = table_mapping[index_name]['his_data']
    query = f"SELECT * FROM {his_table} WHERE code = %s ORDER BY time_key"
    try:
        with connection.cursor(pymysql.cursors.DictCursor) as cur:
            cur.execute(query, (code,))
            rows = cur.fetchall()
    except pymysql.MySQLError as err:
        logging.error(f"Error occurred while reading {his_table}: {err}", exc_info=True)
        return None
    return rows
# Persist the statistics of one holding-period bucket.
def insert_yield_stats(connection, table_mapping, index_name, code, name, diff_year, max_entry, min_entry, avg_yield, median_yield, win_rate, annual_max_entry, annual_min_entry, annual_avg_yield, annual_median_yield, max_deficit_entry, annual_yield_variance):
    """Upsert one row of yield statistics into the market's result table.

    Args:
        connection: open pymysql connection; the row is committed on success.
        table_mapping: market key -> table names; ``'stat_res'`` is the target.
        index_name: market key selecting the target table.
        code, name: stock code and display name.
        diff_year: holding-period bucket the statistics describe.
        max_entry, min_entry: entry dicts supplying the ``max_*`` / ``min_*``
            columns (``yield_rate``, ``start_time_key``, ``end_time_key``).
        avg_yield, median_yield: mean and median total yield.
        win_rate: share of pairs with a positive yield.
        annual_max_entry, annual_min_entry: entry dicts supplying the
            ``annual_max_*`` / ``annual_min_*`` columns (``annual_yield_rate``
            plus the two time keys).
        annual_avg_yield, annual_median_yield: mean and median annualized yield.
        max_deficit_entry: dict with ``max_deficit_days``,
            ``max_deficit_start`` and ``max_deficit_end``.
        annual_yield_variance: variance of the annualized yields.

    MySQL errors are logged and swallowed so one failed row does not stop the
    run; other exceptions (e.g. a missing key in an entry) propagate.
    """
    # Column order must match ``values`` below; both are built from one place
    # so the placeholder count and update list cannot drift out of sync.
    columns = (
        'code', 'name', 'year_diff',
        'max_yield_rate', 'max_yield_rate_start', 'max_yield_rate_end',
        'min_yield_rate', 'min_yield_rate_start', 'min_yield_rate_end',
        'avg_yield_rate', 'median_yield_rate', 'win_rate',
        'annual_max_yield_rate', 'annual_max_yield_rate_start', 'annual_max_yield_rate_end',
        'annual_min_yield_rate', 'annual_min_yield_rate_start', 'annual_min_yield_rate_end',
        'annual_avg_yield_rate', 'annual_median_yield_rate',
        'max_deficit_days', 'max_deficit_start', 'max_deficit_end',
        'annual_yield_variance',
    )
    # Every numeric statistic is converted to a plain Python number: pymysql has
    # no encoder for numpy scalars (e.g. the np.var result) and would otherwise
    # fall back to quoting them as strings.
    values = (
        code, name, int(diff_year),
        float(max_entry['yield_rate']), max_entry['start_time_key'], max_entry['end_time_key'],
        float(min_entry['yield_rate']), min_entry['start_time_key'], min_entry['end_time_key'],
        float(avg_yield), float(median_yield), float(win_rate),
        float(annual_max_entry['annual_yield_rate']), annual_max_entry['start_time_key'], annual_max_entry['end_time_key'],
        float(annual_min_entry['annual_yield_rate']), annual_min_entry['start_time_key'], annual_min_entry['end_time_key'],
        float(annual_avg_yield), float(annual_median_yield),
        int(max_deficit_entry['max_deficit_days']), max_deficit_entry['max_deficit_start'], max_deficit_entry['max_deficit_end'],
        float(annual_yield_variance),
    )
    placeholders = ", ".join(["%s"] * len(columns))
    # code, name and year_diff are written on insert only; every statistic is
    # refreshed when the row already exists. (Presumably the unique key covers
    # code and year_diff -- confirm against the table schema.)
    # NOTE(review): VALUES() here is deprecated since MySQL 8.0.20; switch to the
    # row-alias form once the server version is confirmed.
    updates = ",\n    ".join(f"{col} = VALUES({col})" for col in columns[3:])
    sql = (
        f"INSERT INTO {table_mapping[index_name]['stat_res']}\n"
        f"({', '.join(columns)})\n"
        f"VALUES ({placeholders})\n"
        f"ON DUPLICATE KEY UPDATE\n    {updates}"
    )
    try:
        with connection.cursor() as cursor:
            cursor.execute(sql, values)
        connection.commit()
    except pymysql.MySQLError as e:
        logging.error(f"Error occurred while inserting yield stats for code {code}: {e}", exc_info=True)
# 计算收益率并计算最长连续亏损
def calculate_yield_rate(data):
results = {}
all_entries = []
num_rows = len(data)
for i in range(num_rows):
for j in range(i + 1, num_rows):
try:
start_time_key = data[i]['time_key']
end_time_key = data[j]['time_key']
time_diff = int((end_time_key - start_time_key).days / 365.0)
if time_diff < min_stat_years:
continue
close_start = data[i]['close']
close_end = data[j]['close']
yield_rate = (close_end / close_start) - 1
annual_yield_rate = yield_rate * 365 / (end_time_key - start_time_key).days
# 找到从 data[i]['close'] 到 data[j]['close'] 之间的最大连续亏损
max_deficit_days = 0
max_deficit_start = start_time_key
max_deficit_end = end_time_key
for k in range(i + 1, j):
if data[k]['close'] > close_start:
deficit_days = (data[k]['time_key'] - start_time_key).days
max_deficit_days = deficit_days
max_deficit_end = data[k]['time_key']
break
# 如果没有找到符合条件的亏损结束点,则认为 j 是亏损结束点
if max_deficit_days == 0:
max_deficit_days = (end_time_key - start_time_key).days
max_deficit_end = end_time_key
entry = {
'diff_year': time_diff,
'start_time_key': start_time_key,
'end_time_key': end_time_key,
'yield_rate': yield_rate,
'annual_yield_rate': annual_yield_rate,
'max_deficit_days': max_deficit_days,
'max_deficit_start': max_deficit_start,
'max_deficit_end': max_deficit_end
}
all_entries.append(entry)
if time_diff not in results:
results[time_diff] = []
results[time_diff].append(entry)
except ZeroDivisionError:
logging.warning(f"Zero division error for code {data[i]['code']}")
except Exception as e:
logging.error(f"Error occurred while calculating yield rate: {e}", exc_info=True)
# 将全局最大亏损信息加入到汇总部分
results[10000] = all_entries # 汇总
return results
# Summarise each holding-period bucket and write it to the database.
def compute_statistics(connection, table_mapping, index_name, code, name, results):
    """Aggregate the entries of every bucket in *results* and upsert one row each.

    For each non-empty bucket the function computes:
    - max, min, mean and median of the total and annualized yields, keeping
      the entries behind the extremes so their dates can be stored;
    - the win rate and the variance of the annualized yields;
    - the entry with the longest underwater period.
    Empty buckets are skipped.

    Args:
        connection, table_mapping, index_name: forwarded to insert_yield_stats.
        code, name: stock being summarised.
        results: bucket key -> list of entry dicts carrying ``yield_rate``,
            ``annual_yield_rate`` and ``max_deficit_days`` (presumably the
            output of calculate_yield_rate).
    """
    for diff_year, entries in results.items():
        if not entries:
            continue
        yield_rates = [entry['yield_rate'] for entry in entries]
        annual_yield_rates = [entry['annual_yield_rate'] for entry in entries]

        # max()/min() with a key pick the first extreme entry in a single pass,
        # so ties resolve to the earliest pair without rescanning for the value.
        max_entry = max(entries, key=lambda entry: entry['yield_rate'])
        min_entry = min(entries, key=lambda entry: entry['yield_rate'])
        avg_yield = np.mean(yield_rates)
        median_yield = np.median(yield_rates)

        # Annualized-yield statistics.
        annual_max_entry = max(entries, key=lambda entry: entry['annual_yield_rate'])
        annual_min_entry = min(entries, key=lambda entry: entry['annual_yield_rate'])
        annual_avg_yield = np.mean(annual_yield_rates)
        annual_median_yield = np.median(annual_yield_rates)

        # Share of pairs that ended with a positive total yield.
        win_rate = sum(1 for rate in yield_rates if rate > 0) / len(yield_rates)
        annual_yield_variance = np.var(annual_yield_rates)

        # Pair with the longest underwater period.
        max_deficit_entry = max(entries, key=lambda entry: entry['max_deficit_days'])

        insert_yield_stats(connection, table_mapping, index_name, code, name, diff_year,
                           max_entry, min_entry, avg_yield, median_yield, win_rate,
                           annual_max_entry, annual_min_entry, annual_avg_yield, annual_median_yield,
                           max_deficit_entry, annual_yield_variance)
# Entry point: process every constituent of one market.
def _process_code(index_name, code, name):
    """Load one stock's history, compute its statistics and store them."""
    logging.info(f"开始处理 {code} ({name}) 的数据")
    data = get_historical_data(table_mapping, index_name, code)
    if not data:
        logging.warning(f"未找到 {code} 的历史数据")
        return
    results = calculate_yield_rate(data)
    compute_statistics(connection, table_mapping, index_name, code, name, results)
    logging.info(f"完成 {code} 的处理")


def main(index_name):
    """Compute and store yield statistics for every stock of *index_name*.

    Each stock is handled in its own ``try``. An unexpected error is logged and
    the run moves on to the next stock instead of abandoning all remaining
    ones. The shared connection is closed when the run ends, however it ends.
    """
    try:
        # ``or ()`` tolerates get_codes returning None after a logged query error.
        codes = get_codes(table_mapping, index_name) or ()
        for code_row in codes:
            try:
                code, name = code_row[0], code_row[1]
                _process_code(index_name, code, name)
            except Exception as e:
                logging.error(f"处理过程中出现错误: {e}", exc_info=True)
    except Exception as e:
        logging.error(f"处理过程中出现错误: {e}", exc_info=True)
    finally:
        connection.close()
if __name__ == "__main__":
    # Run for the market chosen on the command line (parsed at import time).
    main(index_name=market_key)