modify scripts

This commit is contained in:
oscarz
2025-03-24 10:48:35 +08:00
parent 7ded7c5a19
commit 1521ff1fc0
11 changed files with 248 additions and 565 deletions

View File

@@ -1,86 +0,0 @@
import logging
import os
import inspect
import time
from datetime import datetime
from logging.handlers import RotatingFileHandler
from collections import defaultdict

home_dir = os.path.expanduser("~")
global_host_data_dir = f'{home_dir}/hostdir/stock_data'
global_share_db_dir = f'{home_dir}/sharedata/sqlite'

# Track log frequency per message
log_count = defaultdict(int)        # number of times each message has been logged
last_log_time = defaultdict(float)  # timestamp of the last write for each message

class RateLimitFilter(logging.Filter):
    """
    Rate-limit filter: within a 60-second window, the same message is
    written at most LOG_LIMIT times; anything beyond that is dropped.
    """
    LOG_LIMIT = 600  # at most 600 identical messages per minute

    def filter(self, record):
        global log_count, last_log_time
        message_key = record.getMessage()  # the log message text
        # Current time and elapsed time since the last write of this message
        now = time.time()
        elapsed = now - last_log_time[message_key]
        # Throttle repeated writes of the same message
        if elapsed < 60:  # within the 60-second window
            log_count[message_key] += 1
            if log_count[message_key] > self.LOG_LIMIT:
                return False  # drop the record
        else:
            log_count[message_key] = 1  # window expired, restart the count
            last_log_time[message_key] = now
        return True  # allow the record through

def setup_logging(log_filename=None):
    if log_filename is None:
        caller_frame = inspect.stack()[1]
        caller_filename = os.path.splitext(os.path.basename(caller_frame.filename))[0]
        current_date = datetime.now().strftime('%Y%m%d')
        log_filename = f'../log/{caller_filename}_{current_date}.log'
    max_log_size = 100 * 1024 * 1024  # 100 MB
    max_log_files = 10                # keep at most 10 rotated log files
    file_handler = RotatingFileHandler(log_filename, maxBytes=max_log_size, backupCount=max_log_files)
    file_handler.setFormatter(logging.Formatter(
        '%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] (%(funcName)s) - %(message)s'
    ))
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(logging.Formatter(
        '%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] (%(funcName)s) - %(message)s'
    ))
    # Configure the root logger
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    logger.handlers = []  # avoid adding duplicate handlers
    logger.addHandler(file_handler)
    logger.addHandler(console_handler)
    # Attach the rate limiter
    rate_limit_filter = RateLimitFilter()
    file_handler.addFilter(rate_limit_filter)
    console_handler.addFilter(rate_limit_filter)

# Usage example
if __name__ == "__main__":
    setup_logging()
    for i in range(1000):
        logging.info("test log message for the rate-limit check")
        time.sleep(0.01)  # simulate rapid logging

View File

@@ -1,115 +0,0 @@
#!/bin/bash

# Remote server list; edit as needed
SERVERS=("175.178.54.98" "1.12.218.143" "43.139.169.25" "129.204.180.174" "42.194.142.169")
REMOTE_USER="ubuntu"
REMOTE_SCRIPT_DIR="/home/ubuntu/pyscripts/stockapp/reports_em"
REMOTE_LOG_DIR="/home/ubuntu/pyscripts/stockapp/log"
DATA_DIR="/home/ubuntu/hostdir/stock_data/pdfs"
LOCAL_FILES=("config.py" "em_reports.py" "fetch.py")

# Per-server task parameters (each server downloads a different date range)
TASK_PARAMS=(
    "--mode=down --begin=2024-01-01 --end=2024-06-30"
    "--mode=down --begin=2023-01-01 --end=2023-06-30"
    "--mode=down --begin=2022-01-01 --end=2022-06-30"
    "--mode=down --begin=2021-01-01 --end=2021-06-30"
    "--mode=down --begin=2020-01-01 --end=2020-06-30"
)

# Push the code to all servers
function push_code() {
    for SERVER in "${SERVERS[@]}"; do
        echo "Pushing code to $SERVER..."
        scp "${LOCAL_FILES[@]}" "$REMOTE_USER@$SERVER:$REMOTE_SCRIPT_DIR/"
    done
}

# Start the tasks
function start_tasks() {
    for i in "${!SERVERS[@]}"; do
        SERVER="${SERVERS[$i]}"
        PARAMS="${TASK_PARAMS[$i]}"
        echo "Starting task on $SERVER with params: $PARAMS"
        #ssh "$REMOTE_USER@$SERVER" "cd $REMOTE_SCRIPT_DIR && nohup python3 ./fetch.py $PARAMS > ../log/nohup.log 2>&1 &"
        #ssh "$REMOTE_USER@$SERVER" "cd $REMOTE_SCRIPT_DIR && nohup python3 ./fetch.py $PARAMS > ../log/nohup.log 2>&1 < /dev/null & disown"
        # nohup ... < /dev/null: stops nohup from waiting on stdin, so control is released immediately.
        # & disown: fully detaches the process from the SSH session so SIGHUP cannot reach it; ssh returns at once.
        # ssh -n: disables ssh's stdin, so the remote process cannot block waiting for input.
        ssh -n "$REMOTE_USER@$SERVER" "cd $REMOTE_SCRIPT_DIR && nohup python3 ./fetch.py $PARAMS > ../log/nohup.log 2>&1 < /dev/null & disown"
    done
}

# Stop the tasks
function stop_tasks() {
    for SERVER in "${SERVERS[@]}"; do
        echo "Stopping task on $SERVER..."
        ssh "$REMOTE_USER@$SERVER" "pkill -f 'python3 ./fetch.py'"
    done
}

# Check task progress
function check_progress() {
    for SERVER in "${SERVERS[@]}"; do
        echo -e "\nChecking progress on $SERVER..."
        FILE_COUNT=$(ssh "$REMOTE_USER@$SERVER" "ls -lRh $DATA_DIR | grep pdf | wc -l")
        FILE_SIZE=$(ssh "$REMOTE_USER@$SERVER" "du -sh $DATA_DIR")
        PROCESS_COUNT=$(ssh "$REMOTE_USER@$SERVER" "ps aux | grep '[f]etch.py' | wc -l")
        if [ "$PROCESS_COUNT" -gt 0 ]; then
            echo "Process status: Running ($PROCESS_COUNT instances; a count of 2 includes the parent process)"
        else
            echo "Process status: Not running"
        fi
        echo "Total files: $FILE_COUNT"
        echo "Total size : $FILE_SIZE"
        ERROR_LINES=$(ssh "$REMOTE_USER@$SERVER" "grep -v INFO $REMOTE_LOG_DIR/fetch_202503* | wc -l")
        echo "Error lines: $ERROR_LINES"
        TASK_COUNT=$(ssh "$REMOTE_USER@$SERVER" "grep 'running task. id' ~/pyscripts/stockapp/log/fetch_20250316.log | wc -l")
        echo "Task count: $TASK_COUNT"
    done
}

# Check task progress and send the summary through the notification bot
function check_progress_robot() {
    result=""
    for SERVER in "${SERVERS[@]}"; do
        result+="\nChecking progress on $SERVER...\n"
        FILE_COUNT=$(ssh "$REMOTE_USER@$SERVER" "ls -lRh $DATA_DIR | grep pdf | wc -l")
        FILE_SIZE=$(ssh "$REMOTE_USER@$SERVER" "du -sh $DATA_DIR")
        PROCESS_COUNT=$(ssh "$REMOTE_USER@$SERVER" "ps aux | grep '[f]etch.py' | wc -l")
        if [ "$PROCESS_COUNT" -gt 0 ]; then
            result+="Process status: Running ($PROCESS_COUNT instances; a count of 2 includes the parent process)\n"
        else
            result+="Process status: Not running\n"
        fi
        result+="Total files: $FILE_COUNT\n"
        result+="Total size : $FILE_SIZE\n"
        ERROR_LINES=$(ssh "$REMOTE_USER@$SERVER" "grep -v INFO $REMOTE_LOG_DIR/fetch_202503* | wc -l")
        result+="Error lines: $ERROR_LINES\n"
        TASK_COUNT=$(ssh "$REMOTE_USER@$SERVER" "grep 'running task. id' ~/pyscripts/stockapp/log/fetch_20250316.log | wc -l")
        result+="Task count: $TASK_COUNT\n"
    done
    echo -e "$result"
    # Hand the summary to the Python helper that posts the message
    python3 ./robot.py "$result"
}

# Command menu
case "$1" in
    push)  push_code ;;
    start) start_tasks ;;
    stop)  stop_tasks ;;
    check) check_progress ;;
    robot) check_progress_robot ;;
    *)
        echo "Usage: $0 {push|start|stop|check|robot}"
        exit 1
        ;;
esac

View File

@@ -1,248 +0,0 @@
import sqlite3
import json
import config
import utils
import logging
import sys
from datetime import datetime

# Connect to the SQLite database
DB_PATH = f"{config.global_share_db_dir}/stock_report.db"  # replace with your database file
conn = sqlite3.connect(DB_PATH)
cursor = conn.cursor()

# Get the table's column names and default values
def get_table_columns_and_defaults(tbl_name):
    try:
        cursor.execute(f"PRAGMA table_info({tbl_name})")
        columns = cursor.fetchall()
        column_info = {}
        for col in columns:
            col_name = col[1]
            default_value = col[4]
            column_info[col_name] = default_value
        return column_info
    except sqlite3.Error as e:
        logging.error(f"Error getting table columns: {e}")
        return None

# Validate and normalize the incoming data
def check_and_process_data(data, tbl_name):
    column_info = get_table_columns_and_defaults(tbl_name=tbl_name)
    if column_info is None:
        return None
    processed_data = {}
    for col, default in column_info.items():
        if col == 'id':  # auto-increment primary key; not supplied by the caller
            continue
        if col == 'created_at' or col == 'updated_at':  # set via SQL date functions, not by the caller
            continue
        if col in ['author', 'authorID']:
            values = data.get(col, [])
            processed_data[col] = ','.join(values)
        elif col in data:
            processed_data[col] = data[col]
        else:
            if default is not None:
                processed_data[col] = default
            else:
                processed_data[col] = None
    return processed_data

# Insert or update a row
def insert_or_update_common(data, tbl_name, uniq_key='infoCode'):
    try:
        processed_data = check_and_process_data(data, tbl_name)
        if processed_data is None:
            return None
        columns = ', '.join(processed_data.keys())
        values = list(processed_data.values())
        placeholders = ', '.join(['?' for _ in values])
        update_clause = ', '.join([f"{col}=EXCLUDED.{col}" for col in processed_data.keys() if col != 'infoCode']) + ', updated_at=datetime(\'now\', \'localtime\')'
        sql = f'''
            INSERT INTO {tbl_name} ({columns}, updated_at)
            VALUES ({placeholders}, datetime('now', 'localtime'))
            ON CONFLICT (infoCode) DO UPDATE SET {update_clause}
        '''
        cursor.execute(sql, values)
        conn.commit()
        # Fetch the report_id of the inserted or updated row
        cursor.execute(f"SELECT id FROM {tbl_name} WHERE {uniq_key} = ?", (data["infoCode"],))
        report_id = cursor.fetchone()[0]
        return report_id
    except sqlite3.Error as e:
        logging.error(f"Error inserting or updating data: {e}")
        return None

# Query report rows
def query_reports_comm(tbl_name, querystr='', limit=None):
    try:
        if tbl_name in [utils.tbl_stock, utils.tbl_new_stock, utils.tbl_industry, utils.tbl_macresearch, utils.tbl_strategy]:
            sql = f"SELECT id, infoCode, title, orgSName, industryName, stockName, publishDate FROM {tbl_name} WHERE 1=1 {querystr}"
        else:
            logging.warning(f'wrong table name: {tbl_name}')
            return None
        if limit:
            sql = sql + f' limit {limit}'
        cursor.execute(sql)
        results = cursor.fetchall()
        # Column names
        column_names = [description[0] for description in cursor.description]
        # Convert rows to a list of dicts
        result_dict_list = []
        for row in results:
            row_dict = {column_names[i]: value for i, value in enumerate(row)}
            result_dict_list.append(row_dict)
        return result_dict_list
    except sqlite3.Error as e:
        logging.error(f"Report query failed: {e}")
        return None
'''
# Insert or update an industry_report row
def insert_or_update_report(report):
    try:
        sql = """
            INSERT INTO industry_report (infoCode, title, orgCode, orgName, orgSName, publishDate,
                                         industryCode, industryName, emRatingCode, emRatingValue,
                                         emRatingName, attachSize, attachPages, updated_at)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, datetime('now', 'localtime'))
            ON CONFLICT(infoCode) DO UPDATE SET
                title=excluded.title,
                orgCode=excluded.orgCode,
                orgName=excluded.orgName,
                orgSName=excluded.orgSName,
                publishDate=excluded.publishDate,
                industryCode=excluded.industryCode,
                industryName=excluded.industryName,
                emRatingCode=excluded.emRatingCode,
                emRatingValue=excluded.emRatingValue,
                emRatingName=excluded.emRatingName,
                attachSize=excluded.attachSize,
                attachPages=excluded.attachPages,
                updated_at=datetime('now', 'localtime')
        """
        values = (
            report["infoCode"], report["title"], report["orgCode"], report["orgName"],
            report["orgSName"], report["publishDate"], report["industryCode"],
            report["industryName"], report["emRatingCode"], report["emRatingValue"],
            report["emRatingName"], report.get("attachSize", 0), report.get("attachPages", 0)
        )
        cursor.execute(sql, values)
        conn.commit()
        # Fetch the report_id of the inserted or updated row
        cursor.execute("SELECT id FROM industry_report WHERE infoCode = ?", (report["infoCode"],))
        report_id = cursor.fetchone()[0]
        return report_id
    except sqlite3.Error as e:
        conn.rollback()
        logging.error(f"Database error: {e}")
        return None
    except Exception as e:
        conn.rollback()
        logging.error(f"Unexpected error: {e}")
        return None

# Query industry report rows
def query_industry_reports(querystr='', limit=None):
    try:
        sql = f"SELECT id, infoCode, title, orgSName, industryName, publishDate FROM industry_report WHERE 1=1 {querystr}"
        if limit:
            sql = sql + f' limit {limit}'
        cursor.execute(sql)
        results = cursor.fetchall()
        return results
    except sqlite3.Error as e:
        logging.error(f"Report query failed: {e}")
        return None

# Insert or update a stock_report row
def insert_or_update_stock_report(report):
    try:
        sql = """
            INSERT INTO stock_report (infoCode, title, stockName, stockCode, orgCode, orgName, orgSName,
                                      publishDate, industryCode, industryName, emIndustryCode, emRatingCode,
                                      emRatingValue, emRatingName, attachPages, attachSize, author, authorID, updated_at)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, datetime('now', 'localtime'))
            ON CONFLICT(infoCode) DO UPDATE SET
                title=excluded.title,
                stockName=excluded.stockName,
                stockCode=excluded.stockCode,
                orgCode=excluded.orgCode,
                orgName=excluded.orgName,
                orgSName=excluded.orgSName,
                publishDate=excluded.publishDate,
                industryCode=excluded.industryCode,
                industryName=excluded.industryName,
                emIndustryCode=excluded.emIndustryCode,
                emRatingCode=excluded.emRatingCode,
                emRatingValue=excluded.emRatingValue,
                emRatingName=excluded.emRatingName,
                attachPages=excluded.attachPages,
                attachSize=excluded.attachSize,
                author=excluded.author,
                authorID=excluded.authorID,
                updated_at=datetime('now', 'localtime')
        """
        values = (
            report["infoCode"], report["title"], report["stockName"], report["stockCode"],
            report["orgCode"], report["orgName"], report["orgSName"], report["publishDate"],
            report.get("industryCode", ""), report.get("industryName", ""), report.get("emIndustryCode", ""),
            report["emRatingCode"], report["emRatingValue"], report["emRatingName"],
            report.get("attachPages", 0), report.get("attachSize", 0),
            ",".join(report.get("author", [])), ",".join(report.get("authorID", []))
        )
        cursor.execute(sql, values)
        conn.commit()
        # Fetch the report_id of the inserted or updated row
        cursor.execute("SELECT id FROM stock_report WHERE infoCode = ?", (report["infoCode"],))
        report_id = cursor.fetchone()[0]
        return report_id
    except sqlite3.Error as e:
        conn.rollback()
        logging.error(f"Database error: {e}")
        return None
    except Exception as e:
        conn.rollback()
        logging.error(f"Unexpected error: {e}")
        return None

# Query stock report rows
def query_stock_reports(querystr='', limit=None):
    try:
        sql = f"SELECT id, infoCode, title, orgSName, stockName, publishDate FROM stock_report WHERE 1=1 {querystr}"
        if limit:
            sql = sql + f' limit {limit}'
        cursor.execute(sql)
        results = cursor.fetchall()
        return results
    except sqlite3.Error as e:
        logging.error(f"Report query failed: {e}")
        return None
'''

View File

@@ -1,65 +0,0 @@
import logging
import os
import inspect
from datetime import datetime
from pathlib import Path

# MySQL configuration
db_config = {
    'host': 'testdb',
    'user': 'root',
    'password': 'mysqlpw',
    'database': 'stockdb'
}

log_dir_prefix = '../log'
global_share_data_dir = '/root/sharedata'
global_stock_data_dir = '/root/hostdir/stock_data'

# Directory containing this file
current_dir = os.path.dirname(os.path.abspath(__file__))
# Project root (assumes this file lives under src/strategy)
project_root = os.path.abspath(os.path.join(current_dir, '..', '..'))

# Locate the log directory
def get_log_directory():
    """
    Return the path of the log directory under the project root,
    creating it if it does not exist.
    """
    # Directory containing this file
    current_dir = Path(__file__).resolve().parent
    # Walk up to the project root, assumed to be the parent of the src directory
    project_root = current_dir
    while project_root.name != 'src' and project_root != project_root.parent:
        project_root = project_root.parent
    project_root = project_root.parent  # step up to the project root
    # Make sure the log directory exists
    log_dir = project_root / 'log'
    log_dir.mkdir(parents=True, exist_ok=True)
    return log_dir

def get_caller_filename():
    # Name of the script that called setup_logging
    caller_frame = inspect.stack()[2]
    caller_filename = os.path.splitext(os.path.basename(caller_frame.filename))[0]
    return caller_filename

# Configure logging
def setup_logging(log_filename=None):
    # If no log_filename is given, use the calling script's name
    if log_filename is None:
        caller_filename = get_caller_filename()
        common_log_dir = get_log_directory()
        current_date = datetime.now().strftime('%Y%m%d')
        # Build the log file name, inserting the date before the extension
        log_filename = f'{common_log_dir}/{caller_filename}_{current_date}.log'
    logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] (%(funcName)s) - %(message)s',
                        handlers=[
                            logging.FileHandler(log_filename),
                            logging.StreamHandler()
                        ])

View File

@@ -9,8 +9,13 @@ db_config = {
     'database': 'stockdb'
 }
-global_share_data_dir = '/root/sharedata'
-global_stock_data_dir = '/root/hostdir/stock_data'
+home_dir = os.path.expanduser("~")
+global_host_data_dir = f'{home_dir}/hostdir/stock_data'
+global_share_db_dir = f'{home_dir}/sharedata/sqlite'
+# Backward-compatible aliases for the old names
+global_stock_data_dir = global_host_data_dir
+global_share_data_dir = f'{home_dir}/sharedata'
 # Directory containing this file
 current_dir = Path(__file__).resolve().parent

View File

@@ -4,8 +4,6 @@ import requests
 import time
 import logging
 from bs4 import BeautifulSoup
-import sqlite_utils as db_tools
-import config
 # Fetch the given page of the per-stock report list
 def fetch_reports_by_stock(page_no, start_date="2023-03-10", end_date="2025-03-10", page_size=50, max_retries = 3):

src/db_utils/reports.py Normal file
View File

@@ -0,0 +1,123 @@
import sqlite3
import logging
import sys
from datetime import datetime

class DatabaseConnectionError(Exception):
    pass

class StockReportDB:
    # Class-level constants (table names)
    TBL_STOCK = 'reports_stock'
    TBL_NEW_STOCK = 'reports_newstrock'
    TBL_STRATEGY = 'reports_strategy'
    TBL_MACRESEARCH = 'reports_macresearch'
    TBL_INDUSTRY = 'reports_industry'

    def __init__(self, db_path):
        self.DB_PATH = db_path
        self.conn = None
        self.cursor = None
        try:
            self.conn = sqlite3.connect(self.DB_PATH)
            self.cursor = self.conn.cursor()
        except sqlite3.Error as e:
            logging.error(f"Database connection failed: {e}")
            raise DatabaseConnectionError("Database connection failed")

    def __get_table_columns_and_defaults(self, tbl_name):
        try:
            self.cursor.execute(f"PRAGMA table_info({tbl_name})")
            columns = self.cursor.fetchall()
            column_info = {}
            for col in columns:
                col_name = col[1]
                default_value = col[4]
                column_info[col_name] = default_value
            return column_info
        except sqlite3.Error as e:
            logging.error(f"Error getting table columns: {e}")
            return None

    def __check_and_process_data(self, data, tbl_name):
        column_info = self.__get_table_columns_and_defaults(tbl_name=tbl_name)
        if column_info is None:
            return None
        processed_data = {}
        for col, default in column_info.items():
            if col == 'id':  # auto-increment primary key; not supplied by the caller
                continue
            if col == 'created_at' or col == 'updated_at':  # set via SQL date functions, not by the caller
                continue
            if col in ['author', 'authorID']:
                values = data.get(col, [])
                processed_data[col] = ','.join(values)
            elif col in data:
                processed_data[col] = data[col]
            else:
                if default is not None:
                    processed_data[col] = default
                else:
                    processed_data[col] = None
        return processed_data

    def insert_or_update_common(self, data, tbl_name, uniq_key='infoCode'):
        try:
            processed_data = self.__check_and_process_data(data, tbl_name)
            if processed_data is None:
                return None
            columns = ', '.join(processed_data.keys())
            values = list(processed_data.values())
            placeholders = ', '.join(['?' for _ in values])
            update_clause = ', '.join([f"{col}=EXCLUDED.{col}" for col in processed_data.keys() if col != 'infoCode']) + ', updated_at=datetime(\'now\', \'localtime\')'
            sql = f'''
                INSERT INTO {tbl_name} ({columns}, updated_at)
                VALUES ({placeholders}, datetime('now', 'localtime'))
                ON CONFLICT (infoCode) DO UPDATE SET {update_clause}
            '''
            self.cursor.execute(sql, values)
            self.conn.commit()
            # Fetch the report_id of the inserted or updated row
            self.cursor.execute(f"SELECT id FROM {tbl_name} WHERE {uniq_key} = ?", (data["infoCode"],))
            report_id = self.cursor.fetchone()[0]
            return report_id
        except sqlite3.Error as e:
            logging.error(f"Error inserting or updating data: {e}")
            return None

    def query_reports_comm(self, tbl_name, querystr='', limit=None):
        try:
            if tbl_name in [StockReportDB.TBL_STOCK, StockReportDB.TBL_NEW_STOCK, StockReportDB.TBL_INDUSTRY, StockReportDB.TBL_MACRESEARCH, StockReportDB.TBL_STRATEGY]:
                sql = f"SELECT id, infoCode, title, orgSName, industryName, stockName, publishDate FROM {tbl_name} WHERE 1=1 {querystr}"
            else:
                logging.warning(f'wrong table name: {tbl_name}')
                return None
            if limit:
                sql = sql + f' limit {limit}'
            self.cursor.execute(sql)
            results = self.cursor.fetchall()
            # Column names
            column_names = [description[0] for description in self.cursor.description]
            # Convert rows to a list of dicts
            result_dict_list = []
            for row in results:
                row_dict = {column_names[i]: value for i, value in enumerate(row)}
                result_dict_list.append(row_dict)
            return result_dict_list
        except sqlite3.Error as e:
            logging.error(f"Report query failed: {e}")
            return None

    def __del__(self):
        if self.conn:
            self.conn.close()
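
A minimal usage sketch for the new class; the database path and the sample row below are hypothetical, and it assumes the reports_stock table already exists with the expected schema, including a UNIQUE constraint on infoCode (which the ON CONFLICT upsert requires):

# Hypothetical caller; path and row values are illustrative only.
from src.db_utils.reports import StockReportDB, DatabaseConnectionError

try:
    db = StockReportDB("/tmp/stock_report.db")  # hypothetical path
    # Upsert one row; columns missing from the dict fall back to table defaults
    report_id = db.insert_or_update_common(
        {"infoCode": "AP0000000001", "title": "sample report",
         "author": ["author-a"], "authorID": ["0001"]},
        StockReportDB.TBL_STOCK)
    # Read it back as a list of dicts
    rows = db.query_reports_comm(StockReportDB.TBL_STOCK,
                                 querystr=" AND publishDate >= '2025-01-01'",
                                 limit=10)
except DatabaseConnectionError:
    pass  # the constructor has already logged the failure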

View File

@@ -9,34 +9,41 @@ import shutil
 import logging
 from datetime import datetime, timedelta
 from functools import partial
-import config
-import sqlite_utils as db_tools
-import em_reports as em
-import utils
+import src.crawler.em.reports as em
+import src.utils.utils as utils
+from src.config.config import global_host_data_dir, global_share_db_dir
+from src.db_utils.reports import StockReportDB, DatabaseConnectionError
+from src.logger.logger import setup_logging

-config.setup_logging()
+# Initialize logging
+setup_logging()

 debug = False
 force = False
-pdf_base_dir = f"{config.global_host_data_dir}/pdfs"  # directory for downloaded PDFs
+pdf_base_dir = f"{global_host_data_dir}/pdfs"  # directory for downloaded PDFs

+# Links to the report download pages
 map_pdf_page = {
-    utils.tbl_stock : "https://data.eastmoney.com/report/info/{}.html",
-    utils.tbl_new_stock : "https://data.eastmoney.com/report/info/{}.html",
-    utils.tbl_strategy : "https://data.eastmoney.com/report/zw_strategy.jshtml?encodeUrl={}",
-    utils.tbl_macresearch : "https://data.eastmoney.com/report/zw_macresearch.jshtml?encodeUrl={}",
-    utils.tbl_industry : "https://data.eastmoney.com/report/zw_industry.jshtml?infocode={}"
+    StockReportDB.TBL_STOCK : "https://data.eastmoney.com/report/info/{}.html",
+    StockReportDB.TBL_NEW_STOCK : "https://data.eastmoney.com/report/info/{}.html",
+    StockReportDB.TBL_STRATEGY : "https://data.eastmoney.com/report/zw_strategy.jshtml?encodeUrl={}",
+    StockReportDB.TBL_MACRESEARCH : "https://data.eastmoney.com/report/zw_macresearch.jshtml?encodeUrl={}",
+    StockReportDB.TBL_INDUSTRY : "https://data.eastmoney.com/report/zw_industry.jshtml?infocode={}"
 }
+# Map table names to display names, used for the storage paths
 map_tbl_name = {
-    utils.tbl_stock : '个股研报',
-    utils.tbl_new_stock : '新股研报',
-    utils.tbl_strategy : '策略报告',
-    utils.tbl_macresearch : '宏观研究',
-    utils.tbl_industry : '行业研报'
+    StockReportDB.TBL_STOCK : '个股研报',
+    StockReportDB.TBL_NEW_STOCK : '新股研报',
+    StockReportDB.TBL_STRATEGY : '策略报告',
+    StockReportDB.TBL_MACRESEARCH : '宏观研究',
+    StockReportDB.TBL_INDUSTRY : '行业研报'
 }
+# Database connection, initialized in main()
+db_path = f"{global_share_db_dir}/stock_report.db"
+db_tools = None

 current_date = datetime.now()
 seven_days_ago = current_date - timedelta(days=7)
 two_years_ago = current_date - timedelta(days=2*365)
@@ -45,6 +52,7 @@ start_date = two_years_ago.strftime("%Y-%m-%d")
 end_date = current_date.strftime("%Y-%m-%d")
 this_week_date = seven_days_ago.strftime("%Y-%m-%d")
+
 def fetch_reports_list_general(fetch_func, table_name, s_date, e_date, data_dir_prefix):
     # Fetch up to max_pages pages of the report list
     max_pages = 100000
@@ -148,39 +156,39 @@ def download_pdf_stock_general(parse_func, tbl_name, querystr='', s_date=start_d
 # Fetch the stock report list
 def fetch_reports_list_stock(s_date=start_date, e_date=end_date):
-    return fetch_reports_list_general(em.fetch_reports_by_stock, utils.tbl_stock, s_date, e_date, 'stock')
+    return fetch_reports_list_general(em.fetch_reports_by_stock, StockReportDB.TBL_STOCK, s_date, e_date, 'stock')

 # Fetch the new-stock report list
 def fetch_reports_list_newstock(s_date=start_date, e_date=end_date):
-    return fetch_reports_list_general(em.fetch_reports_by_newstock, utils.tbl_new_stock, s_date, e_date, 'new')
+    return fetch_reports_list_general(em.fetch_reports_by_newstock, StockReportDB.TBL_NEW_STOCK, s_date, e_date, 'new')

 # Fetch the industry report list
 def fetch_reports_list_industry(s_date=start_date, e_date=end_date):
-    return fetch_reports_list_general(em.fetch_reports_by_industry, utils.tbl_industry, s_date, e_date, 'industry')
+    return fetch_reports_list_general(em.fetch_reports_by_industry, StockReportDB.TBL_INDUSTRY, s_date, e_date, 'industry')

 # Fetch the macro research report list
 def fetch_reports_list_macresearch(s_date=start_date, e_date=end_date):
-    return fetch_reports_list_general(em.fetch_reports_by_macresearch, utils.tbl_macresearch, s_date, e_date, 'macresearch')
+    return fetch_reports_list_general(em.fetch_reports_by_macresearch, StockReportDB.TBL_MACRESEARCH, s_date, e_date, 'macresearch')

 # Fetch the strategy report list
 def fetch_reports_list_strategy(s_date=start_date, e_date=end_date):
-    return fetch_reports_list_general(em.fetch_reports_by_strategy, utils.tbl_strategy, s_date, e_date, 'strategy')
+    return fetch_reports_list_general(em.fetch_reports_by_strategy, StockReportDB.TBL_STRATEGY, s_date, e_date, 'strategy')

 # Download the report PDFs for each table
 def download_pdf_stock(s_date=start_date, e_date=end_date):
-    download_pdf_stock_general(parse_func_general, utils.tbl_stock, ' ', s_date, e_date, limit=2 if debug else None)
+    download_pdf_stock_general(parse_func_general, StockReportDB.TBL_STOCK, ' ', s_date, e_date, limit=2 if debug else None)

 def download_pdf_newstock(s_date=start_date, e_date=end_date):
-    download_pdf_stock_general(parse_func_general, utils.tbl_new_stock, ' ', s_date, e_date, limit=2 if debug else None)
+    download_pdf_stock_general(parse_func_general, StockReportDB.TBL_NEW_STOCK, ' ', s_date, e_date, limit=2 if debug else None)

 def download_pdf_industry(s_date=start_date, e_date=end_date):
-    download_pdf_stock_general(parse_func_general, utils.tbl_industry, ' ', s_date, e_date, limit=2 if debug else None)
+    download_pdf_stock_general(parse_func_general, StockReportDB.TBL_INDUSTRY, ' ', s_date, e_date, limit=2 if debug else None)

 def download_pdf_macresearch(s_date=start_date, e_date=end_date):
-    download_pdf_stock_general(parse_func_general, utils.tbl_macresearch, ' ', s_date, e_date, limit=2 if debug else None)
+    download_pdf_stock_general(parse_func_general, StockReportDB.TBL_MACRESEARCH, ' ', s_date, e_date, limit=2 if debug else None)

 def download_pdf_strategy(s_date=start_date, e_date=end_date):
-    download_pdf_stock_general(parse_func_general, utils.tbl_strategy, ' ', s_date, e_date, limit=2 if debug else None)
+    download_pdf_stock_general(parse_func_general, StockReportDB.TBL_STRATEGY, ' ', s_date, e_date, limit=2 if debug else None)

 # Map command abbreviations to functions
@@ -265,6 +273,15 @@ def main(cmd, mode, args_debug, args_force, begin, end):
     global end_date
     end_date = end if end else end_date

+    # Initialize the DB connection
+    global db_tools
+    try:
+        db_tools = StockReportDB(db_path)
+        # database operations can proceed from here
+    except DatabaseConnectionError as e:
+        logging.error(f"Database connection failed: {e}")
+        return False
+
     # Start the task
     #task_id = db_tools.insert_task_log()
     task_id = 0

View File

@@ -1,10 +1,46 @@
 import logging
 import os
 import inspect
+import time
 from datetime import datetime
 from pathlib import Path
+from logging.handlers import RotatingFileHandler
+from collections import defaultdict

 from src.config.config import get_log_directory, get_src_directory

+# Track log frequency per message
+log_count = defaultdict(int)        # number of times each message has been logged
+last_log_time = defaultdict(float)  # timestamp of the last write for each message
+
+class RateLimitFilter(logging.Filter):
+    """
+    Rate-limit filter: within a 60-second window, the same message is
+    written at most LOG_LIMIT times; anything beyond that is dropped.
+    """
+    LOG_LIMIT = 600  # at most 600 identical messages per minute
+
+    def filter(self, record):
+        global log_count, last_log_time
+        message_key = record.getMessage()  # the log message text
+        # Current time and elapsed time since the last write of this message
+        now = time.time()
+        elapsed = now - last_log_time[message_key]
+        # Throttle repeated writes of the same message
+        if elapsed < 60:  # within the 60-second window
+            log_count[message_key] += 1
+            if log_count[message_key] > self.LOG_LIMIT:
+                return False  # drop the record
+        else:
+            log_count[message_key] = 1  # window expired, restart the count
+            last_log_time[message_key] = now
+        return True  # allow the record through
+
 def get_caller_filename():
     # Walk the call stack
     stack = inspect.stack()
@@ -26,7 +62,7 @@ def get_caller_filename():
             return os.path.splitext(os.path.basename(frame_info.filename))[0]
     return None

-# Configure logging
+
 def setup_logging(log_filename=None):
     # If no log_filename is given, use the calling script's name
     if log_filename is None:
@@ -36,8 +72,28 @@ def setup_logging(log_filename=None):
         # Build the log file name, inserting the date before the extension
         log_filename = f'{common_log_dir}/{caller_filename}_{current_date}.log'
-    logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] (%(funcName)s) - %(message)s',
-                        handlers=[
-                            logging.FileHandler(log_filename),
-                            logging.StreamHandler()
-                        ])
+    max_log_size = 100 * 1024 * 1024  # 100 MB
+    max_log_files = 10                # keep at most 10 rotated log files
+
+    file_handler = RotatingFileHandler(log_filename, maxBytes=max_log_size, backupCount=max_log_files)
+    file_handler.setFormatter(logging.Formatter(
+        '%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] (%(funcName)s) - %(message)s'
+    ))
+    console_handler = logging.StreamHandler()
+    console_handler.setFormatter(logging.Formatter(
+        '%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] (%(funcName)s) - %(message)s'
+    ))
+    # Configure the root logger
+    logger = logging.getLogger()
+    logger.setLevel(logging.INFO)
+    logger.handlers = []  # avoid adding duplicate handlers
+    logger.addHandler(file_handler)
+    logger.addHandler(console_handler)
+    # Attach the rate limiter
+    rate_limit_filter = RateLimitFilter()
+    file_handler.addFilter(rate_limit_filter)
+    console_handler.addFilter(rate_limit_filter)
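
For reference, a short caller sketch exercising the refactored setup_logging and its rate limiter; this mirrors the demo from the deleted module and assumes src is importable as laid out in this commit (the script itself is hypothetical):

# Hypothetical caller script
import logging
import time
from src.logger.logger import setup_logging

setup_logging()  # log file name is derived from this script's name plus today's date
for i in range(1000):
    # identical messages beyond RateLimitFilter.LOG_LIMIT within 60 s are dropped
    logging.info("test log message for the rate-limit check")
    time.sleep(0.01)  # simulate rapid logging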

View File

@@ -1,6 +1,7 @@
 import pandas as pd
 import numpy as np
 import os
+import warnings
 from src.strategy.prepare import fetch_his_kline
 import src.config.config as config
 import src.crawler.em.stock as em_stock
@@ -15,10 +16,15 @@ def select_stocks(stock_map):
         df = fetch_his_kline(stock_code)
         close_prices = df['close'].values

-        rsi_6 = RSI(close_prices, 6)
-        rsi_12 = RSI(close_prices, 12)
-        rsi_24 = RSI(close_prices, 24)
+        # Capture RuntimeWarning raised during the RSI computation
+        with warnings.catch_warnings(record=True) as w:
+            warnings.simplefilter("always", RuntimeWarning)
+            # Compute RSI over several periods using the MyTT library
+            rsi_6 = RSI(close_prices, 6)
+            rsi_12 = RSI(close_prices, 12)
+            rsi_24 = RSI(close_prices, 24)
+        if w:
+            print(f"Stock {stock_code} {stock_name}: warning while computing RSI: {w[0].message}")

         df['rsi_6'] = rsi_6
         df['rsi_12'] = rsi_12
@@ -96,7 +102,8 @@ if __name__ == "__main__":
     codes = ['105.QFIN', '105.FUTU']

     # Or fetch the list from the network
-    stock_map = em_stock.code_by_fs('hk_famous', em_stock.em_market_fs_types['hk_famous'])
+    plat_id = 'cn_hs300'
+    stock_map = em_stock.code_by_fs(plat_id, em_stock.em_market_fs_types[plat_id])
     if stock_map:
         select_stocks(stock_map)

View File

@@ -5,15 +5,6 @@ import time
 import csv
 import logging
 from datetime import datetime
-import config
-
-tbl_stock = 'reports_stock'
-tbl_new_stock = 'reports_newstrock'
-tbl_strategy = 'reports_strategy'
-tbl_macresearch = 'reports_macresearch'
-tbl_industry = 'reports_industry'
-
-json_data_dir = f'{config.global_host_data_dir}/em_reports/json_data'
 # Save JSON data to a local file
 def save_json_to_file(data, file_path, file_name):