282 lines
12 KiB
Python
282 lines
12 KiB
Python
import argparse
import logging
import os
import sys
import time
from datetime import datetime

import numpy as np
import pymysql
|
||
|
||
# Default values and the values accepted on the command line.
default_debug = False

# Market selection.
allowed_market_keys = ['hs300', 'sp500']
default_market_key = "hs300"

# Year spans that may be requested for the statistics.
allowed_min_stat_years = [3, 5]
default_min_stat_years = 5
|
||
|
||
# 配置命令行参数
|
||
def parse_arguments(argv=None):
    """Parse the command-line options of the yield-statistics script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) lets
            argparse read ``sys.argv[1:]``, which is what a call without
            arguments has always done.

    Returns:
        A ``(min_stat_years, debug, market_key)`` tuple. Options that are not
        supplied take ``default_min_stat_years``, ``default_debug`` and
        ``default_market_key`` respectively.
    """
    parser = argparse.ArgumentParser(description="Run stock yield statistics.")

    # --market: which market's tables to process.
    parser.add_argument('--market', type=str, choices=allowed_market_keys,
                        default=default_market_key,
                        help=f'Set market key for statistics (allowed: {allowed_market_keys}). Default is {default_market_key}.')

    # --min_stat_years: minimum year span for the statistics.
    parser.add_argument('--min_stat_years', type=int, choices=allowed_min_stat_years,
                        default=default_min_stat_years,
                        help=f'Set minimum years for statistics (allowed: {allowed_min_stat_years}). Default is {default_min_stat_years}.')

    # --debug: boolean switch.
    parser.add_argument('--debug', action='store_true', default=default_debug,
                        help='Enable debug mode (default: False).')

    # Defaults are applied by argparse itself, so no manual fallback is needed.
    args = parser.parse_args(argv)
    return args.min_stat_years, args.debug, args.market
|
||
|
||
# Read the run parameters from the command line.
min_stat_years, debug, market_key = parse_arguments()

# Line layout shared by the file and console handlers.
formatter = logging.Formatter('%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] - %(message)s')

# The log file name is derived from market_key and min_stat_years.
log_filename = f'./log/stat_yield_{market_key}_{min_stat_years}years_rate.log'
# FileHandler does not create missing directories, so make sure the target
# directory exists before opening the file.
os.makedirs(os.path.dirname(log_filename), exist_ok=True)
file_handler = logging.FileHandler(log_filename)
file_handler.setFormatter(formatter)

console_handler = logging.StreamHandler(sys.stdout)
console_handler.setFormatter(formatter)

logging.basicConfig(level=logging.INFO, handlers=[file_handler, console_handler])
|
||
|
||
|
||
# MySQL connection settings. Every value can be overridden through an
# environment variable so that credentials do not have to live in the source;
# the fallbacks are the values that used to be hard-coded here.
db_config = {
    'host': os.environ.get('STOCKDB_HOST', '172.18.0.2'),
    'user': os.environ.get('STOCKDB_USER', 'root'),
    'password': os.environ.get('STOCKDB_PASSWORD', 'mysqlpw'),
    'database': os.environ.get('STOCKDB_DATABASE', 'stockdb'),
}
|
||
|
||
# Table names per market: the code list ("codes"), the price history
# ("his_data") and the table receiving the statistics ("stat_res"). The
# result table name embeds the selected min_stat_years.
table_mapping = dict(
    hs300=dict(
        codes="hs300",
        his_data="hs300_qfq_his",
        stat_res=f"hs300_{min_stat_years}years_yield_stats_2410",
    ),
    sp500=dict(
        codes="sp500",
        his_data="sp500_qfq_his_202410",
        stat_res=f"sp500_{min_stat_years}years_yield_stats_2410",
    ),
)
|
||
|
||
# 连接 MySQL
|
||
# One module-level connection, opened as soon as this statement runs.
# get_codes() and get_historical_data() use it through this global name,
# and main() closes it in its finally clause.
connection = pymysql.connect(**db_config)
|
||
|
||
# 获取股票代码
|
||
def get_codes(table_mapping, index_name):
    """Fetch the ``(code, code_name)`` rows of one market's code table.

    Uses the module-level ``connection``.

    Args:
        table_mapping: Mapping from market key to its table names; the
            ``'codes'`` entry names the table that is read.
        index_name: Key into ``table_mapping``.

    Returns:
        The rows from ``cursor.fetchall()``, limited to a single row when the
        module-level ``debug`` flag is set. When the query raises
        ``pymysql.MySQLError`` the error is logged and an empty tuple is
        returned, so callers can always iterate over the result.
    """
    codes_table = table_mapping[index_name]['codes']
    sql = f"SELECT code, code_name FROM {codes_table}"
    if debug:
        # Debug mode: fetch one row only.
        sql += " LIMIT 1"

    try:
        with connection.cursor() as cursor:
            cursor.execute(sql)
            return cursor.fetchall()
    except pymysql.MySQLError as e:
        logging.error(f"Error occurred while reading {codes_table} : {e}", exc_info=True)
        return ()
