modify scripts
This commit is contained in:
@ -32,6 +32,8 @@ from src.sqlalchemy.models.stockdb import DailySanpModel, Base
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from src.sqlalchemy.config import global_db_url
|
||||
from .trading_day import TradingDayChecker
|
||||
from src.utils.send_to_wecom import send_to_wecom
|
||||
|
||||
# 配置日志
|
||||
logger.setup_logging()
|
||||
@ -40,7 +42,15 @@ current_date = datetime.now().strftime("%Y%m%d")
|
||||
current_year = datetime.now().strftime("%Y")
|
||||
|
||||
res_dir = global_stock_data_dir
|
||||
debug = False
|
||||
|
||||
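# `fs` filter expression for each supported market, keyed by lowercase market id; fetch_snap_all passes the selected value to his_em.stock_zh_a_spot_em.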
|
||||
market_fs = {
|
||||
"cn": "m:0 t:6,m:0 t:80,m:1 t:2,m:1 t:23,m:0 t:81 s:2048",
|
||||
"hk": "m:128 t:3,m:128 t:4,m:128 t:1,m:128 t:2",
|
||||
"us": "m:105,m:106,m:107"
|
||||
}
|
||||
|
||||
# 刷新代码列表,并返回
|
||||
def flush_code_map():
|
||||
code_id_map_em_df = his_em.code_id_map_em()
|
||||
@ -48,10 +58,11 @@ def flush_code_map():
|
||||
return code_id_map_em_df
|
||||
|
||||
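# Fetch the quote snapshot of a single market for the given trading date; a previously saved CSV for that market and date is reused only when the module-level `debug` flag is set.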
|
||||
def fetch_snap_all(max_retries: int = 3) -> pd.DataFrame:
|
||||
def fetch_snap_all(market_id, trading_date) -> pd.DataFrame:
|
||||
# 检查文件是否存在
|
||||
file_name = f'{res_dir}/snapshot_em_{current_date}.csv'
|
||||
if os.path.exists(file_name):
|
||||
os.makedirs(res_dir, exist_ok=True)
|
||||
file_name = f'{res_dir}/snapshot_em_{market_id}_{trading_date}.csv'
|
||||
if os.path.exists(file_name) and debug:
|
||||
try:
|
||||
# 读取本地文件
|
||||
snap_data = pd.read_csv(file_name, encoding='utf-8')
|
||||
@ -60,22 +71,22 @@ def fetch_snap_all(max_retries: int = 3) -> pd.DataFrame:
|
||||
except Exception as e:
|
||||
logging.warning(f"读取本地文件失败: {e},将重新拉取数据\n\n")
|
||||
|
||||
# 拉取数据
|
||||
market_fs = {"cn": "m:0 t:6,m:0 t:80,m:1 t:2,m:1 t:23,m:0 t:81 s:2048",
|
||||
"hk": "m:128 t:3,m:128 t:4,m:128 t:1,m:128 t:2",
|
||||
"us": "m:105,m:106,m:107"}
|
||||
|
||||
result = pd.DataFrame()
|
||||
for market_id, fs in market_fs.items():
|
||||
df = his_em.stock_zh_a_spot_em(fs, fs_desc=market_id)
|
||||
if df.empty:
|
||||
logging.warning(f'{market_id} empty data. please check.')
|
||||
return pd.DataFrame()
|
||||
else:
|
||||
logging.info(f'get {market_id} stock snapshot. stock count: {len(df)}')
|
||||
# 关键步骤:添加market_id列,值为当前市场标识
|
||||
df['market_id'] = market_id # 新增一列,记录数据所属市场
|
||||
result = pd.concat([result, df], ignore_index=True)
|
||||
fs = market_fs.get(market_id, None)
|
||||
if not fs:
|
||||
logging.error(f"未找到市场 {market_id} 的数据源配置,请检查 market_fs 配置")
|
||||
return result
|
||||
|
||||
df = his_em.stock_zh_a_spot_em(fs, fs_desc=market_id)
|
||||
if df.empty:
|
||||
logging.warning(f'{market_id} empty data. please check.')
|
||||
return pd.DataFrame()
|
||||
else:
|
||||
logging.info(f'get {market_id} stock snapshot. stock count: {len(df)}')
|
||||
# 关键步骤:添加market_id列,值为当前市场标识
|
||||
df['market_id'] = market_id # 新增一列,记录数据所属市场
|
||||
df['curr_date'] = trading_date
|
||||
result = pd.concat([result, df], ignore_index=True)
|
||||
|
||||
result.to_csv(file_name, index=False, encoding='utf-8')
|
||||
logging.info(f"get snapshot data and write to file: {file_name}\n\n")
|
||||
@ -149,7 +160,7 @@ def insert_stock_data_to_db(dataframe, db_url=global_db_url):
|
||||
# 创建股票数据对象
|
||||
stock = DailySanpModel(
|
||||
code=row['代码'],
|
||||
curr_date=current_date, # TODO: 怎么判断当前的数据是哪一天的? 要看当前时间是否已经开盘,还是在盘前,还是前一个交易日的?
|
||||
curr_date=row['curr_date'],
|
||||
name=row['名称'],
|
||||
market_id=row['market_id'],
|
||||
code_prefix=row['代码前缀'],
|
||||
@ -217,15 +228,47 @@ def insert_stock_data_to_db(dataframe, db_url=global_db_url):
|
||||
session.close()
|
||||
|
||||
|
||||
def main(list="cn,hk,us", args_debug=False, notify=False):
    """Fetch and store the quote snapshot of every requested market.

    For each market the trading date is resolved through ``TradingDayChecker``,
    the snapshot is fetched with ``fetch_snap_all`` and the rows are written
    with ``insert_stock_data_to_db``. A market whose trading date or snapshot
    cannot be obtained is logged and skipped; the remaining markets are still
    processed.

    Args:
        list: Comma-separated market ids, e.g. ``"cn,hk,us"``. Surrounding
            whitespace and empty items are ignored and ids are lower-cased.
            The name shadows the builtin but is kept because it is part of
            the existing call signature.
        args_debug: Copied into the module-level ``debug`` flag that
            ``fetch_snap_all`` checks before reusing a saved CSV.
        notify: When true, a WeCom message is sent for each market whose
            snapshot was stored.

    The defaults mirror the command-line defaults, so a bare ``main()`` call
    (the previous signature) keeps working.
    """
    global debug
    debug = args_debug

    # str.split(',') never returns an empty list ("".split(",") == [""]), so
    # blank items are dropped explicitly; otherwise the guard below could
    # never fire and an empty market id would be processed.
    # Ids are lower-cased because the market_fs table is keyed in lowercase.
    market_list = [item.strip().lower() for item in list.split(',') if item.strip()]
    if not market_list:
        logging.error("未指定市场列表,请使用 --list 参数指定市场(如 cn,hk,us)")
        return

    for market_id in market_list:
        # Resolve which trading date the snapshot belongs to for this market.
        trading_day_checker = TradingDayChecker()
        trading_date = trading_day_checker.get_trading_date(market_id.upper())
        if not trading_date:
            logging.error(f"无法获取 {market_id} 市场的交易日期")
            continue

        # Fetch the snapshot; an empty frame means the fetch failed.
        snap_data = fetch_snap_all(market_id, trading_date)
        if snap_data.empty:
            logging.error(f"未获取到 {market_id} 市场的快照数据")
            continue

        insert_stock_data_to_db(dataframe=snap_data)

        # Build the success message once and reuse it for log and notification.
        message = f"成功获取 {market_id} 市场的快照数据,记录数: {len(snap_data)}"
        logging.info(message)
        if notify:
            send_to_wecom(message)

        # Pause before the next market -- presumably to throttle requests to
        # the upstream data source; TODO confirm the intended interval.
        time.sleep(5)
|
||||
|
||||
if __name__ == "__main__":
    # Command-line entry point: parse the options first, then run the job
    # exactly once with them. The stale zero-argument main() call from the
    # previous revision is dropped, as it ran before any option was read.
    parser = argparse.ArgumentParser(description='获取指定市场的快照数据并存储到数据库')
    # Comma-separated market ids; main() splits the value itself.
    parser.add_argument('--list', type=str, default='cn,hk,us', help='Stocklist to process (cn,hk,us)')
    # Forwarded to main(), which stores it in the module-level `debug` flag.
    parser.add_argument('--debug', action='store_true', help='Enable debug mode (limit records)')
    # When set, main() sends a WeCom message per successfully stored market.
    parser.add_argument('--notify', action='store_true', help='notify to wecom')
    args = parser.parse_args()

    main(args.list, args.debug, args.notify)
|
||||
Reference in New Issue
Block a user