Files
stock/src/em_reports/fetch.py
2025-08-10 19:26:47 +08:00

399 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import time
import csv
import os
import re
import argparse
import shutil
import logging
from datetime import datetime, timedelta
from functools import partial
import src.crawler.em.reports as em
import src.utils.utils as utils
from src.config.config import global_host_data_dir, global_share_db_dir, db_config
from src.db_utils.reports import StockReportDB, DatabaseConnectionError
from src.db_utils.reports_mysql import StockReportMysql
from src.logger.logger import setup_logging
import PyPDF2
# Initialize logging for the whole module
setup_logging()
# Runtime flags, overwritten by main() from CLI args
debug = False
force = False
pdf_base_dir = f"{global_host_data_dir}/pdfs"  # directory where downloaded PDFs are stored
# Detail-page URL templates per report table; the placeholder is filled with
# the row's infoCode (or encodeUrl for strategy/macro reports).
map_pdf_page = {
    StockReportDB.TBL_STOCK : "https://data.eastmoney.com/report/info/{}.html",
    StockReportDB.TBL_NEW_STOCK : "https://data.eastmoney.com/report/info/{}.html",
    StockReportDB.TBL_STRATEGY : "https://data.eastmoney.com/report/zw_strategy.jshtml?encodeUrl={}",
    StockReportDB.TBL_MACRESEARCH : "https://data.eastmoney.com/report/zw_macresearch.jshtml?encodeUrl={}",
    StockReportDB.TBL_INDUSTRY : "https://data.eastmoney.com/report/zw_industry.jshtml?infocode={}"
}
# Human-readable (Chinese) report-type labels per table, used inside saved file names
map_tbl_name = {
    StockReportDB.TBL_STOCK : '个股研报',
    StockReportDB.TBL_NEW_STOCK : '新股研报',
    StockReportDB.TBL_STRATEGY : '策略报告',
    StockReportDB.TBL_MACRESEARCH : '宏观研究',
    StockReportDB.TBL_INDUSTRY : '行业研报'
}
# Database handle; the SQLite path is kept for reference, but main() actually
# connects to MySQL and assigns db_tools there.
db_path = f"{global_share_db_dir}/stock_report.db"
db_tools = None
# Default date window: the last two years up to today (lastweek mode narrows it)
current_date = datetime.now()
seven_days_ago = current_date - timedelta(days=7)
two_years_ago = current_date - timedelta(days=2*365)
start_date = two_years_ago.strftime("%Y-%m-%d")
end_date = current_date.strftime("%Y-%m-%d")
this_week_date = seven_days_ago.strftime("%Y-%m-%d")
# PDFs with fewer pages than this are considered junk and moved aside
min_down_pages = 10
def get_pdf_page_count(pdf_path):
    """Return the number of pages in the PDF at *pdf_path*, or None if it cannot be read."""
    try:
        # Open in binary mode and let PyPDF2 parse the document structure.
        with open(pdf_path, 'rb') as file:
            return len(PyPDF2.PdfReader(file).pages)
    except Exception as e:
        # Corrupt / truncated downloads are common; log and signal failure with None.
        logging.warning(f"处理文件 {pdf_path} 时出错: {e}")
        return None
