modify some scripts.

2025-03-11 14:22:37 +08:00
parent 4d5f9521ef
commit af92229a3e
9 changed files with 1197 additions and 1 deletion


@@ -0,0 +1,323 @@
import json
import time
import csv
import os
import re
import argparse
import shutil
import logging
from datetime import datetime, timedelta
from functools import partial
import config
import sqlite_utils as db_tools
import em_reports as em
import utils
config.setup_logging()
debug = False
force = False
pdf_base_dir = "/root/hostdir/stock_data/pdfs" # directory where downloaded PDFs are stored
map_pdf_page = {
utils.tbl_stock : "https://data.eastmoney.com/report/info/{}.html",
utils.tbl_new_stock : "https://data.eastmoney.com/report/info/{}.html",
utils.tbl_strategy : "https://data.eastmoney.com/report/zw_strategy.jshtml?encodeUrl={}",
utils.tbl_macresearch : "https://data.eastmoney.com/report/zw_macresearch.jshtml?encodeUrl={}",
utils.tbl_industry : "https://data.eastmoney.com/report/zw_industry.jshtml?infocode={}"
}
map_pdf_path = {
utils.tbl_stock : f'{pdf_base_dir}/stock',
utils.tbl_new_stock : f'{pdf_base_dir}/newstock',
utils.tbl_strategy : f'{pdf_base_dir}/strategy',
utils.tbl_macresearch : f'{pdf_base_dir}/macresearch',
utils.tbl_industry : f'{pdf_base_dir}/industry'
}
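# map_pdf_page: report-detail page URL template for each table (formatted with infoCode/encodeUrl);
# map_pdf_path: local directory where that table's PDFs are stored.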
current_date = datetime.now()
seven_days_ago = current_date - timedelta(days=7)
two_years_ago = current_date - timedelta(days=2*365)
start_date = two_years_ago.strftime("%Y-%m-%d")
end_date = current_date.strftime("%Y-%m-%d")
this_week_date = seven_days_ago.strftime("%Y-%m-%d")
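# Generic list fetcher: pages through fetch_func, upserts every returned row into
# table_name, and archives each raw page as JSON under {utils.json_data_dir}/{data_dir_prefix}.
# max_pages is taken from the first response's TotalPage field.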
def fetch_reports_list_general(fetch_func, table_name, s_date, e_date, data_dir_prefix):
# page through all result pages; max_pages is updated from the first response
max_pages = 100000
page = 1
while page <= max_pages:
while True: # retry until fetch_func returns data
data = fetch_func(page_no=page, start_date=s_date, end_date=e_date, page_size=100)
if data:
break
if page == 1:
max_pages = data.get('TotalPage', 1000000)
for row in data.get('data', []):
# infoCode is used as the UNIQUE key across tables, so fill it in here when missing
if row.get('infoCode') is None and row.get('encodeUrl'):
row['infoCode'] = row['encodeUrl']
row_id = db_tools.insert_or_update_common(row, table_name)
if row_id:
logging.debug(f'inserted one row. rowid: {row_id}')
else:
logging.warning(f'insert data failed. page: {page}')
return False
utils.save_json_to_file(data, f'{utils.json_data_dir}/{data_dir_prefix}', f'{data_dir_prefix}_report_{page}.json')
logging.info(f"{page} 页, 获取 {len(data['data'])} 条数据, 共 {max_pages}")
page += 1
time.sleep(1) # avoid sending requests too quickly
# URL and file-path builder for stock reports
def parse_func_stock(row, tbl_name):
info_code = row['infoCode']
title = row['title'].replace("/", "_").replace("\\", "_")
org_sname = row['orgSName']
stock_name = row['stockName']
industry_name = row['industryName']
publish_date = row['publishDate'].split(" ")[0]
file_name = f"{publish_date}_{org_sname}_{stock_name}_{title}.pdf"
url = map_pdf_page.get(tbl_name, None)
if url is None:
logging.warning(f'wrong table name: {tbl_name}')
return None, None, None
url = url.format(info_code)
os.makedirs(map_pdf_path[tbl_name], exist_ok=True)
return url, os.path.join(map_pdf_path[tbl_name], file_name), None
# URL and file-path builder for the other report types
def parse_func_other(row, tbl_name):
info_code = row['infoCode']
title = row['title'].replace("/", "_").replace("\\", "_")
org_sname = row['orgSName']
industry_name = row['industryName']
publish_date = row['publishDate'].split(" ")[0]
file_name = f"{publish_date}_{org_sname}_{industry_name}_{title}.pdf"
old_file_name = f"{publish_date}_{industry_name}_{org_sname}_{title}.pdf"
url = map_pdf_page.get(tbl_name, None)
if url is None:
logging.warning(f'wrong table name: {tbl_name}')
return None, None, None
url = url.format(info_code)
os.makedirs(map_pdf_path[tbl_name], exist_ok=True)
return url, os.path.join(map_pdf_path[tbl_name], file_name), os.path.join(map_pdf_path[tbl_name], old_file_name)
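# Both parse_func_* helpers return (report_page_url, target_pdf_path, legacy_pdf_path);
# legacy_pdf_path is only set for the non-stock tables and is used to rename files
# saved under the old naming scheme.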
# generic download routine for report PDFs
def download_pdf_stock_general(parse_func, tbl_name, querystr='', s_date=start_date, e_date=end_date, limit=None):
# build the date filter, query matching reports, and download each PDF
if s_date:
querystr += f" AND publishDate >= '{s_date} 00:00:00.000' "
if e_date:
querystr += f" AND publishDate <= '{e_date} 23:59:59.999' "
rows = db_tools.query_reports_comm(tbl_name, querystr=querystr, limit=limit)
if rows is None:
rows = []
for row in rows:
url, file_path, old_file_path = parse_func(row, tbl_name)
if url is None or file_path is None:
logging.warning(f'wrong url or file_path. tbl_name: {tbl_name}')
continue
# skip files that already exist
if file_path and os.path.isfile(file_path):
logging.info(f'{file_path} already exists. skipping...')
continue
# rename files that were saved under the old naming scheme
if old_file_path and os.path.isfile(old_file_path):
shutil.move(old_file_path, file_path)
logging.info(f'rename existed file to {file_path}')
continue
# resolve the direct PDF link from the report page
if url:
pdf_url = em.fetch_pdf_link(url)
if pdf_url:
# download the PDF
down = em.download_pdf(pdf_url, file_path)
if down:
logging.info(f'saved file {file_path}')
time.sleep(1) # avoid sending requests too quickly
# fetch the stock report list
def fetch_reports_list_stock(s_date=start_date, e_date=end_date):
return fetch_reports_list_general(em.fetch_reports_by_stock, utils.tbl_stock, s_date, e_date, 'stock')
# fetch the new-stock report list
def fetch_reports_list_newstock(s_date=start_date, e_date=end_date):
return fetch_reports_list_general(em.fetch_reports_by_newstock, utils.tbl_new_stock, s_date, e_date, 'new')
# fetch the industry report list
def fetch_reports_list_industry(s_date=start_date, e_date=end_date):
return fetch_reports_list_general(em.fetch_reports_by_industry, utils.tbl_industry, s_date, e_date, 'industry')
# fetch the macro research report list
def fetch_reports_list_macresearch(s_date=start_date, e_date=end_date):
return fetch_reports_list_general(em.fetch_reports_by_macresearch, utils.tbl_macresearch, s_date, e_date, 'macresearch')
# fetch the strategy report list
def fetch_reports_list_strategy(s_date=start_date, e_date=end_date):
return fetch_reports_list_general(em.fetch_reports_by_strategy, utils.tbl_strategy, s_date, e_date, 'strategy')
# download stock report PDFs
def download_pdf_stock(s_date=start_date, e_date=end_date):
download_pdf_stock_general(parse_func_stock, utils.tbl_stock, ' AND attachPages>=30', s_date, e_date, limit=2 if debug else None)
def download_pdf_newstock(s_date=start_date, e_date=end_date):
download_pdf_stock_general(parse_func_stock, utils.tbl_new_stock, ' AND attachPages>=30', s_date, e_date, limit=2 if debug else None)
def download_pdf_industry(s_date=start_date, e_date=end_date):
download_pdf_stock_general(parse_func_other, utils.tbl_industry, ' AND attachPages>=30', s_date, e_date, limit=2 if debug else None)
def download_pdf_macresearch(s_date=start_date, e_date=end_date):
download_pdf_stock_general(parse_func_other, utils.tbl_macresearch, ' ', s_date, e_date, limit=2 if debug else None)
def download_pdf_strategy(s_date=start_date, e_date=end_date):
download_pdf_stock_general(parse_func_other, utils.tbl_strategy, ' ', s_date, e_date, limit=2 if debug else None)
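# Note: the stock, new-stock and industry downloads only consider reports with
# attachPages >= 30; macro research and strategy reports are downloaded regardless of length.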
# map short names to the list-fetching functions
function_list_map = {
'stock' : fetch_reports_list_stock,
'new' : fetch_reports_list_newstock,
'indust' : fetch_reports_list_industry,
'macro' : fetch_reports_list_macresearch,
'stra' : fetch_reports_list_strategy,
}
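# map short names to the download functions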
function_down_map = {
'stock' : download_pdf_stock,
'new' : download_pdf_newstock,
'indust' : download_pdf_industry,
'macro' : download_pdf_macresearch,
'stra' : download_pdf_strategy,
}
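# e.g. --cmd stock,indust selects fetch_reports_list_stock/fetch_reports_list_industry in
# fetch mode, or download_pdf_stock/download_pdf_industry in down mode.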
# collect last week's reports into a last_week directory of symlinks
def create_last_week_links(s_date=start_date, e_date=end_date):
last_week_dir = os.path.join(pdf_base_dir, 'last_week')
# if the last_week directory exists, remove it and all of its contents first
if os.path.exists(last_week_dir):
shutil.rmtree(last_week_dir)
os.makedirs(last_week_dir)
for root, dirs, files in os.walk(pdf_base_dir):
# skip the last_week directory and its subdirectories
if 'last_week' in dirs:
dirs.remove('last_week')
for file in files:
if file.endswith('.pdf'):
match = re.match(r'(\d{4}-\d{2}-\d{2})_(.*)\.pdf', file)
if match:
date_str = match.group(1)
if utils.is_within_last_week(date_str):
file_path = os.path.join(root, file)
# name of the immediate subdirectory (report category)
sub_dir_name = os.path.basename(os.path.dirname(file_path))
# build the link name with the subdirectory name as a prefix
new_file_name = f"[{sub_dir_name}]_{file}"
link_name = os.path.join(last_week_dir, new_file_name)
if not os.path.exists(link_name):
os.symlink(file_path, link_name)
# run the selected functions
def run_func(function_names, function_map):
global start_date
global end_date
for short_name in function_names:
func = function_map.get(short_name.strip()) # look up the function for this shortcut
if callable(func):
#db_tools.update_task_log(task_id, task_status=f'Running {func}')
logging.info(f'exec function: {func}, begin: {start_date}, end: {end_date}')
func(start_date, end_date)
else:
logging.warning(f"Warning: {short_name} is not a valid function shortcut.")
# main entry point
def main(cmd, mode, args_debug, args_force, begin, end):
global debug
debug = args_debug
global force
force = args_force
global start_date
start_date = begin if begin else start_date
global end_date
end_date = end if end else end_date
# start the task
#task_id = db_tools.insert_task_log()
task_id = 0
if task_id is None:
logging.warning('insert task log error.')
return None
logging.info(f'running task. id: {task_id}, debug: {debug}, force: {force}, cmd: {cmd}, mode: {mode}')
# for lastweek mode, fetch the lists first and then run the downloads
function_list = []
if mode == 'fetch':
function_list.append(function_list_map)
elif mode == 'down':
function_list.append(function_down_map)
elif mode == 'lastweek':
start_date = this_week_date
function_list.append(function_list_map)
function_list.append(function_down_map)
else:
function_list.append(function_list_map)
# resolve which function shortcuts to run
if cmd and mode != 'lastweek':
function_names = cmd.split(",") # split the comma-separated input
else:
function_names = function_list_map.keys()
# iterate over the function maps and run them
for function_map in function_list:
run_func(function_names, function_map)
logging.info('all processing completed!')
#db_tools.finalize_task_log(task_id)
if __name__ == "__main__":
# command-line argument parsing
keys_str = ",".join(function_list_map.keys())
parser = argparse.ArgumentParser(description='fetch eastmoney research report data.')
parser.add_argument("--cmd", type=str, help=f"Comma-separated list of function shortcuts: {keys_str}")
parser.add_argument("--mode", type=str, help=f"Fetch list or Download pdf: (fetch, down, lastweek)")
parser.add_argument("--begin", type=str, help=f"begin date, YYYY-mm-dd")
parser.add_argument("--end", type=str, help=f"end date, YYYY-mm-dd")
parser.add_argument('--debug', action='store_true', help='Enable debug mode (limit records)')
parser.add_argument('--force', action='store_true', help='force update (true for rewrite all)')
args = parser.parse_args()
main(args.cmd, args.mode, args.debug, args.force, args.begin, args.end)
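# Example invocations (the script file name below is a placeholder; use the actual file name):
#   python fetch_em_reports.py --mode fetch --cmd stock,indust --begin 2024-01-01 --end 2024-12-31
#   python fetch_em_reports.py --mode down --cmd stock --debug
#   python fetch_em_reports.py --mode lastweek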