313 lines
12 KiB
Python
313 lines
12 KiB
Python
|
||
import json
|
||
import time
|
||
import csv
|
||
import os
|
||
import re
|
||
import argparse
|
||
import shutil
|
||
import logging
|
||
from datetime import datetime, timedelta
|
||
from functools import partial
|
||
import config
|
||
import sqlite_utils as db_tools
|
||
import em_reports as em
|
||
import utils
|
||
|
||
# Module-wide logging setup and runtime switches.
config.setup_logging()

debug = False  # when True, download wrappers cap each query at 2 rows
force = False  # reserved: rewrite everything when True
pdf_base_dir = "/root/hostdir/stock_data/pdfs"  # directory where downloaded PDFs are stored


# Per-table report detail-page URL template; {} is filled with the row's
# infoCode (or encodeUrl for strategy/macro reports).
map_pdf_page = {
    utils.tbl_stock: "https://data.eastmoney.com/report/info/{}.html",
    utils.tbl_new_stock: "https://data.eastmoney.com/report/info/{}.html",
    utils.tbl_strategy: "https://data.eastmoney.com/report/zw_strategy.jshtml?encodeUrl={}",
    utils.tbl_macresearch: "https://data.eastmoney.com/report/zw_macresearch.jshtml?encodeUrl={}",
    utils.tbl_industry: "https://data.eastmoney.com/report/zw_industry.jshtml?infocode={}",
}

# Human-readable report-type label per table; used in PDF file/dir names.
map_tbl_name = {
    utils.tbl_stock: '个股研报',
    utils.tbl_new_stock: '新股研报',
    utils.tbl_strategy: '策略报告',
    utils.tbl_macresearch: '宏观研究',
    utils.tbl_industry: '行业研报',
}

# Default date window: the last two years; 'lastweek' mode narrows the
# start to seven days ago (see main()).
current_date = datetime.now()
seven_days_ago = current_date - timedelta(days=7)
two_years_ago = current_date - timedelta(days=2 * 365)

start_date = two_years_ago.strftime("%Y-%m-%d")
end_date = current_date.strftime("%Y-%m-%d")
this_week_date = seven_days_ago.strftime("%Y-%m-%d")
|
||
|
||
def fetch_reports_list_general(fetch_func, table_name, s_date, e_date, data_dir_prefix):
    """Page through a report-list endpoint and upsert every row into table_name.

    Args:
        fetch_func: one of the em.fetch_reports_by_* callables, taking
            page_no/start_date/end_date/page_size and returning the decoded
            JSON payload (or a falsy value on failure).
        table_name: destination table (one of the utils.tbl_* names).
        s_date, e_date: inclusive YYYY-mm-dd window passed to the API.
        data_dir_prefix: sub-directory and file-name prefix for raw JSON dumps.

    Returns:
        True when every page was fetched and stored; False on a DB insert
        failure or when one page keeps returning an empty payload.
    """
    max_pages = 100000  # placeholder; replaced by the API's TotalPage after page 1
    page = 1
    while page <= max_pages:
        # Retry the current page until the API answers. The original retried
        # in a hot loop forever; now we pause between attempts and give up
        # after a bounded number of failures instead of spinning.
        attempts = 0
        while True:
            data = fetch_func(page_no=page, start_date=s_date, end_date=e_date, page_size=100)
            if data:
                break
            attempts += 1
            if attempts >= 10:
                logging.warning(f'fetch page {page} kept returning empty payloads; giving up')
                return False
            time.sleep(1)

        if page == 1:
            max_pages = data.get('TotalPage', 1000000)

        rows = data.get('data', [])
        for row in rows:
            # infoCode is the UNIQUE key for all tables; strategy/macro
            # payloads only carry encodeUrl, so mirror it into infoCode
            # before the upsert.
            if row.get('infoCode') is None and row.get('encodeUrl'):
                row['infoCode'] = row['encodeUrl']
            row_id = db_tools.insert_or_update_common(row, table_name)
            if row_id:
                logging.debug(f'insert one row. rowid:{row_id}, ')
            else:
                logging.warning(f'insert data failed. page : {page}')
                return False

        utils.save_json_to_file(data, f'{utils.json_data_dir}/{data_dir_prefix}', f'{data_dir_prefix}_report_{page}.json')

        # BUGFIX: use rows here — data['data'] raised KeyError for a payload
        # without a 'data' key even though the loop above tolerated one.
        logging.info(f"第 {page} 页, 获取 {len(rows)} 条数据, 共 {max_pages} 页")

        page += 1
        time.sleep(1)  # throttle requests

    return True