def fetch_reports_list_general(fetch_func, table_name, s_date, e_date, data_dir_prefix):
    """Page through *fetch_func* results and upsert every row into *table_name*.

    fetch_func   -- crawler function taking page_no/start_date/end_date/page_size
    table_name   -- destination DB table
    s_date/e_date -- inclusive date window (YYYY-mm-dd strings)
    data_dir_prefix -- legacy prefix for on-disk JSON dumps (currently unused)

    Returns False as soon as a DB insert fails; otherwise runs to the last page.
    """
    max_pages = 100000
    page = 1
    while page <= max_pages:
        # Retry the fetch until it returns data; sleep between attempts so a
        # persistent failure does not busy-spin and hammer the server.
        while True:
            data = fetch_func(page_no=page, start_date=s_date, end_date=e_date, page_size=100)
            if data:
                break
            time.sleep(1)
        if page == 1:
            max_pages = data.get('TotalPage', 1000000)
        for row in data.get('data', []):
            # infoCode is the UNIQUE key across all tables, so fill it in from
            # encodeUrl for the report types that only provide the latter.
            if row.get('infoCode') is None and row.get('encodeUrl'):
                row['infoCode'] = row['encodeUrl']
            row['count_all'] = row.get('count', 0)  # backward compatibility with old data
            row['curr_column'] = row.get('column', 'None')  # backward compatibility with old data
            if 'stockName' not in row or row['stockName'] is None or row['stockName'] == '':
                row['stockName'] = ''
            if 'stockCode' not in row or row['stockCode'] is None or row['stockCode'] == '':
                row['stockCode'] = ''
            # BUG FIX: the original tested `is None` and then called float(None)
            # (uncaught TypeError), while real values fell into the else branch
            # and were clobbered to 0.0. Convert only non-None values instead.
            if row.get('newPeIssueA') is not None:
                try:
                    row['newPeIssueA'] = float(row['newPeIssueA'])
                except (TypeError, ValueError):
                    row['newPeIssueA'] = 0.0
            else:
                row['newPeIssueA'] = 0.0
            row_id = db_tools.insert_or_update_common(row, table_name)
            if row_id:
                logging.debug(f'insert one row. rowid:{row_id}, ')
            else:
                logging.warning(f'insert data failed. page : {page}')
                return False
        logging.info(f"{page} 页, 获取 {len(data['data'])} 条数据, 共 {max_pages}")
        page += 1
        time.sleep(1)  # throttle requests
# Builds the stock-style detail-page URL and local save path for one report row.
def parse_func_general(row, tbl_name):
    """Return (detail_page_url, pdf_save_path) for *row*, or (None, None) for an unknown table.

    Side effect: creates the per-year destination directory if it is missing.
    """
    # Guard clause first: an unknown table has no URL template.
    url_template = map_pdf_page.get(tbl_name, None)
    if url_template is None:
        logging.warning(f'wrong table name: {tbl_name}')
        return None, None
    publish_date = row['publishDate'].split(" ")[0]
    # Strip path separators so the title is safe as part of a file name.
    title = row['title'].replace("/", "_").replace("\\", "_")
    # The table's column defaults are slightly off, so normalise empties here.
    org_sname = row['orgSName']
    if org_sname == '':
        org_sname = 'None'
    stock_name = row['stockName']
    if stock_name == '' or stock_name == "''":
        stock_name = 'None'
    industry_name = row['industryName']
    if industry_name == '':
        industry_name = 'None'
    report_type = map_tbl_name.get(tbl_name, 'None')
    file_name = f"{publish_date}_{report_type}_{org_sname}_{industry_name}_{stock_name}_{title}.pdf"
    # Group saved PDFs by publication year.
    dir_year = publish_date[:4] if len(publish_date) > 4 else ''
    dir_path = f'{pdf_base_dir}/{dir_year}'
    os.makedirs(dir_path, exist_ok=True)
    return url_template.format(row['infoCode']), os.path.join(dir_path, file_name)
