modify scripts
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@ -14,4 +14,4 @@ reports_em/pdfs/
|
||||
reports_em/raw/
|
||||
|
||||
# 忽略sqlachemy生成的文件
|
||||
alembic/versions/
|
||||
**versions/
|
||||
|
||||
@ -32,6 +32,8 @@ from src.sqlalchemy.models.stockdb import DailySanpModel, Base
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from src.sqlalchemy.config import global_db_url
|
||||
from .trading_day import TradingDayChecker
|
||||
from src.utils.send_to_wecom import send_to_wecom
|
||||
|
||||
# 配置日志
|
||||
logger.setup_logging()
|
||||
@ -40,6 +42,14 @@ current_date = datetime.now().strftime("%Y%m%d")
|
||||
current_year = datetime.now().strftime("%Y")
|
||||
|
||||
res_dir = global_stock_data_dir
|
||||
debug = False
|
||||
|
||||
# 拉取数据
|
||||
market_fs = {
|
||||
"cn": "m:0 t:6,m:0 t:80,m:1 t:2,m:1 t:23,m:0 t:81 s:2048",
|
||||
"hk": "m:128 t:3,m:128 t:4,m:128 t:1,m:128 t:2",
|
||||
"us": "m:105,m:106,m:107"
|
||||
}
|
||||
|
||||
# 刷新代码列表,并返回
|
||||
def flush_code_map():
|
||||
@ -48,10 +58,11 @@ def flush_code_map():
|
||||
return code_id_map_em_df
|
||||
|
||||
# 获取所有市场的当年股价快照,带重试机制。
|
||||
def fetch_snap_all(max_retries: int = 3) -> pd.DataFrame:
|
||||
def fetch_snap_all(market_id, trading_date) -> pd.DataFrame:
|
||||
# 检查文件是否存在
|
||||
file_name = f'{res_dir}/snapshot_em_{current_date}.csv'
|
||||
if os.path.exists(file_name):
|
||||
os.makedirs(res_dir, exist_ok=True)
|
||||
file_name = f'{res_dir}/snapshot_em_{market_id}_{trading_date}.csv'
|
||||
if os.path.exists(file_name) and debug:
|
||||
try:
|
||||
# 读取本地文件
|
||||
snap_data = pd.read_csv(file_name, encoding='utf-8')
|
||||
@ -60,13 +71,12 @@ def fetch_snap_all(max_retries: int = 3) -> pd.DataFrame:
|
||||
except Exception as e:
|
||||
logging.warning(f"读取本地文件失败: {e},将重新拉取数据\n\n")
|
||||
|
||||
# 拉取数据
|
||||
market_fs = {"cn": "m:0 t:6,m:0 t:80,m:1 t:2,m:1 t:23,m:0 t:81 s:2048",
|
||||
"hk": "m:128 t:3,m:128 t:4,m:128 t:1,m:128 t:2",
|
||||
"us": "m:105,m:106,m:107"}
|
||||
|
||||
result = pd.DataFrame()
|
||||
for market_id, fs in market_fs.items():
|
||||
fs = market_fs.get(market_id, None)
|
||||
if not fs:
|
||||
logging.error(f"未找到市场 {market_id} 的数据源配置,请检查 market_fs 配置")
|
||||
return result
|
||||
|
||||
df = his_em.stock_zh_a_spot_em(fs, fs_desc=market_id)
|
||||
if df.empty:
|
||||
logging.warning(f'{market_id} empty data. please check.')
|
||||
@ -75,6 +85,7 @@ def fetch_snap_all(max_retries: int = 3) -> pd.DataFrame:
|
||||
logging.info(f'get {market_id} stock snapshot. stock count: {len(df)}')
|
||||
# 关键步骤:添加market_id列,值为当前市场标识
|
||||
df['market_id'] = market_id # 新增一列,记录数据所属市场
|
||||
df['curr_date'] = trading_date
|
||||
result = pd.concat([result, df], ignore_index=True)
|
||||
|
||||
result.to_csv(file_name, index=False, encoding='utf-8')
|
||||
@ -149,7 +160,7 @@ def insert_stock_data_to_db(dataframe, db_url=global_db_url):
|
||||
# 创建股票数据对象
|
||||
stock = DailySanpModel(
|
||||
code=row['代码'],
|
||||
curr_date=current_date, # TODO: 怎么判断当前的数据是哪一天的? 要看当前时间是否已经开盘,还是在盘前,还是前一个交易日的?
|
||||
curr_date=row['curr_date'],
|
||||
name=row['名称'],
|
||||
market_id=row['market_id'],
|
||||
code_prefix=row['代码前缀'],
|
||||
@ -217,15 +228,47 @@ def insert_stock_data_to_db(dataframe, db_url=global_db_url):
|
||||
session.close()
|
||||
|
||||
|
||||
def main():
|
||||
# 获取快照数据
|
||||
snap_data = fetch_snap_all()
|
||||
if snap_data.empty:
|
||||
logging.error(f"fetching snapshot data error!")
|
||||
return
|
||||
em_code_map = {row['代码']: row['代码前缀'] for _, row in snap_data.iterrows()}
|
||||
def main(list, args_debug, notify):
|
||||
global debug
|
||||
debug = args_debug
|
||||
|
||||
# 获取快照数据
|
||||
market_list = list.split(',')
|
||||
if not market_list:
|
||||
logging.error("未指定市场列表,请使用 --list 参数指定市场(如 cn,hk,us)")
|
||||
return
|
||||
em_code_map = {}
|
||||
for market_id in market_list:
|
||||
# 获取交易日期
|
||||
trading_day_checker = TradingDayChecker()
|
||||
trading_date = trading_day_checker.get_trading_date(market_id.upper())
|
||||
if not trading_date:
|
||||
logging.error(f"无法获取 {market_id} 市场的交易日期")
|
||||
continue
|
||||
|
||||
# 获取快照数据
|
||||
snap_data = fetch_snap_all(market_id, trading_date)
|
||||
if snap_data.empty:
|
||||
logging.error(f"未获取到 {market_id} 市场的快照数据")
|
||||
continue
|
||||
if snap_data.empty:
|
||||
logging.error(f"fetching snapshot data error for {market_id}!")
|
||||
continue
|
||||
insert_stock_data_to_db(dataframe=snap_data)
|
||||
logging.info(f"成功获取 {market_id} 市场的快照数据,记录数: {len(snap_data)}")
|
||||
|
||||
if notify:
|
||||
send_to_wecom(f"成功获取 {market_id} 市场的快照数据,记录数: {len(snap_data)}")
|
||||
|
||||
em_code_map.update({row['代码']: row['代码前缀'] for _, row in snap_data.iterrows()})
|
||||
time.sleep(5)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
# 命令行参数处理
|
||||
parser = argparse.ArgumentParser(description='获取指定市场的快照数据并存储到数据库')
|
||||
parser.add_argument('--list', type=str, default='cn,hk,us', help='Stocklist to process (cn,hk,us)')
|
||||
parser.add_argument('--debug', action='store_true', help='Enable debug mode (limit records)')
|
||||
parser.add_argument('--notify', action='store_true', help='notify to wecom')
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args.list, args.debug, args.notify)
|
||||
@ -7,11 +7,18 @@ from src.sqlalchemy.models.stockdb import FutuTradingDayModel
|
||||
from src.sqlalchemy.config import global_db_url
|
||||
|
||||
|
||||
# 市场时间配置(东八区 Asia/Shanghai)
|
||||
# 新增盘前开始时间(pre_market_start)作为判断临界点
|
||||
MARKET_HOURS = {
|
||||
class TradingDayChecker:
|
||||
"""交易日检查器,封装交易日相关的判断逻辑"""
|
||||
|
||||
# 市场标识静态变量(对外暴露的公共常量)
|
||||
MARKET_CN = "CN" # A股(沪深市场)
|
||||
MARKET_HK = "HK" # 港股市场
|
||||
MARKET_US = "US" # 美股市场
|
||||
|
||||
# 市场时间配置(东八区 Asia/Shanghai)
|
||||
MARKET_HOURS = {
|
||||
# A股(沪深):盘前集合竞价9:15开始,实际交易9:30-15:00
|
||||
"CN": {
|
||||
MARKET_CN: {
|
||||
"pre_market_start": time(9, 15), # 盘前开始时间(集合竞价)
|
||||
"morning_start": time(9, 30),
|
||||
"morning_end": time(11, 30),
|
||||
@ -19,7 +26,7 @@ MARKET_HOURS = {
|
||||
"afternoon_end": time(15, 0)
|
||||
},
|
||||
# 港股:盘前竞价9:00开始,实际交易9:30-16:00
|
||||
"HK": {
|
||||
MARKET_HK: {
|
||||
"pre_market_start": time(9, 0), # 盘前开始时间
|
||||
"morning_start": time(9, 30),
|
||||
"morning_end": time(12, 0),
|
||||
@ -27,74 +34,76 @@ MARKET_HOURS = {
|
||||
"afternoon_end": time(16, 0)
|
||||
},
|
||||
# 美股:盘前交易4:00-9:30(纽约时间),对应东八区夏令时16:00-21:30,冬令时17:00-22:30
|
||||
"US": {
|
||||
MARKET_US: {
|
||||
# 东八区盘前开始时间(夏令时/冬令时)
|
||||
"dst_pre_start": time(16, 0), # 夏令时盘前开始
|
||||
"std_pre_start": time(17, 0), # 冬令时盘前开始
|
||||
# 东八区交易时间(同之前)
|
||||
# 东八区交易时间
|
||||
"dst_start": time(21, 30),
|
||||
"dst_end": time(4, 0),
|
||||
"std_start": time(22, 30),
|
||||
"std_end": time(5, 0)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def get_trading_date(market: str, db_session: Session) -> str:
|
||||
def __init__(self, db_url=global_db_url):
|
||||
"""
|
||||
根据新逻辑返回交易日期:
|
||||
1. 若当前日期是交易日,且当前时间 >= 盘前开始时间 → 取当前交易日
|
||||
2. 否则(非交易日,或交易日但未到盘前时间) → 取前一交易日
|
||||
|
||||
参数:
|
||||
market: 市场标识(CN=A股, HK=港股, US=美股)
|
||||
db_session: SQLAlchemy数据库会话对象
|
||||
返回:
|
||||
交易日期字符串(YYYY-MM-DD)
|
||||
初始化交易日检查器
|
||||
:param db_session: SQLAlchemy数据库会话对象
|
||||
"""
|
||||
# 1. 获取东八区当前时间和日期
|
||||
tz_sh = ZoneInfo("Asia/Shanghai")
|
||||
now = datetime.now(tz_sh) # 含时区的当前时间
|
||||
current_date: date = now.date() # 东八区当前日期(仅日期)
|
||||
current_time: time = now.time() # 东八区当前时间(仅时间)
|
||||
engine = create_engine(db_url)
|
||||
Session = sessionmaker(bind=engine)
|
||||
self.db_session = Session()
|
||||
|
||||
def __del__(self):
|
||||
try:
|
||||
self.db_session.close()
|
||||
except Exception:
|
||||
pass # 避免销毁时抛出异常影响程序退出
|
||||
|
||||
# 2. 判断当前日期是否为该市场的交易日
|
||||
def is_trading_day(market: str, date: date, session: Session) -> bool:
|
||||
"""检查指定日期是否为该市场的交易日"""
|
||||
return session.query(FutuTradingDayModel).filter(
|
||||
def is_trading_day(self, market: str, check_date: date) -> bool:
|
||||
"""
|
||||
检查指定日期是否为该市场的交易日
|
||||
:param market: 市场标识(CN=A股, HK=港股, US=美股)
|
||||
:param check_date: 要检查的日期
|
||||
:return: 是否为交易日
|
||||
"""
|
||||
return self.db_session.query(FutuTradingDayModel).filter(
|
||||
FutuTradingDayModel.market == market,
|
||||
FutuTradingDayModel.trade_date == date
|
||||
FutuTradingDayModel.trade_date == check_date
|
||||
).first() is not None
|
||||
|
||||
def get_trading_date(self, market: str) -> str:
|
||||
"""
|
||||
根据当前时间判断并返回目标交易日期
|
||||
1. 若当前日期是交易日,且当前时间 >= 盘前开始时间 → 取当前交易日
|
||||
2. 否则(非交易日,或交易日但未到盘前时间) → 取前一交易日
|
||||
:param market: 市场标识(CN=A股, HK=港股, US=美股)
|
||||
:return: 交易日期字符串(YYYY-MM-DD)
|
||||
"""
|
||||
# 获取东八区当前时间和日期
|
||||
tz_sh = ZoneInfo("Asia/Shanghai")
|
||||
if market == self.MARKET_US:
|
||||
# 美股需要使用纽约时区来判断夏令时
|
||||
tz_sh = ZoneInfo("America/New_York")
|
||||
now = datetime.now(tz_sh)
|
||||
current_date: date = now.date()
|
||||
current_time: time = now.time()
|
||||
#print(f"当前时间: {now}, market:{market}, 当前日期: {current_date}, 当前时间: {current_time}")
|
||||
|
||||
# 3. 获取该市场的盘前开始时间(东八区)
|
||||
def get_pre_market_start(market: str, now: datetime) -> time:
|
||||
"""根据市场和当前时间(判断夏令时)返回盘前开始时间"""
|
||||
if market == "US":
|
||||
# 美股需根据纽约时区判断夏令时
|
||||
tz_ny = ZoneInfo("America/New_York")
|
||||
now_ny = now.astimezone(tz_ny)
|
||||
is_dst = now_ny.dst() != timedelta(0) # 夏令时判断
|
||||
return MARKET_HOURS["US"]["dst_pre_start"] if is_dst else MARKET_HOURS["US"]["std_pre_start"]
|
||||
else:
|
||||
# A股/港股直接返回配置的盘前时间
|
||||
return MARKET_HOURS[market]["pre_market_start"]
|
||||
# 判断当前是否为交易日
|
||||
current_is_trading_day = self.is_trading_day(market, current_date)
|
||||
|
||||
# 获取盘前开始时间
|
||||
pre_market_start = self._get_pre_market_start(market, now)
|
||||
is_after_pre_market = current_time >= pre_market_start
|
||||
|
||||
# 4. 核心逻辑判断
|
||||
current_is_trading_day = is_trading_day(market, current_date, db_session)
|
||||
pre_market_start = get_pre_market_start(market, now)
|
||||
is_after_pre_market = current_time >= pre_market_start # 当前时间是否过了盘前开始时间
|
||||
|
||||
|
||||
# 5. 确定查询条件
|
||||
# 确定目标交易日期
|
||||
if current_is_trading_day and is_after_pre_market:
|
||||
# 情况1:当前是交易日,且已过盘前时间 → 取当前交易日
|
||||
target_date = current_date
|
||||
else:
|
||||
# 情况2:非交易日,或未到盘前时间 → 取前一交易日
|
||||
# 查询小于当前日期的最大交易日
|
||||
prev_trading_day = db_session.query(FutuTradingDayModel.trade_date).filter(
|
||||
# 查询前一交易日
|
||||
prev_trading_day = self.db_session.query(FutuTradingDayModel.trade_date).filter(
|
||||
FutuTradingDayModel.market == market,
|
||||
FutuTradingDayModel.trade_date < current_date
|
||||
).order_by(FutuTradingDayModel.trade_date.desc()).first()
|
||||
@ -103,20 +112,40 @@ def get_trading_date(market: str, db_session: Session) -> str:
|
||||
raise ValueError(f"未查询到{market}市场的前一交易日数据")
|
||||
target_date = prev_trading_day.trade_date
|
||||
|
||||
return target_date.strftime("%Y%m%d")
|
||||
|
||||
return target_date.strftime("%Y-%m-%d")
|
||||
def _get_pre_market_start(self, market: str, now: datetime) -> time:
|
||||
"""
|
||||
内部方法:获取指定市场的盘前开始时间(东八区)
|
||||
:param market: 市场标识
|
||||
:param now: 当前时间(含时区)
|
||||
:return: 盘前开始时间
|
||||
"""
|
||||
if market == self.MARKET_US:
|
||||
# 美股需根据纽约时区判断夏令时
|
||||
tz_ny = ZoneInfo("America/New_York")
|
||||
now_ny = now.astimezone(tz_ny)
|
||||
is_dst = now_ny.dst() != timedelta(0)
|
||||
return self.MARKET_HOURS[self.MARKET_US]["dst_pre_start"] if is_dst else self.MARKET_HOURS[self.MARKET_US]["std_pre_start"]
|
||||
else:
|
||||
# A股/港股直接使用配置的盘前时间
|
||||
return self.MARKET_HOURS[market]["pre_market_start"]
|
||||
|
||||
# 示例:获取各市场的目标交易日期
|
||||
|
||||
# 示例用法
|
||||
if __name__ == "__main__":
|
||||
engine = create_engine(global_db_url)
|
||||
Session = sessionmaker(bind=engine)
|
||||
db_session = Session()
|
||||
|
||||
try:
|
||||
# 分别获取三个市场的交易日期
|
||||
print(f"A股目标交易日期:{get_trading_date('CN', db_session)}")
|
||||
print(f"港股目标交易日期:{get_trading_date('HK', db_session)}")
|
||||
print(f"美股目标交易日期:{get_trading_date('US', db_session)}")
|
||||
checker = TradingDayChecker()
|
||||
|
||||
# 示例:检查今天是否为A股交易日
|
||||
today = date.today()
|
||||
print(f"今天是否为A股交易日: {checker.is_trading_day(TradingDayChecker.MARKET_CN, today)}")
|
||||
|
||||
# 示例:获取各市场目标交易日期
|
||||
print(f"A股目标交易日期:{checker.get_trading_date(TradingDayChecker.MARKET_CN)}")
|
||||
print(f"港股目标交易日期:{checker.get_trading_date(TradingDayChecker.MARKET_HK)}")
|
||||
print(f"美股目标交易日期:{checker.get_trading_date(TradingDayChecker.MARKET_US)}")
|
||||
finally:
|
||||
db_session.close()
|
||||
pass
|
||||
|
||||
127
src/utils/send_to_wecom.py
Normal file
127
src/utils/send_to_wecom.py
Normal file
@ -0,0 +1,127 @@
|
||||
import requests
|
||||
import time
|
||||
import json
|
||||
import sys
|
||||
|
||||
# 企业微信相关信息
|
||||
CORP_ID = 'ww5d7d350d9b8c0be3'
|
||||
SECRET = 'YhagYQpaNIK9j1ATopgKNQhw3D13mpGZ64YVr23Je-A'
|
||||
AGENT_ID = '1000003'
|
||||
|
||||
# 获取 access_token
|
||||
def get_access_token():
|
||||
url = f'https://qyapi.weixin.qq.com/cgi-bin/gettoken?corpid={CORP_ID}&corpsecret={SECRET}'
|
||||
response = requests.get(url)
|
||||
result = response.json()
|
||||
if result.get('errcode') == 0:
|
||||
return result.get('access_token')
|
||||
else:
|
||||
print(f"获取 access_token 失败: {result.get('errmsg')}")
|
||||
return None
|
||||
|
||||
# 发送消息到企业微信
|
||||
def send_message(access_token, message, touser=None, toparty=None, totag=None):
|
||||
url = f'https://qyapi.weixin.qq.com/cgi-bin/message/send?access_token={access_token}'
|
||||
data = {
|
||||
"msgtype": "text",
|
||||
"agentid": AGENT_ID,
|
||||
"text": {
|
||||
"content": message
|
||||
},
|
||||
"safe": 0
|
||||
}
|
||||
if touser:
|
||||
data["touser"] = touser
|
||||
if toparty:
|
||||
data["toparty"] = toparty
|
||||
if totag:
|
||||
data["totag"] = totag
|
||||
|
||||
response = requests.post(url, json=data)
|
||||
result = response.json()
|
||||
if result.get('errcode') == 0:
|
||||
print("消息发送成功")
|
||||
else:
|
||||
print(f"消息发送失败: {result.get('errmsg')}")
|
||||
|
||||
def pretty_print_json(data, n=10, indent=4, sort_keys=False):
|
||||
"""
|
||||
以美化格式打印数组的前n个元素,其他元素用"..."表示
|
||||
|
||||
参数:
|
||||
- data: 要打印的数据(应为数组)
|
||||
- n: 要显示的元素数量
|
||||
- indent: 缩进空格数
|
||||
- sort_keys: 是否按键排序
|
||||
"""
|
||||
try:
|
||||
# 处理非数组数据
|
||||
if not isinstance(data, list):
|
||||
formatted = json.dumps(data, indent=indent, ensure_ascii=False, sort_keys=sort_keys)
|
||||
return formatted
|
||||
|
||||
# 复制原始数据,避免修改原数组
|
||||
data_copy = data.copy()
|
||||
|
||||
# 切片取前n个元素
|
||||
first_n_elements = data_copy[:n]
|
||||
|
||||
# 如果数组长度超过n,添加"..."标记
|
||||
if len(data) > n:
|
||||
result = first_n_elements + ["... ({} more elements)".format(len(data) - n)]
|
||||
else:
|
||||
result = first_n_elements
|
||||
|
||||
# 格式化输出
|
||||
formatted = json.dumps(result, indent=indent, ensure_ascii=False, sort_keys=sort_keys)
|
||||
return formatted
|
||||
|
||||
except TypeError as e:
|
||||
print(f"错误:无法格式化数据。详情:{e}")
|
||||
return str(data)
|
||||
except Exception as e:
|
||||
print(f"格式化时发生意外错误:{e}")
|
||||
return str(data)
|
||||
|
||||
def is_json(s):
|
||||
"""判断字符串是否可以解析为JSON"""
|
||||
try:
|
||||
json.loads(s)
|
||||
return True
|
||||
except json.JSONDecodeError:
|
||||
return False
|
||||
|
||||
# 主函数
|
||||
def send_to_wecom(report_content=None):
|
||||
# 模拟数据报表内容
|
||||
if report_content is None:
|
||||
report_content = "这是第一行\n这是第二行\n这是第三行"
|
||||
else:
|
||||
# 处理转义字符
|
||||
report_content = report_content.encode().decode('unicode_escape')
|
||||
|
||||
# 判断是否为JSON并格式化
|
||||
if is_json(report_content):
|
||||
try:
|
||||
parsed_data = json.loads(report_content)
|
||||
report_content = pretty_print_json(parsed_data)
|
||||
except Exception as e:
|
||||
print(f"JSON解析或格式化失败: {e}")
|
||||
# 解析失败时保持原始内容
|
||||
|
||||
# 获取 access_token
|
||||
access_token = get_access_token()
|
||||
if access_token:
|
||||
# 示例:发送给特定人员
|
||||
send_message(access_token, report_content, touser='oscar')
|
||||
# 示例:发送给特定部门
|
||||
# send_message(access_token, report_content, toparty='department1|department2')
|
||||
# 示例:发送给特定标签
|
||||
# send_message(access_token, report_content, totag='tag1|tag2')
|
||||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) > 1:
|
||||
message = sys.argv[1]
|
||||
send_to_wecom(message)
|
||||
else:
|
||||
send_to_wecom()
|
||||
Reference in New Issue
Block a user