modify scripts

This commit is contained in:
oscarz
2025-03-17 11:08:13 +08:00
parent e6327fbe73
commit f43cd53159
177 changed files with 5 additions and 178173 deletions

85
reports_em/config.py Normal file
View File

@ -0,0 +1,85 @@
import logging
import os
import inspect
import time
from datetime import datetime
from logging.handlers import RotatingFileHandler
from collections import defaultdict
home_dir = os.path.expanduser("~")
global_host_data_dir = f'{home_dir}/hostdir/stock_data'
# 统计日志频率
log_count = defaultdict(int) # 记录日志的次数
last_log_time = defaultdict(float) # 记录上次写入的时间戳
class RateLimitFilter(logging.Filter):
"""
频率限制过滤器:
1. 在 60 秒内,同样的日志最多写入 60 次,超过则忽略
2. 如果日志速率超过 100 条/秒,发出告警
"""
LOG_LIMIT = 600 # 每分钟最多记录相同消息 10 次
def filter(self, record):
global log_count, last_log_time
message_key = record.getMessage() # 获取日志内容
# 计算当前时间
now = time.time()
elapsed = now - last_log_time[message_key]
# 限制相同日志的写入频率
if elapsed < 60: # 60 秒内
log_count[message_key] += 1
if log_count[message_key] > self.LOG_LIMIT:
return False # 直接丢弃
else:
log_count[message_key] = 1 # 超过 60 秒,重新计数
last_log_time[message_key] = now
return True # 允许写入日志
def setup_logging(log_filename=None):
if log_filename is None:
caller_frame = inspect.stack()[1]
caller_filename = os.path.splitext(os.path.basename(caller_frame.filename))[0]
current_date = datetime.now().strftime('%Y%m%d')
log_filename = f'../log/{caller_filename}_{current_date}.log'
max_log_size = 100 * 1024 * 1024 # 10 MB
max_log_files = 10 # 最多保留 10 个日志文件
file_handler = RotatingFileHandler(log_filename, maxBytes=max_log_size, backupCount=max_log_files)
file_handler.setFormatter(logging.Formatter(
'%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] (%(funcName)s) - %(message)s'
))
console_handler = logging.StreamHandler()
console_handler.setFormatter(logging.Formatter(
'%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] (%(funcName)s) - %(message)s'
))
# 创建 logger
logger = logging.getLogger()
logger.setLevel(logging.INFO)
logger.handlers = [] # 避免重复添加 handler
logger.addHandler(file_handler)
logger.addHandler(console_handler)
# 添加频率限制
rate_limit_filter = RateLimitFilter()
file_handler.addFilter(rate_limit_filter)
console_handler.addFilter(rate_limit_filter)
# 运行示例
if __name__ == "__main__":
setup_logging()
for i in range(1000):
logging.info("测试日志,检测频率限制")
time.sleep(0.01) # 模拟快速写入日志

115
reports_em/deploy.sh Executable file
View File