# If the downloaded PDF has fewer pages than the configured minimum (or cannot
# be read at all), move it into a tmp/ subdirectory next to where it was saved.
def check_pdf_pages(file_path, row, tbl_name):
    pages = get_pdf_page_count(file_path)
    if pages is None or pages < min_down_pages:
        # Directory the file currently lives in
        file_dir = os.path.dirname(file_path)
        # Create the tmp subdirectory on demand
        tmp_dir = os.path.join(file_dir, 'tmp')
        if not os.path.exists(tmp_dir):
            os.makedirs(tmp_dir)
        # Move the undersized/unreadable file into tmp
        file_name = os.path.basename(file_path)
        new_path = os.path.join(tmp_dir, file_name)
        shutil.move(file_path, new_path)
        logging.debug(f"move {file_name} to {tmp_dir}")
    # The macro and strategy list APIs do not supply a page count, so write the
    # one we measured back to the database for those tables.
    # NOTE(review): the pasted source lost indentation — this update is placed at
    # function level (runs for every file, not only moved ones), which matches
    # the download filter that relies on attachPages; confirm against history.
    if tbl_name == StockReportDB.TBL_MACRESEARCH or tbl_name == StockReportDB.TBL_STRATEGY:
        data={}
        data['infoCode'] = row['infoCode']
        data['id'] = row['id']
        data['attachPages'] = pages
        row_id = db_tools.update_pages(data, tbl_name)
        if row_id:
            logging.debug(f"update one row. tbl: {tbl_name}, rowid:{row_id}")
        else:
            logging.warning(f"update data failed. tbl: {tbl_name}, rowid:{row['id']}")
            return False
# Generic download driver: query report rows from the DB, resolve each row's
# PDF link from its detail page, download it, and sanity-check the page count.
def download_pdf_stock_general(parse_func, tbl_name, querystr='', s_date=start_date, e_date=end_date, limit=None, min_page=None):
    """Download PDFs for *tbl_name* rows matching the date window and page filter.

    NOTE(review): querystr is assembled by string interpolation; the dates come
    from trusted CLI arguments here, but parameterized queries would be safer.
    """
    if s_date:
        querystr += f" AND publishDate >= '{s_date} 00:00:00.000' "
    if e_date:
        querystr += f" AND publishDate <= '{e_date} 23:59:59.999' "
    if min_page:
        querystr += f" AND attachPages >= {min_page} "
    rows = db_tools.query_reports_comm(tbl_name, querystr=querystr, limit=limit) or []
    for row in rows:
        url, file_path = parse_func(row, tbl_name)
        if url is None or file_path is None:
            logging.warning(f'wrong url or file_path. tbl_name: {tbl_name}')
            continue
        # Already downloaded on a previous run — skip it.
        if file_path and os.path.isfile(file_path):
            logging.info(f'{file_path} already exists. skipping...')
            continue
        # Resolve the actual PDF URL from the detail page, then fetch it.
        pdf_url = em.fetch_pdf_link(url)
        if not pdf_url:
            logging.warning(f'cannot get pdf link. url: {url}, save_path: {file_path}')
        elif em.download_pdf(pdf_url, file_path):
            logging.info(f'saved file {file_path}')
            check_pdf_pages(file_path, row, tbl_name)
        else:
            logging.warning(f'download pdf file error. file_path: {pdf_url}, save_path: {file_path}')
        time.sleep(1)  # throttle requests
# Fetch the individual-stock report list
def fetch_reports_list_stock(s_date=start_date, e_date=end_date):
    return fetch_reports_list_general(em.fetch_reports_by_stock, StockReportDB.TBL_STOCK, s_date, e_date, 'stock')
# Fetch the new-stock report list
def fetch_reports_list_newstock(s_date=start_date, e_date=end_date):
    return fetch_reports_list_general(em.fetch_reports_by_newstock, StockReportDB.TBL_NEW_STOCK, s_date, e_date, 'new')
# Fetch the industry report list
def fetch_reports_list_industry(s_date=start_date, e_date=end_date):
    return fetch_reports_list_general(em.fetch_reports_by_industry, StockReportDB.TBL_INDUSTRY, s_date, e_date, 'industry')
# Fetch the macro-research report list
def fetch_reports_list_macresearch(s_date=start_date, e_date=end_date):
    return fetch_reports_list_general(em.fetch_reports_by_macresearch, StockReportDB.TBL_MACRESEARCH, s_date, e_date, 'macresearch')