|
||
|
||
# 获取历史行情数据
|
||
def get_historical_data(table_mapping, index_name, code):
    """Load all history rows stored for ``code``, ordered by ``time_key``.

    The rows are read through a ``DictCursor`` on the module-level
    ``connection`` from the table named by
    ``table_mapping[index_name]['his_data']``.

    Returns:
        The result of ``cursor.fetchall()``, or ``None`` after logging when
        the query raises ``pymysql.MySQLError``.
    """
    try:
        with connection.cursor(pymysql.cursors.DictCursor) as dict_cursor:
            query = f"SELECT * FROM {table_mapping[index_name]['his_data']} WHERE code = %s ORDER BY time_key"
            dict_cursor.execute(query, (code,))
            rows = dict_cursor.fetchall()
            return rows
    except pymysql.MySQLError as err:
        logging.error(f"Error occurred while reading {table_mapping[index_name]['his_data']}: {err}", exc_info=True)
        return None
|
||
|
||
|
||
# 插入统计结果
|
||
def insert_yield_stats(connection, table_mapping, index_name, code, name, diff_year, max_entry, min_entry, avg_yield, median_yield, win_rate, annual_max_entry, annual_min_entry, annual_avg_yield, annual_median_yield, max_deficit_entry, annual_yield_variance):
    """Insert or update one statistics row in the market's result table.

    The target table is ``table_mapping[index_name]['stat_res']``. The
    statement is an ``INSERT ... ON DUPLICATE KEY UPDATE`` and is committed
    immediately. A ``pymysql.MySQLError`` is logged and swallowed, so a failed
    row does not stop the caller.

    Args:
        connection: Open DB connection providing ``cursor()`` and ``commit()``.
        table_mapping: Mapping from market key to its table names.
        index_name: Key into ``table_mapping``.
        code, name: Identify the instrument; written to ``code`` / ``name``.
        diff_year: Year bucket of the row; written to ``year_diff``.
        max_entry, min_entry: Entries providing ``yield_rate``,
            ``start_time_key`` and ``end_time_key``.
        avg_yield, median_yield, win_rate: Aggregates over the bucket.
        annual_max_entry, annual_min_entry: Entries providing
            ``annual_yield_rate``, ``start_time_key`` and ``end_time_key``.
        annual_avg_yield, annual_median_yield, annual_yield_variance:
            Aggregates over the annualised rates.
        max_deficit_entry: Entry providing ``max_deficit_days``,
            ``max_deficit_start`` and ``max_deficit_end``.
    """
    try:
        with connection.cursor() as cursor:
            sql = f"""
                INSERT INTO {table_mapping[index_name]['stat_res']}
                (code, name, year_diff, max_yield_rate, max_yield_rate_start, max_yield_rate_end,
                min_yield_rate, min_yield_rate_start, min_yield_rate_end, avg_yield_rate,
                median_yield_rate, win_rate, annual_max_yield_rate, annual_max_yield_rate_start,
                annual_max_yield_rate_end, annual_min_yield_rate, annual_min_yield_rate_start,
                annual_min_yield_rate_end, annual_avg_yield_rate, annual_median_yield_rate,
                max_deficit_days, max_deficit_start, max_deficit_end, annual_yield_variance)
                VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
                ON DUPLICATE KEY UPDATE
                max_yield_rate = VALUES(max_yield_rate),
                max_yield_rate_start = VALUES(max_yield_rate_start),
                max_yield_rate_end = VALUES(max_yield_rate_end),
                min_yield_rate = VALUES(min_yield_rate),
                min_yield_rate_start = VALUES(min_yield_rate_start),
                min_yield_rate_end = VALUES(min_yield_rate_end),
                avg_yield_rate = VALUES(avg_yield_rate),
                median_yield_rate = VALUES(median_yield_rate),
                win_rate = VALUES(win_rate),
                annual_max_yield_rate = VALUES(annual_max_yield_rate),
                annual_max_yield_rate_start = VALUES(annual_max_yield_rate_start),
                annual_max_yield_rate_end = VALUES(annual_max_yield_rate_end),
                annual_min_yield_rate = VALUES(annual_min_yield_rate),
                annual_min_yield_rate_start = VALUES(annual_min_yield_rate_start),
                annual_min_yield_rate_end = VALUES(annual_min_yield_rate_end),
                annual_avg_yield_rate = VALUES(annual_avg_yield_rate),
                annual_median_yield_rate = VALUES(annual_median_yield_rate),
                max_deficit_days = VALUES(max_deficit_days),
                max_deficit_start = VALUES(max_deficit_start),
                max_deficit_end = VALUES(max_deficit_end),
                annual_yield_variance = VALUES(annual_yield_variance)
            """
            # Every numeric aggregate is converted to a built-in float so the
            # driver never receives a NumPy scalar; win_rate and
            # annual_yield_variance are now treated like the others.
            params = (
                code, name, int(diff_year),
                float(max_entry['yield_rate']), max_entry['start_time_key'], max_entry['end_time_key'],
                float(min_entry['yield_rate']), min_entry['start_time_key'], min_entry['end_time_key'],
                float(avg_yield), float(median_yield), float(win_rate),
                float(annual_max_entry['annual_yield_rate']), annual_max_entry['start_time_key'], annual_max_entry['end_time_key'],
                float(annual_min_entry['annual_yield_rate']), annual_min_entry['start_time_key'], annual_min_entry['end_time_key'],
                float(annual_avg_yield), float(annual_median_yield),
                max_deficit_entry['max_deficit_days'], max_deficit_entry['max_deficit_start'], max_deficit_entry['max_deficit_end'],
                float(annual_yield_variance),
            )
            cursor.execute(sql, params)
            connection.commit()
    except pymysql.MySQLError as e:
        logging.error(f"Error occurred while inserting yield stats for code {code}: {e}", exc_info=True)