|
||
|
||
|
||
# Build the detail-page URL and local PDF path for one report row.
def parse_func_general(row, tbl_name):
    """Return (detail_page_url, pdf_file_path) for a report row.

    Returns (None, None) when tbl_name has no URL template.
    Side effect: creates the destination directory if it is missing.
    """
    info_code = row['infoCode']
    # '/' and '\' would break the file path; replace them inside the title.
    title = row['title'].replace("/", "_").replace("\\", "_")
    org_sname = row['orgSName']
    stock_name = row['stockName']
    industry_name = row['industryName']
    publish_date = row['publishDate'].split(" ")[0]

    # The table was created with a bad default value, so empty fields show
    # up as '' (or the literal "''"); normalize them to the string 'None'.
    if stock_name == '' or stock_name == "''":
        stock_name = 'None'
    if industry_name == '':
        industry_name = 'None'
    if org_sname == '':
        org_sname = 'None'
    report_type = map_tbl_name.get(tbl_name, 'None')

    file_name = f"{publish_date}_{report_type}_{org_sname}_{industry_name}_{stock_name}_{title}.pdf"
    url = map_pdf_page.get(tbl_name, None)
    if url is None:
        logging.warning(f'wrong table name: {tbl_name}')
        return None, None

    url = url.format(info_code)
    # Layout: <base>/<publish year>/<report type>/<file>.
    dir_year = publish_date[:4] if len(publish_date) > 4 else ''
    # CONSISTENCY FIX: reuse report_type instead of indexing map_tbl_name
    # directly (a direct index could KeyError for a key present only in
    # map_pdf_page), and build the path with os.path.join.
    dir_path = os.path.join(pdf_base_dir, dir_year, report_type)
    os.makedirs(dir_path, exist_ok=True)
    return url, os.path.join(dir_path, file_name)
|
||
|
||
|
||
# Generic downloader: query report rows from tbl_name and fetch each PDF.
def download_pdf_stock_general(parse_func, tbl_name, querystr='', s_date=start_date, e_date=end_date, limit=None):
    """Download the PDF for every report row published in [s_date, e_date].

    Args:
        parse_func: maps (row, tbl_name) -> (detail_page_url, local_path).
        tbl_name: source table to query.
        querystr: extra SQL predicate fragment appended to the WHERE clause.
        s_date, e_date: optional YYYY-mm-dd bounds on publishDate.
        limit: cap on row count (used by debug mode); None = no cap.
    """
    # NOTE(review): the dates are interpolated straight into SQL. Today they
    # come from CLI args / module defaults; parameterize in db_tools before
    # exposing this path to untrusted input.
    if s_date:
        querystr += f" AND publishDate >= '{s_date} 00:00:00.000' "
    if e_date:
        querystr += f" AND publishDate <= '{e_date} 23:59:59.999' "

    rows = db_tools.query_reports_comm(tbl_name, querystr=querystr, limit=limit)
    if rows is None:
        rows = []

    for row in rows:
        url, file_path = parse_func(row, tbl_name)
        if url is None or file_path is None:
            logging.warning(f'wrong url or file_path. tbl_name: {tbl_name}')
            continue
        # Skip PDFs that are already on disk.
        if os.path.isfile(file_path):
            logging.info(f'{file_path} already exists. skipping...')
            continue

        # Resolve the actual PDF link from the report detail page.
        pdf_url = em.fetch_pdf_link(url)
        if pdf_url:
            if em.download_pdf(pdf_url, file_path):
                logging.info(f'saved file {file_path}')
        else:
            # Previously a silent skip; log it so missing links are visible.
            logging.debug(f'no pdf link found on {url}')

        time.sleep(1)  # throttle requests
|
||
|
||
|
||
|
||
# Fetch the per-stock report list.
def fetch_reports_list_stock(s_date=start_date, e_date=end_date):
    """Page the stock-report list API into utils.tbl_stock."""
    return fetch_reports_list_general(
        em.fetch_reports_by_stock, utils.tbl_stock, s_date, e_date, 'stock')
|
||
|
||
# Fetch the new-stock (IPO) report list.
def fetch_reports_list_newstock(s_date=start_date, e_date=end_date):
    """Page the new-stock report list API into utils.tbl_new_stock."""
    return fetch_reports_list_general(
        em.fetch_reports_by_newstock, utils.tbl_new_stock, s_date, e_date, 'new')