# Fetch the strategy report list
def fetch_reports_list_strategy(s_date=start_date, e_date=end_date):
    return fetch_reports_list_general(em.fetch_reports_by_strategy, StockReportDB.TBL_STRATEGY, s_date, e_date, 'strategy')
# Download stock PDFs (in debug mode only 2 rows are processed per table).
# The stock/new-stock/industry tables filter on attachPages; macro and strategy
# cannot, since their list API does not provide a page count up front.
def download_pdf_stock(s_date=start_date, e_date=end_date):
    return download_pdf_stock_general(parse_func_general, StockReportDB.TBL_STOCK, ' ', s_date, e_date, limit=2 if debug else None, min_page=min_down_pages)
def download_pdf_newstock(s_date=start_date, e_date=end_date):
    return download_pdf_stock_general(parse_func_general, StockReportDB.TBL_NEW_STOCK, ' ', s_date, e_date, limit=2 if debug else None, min_page=min_down_pages)
def download_pdf_industry(s_date=start_date, e_date=end_date):
    return download_pdf_stock_general(parse_func_general, StockReportDB.TBL_INDUSTRY, ' ', s_date, e_date, limit=2 if debug else None, min_page=min_down_pages)
def download_pdf_macresearch(s_date=start_date, e_date=end_date):
    return download_pdf_stock_general(parse_func_general, StockReportDB.TBL_MACRESEARCH, ' ', s_date, e_date, limit=2 if debug else None)
def download_pdf_strategy(s_date=start_date, e_date=end_date):
    return download_pdf_stock_general(parse_func_general, StockReportDB.TBL_STRATEGY, ' ', s_date, e_date, limit=2 if debug else None)
# Map CLI shortcut names to the list-fetching functions
function_list_map = {
    'stock' : fetch_reports_list_stock,
    'new' : fetch_reports_list_newstock,
    'indust' : fetch_reports_list_industry,
    'macro' : fetch_reports_list_macresearch,
    'stra' : fetch_reports_list_strategy,
}
# Map the same shortcut names to the PDF-downloading functions
function_down_map = {
    'stock' : download_pdf_stock,
    'new' : download_pdf_newstock,
    'indust' : download_pdf_industry,
    'macro' : download_pdf_macresearch,
    'stra' : download_pdf_strategy,
}
# Collect last week's PDFs under pdf_base_dir/last_week as symlinks.
def create_last_week_links(s_date=start_date, e_date=end_date):
    """Rebuild the last_week directory with symlinks to PDFs published in the last week.

    The *s_date*/*e_date* parameters are accepted only for signature
    compatibility with the other task functions; the actual filter is
    utils.is_within_last_week() on the date embedded in each file name.
    """
    last_week_dir = os.path.join(pdf_base_dir, 'last_week')
    # Start from a clean directory. shutil.rmtree replaces the original
    # hand-rolled bottom-up os.walk deletion with the stdlib equivalent.
    if os.path.exists(last_week_dir):
        shutil.rmtree(last_week_dir)
    os.makedirs(last_week_dir)
    for root, dirs, files in os.walk(pdf_base_dir):
        # Never descend into last_week itself.
        if 'last_week' in dirs:
            dirs.remove('last_week')
        for file in files:
            if not file.endswith('.pdf'):
                continue
            # File names start with the publication date: YYYY-mm-dd_...
            match = re.match(r'(\d{4}-\d{2}-\d{2})_(.*)\.pdf', file)
            if not match or not utils.is_within_last_week(match.group(1)):
                continue
            file_path = os.path.join(root, file)
            # Prefix the link with its containing directory name (the year)
            # so files from different directories cannot clash.
            sub_dir_name = os.path.basename(os.path.dirname(file_path))
            link_name = os.path.join(last_week_dir, f"[{sub_dir_name}]_{file}")
            if not os.path.exists(link_name):
                os.symlink(file_path, link_name)