|
||
|
||
# 计算收益率并计算最长连续亏损
|
||
def calculate_yield_rate(data, min_years=None):
    """Compute returns for every (start, end) row pair spanning enough years.

    For each pair ``i < j`` whose span, truncated to whole years of 365 days,
    is at least ``min_years``, one entry dict is produced with:

    * ``yield_rate``: ``close_end / close_start - 1``.
    * ``annual_yield_rate``: ``yield_rate * 365 / span_days`` (linear scaling).
    * ``max_deficit_days`` / ``max_deficit_start`` / ``max_deficit_end``: the
      stretch from the start row to the first row strictly between start and
      end whose close exceeds the starting close. When there is no such row
      (or that stretch is 0 days) the whole window is used instead.

    Args:
        data: Sequence of row dicts with ``'time_key'`` (values whose
            difference has a ``.days`` attribute) and ``'close'``; ``'code'``
            is only read for a log message.
        min_years: Minimum whole-year span of a window. ``None`` (the default)
            uses the module-level ``min_stat_years``.

    Returns:
        A dict mapping each whole-year span to the list of its entries, plus
        the key ``10000`` holding every entry across all spans.
    """
    if min_years is None:
        min_years = min_stat_years

    results = {}
    all_entries = []
    num_rows = len(data)

    for i in range(num_rows):
        # The first row after i that closes above data[i]['close'] depends only
        # on i, so it is searched incrementally as j advances instead of being
        # rescanned from i + 1 for every pair.
        recovery_idx = None
        next_k = i + 1

        for j in range(i + 1, num_rows):
            try:
                start_time_key = data[i]['time_key']
                end_time_key = data[j]['time_key']
                span_days = (end_time_key - start_time_key).days
                time_diff = int(span_days / 365.0)
                if time_diff < min_years:
                    continue

                close_start = data[i]['close']
                close_end = data[j]['close']
                yield_rate = (close_end / close_start) - 1
                annual_yield_rate = yield_rate * 365 / span_days

                # Extend the scan over the rows strictly between i and j that
                # have not been examined yet. next_k only advances past rows
                # that did not recover, so a failing row is retried (and fails
                # again) for later j, exactly as a full rescan would.
                while recovery_idx is None and next_k < j:
                    if data[next_k]['close'] > close_start:
                        recovery_idx = next_k
                    else:
                        next_k += 1

                max_deficit_days = 0
                max_deficit_start = start_time_key
                max_deficit_end = end_time_key
                if recovery_idx is not None:
                    max_deficit_end = data[recovery_idx]['time_key']
                    max_deficit_days = (max_deficit_end - start_time_key).days

                # No recovery point found (or a 0-day stretch): the window end
                # is taken as the end of the stretch.
                if max_deficit_days == 0:
                    max_deficit_days = span_days
                    max_deficit_end = end_time_key

                entry = {
                    'diff_year': time_diff,
                    'start_time_key': start_time_key,
                    'end_time_key': end_time_key,
                    'yield_rate': yield_rate,
                    'annual_yield_rate': annual_yield_rate,
                    'max_deficit_days': max_deficit_days,
                    'max_deficit_start': max_deficit_start,
                    'max_deficit_end': max_deficit_end
                }
                all_entries.append(entry)
                results.setdefault(time_diff, []).append(entry)
            except ZeroDivisionError:
                # .get() keeps the handler itself from raising when a row
                # carries no 'code' key.
                logging.warning(f"Zero division error for code {data[i].get('code')}")
            except Exception as e:
                logging.error(f"Error occurred while calculating yield rate: {e}", exc_info=True)

    # Aggregate bucket: every entry regardless of its year span.
    results[10000] = all_entries
    return results