|
||
|
||
# Fetch the industry report list.
def fetch_reports_list_industry(s_date=start_date, e_date=end_date):
    """Page the industry report list API into utils.tbl_industry."""
    return fetch_reports_list_general(
        em.fetch_reports_by_industry, utils.tbl_industry, s_date, e_date, 'industry')
|
||
|
||
# Fetch the macro-research report list.
def fetch_reports_list_macresearch(s_date=start_date, e_date=end_date):
    """Page the macro-research report list API into utils.tbl_macresearch."""
    return fetch_reports_list_general(
        em.fetch_reports_by_macresearch, utils.tbl_macresearch, s_date, e_date, 'macresearch')
|
||
|
||
# Fetch the strategy report list.
def fetch_reports_list_strategy(s_date=start_date, e_date=end_date):
    """Page the strategy report list API into utils.tbl_strategy."""
    return fetch_reports_list_general(
        em.fetch_reports_by_strategy, utils.tbl_strategy, s_date, e_date, 'strategy')
|
||
|
||
# Download PDFs for per-stock reports.
def download_pdf_stock(s_date=start_date, e_date=end_date):
    """Download PDFs for utils.tbl_stock rows (capped at 2 in debug mode)."""
    cap = 2 if debug else None
    download_pdf_stock_general(parse_func_general, utils.tbl_stock, ' ', s_date, e_date, limit=cap)
|
||
|
||
def download_pdf_newstock(s_date=start_date, e_date=end_date):
    """Download PDFs for utils.tbl_new_stock rows (capped at 2 in debug mode)."""
    cap = 2 if debug else None
    download_pdf_stock_general(parse_func_general, utils.tbl_new_stock, ' ', s_date, e_date, limit=cap)
|
||
|
||
def download_pdf_industry(s_date=start_date, e_date=end_date):
    """Download PDFs for utils.tbl_industry rows (capped at 2 in debug mode)."""
    cap = 2 if debug else None
    download_pdf_stock_general(parse_func_general, utils.tbl_industry, ' ', s_date, e_date, limit=cap)
|
||
|
||
def download_pdf_macresearch(s_date=start_date, e_date=end_date):
    """Download PDFs for utils.tbl_macresearch rows (capped at 2 in debug mode)."""
    cap = 2 if debug else None
    download_pdf_stock_general(parse_func_general, utils.tbl_macresearch, ' ', s_date, e_date, limit=cap)
|
||
|
||
def download_pdf_strategy(s_date=start_date, e_date=end_date):
    """Download PDFs for utils.tbl_strategy rows (capped at 2 in debug mode)."""
    cap = 2 if debug else None
    download_pdf_stock_general(parse_func_general, utils.tbl_strategy, ' ', s_date, e_date, limit=cap)
|
||
|
||
|
||
# CLI shortcut -> list-fetching function.
function_list_map = {
    'stock': fetch_reports_list_stock,
    'new': fetch_reports_list_newstock,
    'indust': fetch_reports_list_industry,
    'macro': fetch_reports_list_macresearch,
    'stra': fetch_reports_list_strategy,
}

# CLI shortcut -> PDF-downloading function (same shortcut names).
function_down_map = {
    'stock': download_pdf_stock,
    'new': download_pdf_newstock,
    'indust': download_pdf_industry,
    'macro': download_pdf_macresearch,
    'stra': download_pdf_strategy,
}
|
||
|
||
# Build a 'last_week' folder of symlinks to every PDF published in the last week.
def create_last_week_links(s_date=start_date, e_date=end_date):
    """Rebuild pdf_base_dir/last_week with symlinks to recently published PDFs.

    s_date/e_date are unused; they exist so this function matches the
    (start, end) dispatch signature used by run_func().
    """
    last_week_dir = os.path.join(pdf_base_dir, 'last_week')

    # Start from a clean slate. shutil.rmtree replaces the original manual
    # bottom-up os.walk + remove/rmdir loop (shutil is already imported).
    if os.path.exists(last_week_dir):
        shutil.rmtree(last_week_dir)
    os.makedirs(last_week_dir)

    for root, dirs, files in os.walk(pdf_base_dir):
        # Never descend into the link folder we are rebuilding.
        if 'last_week' in dirs:
            dirs.remove('last_week')

        for file in files:
            if not file.endswith('.pdf'):
                continue
            # File names start with the publish date: YYYY-MM-DD_<rest>.pdf
            match = re.match(r'(\d{4}-\d{2}-\d{2})_(.*)\.pdf', file)
            if not match:
                continue
            date_str = match.group(1)
            if not utils.is_within_last_week(date_str):
                continue

            file_path = os.path.join(root, file)
            # Prefix the link name with the parent directory (report
            # category) so all categories can live flat in one folder.
            sub_dir_name = os.path.basename(os.path.dirname(file_path))
            link_name = os.path.join(last_week_dir, f"[{sub_dir_name}]_{file}")

            if not os.path.exists(link_name):
                os.symlink(file_path, link_name)