@ -0,0 +1,115 @@
#!/bin/bash
# 远程服务器列表,按需修改
SERVERS=("175.178.54.98" "1.12.218.143" "43.139.169.25" "129.204.180.174" "42.194.142.169")
REMOTE_USER="ubuntu"
REMOTE_SCRIPT_DIR="/home/ubuntu/pyscripts/stockapp/reports_em"
REMOTE_LOG_DIR="/home/ubuntu/pyscripts/stockapp/log"
DATA_DIR="/home/ubuntu/hostdir/stock_data/pdfs"
LOCAL_FILES=("config.py" "em_reports.py" "fetch.py")
# 远程任务参数配置(每台服务器的不同参数)
TASK_PARAMS=(
"--mode=down --begin=2024-01-01 --end=2024-06-30"
"--mode=down --begin=2023-01-01 --end=2023-06-30"
"--mode=down --begin=2022-01-01 --end=2022-06-30"
"--mode=down --begin=2021-01-01 --end=2021-06-30"
"--mode=down --begin=2020-01-01 --end=2020-06-30"
)
# 推送代码到所有服务器
function push_code() {
for SERVER in "${SERVERS[@]}"; do
echo "Pushing code to $SERVER..."
scp "${LOCAL_FILES[@]}" "$REMOTE_USER@$SERVER:$REMOTE_SCRIPT_DIR/"
done
}
# 启动任务
function start_tasks() {
for i in "${!SERVERS[@]}"; do
SERVER="${SERVERS[$i]}"
PARAMS="${TASK_PARAMS[$i]}"
echo "Starting task on $SERVER with params: $PARAMS"
#ssh "$REMOTE_USER@$SERVER" "cd $REMOTE_SCRIPT_DIR && nohup python3 ./fetch.py $PARAMS > ../log/nohup.log 2>&1 &"
#ssh "$REMOTE_USER@$SERVER" "cd $REMOTE_SCRIPT_DIR && nohup python3 ./fetch.py $PARAMS > ../log/nohup.log 2>&1 < /dev/null & disown"
# nohup ... < /dev/null防止 nohup 等待 stdin立即释放控制权。
# & disown确保进程与 SSH 彻底分离,避免 SIGHUP 信号影响SSH 立即返回。
# ssh -n禁用 ssh 的 stdin防止远程进程等待输入。
ssh -n "$REMOTE_USER@$SERVER" "cd $REMOTE_SCRIPT_DIR && nohup python3 ./fetch.py $PARAMS > ../log/nohup.log 2>&1 < /dev/null & disown"
done
}
# 停止任务
function stop_tasks() {
for SERVER in "${SERVERS[@]}"; do
echo "Stopping task on $SERVER..."
ssh "$REMOTE_USER@$SERVER" "pkill -f 'python3 ./fetch.py'"
done
}
# 获取任务进度
function check_progress() {
for SERVER in "${SERVERS[@]}"; do
echo -e "\nChecking progress on $SERVER..."
FILE_COUNT=$(ssh "$REMOTE_USER@$SERVER" "ls -lRh $DATA_DIR | grep pdf | wc -l")
FILE_SIZE=$(ssh "$REMOTE_USER@$SERVER" "du -sh $DATA_DIR")
PROCESS_COUNT=$(ssh "$REMOTE_USER@$SERVER" "ps aux | grep '[f]etch.py' | wc -l")
if [ "$PROCESS_COUNT" -gt 0 ]; then
echo "Process status: Running ($PROCESS_COUNT instances), if 2, include parent progress"
else
echo "Process status: Not running"
fi
echo "Total files: $FILE_COUNT"
echo "Total size : $FILE_SIZE"
ERROR_LINES=$(ssh "$REMOTE_USER@$SERVER" "grep -v INFO $REMOTE_LOG_DIR/fetch_202503* | wc -l")
echo "Error lines: $ERROR_LINES"
TASK_COUNT=$(ssh "$REMOTE_USER@$SERVER" "grep 'running task. id' ~/pyscripts/stockapp/log/fetch_20250316.log | wc -l")
echo "Task count: $TASK_COUNT"
done
}
# 获取任务进度
function check_progress_robot() {
result=""
for SERVER in "${SERVERS[@]}"; do
result+="\nChecking progress on $SERVER...\n"
FILE_COUNT=$(ssh "$REMOTE_USER@$SERVER" "ls -lRh $DATA_DIR | grep pdf | wc -l")
FILE_SIZE=$(ssh "$REMOTE_USER@$SERVER" "du -sh $DATA_DIR")
PROCESS_COUNT=$(ssh "$REMOTE_USER@$SERVER" "ps aux | grep '[f]etch.py' | wc -l")
if [ "$PROCESS_COUNT" -gt 0 ]; then
result+="Process status: Running ($PROCESS_COUNT instances), if 2, include parent progress\n"
else
result+="Process status: Not running\n"
fi
result+="Total files: $FILE_COUNT\n"
result+="Total size : $FILE_SIZE\n"
ERROR_LINES=$(ssh "$REMOTE_USER@$SERVER" "grep -v INFO $REMOTE_LOG_DIR/fetch_202503* | wc -l")
result+="Error lines: $ERROR_LINES\n"
TASK_COUNT=$(ssh "$REMOTE_USER@$SERVER" "grep 'running task. id' ~/pyscripts/stockapp/log/fetch_20250316.log | wc -l")
done
echo -e "$result"
# 调用 Python 脚本发送消息
python3 ./robot.py "$result"
}
# 脚本菜单
case "$1" in
push) push_code ;;
start) start_tasks ;;
stop) stop_tasks ;;
check) check_progress ;;
*)
echo "Usage: $0 {push|start|stop|check}"
exit 1
;;
esac

316
reports_em/em_reports.py Normal file
View File

