modify scripts

This commit is contained in:
2025-08-12 11:17:50 +08:00
parent 7b7799307f
commit 5a4cbb5b16
4 changed files with 324 additions and 125 deletions

View File

@ -32,6 +32,8 @@ from src.sqlalchemy.models.stockdb import DailySanpModel, Base
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from src.sqlalchemy.config import global_db_url
from .trading_day import TradingDayChecker
from src.utils.send_to_wecom import send_to_wecom
# 配置日志
logger.setup_logging()
@ -40,7 +42,15 @@ current_date = datetime.now().strftime("%Y%m%d")
current_year = datetime.now().strftime("%Y")
res_dir = global_stock_data_dir
debug = False
# 拉取数据
market_fs = {
"cn": "m:0 t:6,m:0 t:80,m:1 t:2,m:1 t:23,m:0 t:81 s:2048",
"hk": "m:128 t:3,m:128 t:4,m:128 t:1,m:128 t:2",
"us": "m:105,m:106,m:107"
}
# 刷新代码列表,并返回
def flush_code_map():
code_id_map_em_df = his_em.code_id_map_em()
@ -48,10 +58,11 @@ def flush_code_map():
return code_id_map_em_df
# 获取所有市场的当年股价快照,带重试机制。
def fetch_snap_all(max_retries: int = 3) -> pd.DataFrame:
def fetch_snap_all(market_id, trading_date) -> pd.DataFrame:
# 检查文件是否存在
file_name = f'{res_dir}/snapshot_em_{current_date}.csv'
if os.path.exists(file_name):
os.makedirs(res_dir, exist_ok=True)
file_name = f'{res_dir}/snapshot_em_{market_id}_{trading_date}.csv'
if os.path.exists(file_name) and debug:
try:
# 读取本地文件
snap_data = pd.read_csv(file_name, encoding='utf-8')
@ -60,22 +71,22 @@ def fetch_snap_all(max_retries: int = 3) -> pd.DataFrame:
except Exception as e:
logging.warning(f"读取本地文件失败: {e},将重新拉取数据\n\n")
# 拉取数据
market_fs = {"cn": "m:0 t:6,m:0 t:80,m:1 t:2,m:1 t:23,m:0 t:81 s:2048",
"hk": "m:128 t:3,m:128 t:4,m:128 t:1,m:128 t:2",
"us": "m:105,m:106,m:107"}
result = pd.DataFrame()
for market_id, fs in market_fs.items():
df = his_em.stock_zh_a_spot_em(fs, fs_desc=market_id)
if df.empty:
logging.warning(f'{market_id} empty data. please check.')
return pd.DataFrame()
else:
logging.info(f'get {market_id} stock snapshot. stock count: {len(df)}')
# 关键步骤添加market_id列值为当前市场标识
df['market_id'] = market_id # 新增一列,记录数据所属市场
result = pd.concat([result, df], ignore_index=True)
fs = market_fs.get(market_id, None)
if not fs:
logging.error(f"未找到市场 {market_id} 的数据源配置,请检查 market_fs 配置")
return result
df = his_em.stock_zh_a_spot_em(fs, fs_desc=market_id)
if df.empty:
logging.warning(f'{market_id} empty data. please check.')
return pd.DataFrame()
else:
logging.info(f'get {market_id} stock snapshot. stock count: {len(df)}')
# 关键步骤添加market_id列值为当前市场标识
df['market_id'] = market_id # 新增一列,记录数据所属市场
df['curr_date'] = trading_date
result = pd.concat([result, df], ignore_index=True)
result.to_csv(file_name, index=False, encoding='utf-8')
logging.info(f"get snapshot data and write to file: {file_name}\n\n")
@ -149,7 +160,7 @@ def insert_stock_data_to_db(dataframe, db_url=global_db_url):
# 创建股票数据对象
stock = DailySanpModel(
code=row['代码'],
curr_date=current_date, # TODO: 怎么判断当前的数据是哪一天的? 要看当前时间是否已经开盘,还是在盘前,还是前一个交易日的?
curr_date=row['curr_date'],
name=row['名称'],
market_id=row['market_id'],
code_prefix=row['代码前缀'],
@ -217,15 +228,47 @@ def insert_stock_data_to_db(dataframe, db_url=global_db_url):
session.close()
def main():
# 获取快照数据
snap_data = fetch_snap_all()
if snap_data.empty:
logging.error(f"fetching snapshot data error!")
return
em_code_map = {row['代码']: row['代码前缀'] for _, row in snap_data.iterrows()}
def main(list, args_debug, notify):
global debug
debug = args_debug
insert_stock_data_to_db(dataframe=snap_data)
# 获取快照数据
market_list = list.split(',')
if not market_list:
logging.error("未指定市场列表,请使用 --list 参数指定市场(如 cn,hk,us")
return
em_code_map = {}
for market_id in market_list:
# 获取交易日期
trading_day_checker = TradingDayChecker()
trading_date = trading_day_checker.get_trading_date(market_id.upper())
if not trading_date:
logging.error(f"无法获取 {market_id} 市场的交易日期")
continue
# 获取快照数据
snap_data = fetch_snap_all(market_id, trading_date)
if snap_data.empty:
logging.error(f"未获取到 {market_id} 市场的快照数据")
continue
if snap_data.empty:
logging.error(f"fetching snapshot data error for {market_id}!")
continue
insert_stock_data_to_db(dataframe=snap_data)
logging.info(f"成功获取 {market_id} 市场的快照数据,记录数: {len(snap_data)}")
if notify:
send_to_wecom(f"成功获取 {market_id} 市场的快照数据,记录数: {len(snap_data)}")
em_code_map.update({row['代码']: row['代码前缀'] for _, row in snap_data.iterrows()})
time.sleep(5)
if __name__ == "__main__":
main()
# 命令行参数处理
parser = argparse.ArgumentParser(description='获取指定市场的快照数据并存储到数据库')
parser.add_argument('--list', type=str, default='cn,hk,us', help='Stocklist to process (cn,hk,us)')
parser.add_argument('--debug', action='store_true', help='Enable debug mode (limit records)')
parser.add_argument('--notify', action='store_true', help='notify to wecom')
args = parser.parse_args()
main(args.list, args.debug, args.notify)