|
||
|
||
# 统计结果并输出
|
||
def compute_statistics(connection, table_mapping, index_name, code, name, results):
    """Aggregate each group of entries and store one result row per group.

    Args:
        connection: DB connection handed on to ``insert_yield_stats``.
        table_mapping: Mapping from market key to its table names.
        index_name: Key into ``table_mapping``.
        code, name: Identify the instrument the entries belong to.
        results: Mapping from a group key (passed on as ``diff_year``) to a
            list of entry dicts carrying ``yield_rate``, ``annual_yield_rate``
            and ``max_deficit_days``, as built by ``calculate_yield_rate``.

    Groups without entries are skipped. For every other group the extreme
    entries, mean, median, win rate (share of entries with a positive
    ``yield_rate``) and the variance of the annualised rates are computed and
    passed to ``insert_yield_stats``.
    """
    for diff_year, entries in results.items():
        if not entries:
            continue

        yield_rates = [entry['yield_rate'] for entry in entries]
        annual_yield_rates = [entry['annual_yield_rate'] for entry in entries]

        # max()/min() with a key return the first extreme entry in one pass,
        # which is the entry the former max-then-search lookup selected.
        max_entry = max(entries, key=lambda e: e['yield_rate'])
        min_entry = min(entries, key=lambda e: e['yield_rate'])
        avg_yield = np.mean(yield_rates)
        median_yield = np.median(yield_rates)

        # Same statistics on the annualised rates.
        annual_max_entry = max(entries, key=lambda e: e['annual_yield_rate'])
        annual_min_entry = min(entries, key=lambda e: e['annual_yield_rate'])
        annual_avg_yield = np.mean(annual_yield_rates)
        annual_median_yield = np.median(annual_yield_rates)

        # Share of windows that ended with a gain.
        win_rate = sum(1 for rate in yield_rates if rate > 0) / len(yield_rates)

        annual_yield_variance = np.var(annual_yield_rates)

        # Entry with the largest max_deficit_days value.
        max_deficit_entry = max(entries, key=lambda e: e['max_deficit_days'])

        insert_yield_stats(connection, table_mapping, index_name, code, name, diff_year,
                           max_entry, min_entry, avg_yield, median_yield, win_rate,
                           annual_max_entry, annual_min_entry, annual_avg_yield, annual_median_yield,
                           max_deficit_entry, annual_yield_variance)
|
||
|
||
# 主函数
|
||
def main(index_name):
    """Run the statistics for every code of one market.

    For each ``(code, name)`` row from ``get_codes`` the history is loaded,
    the per-window yields are calculated and the aggregated statistics are
    written through ``compute_statistics``. Codes without history are skipped
    with a warning. Any exception is logged, and the module-level
    ``connection`` is closed in all cases.

    Args:
        index_name: Key into the module-level ``table_mapping``.
    """
    try:
        codes = get_codes(table_mapping, index_name)
        if not codes:
            # Covers both "no rows" and a failed lookup; without this guard a
            # None result would surface as an unrelated TypeError in the loop.
            logging.warning(f"未获取到 {index_name} 的股票代码")
            return

        for code_row in codes:
            code, name = code_row[0], code_row[1]
            logging.info(f"开始处理 {code} ({name}) 的数据")

            data = get_historical_data(table_mapping, index_name, code)
            if not data:
                logging.warning(f"未找到 {code} 的历史数据")
                continue

            results = calculate_yield_rate(data)
            compute_statistics(connection, table_mapping, index_name, code, name, results)

            logging.info(f"完成 {code} 的处理")
    except Exception as e:
        logging.error(f"处理过程中出现错误: {e}", exc_info=True)
    finally:
        connection.close()
|
||
|
||
if __name__ == "__main__":
    # Process the market selected through --market (or its default).
    main(market_key)