@ -0,0 +1,316 @@
import os
import json
import requests
import time
import logging
from bs4 import BeautifulSoup
import sqlite_utils as db_tools
import config
# 获取个股研报列表的指定页
def fetch_reports_by_stock(page_no, start_date="2023-03-10", end_date="2025-03-10", page_size=50, max_retries = 3):
# 请求头
HEADERS = {
"Accept": "application/json, text/javascript, */*; q=0.01",
"Content-Type": "application/json",
"Origin": "https://data.eastmoney.com",
"Referer": "https://data.eastmoney.com/report/stock.jshtml",
"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/133.0.0.0 Safari/537.36",
}
# 请求 URL
URL = "https://reportapi.eastmoney.com/report/list2"
payload = {
"beginTime": start_date,
"endTime": end_date,
"industryCode": "*",
"ratingChange": None,
"rating": None,
"orgCode": None,
"code": "*",
"rcode": "",
"pageSize": page_size,
"p": page_no,
"pageNo": page_no,
"pageNum": page_no,
"pageNumber": page_no
}
logging.debug(f'begin: {start_date}, end: {end_date}')
for attempt in range(max_retries):
try:
response = requests.post(URL, headers=HEADERS, json=payload, timeout=10)
response.raise_for_status()
data = response.json()
return data
except requests.RequestException as e:
logging.warning(f"network error on {URL}: {e}, Retring...")
logging.error(f'Fetching failed after max retries. {URL}')
return None # 达到最大重试次数仍然失败
# 获取行业研报列表的指定页
def fetch_reports_by_industry(page_no, start_date="2023-03-10", end_date="2025-03-10", page_size=50, max_retries = 3):
headers = {
"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/133.0.0.0 Safari/537.36",
"Referer": "https://data.eastmoney.com/report/industry.jshtml"
}
url = "https://reportapi.eastmoney.com/report/list"
params = {
"cb": "datatable1413600",
"industryCode": "*",
"pageSize": page_size,
"industry": "*",
"rating": "*",
"ratingChange": "*",
"beginTime": start_date,
"endTime": end_date,
"pageNo": page_no,
"fields": "",
"qType": 1,
"orgCode": "",
"rcode": "",
"p": page_no,
"pageNum": page_no,
"pageNumber": page_no,
"_": int(time.time() * 1000) # 动态时间戳
}
for attempt in range(max_retries):
try:
response = requests.get(url, headers=headers, params=params, timeout=10)
response.raise_for_status()
# 去掉回调函数包装
json_text = response.text.strip("datatable1413600(").rstrip(");")
data = json.loads(json_text)
return data
except requests.RequestException as e:
logging.warning(f"network error on {url}: {e}, Retring...")
return None
except json.JSONDecodeError as e:
logging.warning(f"json decode error on {url}: {e}, Retring...")
return None
logging.error(f'Fetching failed after max retries. {url}')
return None # 达到最大重试次数仍然失败
# 获取宏观研报列表的指定页
def fetch_reports_by_macresearch(page_no, start_date="2023-03-10", end_date="2025-03-10", page_size=50, max_retries = 3):
headers = {
"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/133.0.0.0 Safari/537.36",
"Referer": "https://data.eastmoney.com/report/macresearch.jshtml"
}
url = "https://reportapi.eastmoney.com/report/jg"
params = {
"cb": "datatable2612129",
"industryCode": "*",
"pageSize": page_size,
"author": "",
"beginTime": start_date,
"endTime": end_date,
"pageNo": page_no,
"fields": "",
"qType": 3,
"orgCode": "",
"rcode": "",
"p": page_no,
"pageNum": page_no,
"pageNumber": page_no,
"_": int(time.time() * 1000) # 动态时间戳
}
for attempt in range(max_retries):
try:
response = requests.get(url, headers=headers, params=params, timeout=10)
response.raise_for_status()
# 去掉回调函数包装
json_text = response.text.strip("datatable2612129(").rstrip(");")
data = json.loads(json_text)
return data
except requests.RequestException as e:
logging.warning(f"network error on {url}: {e}, Retring...")
return None
except json.JSONDecodeError as e:
logging.warning(f"json decode error on {url}: {e}, Retring...")
return None
logging.error(f'Fetching failed after max retries. {url}')
return None # 达到最大重试次数仍然失败
# 获取策略研报列表的指定页
def fetch_reports_by_strategy(page_no, start_date="2023-03-10", end_date="2025-03-10", page_size=50, max_retries = 3):
headers = {
"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/133.0.0.0 Safari/537.36",
"Referer": "https://data.eastmoney.com/report/strategyreport.jshtml"
}
url = "https://reportapi.eastmoney.com/report/jg"
params = {
"cb": "datatable5349866",
"industryCode": "*",
"pageSize": page_size,
"author": "",
"beginTime": start_date,
"endTime": end_date,
"pageNo": page_no,
"fields": "",
"qType": 2,
"orgCode": "",
"rcode": "",
"p": page_no,
"pageNum": page_no,
"pageNumber": page_no,
"_": int(time.time() * 1000) # 动态时间戳
}
for attempt in range(max_retries):
try:
response = requests.get(url, headers=headers, params=params, timeout=10)
response.raise_for_status()
# 去掉回调函数包装
json_text = response.text.strip("datatable5349866(").rstrip(");")
data = json.loads(json_text)
return data
except requests.RequestException as e:
logging.warning(f"network error on {url}: {e}, Retring...")
return None
except json.JSONDecodeError as e:
logging.warning(f"json decode error on {url}: {e}, Retring...")
return None
logging.error(f'Fetching failed after max retries. {url}')
return None # 达到最大重试次数仍然失败
# 获取新股研报列表的指定页
def fetch_reports_by_newstock(page_no, start_date="2023-03-10", end_date="2025-03-10", page_size=50, max_retries = 3):
headers = {
"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/133.0.0.0 Safari/537.36",
"Referer": "https://data.eastmoney.com/report/newstock.jshtml"
}
url = "https://reportapi.eastmoney.com/report/newStockList"
params = {
"cb": "datatable5144183",
"pageSize": page_size,
"author": "",
"beginTime": start_date,
"endTime": end_date,
"pageNo": page_no,
"fields": "",
"qType": 4,
"orgCode": "",
"rcode": "",
"p": page_no,
"pageNum": page_no,
"pageNumber": page_no,
"_": int(time.time() * 1000) # 动态时间戳
}
for attempt in range(max_retries):
try:
response = requests.get(url, headers=headers, params=params, timeout=10)
response.raise_for_status()
# 去掉回调函数包装
json_text = response.text.strip("datatable5144183(").rstrip(");")
data = json.loads(json_text)
return data
except requests.RequestException as e:
logging.warning(f"network error on {url}: {e}, Retring...")
return None
except json.JSONDecodeError as e:
logging.warning(f"json decode error on {url}: {e}, Retring...")
return None
logging.error(f'Fetching failed after max retries. {url}')
return None # 达到最大重试次数仍然失败
# 访问指定 infoCode 的页面,提取 PDF 下载链接
def fetch_pdf_link(url, max_retries = 3):
headers = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/113.0.0.0 Safari/537.36"
}
for attempt in range(max_retries):
try:
response = requests.get(url, headers=headers, timeout=10)
response.raise_for_status()
# 解析 HTML
soup = BeautifulSoup(response.text, "html.parser")
pdf_link = soup.find("a", class_="pdf-link")
if pdf_link and "href" in pdf_link.attrs:
return pdf_link["href"]
else:
logging.warning(f"未找到 PDF 链接: {url}")
return None
except requests.RequestException as e:
logging.error(f"请求失败: {url} {e}")
logging.error(f'Fetching failed after max retries. {url}')
return None # 达到最大重试次数仍然失败
def is_valid_pdf(file_path):
try:
with open(file_path, "rb") as f:
header = f.read(4)
return header == b"%PDF"
except Exception as e:
logging.error(f"验证 PDF 失败: {e}")
return False
def download_pdf_wget(pdf_url, save_path):
cmd = f'wget -O "{save_path}" "{pdf_url}" --quiet --user-agent="Mozilla/5.0"'
os.system(cmd)
return os.path.exists(save_path) and is_valid_pdf(save_path)
# 下载 PDF 并保存到本地
def download_pdf(pdf_url, save_path, max_retries=5):
for attempt in range(max_retries):
down = download_pdf_wget(pdf_url, save_path)
if down:
return True
return False
headers = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/113.0.0.0 Safari/537.36"
}
try:
response = requests.get(pdf_url, headers=headers, stream=True, timeout=20)
response.raise_for_status()
with open(save_path, "wb") as file:
for chunk in response.iter_content(chunk_size=1024):
file.write(chunk)
return True
except requests.RequestException as e:
logging.error(f"PDF 下载失败: {e}")
return False