|
||
|
||
# Dispatch each requested shortcut to its mapped worker function.
def run_func(function_names, function_map):
    """Call function_map[name] for every shortcut in function_names.

    Each worker is invoked with the module-level (start_date, end_date)
    window; unknown shortcuts are logged and skipped. (The original
    declared `global start_date/end_date` without ever assigning them —
    those dead declarations are removed; reads need no global statement.)
    """
    for short_name in function_names:
        func = function_map.get(short_name.strip())
        if callable(func):
            logging.info(f'exec function: {func}, begin: {start_date}, end: {end_date}')
            func(start_date, end_date)
        else:
            logging.warning(f"Warning: {short_name} is not a valid function shortcut.")
|
||
|
||
# Program entry: apply CLI options to module globals and run the selected jobs.
def main(cmd, mode, args_debug, args_force, begin, end):
    """Run the selected fetch/download jobs.

    Args:
        cmd: comma-separated shortcut names (see function_list_map); falsy
            means "run every shortcut".
        mode: 'fetch' (list only), 'down' (download only), or 'lastweek'
            (narrow the window to the last week, then fetch AND download);
            any other value behaves like 'fetch'.
        args_debug, args_force: CLI flags copied into the module globals.
        begin, end: optional YYYY-mm-dd overrides for the date window.

    Returns:
        None (also returns None early if task logging ever reports failure).
    """
    global debug, force, start_date, end_date

    debug = args_debug
    force = args_force
    start_date = begin if begin else start_date
    end_date = end if end else end_date

    # Task logging is currently stubbed out; task_id stays 0 (and the None
    # guard stays dead) until db_tools.insert_task_log() is re-enabled.
    # task_id = db_tools.insert_task_log()
    task_id = 0
    if task_id is None:
        logging.warning(f'insert task log error.')
        return None

    logging.info(f'running task. id: {task_id}, debug: {debug}, force: {force}, cmd: {cmd}, mode: {mode}')

    # 'lastweek' narrows the window and chains list-fetch then download.
    function_list = []
    if mode == 'fetch':
        function_list.append(function_list_map)
    elif mode == 'down':
        function_list.append(function_down_map)
    elif mode == 'lastweek':
        start_date = this_week_date
        function_list.append(function_list_map)
        function_list.append(function_down_map)
    else:
        function_list.append(function_list_map)

    # Pick the shortcuts to run.
    if cmd and mode != 'lastweek':
        # BUGFIX: was `args.cmd`, which read the CLI-only global `args` and
        # raised NameError whenever main() was called directly with a cmd.
        function_names = cmd.split(",")
    else:
        function_names = function_list_map.keys()

    # Run every selected map (fetch and/or download) over the shortcuts.
    for function_map in function_list:
        run_func(function_names, function_map)

    logging.info(f'all process completed!')
    # db_tools.finalize_task_log(task_id)
|
||
|
||
if __name__ == "__main__":
    # CLI parsing: which shortcuts to run, in which mode, over which window.
    keys_str = ",".join(function_list_map.keys())

    parser = argparse.ArgumentParser(description='fetch iafd data.')
    parser.add_argument("--cmd", type=str,
                        help=f"Comma-separated list of function shortcuts: {keys_str}")
    parser.add_argument("--mode", type=str,
                        help=f"Fetch list or Download pdf: (fetch, down, lastweek)")
    parser.add_argument("--begin", type=str, help=f"begin date, YYYY-mm-dd")
    parser.add_argument("--end", type=str, help=f"end date, YYYY-mm-dd")
    parser.add_argument('--debug', action='store_true',
                        help='Enable debug mode (limit records)')
    parser.add_argument('--force', action='store_true',
                        help='force update (true for rewrite all)')
    args = parser.parse_args()

    main(args.cmd, args.mode, args.debug, args.force, args.begin, args.end)
|