import time
import os
import re
import argparse
import shutil
import logging
from datetime import datetime, timedelta

import config
import sqlite_utils as db_tools
import em_reports as em
import utils

config.setup_logging()

debug = False
force = False

pdf_base_dir = "/root/hostdir/stock_data/pdfs"  # directory where downloaded PDFs are stored

# Report detail page URL templates, keyed by table name.
map_pdf_page = {
    utils.tbl_stock: "https://data.eastmoney.com/report/info/{}.html",
    utils.tbl_new_stock: "https://data.eastmoney.com/report/info/{}.html",
    utils.tbl_strategy: "https://data.eastmoney.com/report/zw_strategy.jshtml?encodeUrl={}",
    utils.tbl_macresearch: "https://data.eastmoney.com/report/zw_macresearch.jshtml?encodeUrl={}",
    utils.tbl_industry: "https://data.eastmoney.com/report/zw_industry.jshtml?infocode={}",
}

# Local PDF directories, keyed by table name.
map_pdf_path = {
    utils.tbl_stock: f'{pdf_base_dir}/stock',
    utils.tbl_new_stock: f'{pdf_base_dir}/newstock',
    utils.tbl_strategy: f'{pdf_base_dir}/strategy',
    utils.tbl_macresearch: f'{pdf_base_dir}/macresearch',
    utils.tbl_industry: f'{pdf_base_dir}/industry',
}

current_date = datetime.now()
seven_days_ago = current_date - timedelta(days=7)
two_years_ago = current_date - timedelta(days=2 * 365)
start_date = two_years_ago.strftime("%Y-%m-%d")
end_date = current_date.strftime("%Y-%m-%d")
this_week_date = seven_days_ago.strftime("%Y-%m-%d")


def fetch_reports_list_general(fetch_func, table_name, s_date, e_date, data_dir_prefix):
    # Page through the report list until the last page is reached.
    max_pages = 100000
    page = 1
    while page <= max_pages:
        # Retry until the endpoint returns data; back off briefly between attempts.
        while True:
            data = fetch_func(page_no=page, start_date=s_date, end_date=e_date, page_size=100)
            if data:
                break
            time.sleep(1)
        if page == 1:
            max_pages = data.get('TotalPage', 1000000)
        for row in data.get('data', []):
            # infoCode is used as the UNIQUE key everywhere, so backfill it
            # from encodeUrl for endpoints that only return the latter.
            if row.get('infoCode') is None and row.get('encodeUrl'):
                row['infoCode'] = row['encodeUrl']
            row_id = db_tools.insert_or_update_common(row, table_name)
            if row_id:
                logging.debug(f'insert one row. rowid: {row_id}')
            else:
                logging.warning(f'insert data failed. page: {page}')
                return False
        utils.save_json_to_file(data, f'{utils.json_data_dir}/{data_dir_prefix}', f'{data_dir_prefix}_report_{page}.json')
        logging.info(f"page {page}: fetched {len(data['data'])} rows, {max_pages} pages in total")
        page += 1
        time.sleep(1)  # throttle requests
    return True
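# The em_reports fetch_* functions are assumed to return a dict shaped roughly
# like the sketch below; this is inferred from the keys read above ('TotalPage',
# 'data', 'infoCode', 'encodeUrl') and is not the authoritative schema of the
# eastmoney endpoint:
#
#   {
#       "TotalPage": ...,
#       "data": [
#           {"infoCode": "...", "encodeUrl": "...", "title": "...",
#            "orgSName": "...", "stockName": "...", "industryName": "...",
#            "publishDate": "YYYY-mm-dd HH:MM:SS", ...},
#           ...
#       ]
#   }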
# URL/path builder for stock reports.
def parse_func_stock(row, tbl_name):
    info_code = row['infoCode']
    title = row['title'].replace("/", "_").replace("\\", "_")
    org_sname = row['orgSName']
    stock_name = row['stockName']
    publish_date = row['publishDate'].split(" ")[0]
    file_name = f"{publish_date}_{org_sname}_{stock_name}_{title}.pdf"
    url = map_pdf_page.get(tbl_name)
    if url is None:
        logging.warning(f'wrong table name: {tbl_name}')
        return None, None, None
    url = url.format(info_code)
    os.makedirs(map_pdf_path[tbl_name], exist_ok=True)
    return url, os.path.join(map_pdf_path[tbl_name], file_name), None


# URL/path builder for the other report types.
def parse_func_other(row, tbl_name):
    info_code = row['infoCode']
    title = row['title'].replace("/", "_").replace("\\", "_")
    org_sname = row['orgSName']
    industry_name = row['industryName']
    publish_date = row['publishDate'].split(" ")[0]
    file_name = f"{publish_date}_{org_sname}_{industry_name}_{title}.pdf"
    old_file_name = f"{publish_date}_{industry_name}_{org_sname}_{title}.pdf"
    url = map_pdf_page.get(tbl_name)
    if url is None:
        logging.warning(f'wrong table name: {tbl_name}')
        return None, None, None
    url = url.format(info_code)
    os.makedirs(map_pdf_path[tbl_name], exist_ok=True)
    return url, os.path.join(map_pdf_path[tbl_name], file_name), os.path.join(map_pdf_path[tbl_name], old_file_name)


# Generic PDF download routine.
def download_pdf_stock_general(parse_func, tbl_name, querystr='', s_date=start_date, e_date=end_date, limit=None):
    if s_date:
        querystr += f" AND publishDate >= '{s_date} 00:00:00.000' "
    if e_date:
        querystr += f" AND publishDate <= '{e_date} 23:59:59.999' "
    rows = db_tools.query_reports_comm(tbl_name, querystr=querystr, limit=limit)
    if rows is None:
        rows = []
    for row in rows:
        url, file_path, old_file_path = parse_func(row, tbl_name)
        if url is None or file_path is None:
            logging.warning(f'wrong url or file_path. tbl_name: {tbl_name}')
            continue
        # Skip files that have already been downloaded.
        if os.path.isfile(file_path):
            logging.info(f'{file_path} already exists. skipping...')
            continue
        # Rename files that still use the legacy naming scheme.
        if old_file_path and os.path.isfile(old_file_path):
            shutil.move(old_file_path, file_path)
            logging.info(f'renamed existing file to {file_path}')
            continue
        # Resolve the PDF link from the report page, then download it.
        pdf_url = em.fetch_pdf_link(url)
        if pdf_url:
            down = em.download_pdf(pdf_url, file_path)
            if down:
                logging.info(f'saved file {file_path}')
        time.sleep(1)  # throttle requests
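# Example of a generated target path (illustrative values, not real data):
#   /root/hostdir/stock_data/pdfs/stock/2024-05-20_SomeBroker_SomeStock_SomeTitle.pdf
# parse_func_other uses industryName in place of stockName and additionally
# returns the legacy "<date>_<industry>_<org>_<title>.pdf" name, so files
# downloaded under the old scheme can be renamed in place instead of re-fetched.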
# Fetch the stock report list.
def fetch_reports_list_stock(s_date=start_date, e_date=end_date):
    return fetch_reports_list_general(em.fetch_reports_by_stock, utils.tbl_stock, s_date, e_date, 'stock')


# Fetch the new-stock report list.
def fetch_reports_list_newstock(s_date=start_date, e_date=end_date):
    return fetch_reports_list_general(em.fetch_reports_by_newstock, utils.tbl_new_stock, s_date, e_date, 'new')


# Fetch the industry report list.
def fetch_reports_list_industry(s_date=start_date, e_date=end_date):
    return fetch_reports_list_general(em.fetch_reports_by_industry, utils.tbl_industry, s_date, e_date, 'industry')


# Fetch the macro research report list.
def fetch_reports_list_macresearch(s_date=start_date, e_date=end_date):
    return fetch_reports_list_general(em.fetch_reports_by_macresearch, utils.tbl_macresearch, s_date, e_date, 'macresearch')


# Fetch the strategy report list.
def fetch_reports_list_strategy(s_date=start_date, e_date=end_date):
    return fetch_reports_list_general(em.fetch_reports_by_strategy, utils.tbl_strategy, s_date, e_date, 'strategy')


# Download stock report PDFs (only reports with at least 30 pages).
def download_pdf_stock(s_date=start_date, e_date=end_date):
    download_pdf_stock_general(parse_func_stock, utils.tbl_stock, ' AND attachPages>=30', s_date, e_date, limit=2 if debug else None)


def download_pdf_newstock(s_date=start_date, e_date=end_date):
    download_pdf_stock_general(parse_func_stock, utils.tbl_new_stock, ' AND attachPages>=30', s_date, e_date, limit=2 if debug else None)


def download_pdf_industry(s_date=start_date, e_date=end_date):
    download_pdf_stock_general(parse_func_other, utils.tbl_industry, ' AND attachPages>=30', s_date, e_date, limit=2 if debug else None)


def download_pdf_macresearch(s_date=start_date, e_date=end_date):
    download_pdf_stock_general(parse_func_other, utils.tbl_macresearch, ' ', s_date, e_date, limit=2 if debug else None)


def download_pdf_strategy(s_date=start_date, e_date=end_date):
    download_pdf_stock_general(parse_func_other, utils.tbl_strategy, ' ', s_date, e_date, limit=2 if debug else None)


# Map shortcuts to list-fetching functions.
function_list_map = {
    'stock': fetch_reports_list_stock,
    'new': fetch_reports_list_newstock,
    'indust': fetch_reports_list_industry,
    'macro': fetch_reports_list_macresearch,
    'stra': fetch_reports_list_strategy,
}

# Map shortcuts to download functions.
function_down_map = {
    'stock': download_pdf_stock,
    'new': download_pdf_newstock,
    'indust': download_pdf_industry,
    'macro': download_pdf_macresearch,
    'stra': download_pdf_strategy,
}


# Symlink the reports published in the last week into pdf_base_dir/last_week.
def create_last_week_links(s_date=start_date, e_date=end_date):
    last_week_dir = os.path.join(pdf_base_dir, 'last_week')
    # If the last_week directory exists, remove it and start from scratch.
    if os.path.exists(last_week_dir):
        shutil.rmtree(last_week_dir)
    os.makedirs(last_week_dir)
    for root, dirs, files in os.walk(pdf_base_dir):
        # Skip the last_week directory and its subdirectories.
        if 'last_week' in dirs:
            dirs.remove('last_week')
        for file in files:
            if file.endswith('.pdf'):
                match = re.match(r'(\d{4}-\d{2}-\d{2})_(.*)\.pdf', file)
                if match:
                    date_str = match.group(1)
                    if utils.is_within_last_week(date_str):
                        file_path = os.path.join(root, file)
                        # Prefix the link name with the sub-directory name.
                        sub_dir_name = os.path.basename(os.path.dirname(file_path))
                        new_file_name = f"[{sub_dir_name}]_{file}"
                        link_name = os.path.join(last_week_dir, new_file_name)
                        if not os.path.exists(link_name):
                            os.symlink(file_path, link_name)
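# Link naming example (illustrative values): a file such as
#   .../pdfs/industry/2024-05-20_SomeBroker_SomeIndustry_SomeTitle.pdf
# ends up linked as
#   .../pdfs/last_week/[industry]_2024-05-20_SomeBroker_SomeIndustry_SomeTitle.pdf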
# Execute the selected functions.
def run_func(function_names, function_map):
    for short_name in function_names:
        func = function_map.get(short_name.strip())  # look up the function for this shortcut
        if callable(func):
            #db_tools.update_task_log(task_id, task_status=f'Running {func}')
            logging.info(f'exec function: {func}, begin: {start_date}, end: {end_date}')
            func(start_date, end_date)
        else:
            logging.warning(f"Warning: {short_name} is not a valid function shortcut.")


def main(cmd, mode, args_debug, args_force, begin, end):
    global debug, force, start_date, end_date
    debug = args_debug
    force = args_force
    if begin:
        start_date = begin
    if end:
        end_date = end

    # Start the task.
    #task_id = db_tools.insert_task_log()
    task_id = 0
    if task_id is None:
        logging.warning('insert task log error.')
        return None
    logging.info(f'running task. id: {task_id}, debug: {debug}, force: {force}, cmd: {cmd}, mode: {mode}')

    # For lastweek, fetch the lists first, then run the downloads.
    function_list = []
    if mode == 'fetch':
        function_list.append(function_list_map)
    elif mode == 'down':
        function_list.append(function_down_map)
    elif mode == 'lastweek':
        start_date = this_week_date
        function_list.append(function_list_map)
        function_list.append(function_down_map)
    else:
        function_list.append(function_list_map)

    # Determine which functions to run.
    if cmd and mode != 'lastweek':
        function_names = cmd.split(",")  # split the comma-separated shortcuts
    else:
        function_names = function_list_map.keys()

    # Run every selected function from every selected map.
    for function_map in function_list:
        run_func(function_names, function_map)

    logging.info('all processes completed!')
    #db_tools.finalize_task_log(task_id)


if __name__ == "__main__":
    # Command-line argument handling.
    keys_str = ",".join(function_list_map.keys())
    parser = argparse.ArgumentParser(description='fetch eastmoney report data.')
    parser.add_argument("--cmd", type=str, help=f"Comma-separated list of function shortcuts: {keys_str}")
    parser.add_argument("--mode", type=str, help="Fetch list or download PDFs: (fetch, down, lastweek)")
    parser.add_argument("--begin", type=str, help="begin date, YYYY-mm-dd")
    parser.add_argument("--end", type=str, help="end date, YYYY-mm-dd")
    parser.add_argument('--debug', action='store_true', help='Enable debug mode (limit records)')
    parser.add_argument('--force', action='store_true', help='force update (true to rewrite everything)')
    args = parser.parse_args()
    main(args.cmd, args.mode, args.debug, args.force, args.begin, args.end)
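# Example invocations (the shortcut and mode names come from the maps above;
# the script filename is illustrative):
#   python fetch_reports.py --mode fetch --cmd stock,indust --begin 2024-01-01 --end 2024-06-30
#   python fetch_reports.py --mode down --cmd stock --debug   # debug limits each table to 2 records
#   python fetch_reports.py --mode lastweek                   # fetch lists, then download, for the past week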