316
reports_em/fetch.py Normal file
View File

@ -0,0 +1,316 @@
import json
import time
import csv
import os
import re
import argparse
import shutil
import logging
from datetime import datetime, timedelta
from functools import partial
import config
import sqlite_utils as db_tools
import em_reports as em
import utils
config.setup_logging()
debug = False
force = False
pdf_base_dir = f"{config.global_host_data_dir}/pdfs" # 下载 PDF 存放目录
map_pdf_page = {
utils.tbl_stock : "https://data.eastmoney.com/report/info/{}.html",
utils.tbl_new_stock : "https://data.eastmoney.com/report/info/{}.html",
utils.tbl_strategy : "https://data.eastmoney.com/report/zw_strategy.jshtml?encodeUrl={}",
utils.tbl_macresearch : "https://data.eastmoney.com/report/zw_macresearch.jshtml?encodeUrl={}",
utils.tbl_industry : "https://data.eastmoney.com/report/zw_industry.jshtml?infocode={}"
}
map_tbl_name = {
utils.tbl_stock : '个股研报',
utils.tbl_new_stock : '新股研报',
utils.tbl_strategy : '策略报告',
utils.tbl_macresearch : '宏观研究',
utils.tbl_industry : '行业研报'
}
current_date = datetime.now()
seven_days_ago = current_date - timedelta(days=7)
two_years_ago = current_date - timedelta(days=2*365)
start_date = two_years_ago.strftime("%Y-%m-%d")
end_date = current_date.strftime("%Y-%m-%d")
this_week_date = seven_days_ago.strftime("%Y-%m-%d")
def fetch_reports_list_general(fetch_func, table_name, s_date, e_date, data_dir_prefix):
# 示例:获取前 3 页的数据
max_pages = 100000
page = 1
while page <= max_pages:
while True:
data = fetch_func(page_no=page, start_date=s_date, end_date=e_date, page_size=100)
if data:
break
if page == 1:
max_pages = data.get('TotalPage', 1000000)
for row in data.get('data', []):
# 统一以 infoCode 为 UNIQE 键,所以这里对它进行赋值
if row.get('infoCode') is None and row.get('encodeUrl'):
row['infoCode'] = row['encodeUrl']
row_id = db_tools.insert_or_update_common(row, table_name)
if row_id:
logging.debug(f'insert one row. rowid:{row_id}, ')
else:
logging.warning(f'insert data failed. page : {page}')
return False
# 写本地json文件必要性不大
#utils.save_json_to_file(data, f'{utils.json_data_dir}/{data_dir_prefix}', f'{data_dir_prefix}_report_{page}.json')
logging.info(f"{page} 页, 获取 {len(data['data'])} 条数据, 共 {max_pages}")
page += 1
time.sleep(1) # 避免请求过快
# 股票所用的url
def parse_func_general(row, tbl_name):
info_code = row['infoCode']
title = row['title'].replace("/", "_").replace("\\", "_")
org_sname = row['orgSName']
stock_name = row['stockName']
industry_name = row['industryName']
publish_date = row['publishDate'].split(" ")[0]
# 建表的时候默认值有点问题
if stock_name == '' or stock_name=="''":
stock_name = 'None'
if industry_name == '':
industry_name = 'None'
if org_sname == '':
org_sname = 'None'
report_type = map_tbl_name.get(tbl_name, 'None')
file_name = f"{publish_date}_{report_type}_{org_sname}_{industry_name}_{stock_name}_{title}.pdf"
url = map_pdf_page.get(tbl_name, None)
if url is None:
logging.warning(f'wrong table name: {tbl_name}')
return None, None
url = url.format(info_code)
# 拼目录
dir_year = publish_date[:4] if len(publish_date)>4 else ''
dir_path = f'{pdf_base_dir}/{dir_year}/{map_tbl_name[tbl_name]}'
os.makedirs(dir_path, exist_ok=True)
return url, os.path.join(dir_path, file_name)
# 通用下载函数
def download_pdf_stock_general(parse_func, tbl_name, querystr='', s_date=start_date, e_date=end_date, limit=None):
# 下载pdf
if s_date:
querystr += f" AND publishDate >= '{s_date} 00:00:00.000' "
if e_date:
querystr += f" AND publishDate <= '{e_date} 23:59:59.999' "
rows = db_tools.query_reports_comm(tbl_name, querystr=querystr, limit=limit)
if rows is None:
rows = []
for row in rows:
url, file_path = parse_func(row, tbl_name)
if url is None or file_path is None:
logging.warning(f'wrong url or file_path. tbl_name: {tbl_name}')
continue
# 已经存在的,跳过
if file_path and os.path.isfile(file_path):
logging.info(f'{file_path} already exists. skipping...')
continue
# 获取pdf链接地址
pdf_url = em.fetch_pdf_link(url)
if pdf_url:
# 下载 PDF
down = em.download_pdf(pdf_url, file_path)
if down:
logging.info(f'saved file {file_path}')
else:
logging.warning(f'download pdf file error. file_path: {pdf_url}, save_path: {file_path}')
else:
logging.warning(f'cannot get pdf link. url: {url}, save_path: {file_path}')
time.sleep(1) # 避免请求过快
# 获取股票报告列表
def fetch_reports_list_stock(s_date=start_date, e_date=end_date):
return fetch_reports_list_general(em.fetch_reports_by_stock, utils.tbl_stock, s_date, e_date, 'stock')
# 获取股票报告列表
def fetch_reports_list_newstock(s_date=start_date, e_date=end_date):
return fetch_reports_list_general(em.fetch_reports_by_newstock, utils.tbl_new_stock, s_date, e_date, 'new')
# 获取行业报告列表
def fetch_reports_list_industry(s_date=start_date, e_date=end_date):
return fetch_reports_list_general(em.fetch_reports_by_industry, utils.tbl_industry, s_date, e_date, 'industry')
# 获取行业报告列表
def fetch_reports_list_macresearch(s_date=start_date, e_date=end_date):
return fetch_reports_list_general(em.fetch_reports_by_macresearch, utils.tbl_macresearch, s_date, e_date, 'macresearch')
# 获取行业报告列表
def fetch_reports_list_strategy(s_date=start_date, e_date=end_date):
return fetch_reports_list_general(em.fetch_reports_by_strategy, utils.tbl_strategy, s_date, e_date, 'strategy')
# 下载股票pdf
def download_pdf_stock(s_date=start_date, e_date=end_date):
download_pdf_stock_general(parse_func_general, utils.tbl_stock, ' ', s_date, e_date, limit=2 if debug else None)
def download_pdf_newstock(s_date=start_date, e_date=end_date):
download_pdf_stock_general(parse_func_general, utils.tbl_new_stock, ' ', s_date, e_date, limit=2 if debug else None)
def download_pdf_industry(s_date=start_date, e_date=end_date):
download_pdf_stock_general(parse_func_general, utils.tbl_industry, ' ', s_date, e_date, limit=2 if debug else None)
def download_pdf_macresearch(s_date=start_date, e_date=end_date):
download_pdf_stock_general(parse_func_general, utils.tbl_macresearch, ' ', s_date, e_date, limit=2 if debug else None)
def download_pdf_strategy(s_date=start_date, e_date=end_date):
download_pdf_stock_general(parse_func_general, utils.tbl_strategy, ' ', s_date, e_date, limit=2 if debug else None)
# 建立缩写到函数的映射
function_list_map = {
'stock' : fetch_reports_list_stock,
'new' : fetch_reports_list_newstock,
'indust' : fetch_reports_list_industry,
'macro' : fetch_reports_list_macresearch,
'stra' : fetch_reports_list_strategy,
}
function_down_map = {
'stock' : download_pdf_stock,
'new' : download_pdf_newstock,
'indust' : download_pdf_industry,
'macro' : download_pdf_macresearch,
'stra' : download_pdf_strategy,
}
# 获取最新一周的报告列表
def create_last_week_links(s_date=start_date, e_date=end_date):
last_week_dir = os.path.join(pdf_base_dir, 'last_week')
# 如果 last_week 目录存在,先删除它
if os.path.exists(last_week_dir):
for root, dirs, files in os.walk(last_week_dir, topdown=False):
for file in files:
file_path = os.path.join(root, file)
os.remove(file_path)
for dir in dirs:
dir_path = os.path.join(root, dir)
os.rmdir(dir_path)
os.rmdir(last_week_dir)
os.makedirs(last_week_dir)
for root, dirs, files in os.walk(pdf_base_dir):
# 跳过 last_week 目录及其子目录
if 'last_week' in dirs:
dirs.remove('last_week')
for file in files:
if file.endswith('.pdf'):
match = re.match(r'(\d{4}-\d{2}-\d{2})_(.*)\.pdf', file)
if match:
date_str = match.group(1)
if utils.is_within_last_week(date_str):
file_path = os.path.join(root, file)
# 获取子目录名称
sub_dir_name = os.path.basename(os.path.dirname(file_path))
# 生成新的链接名称,添加子目录名前缀
new_file_name = f"[{sub_dir_name}]_{file}"
link_name = os.path.join(last_week_dir, new_file_name)
if not os.path.exists(link_name):
os.symlink(file_path, link_name)
# 执行功能函数
def run_func(function_names, function_map):
global start_date
global end_date
for short_name in function_names:
func = function_map.get(short_name.strip()) # 从映射中获取对应的函数
if callable(func):
#db_tools.update_task_log(task_id, task_status=f'Running {func}')
logging.info(f'exec function: {func}, begin: {start_date}, end: {end_date}')
func(start_date, end_date)
else:
logging.warning(f"Warning: {short_name} is not a valid function shortcut.")
# 主函数
def main(cmd, mode, args_debug, args_force, begin, end):
global debug
debug = args_debug
global force
force = args_force
global start_date
start_date = begin if begin else start_date
global end_date
end_date = end if end else end_date
# 开启任务
#task_id = db_tools.insert_task_log()
task_id = 0
if task_id is None:
logging.warning(f'insert task log error.')
return None
logging.info(f'running task. id: {task_id}, debug: {debug}, force: {force}, cmd: {cmd}, mode: {mode}')
# 如果是lastweek我们先执行列表再执行下载
function_list = []
if mode == 'fetch':
function_list.append(function_list_map)
elif mode == 'down':
function_list.append(function_down_map)
elif mode == 'lastweek':
start_date = this_week_date
function_list.append(function_list_map)
function_list.append(function_down_map)
else:
function_list.append(function_list_map)
# 执行指定的函数
if cmd and mode !='lastweek':
function_names = args.cmd.split(",") # 拆分输入
else:
function_names = function_list_map.keys()
# 遍历功能函数,执行
for function_map in function_list:
run_func(function_names, function_map)
logging.info(f'all process completed!')
#db_tools.finalize_task_log(task_id)
if __name__ == "__main__":
# 命令行参数处理
keys_str = ",".join(function_list_map.keys())
parser = argparse.ArgumentParser(description='fetch iafd data.')
parser.add_argument("--cmd", type=str, help=f"Comma-separated list of function shortcuts: {keys_str}")
parser.add_argument("--mode", type=str, help=f"Fetch list or Download pdf: (fetch, down, lastweek)")
parser.add_argument("--begin", type=str, help=f"begin date, YYYY-mm-dd")
parser.add_argument("--end", type=str, help=f"end date, YYYY-mm-dd")
parser.add_argument('--debug', action='store_true', help='Enable debug mode (limit records)')
parser.add_argument('--force', action='store_true', help='force update (true for rewrite all)')
args = parser.parse_args()
main(args.cmd, args.mode, args.debug, args.force, args.begin, args.end)