# Execute each requested function shortcut over the module-level date window.
def run_func(function_names, function_map):
    """Look up every shortcut in *function_map* and invoke it with (start_date, end_date)."""
    global start_date
    global end_date
    for short_name in function_names:
        # Resolve the shortcut; unknown names are logged and skipped.
        func = function_map.get(short_name.strip())
        if not callable(func):
            logging.warning(f"Warning: {short_name} is not a valid function shortcut.")
            continue
        logging.info(f'exec function: {func}, begin: {start_date}, end: {end_date}')
        func(start_date, end_date)
# Entry point: apply CLI options to the module globals, connect the DB, run tasks.
def main(cmd, mode, args_debug, args_force, begin, end):
    """Run the fetch/download tasks selected by *cmd* in the given *mode*.

    cmd   -- comma-separated function shortcuts, or falsy for all of them
    mode  -- 'fetch' (lists only), 'down' (PDFs only), 'lastweek' (both, last
             7 days), anything else defaults to 'fetch'
    begin/end -- optional YYYY-mm-dd overrides for the date window
    Returns False on DB connection failure, None on task-log failure.
    """
    global debug
    debug = args_debug
    if debug:
        logger = logging.getLogger()
        logger.setLevel(logging.DEBUG)
    global force
    force = args_force
    global start_date
    start_date = begin if begin else start_date
    global end_date
    end_date = end if end else end_date
    # Initialize the DB connection (MySQL; the SQLite path is legacy).
    global db_tools
    try:
        db_tools = StockReportMysql(db_host=db_config['host'], db_user=db_config['user'], db_password=db_config['password'], db_name=db_config['database'], port=3306)
    except DatabaseConnectionError as e:
        logging.error(f"数据库连接失败: {e}")
        return False
    # Task logging is currently disabled; keep a dummy id for the log line.
    task_id = 0
    if task_id is None:
        logging.warning(f'insert task log error.')
        return None
    logging.info(f'running task. id: {task_id}, debug: {debug}, force: {force}, cmd: {cmd}, mode: {mode}')
    # lastweek runs the list fetch first and then the download pass.
    function_list = []
    if mode == 'fetch':
        function_list.append(function_list_map)
    elif mode == 'down':
        function_list.append(function_down_map)
    elif mode == 'lastweek':
        start_date = this_week_date
        function_list.append(function_list_map)
        function_list.append(function_down_map)
    else:
        function_list.append(function_list_map)
    # Pick the shortcuts to run.
    if cmd and mode != 'lastweek':
        # BUG FIX: the original read the module-global `args.cmd` here instead
        # of the `cmd` parameter, so calling main() without the CLI parser
        # (e.g. from another module or a test) raised NameError.
        function_names = cmd.split(",")
    else:
        function_names = function_list_map.keys()
    # Run every selected function against every selected map.
    for function_map in function_list:
        run_func(function_names, function_map)
    logging.info(f'all process completed!')
if __name__ == "__main__":
    # Command-line argument handling
    keys_str = ",".join(function_list_map.keys())
    # NOTE(review): the description mentions "iafd" — looks copied from another
    # project; confirm and update the user-facing text if so.
    parser = argparse.ArgumentParser(description='fetch iafd data.')
    parser.add_argument("--cmd", type=str, help=f"Comma-separated list of function shortcuts: {keys_str}")
    parser.add_argument("--mode", type=str, help=f"Fetch list or Download pdf: (fetch, down, lastweek)")
    parser.add_argument("--begin", type=str, help=f"begin date, YYYY-mm-dd")
    parser.add_argument("--end", type=str, help=f"end date, YYYY-mm-dd")
    parser.add_argument('--debug', action='store_true', help='Enable debug mode (limit records)')
    parser.add_argument('--force', action='store_true', help='force update (true for rewrite all)')
    args = parser.parse_args()
    main(args.cmd, args.mode, args.debug, args.force, args.begin, args.end)