From 5a4cbb5b162572eda767b0afe1e4813fa6a62953 Mon Sep 17 00:00:00 2001 From: sophon Date: Tue, 12 Aug 2025 11:17:50 +0800 Subject: [PATCH] modify scripts --- .gitignore | 2 +- src/static/daily_snap_em.py | 99 +++++++++++----- src/static/trading_day.py | 221 ++++++++++++++++++++---------------- src/utils/send_to_wecom.py | 127 +++++++++++++++++++++ 4 files changed, 324 insertions(+), 125 deletions(-) create mode 100644 src/utils/send_to_wecom.py diff --git a/.gitignore b/.gitignore index 61af784..40fc3ec 100644 --- a/.gitignore +++ b/.gitignore @@ -14,4 +14,4 @@ reports_em/pdfs/ reports_em/raw/ # 忽略sqlachemy生成的文件 -alembic/versions/ +**versions/ diff --git a/src/static/daily_snap_em.py b/src/static/daily_snap_em.py index 51dbf5a..edb1ddb 100644 --- a/src/static/daily_snap_em.py +++ b/src/static/daily_snap_em.py @@ -32,6 +32,8 @@ from src.sqlalchemy.models.stockdb import DailySanpModel, Base from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker from src.sqlalchemy.config import global_db_url +from .trading_day import TradingDayChecker +from src.utils.send_to_wecom import send_to_wecom # 配置日志 logger.setup_logging() @@ -40,7 +42,15 @@ current_date = datetime.now().strftime("%Y%m%d") current_year = datetime.now().strftime("%Y") res_dir = global_stock_data_dir +debug = False +# 拉取数据 +market_fs = { + "cn": "m:0 t:6,m:0 t:80,m:1 t:2,m:1 t:23,m:0 t:81 s:2048", + "hk": "m:128 t:3,m:128 t:4,m:128 t:1,m:128 t:2", + "us": "m:105,m:106,m:107" +} + # 刷新代码列表,并返回 def flush_code_map(): code_id_map_em_df = his_em.code_id_map_em() @@ -48,10 +58,11 @@ def flush_code_map(): return code_id_map_em_df # 获取所有市场的当年股价快照,带重试机制。 -def fetch_snap_all(max_retries: int = 3) -> pd.DataFrame: +def fetch_snap_all(market_id, trading_date) -> pd.DataFrame: # 检查文件是否存在 - file_name = f'{res_dir}/snapshot_em_{current_date}.csv' - if os.path.exists(file_name): + os.makedirs(res_dir, exist_ok=True) + file_name = f'{res_dir}/snapshot_em_{market_id}_{trading_date}.csv' + if os.path.exists(file_name) and debug: try: # 读取本地文件 snap_data = pd.read_csv(file_name, encoding='utf-8') @@ -60,22 +71,22 @@ def fetch_snap_all(max_retries: int = 3) -> pd.DataFrame: except Exception as e: logging.warning(f"读取本地文件失败: {e},将重新拉取数据\n\n") - # 拉取数据 - market_fs = {"cn": "m:0 t:6,m:0 t:80,m:1 t:2,m:1 t:23,m:0 t:81 s:2048", - "hk": "m:128 t:3,m:128 t:4,m:128 t:1,m:128 t:2", - "us": "m:105,m:106,m:107"} - result = pd.DataFrame() - for market_id, fs in market_fs.items(): - df = his_em.stock_zh_a_spot_em(fs, fs_desc=market_id) - if df.empty: - logging.warning(f'{market_id} empty data. please check.') - return pd.DataFrame() - else: - logging.info(f'get {market_id} stock snapshot. stock count: {len(df)}') - # 关键步骤:添加market_id列,值为当前市场标识 - df['market_id'] = market_id # 新增一列,记录数据所属市场 - result = pd.concat([result, df], ignore_index=True) + fs = market_fs.get(market_id, None) + if not fs: + logging.error(f"未找到市场 {market_id} 的数据源配置,请检查 market_fs 配置") + return result + + df = his_em.stock_zh_a_spot_em(fs, fs_desc=market_id) + if df.empty: + logging.warning(f'{market_id} empty data. please check.') + return pd.DataFrame() + else: + logging.info(f'get {market_id} stock snapshot. stock count: {len(df)}') + # 关键步骤:添加market_id列,值为当前市场标识 + df['market_id'] = market_id # 新增一列,记录数据所属市场 + df['curr_date'] = trading_date + result = pd.concat([result, df], ignore_index=True) result.to_csv(file_name, index=False, encoding='utf-8') logging.info(f"get snapshot data and write to file: {file_name}\n\n") @@ -149,7 +160,7 @@ def insert_stock_data_to_db(dataframe, db_url=global_db_url): # 创建股票数据对象 stock = DailySanpModel( code=row['代码'], - curr_date=current_date, # TODO: 怎么判断当前的数据是哪一天的? 要看当前时间是否已经开盘,还是在盘前,还是前一个交易日的? + curr_date=row['curr_date'], name=row['名称'], market_id=row['market_id'], code_prefix=row['代码前缀'], @@ -217,15 +228,47 @@ def insert_stock_data_to_db(dataframe, db_url=global_db_url): session.close() -def main(): - # 获取快照数据 - snap_data = fetch_snap_all() - if snap_data.empty: - logging.error(f"fetching snapshot data error!") - return - em_code_map = {row['代码']: row['代码前缀'] for _, row in snap_data.iterrows()} +def main(list, args_debug, notify): + global debug + debug = args_debug - insert_stock_data_to_db(dataframe=snap_data) + # 获取快照数据 + market_list = list.split(',') + if not market_list: + logging.error("未指定市场列表,请使用 --list 参数指定市场(如 cn,hk,us)") + return + em_code_map = {} + for market_id in market_list: + # 获取交易日期 + trading_day_checker = TradingDayChecker() + trading_date = trading_day_checker.get_trading_date(market_id.upper()) + if not trading_date: + logging.error(f"无法获取 {market_id} 市场的交易日期") + continue + + # 获取快照数据 + snap_data = fetch_snap_all(market_id, trading_date) + if snap_data.empty: + logging.error(f"未获取到 {market_id} 市场的快照数据") + continue + if snap_data.empty: + logging.error(f"fetching snapshot data error for {market_id}!") + continue + insert_stock_data_to_db(dataframe=snap_data) + logging.info(f"成功获取 {market_id} 市场的快照数据,记录数: {len(snap_data)}") + + if notify: + send_to_wecom(f"成功获取 {market_id} 市场的快照数据,记录数: {len(snap_data)}") + + em_code_map.update({row['代码']: row['代码前缀'] for _, row in snap_data.iterrows()}) + time.sleep(5) if __name__ == "__main__": - main() \ No newline at end of file + # 命令行参数处理 + parser = argparse.ArgumentParser(description='获取指定市场的快照数据并存储到数据库') + parser.add_argument('--list', type=str, default='cn,hk,us', help='Stocklist to process (cn,hk,us)') + parser.add_argument('--debug', action='store_true', help='Enable debug mode (limit records)') + parser.add_argument('--notify', action='store_true', help='notify to wecom') + args = parser.parse_args() + + main(args.list, args.debug, args.notify) \ No newline at end of file diff --git a/src/static/trading_day.py b/src/static/trading_day.py index deccf94..c0feedc 100644 --- a/src/static/trading_day.py +++ b/src/static/trading_day.py @@ -7,116 +7,145 @@ from src.sqlalchemy.models.stockdb import FutuTradingDayModel from src.sqlalchemy.config import global_db_url -# 市场时间配置(东八区 Asia/Shanghai) -# 新增盘前开始时间(pre_market_start)作为判断临界点 -MARKET_HOURS = { - # A股(沪深):盘前集合竞价9:15开始,实际交易9:30-15:00 - "CN": { - "pre_market_start": time(9, 15), # 盘前开始时间(集合竞价) - "morning_start": time(9, 30), - "morning_end": time(11, 30), - "afternoon_start": time(13, 0), - "afternoon_end": time(15, 0) - }, - # 港股:盘前竞价9:00开始,实际交易9:30-16:00 - "HK": { - "pre_market_start": time(9, 0), # 盘前开始时间 - "morning_start": time(9, 30), - "morning_end": time(12, 0), - "afternoon_start": time(13, 0), - "afternoon_end": time(16, 0) - }, - # 美股:盘前交易4:00-9:30(纽约时间),对应东八区夏令时16:00-21:30,冬令时17:00-22:30 - "US": { - # 东八区盘前开始时间(夏令时/冬令时) - "dst_pre_start": time(16, 0), # 夏令时盘前开始 - "std_pre_start": time(17, 0), # 冬令时盘前开始 - # 东八区交易时间(同之前) - "dst_start": time(21, 30), - "dst_end": time(4, 0), - "std_start": time(22, 30), - "std_end": time(5, 0) - } -} - -def get_trading_date(market: str, db_session: Session) -> str: - """ - 根据新逻辑返回交易日期: - 1. 若当前日期是交易日,且当前时间 >= 盘前开始时间 → 取当前交易日 - 2. 否则(非交易日,或交易日但未到盘前时间) → 取前一交易日 +class TradingDayChecker: + """交易日检查器,封装交易日相关的判断逻辑""" - 参数: - market: 市场标识(CN=A股, HK=港股, US=美股) - db_session: SQLAlchemy数据库会话对象 - 返回: - 交易日期字符串(YYYY-MM-DD) - """ - # 1. 获取东八区当前时间和日期 - tz_sh = ZoneInfo("Asia/Shanghai") - now = datetime.now(tz_sh) # 含时区的当前时间 - current_date: date = now.date() # 东八区当前日期(仅日期) - current_time: time = now.time() # 东八区当前时间(仅时间) + # 市场标识静态变量(对外暴露的公共常量) + MARKET_CN = "CN" # A股(沪深市场) + MARKET_HK = "HK" # 港股市场 + MARKET_US = "US" # 美股市场 + # 市场时间配置(东八区 Asia/Shanghai) + MARKET_HOURS = { + # A股(沪深):盘前集合竞价9:15开始,实际交易9:30-15:00 + MARKET_CN: { + "pre_market_start": time(9, 15), # 盘前开始时间(集合竞价) + "morning_start": time(9, 30), + "morning_end": time(11, 30), + "afternoon_start": time(13, 0), + "afternoon_end": time(15, 0) + }, + # 港股:盘前竞价9:00开始,实际交易9:30-16:00 + MARKET_HK: { + "pre_market_start": time(9, 0), # 盘前开始时间 + "morning_start": time(9, 30), + "morning_end": time(12, 0), + "afternoon_start": time(13, 0), + "afternoon_end": time(16, 0) + }, + # 美股:盘前交易4:00-9:30(纽约时间),对应东八区夏令时16:00-21:30,冬令时17:00-22:30 + MARKET_US: { + # 东八区盘前开始时间(夏令时/冬令时) + "dst_pre_start": time(16, 0), # 夏令时盘前开始 + "std_pre_start": time(17, 0), # 冬令时盘前开始 + # 东八区交易时间 + "dst_start": time(21, 30), + "dst_end": time(4, 0), + "std_start": time(22, 30), + "std_end": time(5, 0) + } + } - # 2. 判断当前日期是否为该市场的交易日 - def is_trading_day(market: str, date: date, session: Session) -> bool: - """检查指定日期是否为该市场的交易日""" - return session.query(FutuTradingDayModel).filter( + def __init__(self, db_url=global_db_url): + """ + 初始化交易日检查器 + :param db_session: SQLAlchemy数据库会话对象 + """ + engine = create_engine(db_url) + Session = sessionmaker(bind=engine) + self.db_session = Session() + + def __del__(self): + try: + self.db_session.close() + except Exception: + pass # 避免销毁时抛出异常影响程序退出 + + def is_trading_day(self, market: str, check_date: date) -> bool: + """ + 检查指定日期是否为该市场的交易日 + :param market: 市场标识(CN=A股, HK=港股, US=美股) + :param check_date: 要检查的日期 + :return: 是否为交易日 + """ + return self.db_session.query(FutuTradingDayModel).filter( FutuTradingDayModel.market == market, - FutuTradingDayModel.trade_date == date + FutuTradingDayModel.trade_date == check_date ).first() is not None + def get_trading_date(self, market: str) -> str: + """ + 根据当前时间判断并返回目标交易日期 + 1. 若当前日期是交易日,且当前时间 >= 盘前开始时间 → 取当前交易日 + 2. 否则(非交易日,或交易日但未到盘前时间) → 取前一交易日 + :param market: 市场标识(CN=A股, HK=港股, US=美股) + :return: 交易日期字符串(YYYY-MM-DD) + """ + # 获取东八区当前时间和日期 + tz_sh = ZoneInfo("Asia/Shanghai") + if market == self.MARKET_US: + # 美股需要使用纽约时区来判断夏令时 + tz_sh = ZoneInfo("America/New_York") + now = datetime.now(tz_sh) + current_date: date = now.date() + current_time: time = now.time() + #print(f"当前时间: {now}, market:{market}, 当前日期: {current_date}, 当前时间: {current_time}") - # 3. 获取该市场的盘前开始时间(东八区) - def get_pre_market_start(market: str, now: datetime) -> time: - """根据市场和当前时间(判断夏令时)返回盘前开始时间""" - if market == "US": + # 判断当前是否为交易日 + current_is_trading_day = self.is_trading_day(market, current_date) + + # 获取盘前开始时间 + pre_market_start = self._get_pre_market_start(market, now) + is_after_pre_market = current_time >= pre_market_start + + # 确定目标交易日期 + if current_is_trading_day and is_after_pre_market: + target_date = current_date + else: + # 查询前一交易日 + prev_trading_day = self.db_session.query(FutuTradingDayModel.trade_date).filter( + FutuTradingDayModel.market == market, + FutuTradingDayModel.trade_date < current_date + ).order_by(FutuTradingDayModel.trade_date.desc()).first() + + if not prev_trading_day: + raise ValueError(f"未查询到{market}市场的前一交易日数据") + target_date = prev_trading_day.trade_date + + return target_date.strftime("%Y%m%d") + + def _get_pre_market_start(self, market: str, now: datetime) -> time: + """ + 内部方法:获取指定市场的盘前开始时间(东八区) + :param market: 市场标识 + :param now: 当前时间(含时区) + :return: 盘前开始时间 + """ + if market == self.MARKET_US: # 美股需根据纽约时区判断夏令时 tz_ny = ZoneInfo("America/New_York") now_ny = now.astimezone(tz_ny) - is_dst = now_ny.dst() != timedelta(0) # 夏令时判断 - return MARKET_HOURS["US"]["dst_pre_start"] if is_dst else MARKET_HOURS["US"]["std_pre_start"] + is_dst = now_ny.dst() != timedelta(0) + return self.MARKET_HOURS[self.MARKET_US]["dst_pre_start"] if is_dst else self.MARKET_HOURS[self.MARKET_US]["std_pre_start"] else: - # A股/港股直接返回配置的盘前时间 - return MARKET_HOURS[market]["pre_market_start"] + # A股/港股直接使用配置的盘前时间 + return self.MARKET_HOURS[market]["pre_market_start"] - # 4. 核心逻辑判断 - current_is_trading_day = is_trading_day(market, current_date, db_session) - pre_market_start = get_pre_market_start(market, now) - is_after_pre_market = current_time >= pre_market_start # 当前时间是否过了盘前开始时间 - - - # 5. 确定查询条件 - if current_is_trading_day and is_after_pre_market: - # 情况1:当前是交易日,且已过盘前时间 → 取当前交易日 - target_date = current_date - else: - # 情况2:非交易日,或未到盘前时间 → 取前一交易日 - # 查询小于当前日期的最大交易日 - prev_trading_day = db_session.query(FutuTradingDayModel.trade_date).filter( - FutuTradingDayModel.market == market, - FutuTradingDayModel.trade_date < current_date - ).order_by(FutuTradingDayModel.trade_date.desc()).first() - - if not prev_trading_day: - raise ValueError(f"未查询到{market}市场的前一交易日数据") - target_date = prev_trading_day.trade_date - - - return target_date.strftime("%Y-%m-%d") - -# 示例:获取各市场的目标交易日期 - +# 示例用法 if __name__ == "__main__": - engine = create_engine(global_db_url) - Session = sessionmaker(bind=engine) - db_session = Session() try: - # 分别获取三个市场的交易日期 - print(f"A股目标交易日期:{get_trading_date('CN', db_session)}") - print(f"港股目标交易日期:{get_trading_date('HK', db_session)}") - print(f"美股目标交易日期:{get_trading_date('US', db_session)}") + checker = TradingDayChecker() + + # 示例:检查今天是否为A股交易日 + today = date.today() + print(f"今天是否为A股交易日: {checker.is_trading_day(TradingDayChecker.MARKET_CN, today)}") + + # 示例:获取各市场目标交易日期 + print(f"A股目标交易日期:{checker.get_trading_date(TradingDayChecker.MARKET_CN)}") + print(f"港股目标交易日期:{checker.get_trading_date(TradingDayChecker.MARKET_HK)}") + print(f"美股目标交易日期:{checker.get_trading_date(TradingDayChecker.MARKET_US)}") finally: - db_session.close() \ No newline at end of file + pass + \ No newline at end of file diff --git a/src/utils/send_to_wecom.py b/src/utils/send_to_wecom.py new file mode 100644 index 0000000..9c6ef84 --- /dev/null +++ b/src/utils/send_to_wecom.py @@ -0,0 +1,127 @@ +import requests +import time +import json +import sys + +# 企业微信相关信息 +CORP_ID = 'ww5d7d350d9b8c0be3' +SECRET = 'YhagYQpaNIK9j1ATopgKNQhw3D13mpGZ64YVr23Je-A' +AGENT_ID = '1000003' + +# 获取 access_token +def get_access_token(): + url = f'https://qyapi.weixin.qq.com/cgi-bin/gettoken?corpid={CORP_ID}&corpsecret={SECRET}' + response = requests.get(url) + result = response.json() + if result.get('errcode') == 0: + return result.get('access_token') + else: + print(f"获取 access_token 失败: {result.get('errmsg')}") + return None + +# 发送消息到企业微信 +def send_message(access_token, message, touser=None, toparty=None, totag=None): + url = f'https://qyapi.weixin.qq.com/cgi-bin/message/send?access_token={access_token}' + data = { + "msgtype": "text", + "agentid": AGENT_ID, + "text": { + "content": message + }, + "safe": 0 + } + if touser: + data["touser"] = touser + if toparty: + data["toparty"] = toparty + if totag: + data["totag"] = totag + + response = requests.post(url, json=data) + result = response.json() + if result.get('errcode') == 0: + print("消息发送成功") + else: + print(f"消息发送失败: {result.get('errmsg')}") + +def pretty_print_json(data, n=10, indent=4, sort_keys=False): + """ + 以美化格式打印数组的前n个元素,其他元素用"..."表示 + + 参数: + - data: 要打印的数据(应为数组) + - n: 要显示的元素数量 + - indent: 缩进空格数 + - sort_keys: 是否按键排序 + """ + try: + # 处理非数组数据 + if not isinstance(data, list): + formatted = json.dumps(data, indent=indent, ensure_ascii=False, sort_keys=sort_keys) + return formatted + + # 复制原始数据,避免修改原数组 + data_copy = data.copy() + + # 切片取前n个元素 + first_n_elements = data_copy[:n] + + # 如果数组长度超过n,添加"..."标记 + if len(data) > n: + result = first_n_elements + ["... ({} more elements)".format(len(data) - n)] + else: + result = first_n_elements + + # 格式化输出 + formatted = json.dumps(result, indent=indent, ensure_ascii=False, sort_keys=sort_keys) + return formatted + + except TypeError as e: + print(f"错误:无法格式化数据。详情:{e}") + return str(data) + except Exception as e: + print(f"格式化时发生意外错误:{e}") + return str(data) + +def is_json(s): + """判断字符串是否可以解析为JSON""" + try: + json.loads(s) + return True + except json.JSONDecodeError: + return False + +# 主函数 +def send_to_wecom(report_content=None): + # 模拟数据报表内容 + if report_content is None: + report_content = "这是第一行\n这是第二行\n这是第三行" + else: + # 处理转义字符 + report_content = report_content.encode().decode('unicode_escape') + + # 判断是否为JSON并格式化 + if is_json(report_content): + try: + parsed_data = json.loads(report_content) + report_content = pretty_print_json(parsed_data) + except Exception as e: + print(f"JSON解析或格式化失败: {e}") + # 解析失败时保持原始内容 + + # 获取 access_token + access_token = get_access_token() + if access_token: + # 示例:发送给特定人员 + send_message(access_token, report_content, touser='oscar') + # 示例:发送给特定部门 + # send_message(access_token, report_content, toparty='department1|department2') + # 示例:发送给特定标签 + # send_message(access_token, report_content, totag='tag1|tag2') + +if __name__ == "__main__": + if len(sys.argv) > 1: + message = sys.argv[1] + send_to_wecom(message) + else: + send_to_wecom()