248
reports_em/sqlite_utils.py Normal file
View File

@ -0,0 +1,248 @@
import sqlite3
import json
import config
import utils
import logging
import sys
from datetime import datetime
# 连接 SQLite 数据库
DB_PATH = f"{config.global_host_data_dir}/stock_report.db" # 替换为你的数据库文件
conn = sqlite3.connect(DB_PATH)
cursor = conn.cursor()
# 获取表的列名和默认值
def get_table_columns_and_defaults(tbl_name):
try:
cursor.execute(f"PRAGMA table_info({tbl_name})")
columns = cursor.fetchall()
column_info = {}
for col in columns:
col_name = col[1]
default_value = col[4]
column_info[col_name] = default_value
return column_info
except sqlite3.Error as e:
logging.error(f"Error getting table columns: {e}")
return None
# 检查并处理数据
def check_and_process_data(data, tbl_name):
column_info = get_table_columns_and_defaults(tbl_name=tbl_name)
if column_info is None:
return None
processed_data = {}
for col, default in column_info.items():
if col == 'id': # 自增主键,不需要用户提供
continue
if col == 'created_at' or col == 'updated_at': # 日期函数,用户自己指定即可
continue
if col in ['author', 'authorID']:
values = data.get(col, [])
processed_data[col] = ','.join(values)
elif col in data:
processed_data[col] = data[col]
else:
if default is not None:
processed_data[col] = default
else:
processed_data[col] = None
return processed_data
# 插入或更新数据
def insert_or_update_common(data, tbl_name, uniq_key='infoCode'):
try:
processed_data = check_and_process_data(data, tbl_name)
if processed_data is None:
return None
columns = ', '.join(processed_data.keys())
values = list(processed_data.values())
placeholders = ', '.join(['?' for _ in values])
update_clause = ', '.join([f"{col}=EXCLUDED.{col}" for col in processed_data.keys() if col != 'infoCode']) + ', updated_at=datetime(\'now\', \'localtime\')'
sql = f'''
INSERT INTO {tbl_name} ({columns}, updated_at)
VALUES ({placeholders}, datetime('now', 'localtime'))
ON CONFLICT (infoCode) DO UPDATE SET {update_clause}
'''
cursor.execute(sql, values)
conn.commit()
# 获取插入或更新后的 report_id
cursor.execute(f"SELECT id FROM {tbl_name} WHERE {uniq_key} = ?", (data["infoCode"],))
report_id = cursor.fetchone()[0]
return report_id
except sqlite3.Error as e:
logging.error(f"Error inserting or updating data: {e}")
return None
# 查询数据
def query_reports_comm(tbl_name, querystr='', limit=None ):
try:
if tbl_name in [utils.tbl_stock, utils.tbl_new_stock, utils.tbl_industry, utils.tbl_macresearch, utils.tbl_strategy] :
sql = f"SELECT id, infoCode, title, orgSName, industryName, stockName, publishDate FROM {tbl_name} WHERE 1=1 {querystr}"
else:
logging.warning(f'wrong table name: {tbl_name}')
return None
if limit :
sql = sql + f' limit {limit}'
cursor.execute(sql)
results = cursor.fetchall()
# 获取列名
column_names = [description[0] for description in cursor.description]
# 将结果转换为字典列表
result_dict_list = []
for row in results:
row_dict = {column_names[i]: value for i, value in enumerate(row)}
result_dict_list.append(row_dict)
return result_dict_list
except sqlite3.Error as e:
logging.error(f"查询 href 失败: {e}")
return None
'''
# 插入或更新 industry_report 数据
def insert_or_update_report(report):
try:
sql = """
INSERT INTO industry_report (infoCode, title, orgCode, orgName, orgSName, publishDate,
industryCode, industryName, emRatingCode, emRatingValue,
emRatingName, attachSize, attachPages, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, datetime('now', 'localtime'))
ON CONFLICT(infoCode) DO UPDATE SET
title=excluded.title,
orgCode=excluded.orgCode,
orgName=excluded.orgName,
orgSName=excluded.orgSName,
publishDate=excluded.publishDate,
industryCode=excluded.industryCode,
industryName=excluded.industryName,
emRatingCode=excluded.emRatingCode,
emRatingValue=excluded.emRatingValue,
emRatingName=excluded.emRatingName,
attachSize=excluded.attachSize,
attachPages=excluded.attachPages,
updated_at=datetime('now', 'localtime')
"""
values = (
report["infoCode"], report["title"], report["orgCode"], report["orgName"],
report["orgSName"], report["publishDate"], report["industryCode"],
report["industryName"], report["emRatingCode"], report["emRatingValue"],
report["emRatingName"], report.get("attachSize", 0), report.get("attachPages", 0)
)
cursor.execute(sql, values)
conn.commit()
# 获取插入或更新后的 report_id
cursor.execute("SELECT id FROM industry_report WHERE infoCode = ?", (report["infoCode"],))
report_id = cursor.fetchone()[0]
return report_id
except sqlite3.Error as e:
conn.rollback()
logging.error(f"数据库错误: {e}")
return None
except Exception as e:
conn.rollback()
logging.error(f"未知错误: {e}")
return None
# 查询研报数据
def query_industry_reports(querystr='', limit=None):
try:
sql = f"SELECT id, infoCode, title, orgSName, industryName, publishDate FROM industry_report WHERE 1=1 {querystr}"
if limit :
sql = sql + f' limit {limit}'
cursor.execute(sql)
results = cursor.fetchall()
return results
except sqlite3.Error as e:
logging.error(f"查询 href 失败: {e}")
return None
# 插入或更新 industry_report 数据
def insert_or_update_stock_report(report):
try:
sql = """
INSERT INTO stock_report (infoCode, title, stockName, stockCode, orgCode, orgName, orgSName,
publishDate, industryCode, industryName, emIndustryCode, emRatingCode,
emRatingValue, emRatingName, attachPages, attachSize, author, authorID, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, datetime('now', 'localtime'))
ON CONFLICT(infoCode) DO UPDATE SET
title=excluded.title,
stockName=excluded.stockName,
stockCode=excluded.stockCode,
orgCode=excluded.orgCode,
orgName=excluded.orgName,
orgSName=excluded.orgSName,
publishDate=excluded.publishDate,
industryCode=excluded.industryCode,
industryName=excluded.industryName,
emIndustryCode=excluded.emIndustryCode,
emRatingCode=excluded.emRatingCode,
emRatingValue=excluded.emRatingValue,
emRatingName=excluded.emRatingName,
attachPages=excluded.attachPages,
attachSize=excluded.attachSize,
author=excluded.author,
authorID=excluded.authorID,
updated_at=datetime('now', 'localtime')
"""
values = (
report["infoCode"], report["title"], report["stockName"], report["stockCode"],
report["orgCode"], report["orgName"], report["orgSName"], report["publishDate"],
report.get("industryCode", ""), report.get("industryName", ""), report.get("emIndustryCode", ""),
report["emRatingCode"], report["emRatingValue"], report["emRatingName"],
report.get("attachPages", 0), report.get("attachSize", 0),
",".join(report.get("author", [])), ",".join(report.get("authorID", []))
)
cursor.execute(sql, values)
conn.commit()
# 获取插入或更新后的 report_id
cursor.execute("SELECT id FROM stock_report WHERE infoCode = ?", (report["infoCode"],))
report_id = cursor.fetchone()[0]
return report_id
except sqlite3.Error as e:
conn.rollback()
logging.error(f"数据库错误: {e}")
return None
except Exception as e:
conn.rollback()
logging.error(f"未知错误: {e}")
return None
# 查询研报数据
def query_stock_reports(querystr='', limit=None):
try:
sql = f"SELECT id, infoCode, title, orgSName, stockName, publishDate FROM stock_report WHERE 1=1 {querystr}"
if limit :
sql = sql + f' limit {limit}'
cursor.execute(sql)
results = cursor.fetchall()
return results
except sqlite3.Error as e:
logging.error(f"查询 href 失败: {e}")
return None
'''

36
reports_em/utils.py Normal file
View File

@ -0,0 +1,36 @@
import re
import os
import json
import time
import csv
import logging
from datetime import datetime
import config
tbl_stock = 'reports_stock'
tbl_new_stock = 'reports_newstrock'
tbl_strategy = 'reports_strategy'
tbl_macresearch = 'reports_macresearch'
tbl_industry = 'reports_industry'
json_data_dir = f'{config.global_host_data_dir}/em_reports/json_data'
# 保存 JSON 数据到本地文件
def save_json_to_file(data, file_path, file_name):
os.makedirs(file_path, exist_ok=True)
full_name = f"{file_path}/{file_name}"
with open(full_name, "w", encoding="utf-8") as file:
json.dump(data, file, ensure_ascii=False, indent=4)
logging.debug(f"saved json data to: {full_name}")
# 判断日期字符串是否在最近七天内
def is_within_last_week(date_str):
try:
file_date = datetime.strptime(date_str, '%Y-%m-%d')
current_date = datetime.now()
diff = current_date - file_date
return diff.days <= 7
except ValueError:
return False