stock/stockapp/reports_em/fetch.py
import json
import time
import csv
import os
import re
import argparse
import shutil
import logging
from datetime import datetime, timedelta
from functools import partial
import config
import sqlite_utils as db_tools
import em_reports as em
import utils
config.setup_logging()
debug = False
force = False
pdf_base_dir = "/root/hostdir/stock_data/pdfs"  # directory where downloaded PDFs are stored
map_pdf_page = {
    utils.tbl_stock: "https://data.eastmoney.com/report/info/{}.html",
    utils.tbl_new_stock: "https://data.eastmoney.com/report/info/{}.html",
    utils.tbl_strategy: "https://data.eastmoney.com/report/zw_strategy.jshtml?encodeUrl={}",
    utils.tbl_macresearch: "https://data.eastmoney.com/report/zw_macresearch.jshtml?encodeUrl={}",
    utils.tbl_industry: "https://data.eastmoney.com/report/zw_industry.jshtml?infocode={}"
}
map_tbl_name = {
    utils.tbl_stock: '个股研报',
    utils.tbl_new_stock: '新股研报',
    utils.tbl_strategy: '策略报告',
    utils.tbl_macresearch: '宏观研究',
    utils.tbl_industry: '行业研报'
}
current_date = datetime.now()
seven_days_ago = current_date - timedelta(days=7)
two_years_ago = current_date - timedelta(days=2*365)
start_date = two_years_ago.strftime("%Y-%m-%d")
end_date = current_date.strftime("%Y-%m-%d")
this_week_date = seven_days_ago.strftime("%Y-%m-%d")
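# Note: these dates are only defaults. main() replaces start_date/end_date with the
# --begin/--end CLI arguments when given, and 'lastweek' mode swaps start_date for
# this_week_date. Illustrative values if the script were run on 2025-03-15:
#   start_date = "2023-03-16", end_date = "2025-03-15", this_week_date = "2025-03-08"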
def fetch_reports_list_general(fetch_func, table_name, s_date, e_date, data_dir_prefix):
    # Page through the report list; max_pages is refined from TotalPage in the first response
    max_pages = 100000
    page = 1
    while page <= max_pages:
        # Retry until the current page returns data
        while True:
            data = fetch_func(page_no=page, start_date=s_date, end_date=e_date, page_size=100)
            if data:
                break
        if page == 1:
            max_pages = data.get('TotalPage', 1000000)
        for row in data.get('data', []):
            # infoCode is used as the UNIQUE key everywhere, so fill it in here if missing
            if row.get('infoCode') is None and row.get('encodeUrl'):
                row['infoCode'] = row['encodeUrl']
            row_id = db_tools.insert_or_update_common(row, table_name)
            if row_id:
                logging.debug(f'insert one row. rowid: {row_id}')
            else:
                logging.warning(f'insert data failed. page: {page}')
                return False
        utils.save_json_to_file(data, f'{utils.json_data_dir}/{data_dir_prefix}', f'{data_dir_prefix}_report_{page}.json')
        logging.info(f"page {page}: fetched {len(data['data'])} rows, {max_pages} pages in total")
        page += 1
        time.sleep(1)  # throttle requests
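# Note on fetch_reports_list_general: the inner retry loop spins without a delay until
# fetch_func returns data; the per-page time.sleep(1) only throttles between successive pages.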
# Build the detail-page URL and local PDF path for a report row
def parse_func_general(row, tbl_name):
    info_code = row['infoCode']
    title = row['title'].replace("/", "_").replace("\\", "_")
    org_sname = row['orgSName']
    stock_name = row['stockName']
    industry_name = row['industryName']
    publish_date = row['publishDate'].split(" ")[0]
    # The table defaults are slightly off, so normalize empty strings here
    if stock_name == '' or stock_name == "''":
        stock_name = 'None'
    if industry_name == '':
        industry_name = 'None'
    if org_sname == '':
        org_sname = 'None'
    report_type = map_tbl_name.get(tbl_name, 'None')
    file_name = f"{publish_date}_{report_type}_{org_sname}_{industry_name}_{stock_name}_{title}.pdf"
    url = map_pdf_page.get(tbl_name, None)
    if url is None:
        logging.warning(f'wrong table name: {tbl_name}')
        return None, None
    url = url.format(info_code)
    # Build the target directory: <base>/<year>/<report type>
    dir_year = publish_date[:4] if len(publish_date) > 4 else ''
    dir_path = f'{pdf_base_dir}/{dir_year}/{map_tbl_name[tbl_name]}'
    os.makedirs(dir_path, exist_ok=True)
    return url, os.path.join(dir_path, file_name)
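# Example for parse_func_general (hypothetical row; real field values come from the API):
# a stock-report row with publishDate "2025-03-10 00:00:00", orgSName "某证券",
# industryName "电子", stockName "某公司" and title "深度报告" would map to
#   url       -> https://data.eastmoney.com/report/info/<infoCode>.html
#   file path -> /root/hostdir/stock_data/pdfs/2025/个股研报/2025-03-10_个股研报_某证券_电子_某公司_深度报告.pdf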
# Generic PDF download routine
def download_pdf_stock_general(parse_func, tbl_name, querystr='', s_date=start_date, e_date=end_date, limit=None):
    # Query the table for reports in the date range, then download their PDFs
    if s_date:
        querystr += f" AND publishDate >= '{s_date} 00:00:00.000' "
    if e_date:
        querystr += f" AND publishDate <= '{e_date} 23:59:59.999' "
    rows = db_tools.query_reports_comm(tbl_name, querystr=querystr, limit=limit)
    if rows is None:
        rows = []
    for row in rows:
        url, file_path = parse_func(row, tbl_name)
        if url is None or file_path is None:
            logging.warning(f'wrong url or file_path. tbl_name: {tbl_name}')
            continue
        # Skip files that already exist
        if file_path and os.path.isfile(file_path):
            logging.info(f'{file_path} already exists. skipping...')
            continue
        # Resolve the actual PDF link from the detail page
        if url:
            pdf_url = em.fetch_pdf_link(url)
            if pdf_url:
                # Download the PDF
                down = em.download_pdf(pdf_url, file_path)
                if down:
                    logging.info(f'saved file {file_path}')
        time.sleep(1)  # throttle requests
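# Note on download_pdf_stock_general: em.fetch_pdf_link and em.download_pdf are helpers
# from em_reports; their contract as assumed from the usage above is that
# fetch_pdf_link(detail_url) returns the direct PDF URL (or a falsy value) and
# download_pdf(pdf_url, file_path) returns a truthy value when the file was written.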
# Fetch the stock report list
def fetch_reports_list_stock(s_date=start_date, e_date=end_date):
    return fetch_reports_list_general(em.fetch_reports_by_stock, utils.tbl_stock, s_date, e_date, 'stock')

# Fetch the new-stock (IPO) report list
def fetch_reports_list_newstock(s_date=start_date, e_date=end_date):
    return fetch_reports_list_general(em.fetch_reports_by_newstock, utils.tbl_new_stock, s_date, e_date, 'new')

# Fetch the industry report list
def fetch_reports_list_industry(s_date=start_date, e_date=end_date):
    return fetch_reports_list_general(em.fetch_reports_by_industry, utils.tbl_industry, s_date, e_date, 'industry')

# Fetch the macro research report list
def fetch_reports_list_macresearch(s_date=start_date, e_date=end_date):
    return fetch_reports_list_general(em.fetch_reports_by_macresearch, utils.tbl_macresearch, s_date, e_date, 'macresearch')

# Fetch the strategy report list
def fetch_reports_list_strategy(s_date=start_date, e_date=end_date):
    return fetch_reports_list_general(em.fetch_reports_by_strategy, utils.tbl_strategy, s_date, e_date, 'strategy')
# Download report PDFs for each table (debug mode limits each table to 2 records)
def download_pdf_stock(s_date=start_date, e_date=end_date):
    download_pdf_stock_general(parse_func_general, utils.tbl_stock, ' ', s_date, e_date, limit=2 if debug else None)

def download_pdf_newstock(s_date=start_date, e_date=end_date):
    download_pdf_stock_general(parse_func_general, utils.tbl_new_stock, ' ', s_date, e_date, limit=2 if debug else None)

def download_pdf_industry(s_date=start_date, e_date=end_date):
    download_pdf_stock_general(parse_func_general, utils.tbl_industry, ' ', s_date, e_date, limit=2 if debug else None)

def download_pdf_macresearch(s_date=start_date, e_date=end_date):
    download_pdf_stock_general(parse_func_general, utils.tbl_macresearch, ' ', s_date, e_date, limit=2 if debug else None)

def download_pdf_strategy(s_date=start_date, e_date=end_date):
    download_pdf_stock_general(parse_func_general, utils.tbl_strategy, ' ', s_date, e_date, limit=2 if debug else None)
# Map CLI shortcuts to the list-fetching and PDF-downloading functions
function_list_map = {
    'stock': fetch_reports_list_stock,
    'new': fetch_reports_list_newstock,
    'indust': fetch_reports_list_industry,
    'macro': fetch_reports_list_macresearch,
    'stra': fetch_reports_list_strategy,
}

function_down_map = {
    'stock': download_pdf_stock,
    'new': download_pdf_newstock,
    'indust': download_pdf_industry,
    'macro': download_pdf_macresearch,
    'stra': download_pdf_strategy,
}
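# Illustrative use of the shortcut maps (hypothetical invocation): passing
# --cmd stock,macro selects fetch_reports_list_stock and fetch_reports_list_macresearch
# in 'fetch' mode, or download_pdf_stock and download_pdf_macresearch in 'down' mode.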
# Collect last week's reports into a 'last_week' directory of symlinks
def create_last_week_links(s_date=start_date, e_date=end_date):
    last_week_dir = os.path.join(pdf_base_dir, 'last_week')
    # If the last_week directory already exists, remove it first
    if os.path.exists(last_week_dir):
        for root, dirs, files in os.walk(last_week_dir, topdown=False):
            for file in files:
                file_path = os.path.join(root, file)
                os.remove(file_path)
            for dir in dirs:
                dir_path = os.path.join(root, dir)
                os.rmdir(dir_path)
        os.rmdir(last_week_dir)
    os.makedirs(last_week_dir)
    for root, dirs, files in os.walk(pdf_base_dir):
        # Skip the last_week directory and everything under it
        if 'last_week' in dirs:
            dirs.remove('last_week')
        for file in files:
            if file.endswith('.pdf'):
                match = re.match(r'(\d{4}-\d{2}-\d{2})_(.*)\.pdf', file)
                if match:
                    date_str = match.group(1)
                    if utils.is_within_last_week(date_str):
                        file_path = os.path.join(root, file)
                        # Name of the immediate parent directory (the report type)
                        sub_dir_name = os.path.basename(os.path.dirname(file_path))
                        # Prefix the link name with the subdirectory name
                        new_file_name = f"[{sub_dir_name}]_{file}"
                        link_name = os.path.join(last_week_dir, new_file_name)
                        if not os.path.exists(link_name):
                            os.symlink(file_path, link_name)
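# Illustrative result of create_last_week_links (hypothetical file names): a PDF stored at
#   /root/hostdir/stock_data/pdfs/2025/行业研报/2025-03-12_行业研报_某证券_电子_None_周报.pdf
# whose date falls within the last week gets a symlink at
#   /root/hostdir/stock_data/pdfs/last_week/[行业研报]_2025-03-12_行业研报_某证券_电子_None_周报.pdf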
# Run the selected functions
def run_func(function_names, function_map):
    global start_date
    global end_date
    for short_name in function_names:
        func = function_map.get(short_name.strip())  # look up the function for this shortcut
        if callable(func):
            #db_tools.update_task_log(task_id, task_status=f'Running {func}')
            logging.info(f'exec function: {func}, begin: {start_date}, end: {end_date}')
            func(start_date, end_date)
        else:
            logging.warning(f"Warning: {short_name} is not a valid function shortcut.")
# Main entry point
def main(cmd, mode, args_debug, args_force, begin, end):
    global debug
    debug = args_debug
    global force
    force = args_force
    global start_date
    start_date = begin if begin else start_date
    global end_date
    end_date = end if end else end_date
    # Start the task
    #task_id = db_tools.insert_task_log()
    task_id = 0
    if task_id is None:
        logging.warning('insert task log error.')
        return None
    logging.info(f'running task. id: {task_id}, debug: {debug}, force: {force}, cmd: {cmd}, mode: {mode}')
    # For 'lastweek' mode, fetch the lists first and then download the PDFs
    function_list = []
    if mode == 'fetch':
        function_list.append(function_list_map)
    elif mode == 'down':
        function_list.append(function_down_map)
    elif mode == 'lastweek':
        start_date = this_week_date
        function_list.append(function_list_map)
        function_list.append(function_down_map)
    else:
        function_list.append(function_list_map)
    # Run only the requested shortcuts; 'lastweek' always runs all of them
    if cmd and mode != 'lastweek':
        function_names = cmd.split(",")  # split the comma-separated input
    else:
        function_names = function_list_map.keys()
    # Execute each selected function from each map
    for function_map in function_list:
        run_func(function_names, function_map)
    logging.info('all processing completed!')
    #db_tools.finalize_task_log(task_id)
if __name__ == "__main__":
# 命令行参数处理
keys_str = ",".join(function_list_map.keys())
parser = argparse.ArgumentParser(description='fetch iafd data.')
parser.add_argument("--cmd", type=str, help=f"Comma-separated list of function shortcuts: {keys_str}")
parser.add_argument("--mode", type=str, help=f"Fetch list or Download pdf: (fetch, down, lastweek)")
parser.add_argument("--begin", type=str, help=f"begin date, YYYY-mm-dd")
parser.add_argument("--end", type=str, help=f"end date, YYYY-mm-dd")
parser.add_argument('--debug', action='store_true', help='Enable debug mode (limit records)')
parser.add_argument('--force', action='store_true', help='force update (true for rewrite all)')
args = parser.parse_args()
main(args.cmd, args.mode, args.debug, args.force, args.begin, args.end)
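# Example invocations (illustrative; dates are placeholders):
#   python fetch.py --mode fetch --cmd stock,indust --begin 2024-01-01 --end 2024-12-31
#   python fetch.py --mode down --cmd stock --debug
#   python fetch.py --mode lastweek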