modify scripts

oscarz
2025-03-24 10:48:35 +08:00
parent 7ded7c5a19
commit 1521ff1fc0
11 changed files with 248 additions and 565 deletions

View File

@@ -1,86 +0,0 @@
import logging
import os
import inspect
import time
from datetime import datetime
from logging.handlers import RotatingFileHandler
from collections import defaultdict
home_dir = os.path.expanduser("~")
global_host_data_dir = f'{home_dir}/hostdir/stock_data'
global_share_db_dir = f'{home_dir}/sharedata/sqlite'
# Track logging frequency
log_count = defaultdict(int) # how many times each message has been logged
last_log_time = defaultdict(float) # timestamp of the last write per message
class RateLimitFilter(logging.Filter):
"""
频率限制过滤器:
1. 在 60 秒内,同样的日志最多写入 60 次,超过则忽略
2. 如果日志速率超过 100 条/秒,发出告警
"""
LOG_LIMIT = 600 # 每分钟最多记录相同消息 10 次
def filter(self, record):
global log_count, last_log_time
message_key = record.getMessage() # the rendered log message
# current time
now = time.time()
elapsed = now - last_log_time[message_key]
# throttle identical messages
if elapsed < 60: # within the 60-second window
log_count[message_key] += 1
if log_count[message_key] > self.LOG_LIMIT:
return False # drop the record
else:
log_count[message_key] = 1 # window expired; restart the count
last_log_time[message_key] = now
return True # allow the record
def setup_logging(log_filename=None):
if log_filename is None:
caller_frame = inspect.stack()[1]
caller_filename = os.path.splitext(os.path.basename(caller_frame.filename))[0]
current_date = datetime.now().strftime('%Y%m%d')
log_filename = f'../log/{caller_filename}_{current_date}.log'
max_log_size = 100 * 1024 * 1024 # 100 MB
max_log_files = 10 # keep at most 10 rotated log files
file_handler = RotatingFileHandler(log_filename, maxBytes=max_log_size, backupCount=max_log_files)
file_handler.setFormatter(logging.Formatter(
'%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] (%(funcName)s) - %(message)s'
))
console_handler = logging.StreamHandler()
console_handler.setFormatter(logging.Formatter(
'%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] (%(funcName)s) - %(message)s'
))
# create the root logger
logger = logging.getLogger()
logger.setLevel(logging.INFO)
logger.handlers = [] # clear existing handlers to avoid duplicates
logger.addHandler(file_handler)
logger.addHandler(console_handler)
# attach the rate-limit filter
rate_limit_filter = RateLimitFilter()
file_handler.addFilter(rate_limit_filter)
console_handler.addFilter(rate_limit_filter)
# demo run
if __name__ == "__main__":
setup_logging()
for i in range(1000):
logging.info("test message to exercise the rate limit")
time.sleep(0.01) # simulate rapid logging

View File

@@ -1,115 +0,0 @@
#!/bin/bash
# Remote server list; edit as needed
SERVERS=("175.178.54.98" "1.12.218.143" "43.139.169.25" "129.204.180.174" "42.194.142.169")
REMOTE_USER="ubuntu"
REMOTE_SCRIPT_DIR="/home/ubuntu/pyscripts/stockapp/reports_em"
REMOTE_LOG_DIR="/home/ubuntu/pyscripts/stockapp/log"
DATA_DIR="/home/ubuntu/hostdir/stock_data/pdfs"
LOCAL_FILES=("config.py" "em_reports.py" "fetch.py")
# Per-server task parameters
TASK_PARAMS=(
"--mode=down --begin=2024-01-01 --end=2024-06-30"
"--mode=down --begin=2023-01-01 --end=2023-06-30"
"--mode=down --begin=2022-01-01 --end=2022-06-30"
"--mode=down --begin=2021-01-01 --end=2021-06-30"
"--mode=down --begin=2020-01-01 --end=2020-06-30"
)
# Push the code to every server
function push_code() {
for SERVER in "${SERVERS[@]}"; do
echo "Pushing code to $SERVER..."
scp "${LOCAL_FILES[@]}" "$REMOTE_USER@$SERVER:$REMOTE_SCRIPT_DIR/"
done
}
# Start the tasks
function start_tasks() {
for i in "${!SERVERS[@]}"; do
SERVER="${SERVERS[$i]}"
PARAMS="${TASK_PARAMS[$i]}"
echo "Starting task on $SERVER with params: $PARAMS"
#ssh "$REMOTE_USER@$SERVER" "cd $REMOTE_SCRIPT_DIR && nohup python3 ./fetch.py $PARAMS > ../log/nohup.log 2>&1 &"
#ssh "$REMOTE_USER@$SERVER" "cd $REMOTE_SCRIPT_DIR && nohup python3 ./fetch.py $PARAMS > ../log/nohup.log 2>&1 < /dev/null & disown"
# nohup ... < /dev/null防止 nohup 等待 stdin立即释放控制权。
# & disown确保进程与 SSH 彻底分离,避免 SIGHUP 信号影响SSH 立即返回。
# ssh -n禁用 ssh 的 stdin防止远程进程等待输入。
ssh -n "$REMOTE_USER@$SERVER" "cd $REMOTE_SCRIPT_DIR && nohup python3 ./fetch.py $PARAMS > ../log/nohup.log 2>&1 < /dev/null & disown"
done
}
# Stop the tasks
function stop_tasks() {
for SERVER in "${SERVERS[@]}"; do
echo "Stopping task on $SERVER..."
ssh "$REMOTE_USER@$SERVER" "pkill -f 'python3 ./fetch.py'"
done
}
# Report task progress
function check_progress() {
for SERVER in "${SERVERS[@]}"; do
echo -e "\nChecking progress on $SERVER..."
FILE_COUNT=$(ssh "$REMOTE_USER@$SERVER" "ls -lRh $DATA_DIR | grep pdf | wc -l")
FILE_SIZE=$(ssh "$REMOTE_USER@$SERVER" "du -sh $DATA_DIR")
PROCESS_COUNT=$(ssh "$REMOTE_USER@$SERVER" "ps aux | grep '[f]etch.py' | wc -l")
if [ "$PROCESS_COUNT" -gt 0 ]; then
echo "Process status: Running ($PROCESS_COUNT instances), if 2, include parent progress"
else
echo "Process status: Not running"
fi
echo "Total files: $FILE_COUNT"
echo "Total size : $FILE_SIZE"
ERROR_LINES=$(ssh "$REMOTE_USER@$SERVER" "grep -v INFO $REMOTE_LOG_DIR/fetch_202503* | wc -l")
echo "Error lines: $ERROR_LINES"
TASK_COUNT=$(ssh "$REMOTE_USER@$SERVER" "grep 'running task. id' ~/pyscripts/stockapp/log/fetch_20250316.log | wc -l")
echo "Task count: $TASK_COUNT"
done
}
# Report task progress and push it to the notification bot
function check_progress_robot() {
result=""
for SERVER in "${SERVERS[@]}"; do
result+="\nChecking progress on $SERVER...\n"
FILE_COUNT=$(ssh "$REMOTE_USER@$SERVER" "ls -lRh $DATA_DIR | grep pdf | wc -l")
FILE_SIZE=$(ssh "$REMOTE_USER@$SERVER" "du -sh $DATA_DIR")
PROCESS_COUNT=$(ssh "$REMOTE_USER@$SERVER" "ps aux | grep '[f]etch.py' | wc -l")
if [ "$PROCESS_COUNT" -gt 0 ]; then
result+="Process status: Running ($PROCESS_COUNT instances), if 2, include parent progress\n"
else
result+="Process status: Not running\n"
fi
result+="Total files: $FILE_COUNT\n"
result+="Total size : $FILE_SIZE\n"
ERROR_LINES=$(ssh "$REMOTE_USER@$SERVER" "grep -v INFO $REMOTE_LOG_DIR/fetch_202503* | wc -l")
result+="Error lines: $ERROR_LINES\n"
TASK_COUNT=$(ssh "$REMOTE_USER@$SERVER" "grep 'running task. id' ~/pyscripts/stockapp/log/fetch_20250316.log | wc -l")
result+="Task count: $TASK_COUNT\n"
done
echo -e "$result"
# send the message via the Python helper script
python3 ./robot.py "$result"
}
# Command menu
case "$1" in
push) push_code ;;
start) start_tasks ;;
stop) stop_tasks ;;
check) check_progress ;;
robot) check_progress_robot ;;
*)
echo "Usage: $0 {push|start|stop|check|robot}"
exit 1
;;
esac
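
The detached-launch pattern explained in start_tasks can also be driven from Python with subprocess; a minimal sketch, reusing a host and path from the script above (the params value is illustrative):

import subprocess

def start_remote_task(host, user="ubuntu", params="--mode=down"):
    # ssh -n closes stdin; nohup + </dev/null + disown detaches the remote
    # process completely, so this call returns as soon as ssh does.
    remote_cmd = (
        "cd /home/ubuntu/pyscripts/stockapp/reports_em && "
        f"nohup python3 ./fetch.py {params} > ../log/nohup.log 2>&1 < /dev/null & disown"
    )
    subprocess.run(["ssh", "-n", f"{user}@{host}", remote_cmd], check=True)

start_remote_task("175.178.54.98", params="--mode=down --begin=2024-01-01 --end=2024-06-30")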

View File

@@ -1,248 +0,0 @@
import sqlite3
import json
import config
import utils
import logging
import sys
from datetime import datetime
# Connect to the SQLite database
DB_PATH = f"{config.global_share_db_dir}/stock_report.db" # replace with your database file
conn = sqlite3.connect(DB_PATH)
cursor = conn.cursor()
# Fetch a table's column names and default values
def get_table_columns_and_defaults(tbl_name):
try:
cursor.execute(f"PRAGMA table_info({tbl_name})")
columns = cursor.fetchall()
column_info = {}
for col in columns:
col_name = col[1]
default_value = col[4]
column_info[col_name] = default_value
return column_info
except sqlite3.Error as e:
logging.error(f"Error getting table columns: {e}")
return None
# Validate and normalize incoming data
def check_and_process_data(data, tbl_name):
column_info = get_table_columns_and_defaults(tbl_name=tbl_name)
if column_info is None:
return None
processed_data = {}
for col, default in column_info.items():
if col == 'id': # auto-increment primary key; callers never supply it
continue
if col == 'created_at' or col == 'updated_at': # maintained via SQL datetime functions
continue
if col in ['author', 'authorID']:
values = data.get(col, [])
processed_data[col] = ','.join(values)
elif col in data:
processed_data[col] = data[col]
else:
if default is not None:
processed_data[col] = default
else:
processed_data[col] = None
return processed_data
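
For reference, PRAGMA table_info returns one row per column as (cid, name, type, notnull, dflt_value, pk), which is why the code above reads indices 1 and 4. A minimal self-contained sketch on a hypothetical table:

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE demo (id INTEGER PRIMARY KEY, title TEXT, pages INTEGER DEFAULT 0)")
for cid, name, col_type, notnull, dflt, pk in conn.execute("PRAGMA table_info(demo)"):
    print(name, dflt)  # id None / title None / pages 0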
# Insert or update a row
def insert_or_update_common(data, tbl_name, uniq_key='infoCode'):
try:
processed_data = check_and_process_data(data, tbl_name)
if processed_data is None:
return None
columns = ', '.join(processed_data.keys())
values = list(processed_data.values())
placeholders = ', '.join(['?' for _ in values])
update_clause = ', '.join([f"{col}=EXCLUDED.{col}" for col in processed_data.keys() if col != 'infoCode']) + ', updated_at=datetime(\'now\', \'localtime\')'
sql = f'''
INSERT INTO {tbl_name} ({columns}, updated_at)
VALUES ({placeholders}, datetime('now', 'localtime'))
ON CONFLICT (infoCode) DO UPDATE SET {update_clause}
'''
cursor.execute(sql, values)
conn.commit()
# fetch the id of the inserted/updated row
cursor.execute(f"SELECT id FROM {tbl_name} WHERE {uniq_key} = ?", (data[uniq_key],))
report_id = cursor.fetchone()[0]
return report_id
except sqlite3.Error as e:
logging.error(f"Error inserting or updating data: {e}")
return None
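
The ON CONFLICT ... DO UPDATE form built above is SQLite's UPSERT (available since SQLite 3.24); EXCLUDED.col refers to the value the conflicting INSERT tried to write. A minimal sketch of the same pattern on a toy table:

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE r (infoCode TEXT PRIMARY KEY, title TEXT)")
sql = "INSERT INTO r (infoCode, title) VALUES (?, ?) ON CONFLICT(infoCode) DO UPDATE SET title=EXCLUDED.title"
conn.execute(sql, ("AP001", "first draft"))
conn.execute(sql, ("AP001", "revised title"))  # same key: updates instead of failing
print(conn.execute("SELECT title FROM r WHERE infoCode='AP001'").fetchone())  # ('revised title',)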
# Query report rows
def query_reports_comm(tbl_name, querystr='', limit=None ):
try:
if tbl_name in [utils.tbl_stock, utils.tbl_new_stock, utils.tbl_industry, utils.tbl_macresearch, utils.tbl_strategy] :
sql = f"SELECT id, infoCode, title, orgSName, industryName, stockName, publishDate FROM {tbl_name} WHERE 1=1 {querystr}"
else:
logging.warning(f'wrong table name: {tbl_name}')
return None
if limit :
sql = sql + f' limit {limit}'
cursor.execute(sql)
results = cursor.fetchall()
# column names
column_names = [description[0] for description in cursor.description]
# convert rows to a list of dicts
result_dict_list = []
for row in results:
row_dict = {column_names[i]: value for i, value in enumerate(row)}
result_dict_list.append(row_dict)
return result_dict_list
except sqlite3.Error as e:
logging.error(f"查询 href 失败: {e}")
return None
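
Building dicts from cursor.description works; the standard library can also do this natively with a row factory. An equivalent sketch using sqlite3.Row (a swapped-in stdlib technique, not what this module uses):

import sqlite3

conn = sqlite3.connect(":memory:")
conn.row_factory = sqlite3.Row  # rows become addressable by column name
conn.execute("CREATE TABLE t (id INTEGER, title TEXT)")
conn.execute("INSERT INTO t VALUES (1, 'demo')")
row = conn.execute("SELECT id, title FROM t").fetchone()
print(dict(row))  # {'id': 1, 'title': 'demo'}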
'''
# Insert or update an industry_report row
def insert_or_update_report(report):
try:
sql = """
INSERT INTO industry_report (infoCode, title, orgCode, orgName, orgSName, publishDate,
industryCode, industryName, emRatingCode, emRatingValue,
emRatingName, attachSize, attachPages, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, datetime('now', 'localtime'))
ON CONFLICT(infoCode) DO UPDATE SET
title=excluded.title,
orgCode=excluded.orgCode,
orgName=excluded.orgName,
orgSName=excluded.orgSName,
publishDate=excluded.publishDate,
industryCode=excluded.industryCode,
industryName=excluded.industryName,
emRatingCode=excluded.emRatingCode,
emRatingValue=excluded.emRatingValue,
emRatingName=excluded.emRatingName,
attachSize=excluded.attachSize,
attachPages=excluded.attachPages,
updated_at=datetime('now', 'localtime')
"""
values = (
report["infoCode"], report["title"], report["orgCode"], report["orgName"],
report["orgSName"], report["publishDate"], report["industryCode"],
report["industryName"], report["emRatingCode"], report["emRatingValue"],
report["emRatingName"], report.get("attachSize", 0), report.get("attachPages", 0)
)
cursor.execute(sql, values)
conn.commit()
# fetch the id of the inserted/updated report
cursor.execute("SELECT id FROM industry_report WHERE infoCode = ?", (report["infoCode"],))
report_id = cursor.fetchone()[0]
return report_id
except sqlite3.Error as e:
conn.rollback()
logging.error(f"数据库错误: {e}")
return None
except Exception as e:
conn.rollback()
logging.error(f"未知错误: {e}")
return None
# Query industry report rows
def query_industry_reports(querystr='', limit=None):
try:
sql = f"SELECT id, infoCode, title, orgSName, industryName, publishDate FROM industry_report WHERE 1=1 {querystr}"
if limit :
sql = sql + f' limit {limit}'
cursor.execute(sql)
results = cursor.fetchall()
return results
except sqlite3.Error as e:
logging.error(f"查询 href 失败: {e}")
return None
# Insert or update a stock_report row
def insert_or_update_stock_report(report):
try:
sql = """
INSERT INTO stock_report (infoCode, title, stockName, stockCode, orgCode, orgName, orgSName,
publishDate, industryCode, industryName, emIndustryCode, emRatingCode,
emRatingValue, emRatingName, attachPages, attachSize, author, authorID, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, datetime('now', 'localtime'))
ON CONFLICT(infoCode) DO UPDATE SET
title=excluded.title,
stockName=excluded.stockName,
stockCode=excluded.stockCode,
orgCode=excluded.orgCode,
orgName=excluded.orgName,
orgSName=excluded.orgSName,
publishDate=excluded.publishDate,
industryCode=excluded.industryCode,
industryName=excluded.industryName,
emIndustryCode=excluded.emIndustryCode,
emRatingCode=excluded.emRatingCode,
emRatingValue=excluded.emRatingValue,
emRatingName=excluded.emRatingName,
attachPages=excluded.attachPages,
attachSize=excluded.attachSize,
author=excluded.author,
authorID=excluded.authorID,
updated_at=datetime('now', 'localtime')
"""
values = (
report["infoCode"], report["title"], report["stockName"], report["stockCode"],
report["orgCode"], report["orgName"], report["orgSName"], report["publishDate"],
report.get("industryCode", ""), report.get("industryName", ""), report.get("emIndustryCode", ""),
report["emRatingCode"], report["emRatingValue"], report["emRatingName"],
report.get("attachPages", 0), report.get("attachSize", 0),
",".join(report.get("author", [])), ",".join(report.get("authorID", []))
)
cursor.execute(sql, values)
conn.commit()
# fetch the id of the inserted/updated report
cursor.execute("SELECT id FROM stock_report WHERE infoCode = ?", (report["infoCode"],))
report_id = cursor.fetchone()[0]
return report_id
except sqlite3.Error as e:
conn.rollback()
logging.error(f"数据库错误: {e}")
return None
except Exception as e:
conn.rollback()
logging.error(f"未知错误: {e}")
return None
# Query stock report rows
def query_stock_reports(querystr='', limit=None):
try:
sql = f"SELECT id, infoCode, title, orgSName, stockName, publishDate FROM stock_report WHERE 1=1 {querystr}"
if limit :
sql = sql + f' limit {limit}'
cursor.execute(sql)
results = cursor.fetchall()
return results
except sqlite3.Error as e:
logging.error(f"查询 href 失败: {e}")
return None
'''

View File

@@ -1,65 +0,0 @@
import logging
import os
import inspect
from datetime import datetime
from pathlib import Path
# MySQL configuration
db_config = {
'host': 'testdb',
'user': 'root',
'password': 'mysqlpw',
'database': 'stockdb'
}
log_dir_prefix = '../log'
global_share_data_dir = '/root/sharedata'
global_stock_data_dir = '/root/hostdir/stock_data'
# directory of this file
current_dir = os.path.dirname(os.path.abspath(__file__))
# project root (assumes this file lives under src/strategy)
project_root = os.path.abspath(os.path.join(current_dir, '..', '..'))
# locate the log directory
def get_log_directory():
"""
获取项目根目录下的 log 目录路径。如果 log 目录不存在,则自动创建。
"""
# directory of this file
current_dir = Path(__file__).resolve().parent
# walk up to the project root (assumed to contain a log folder)
project_root = current_dir
while project_root.name != 'src' and project_root != project_root.parent:
project_root = project_root.parent
project_root = project_root.parent # step from src/ up to the project root
# make sure the log directory exists
log_dir = project_root / 'log'
log_dir.mkdir(parents=True, exist_ok=True)
return log_dir
def get_caller_filename():
# name of the script that called setup_logging
caller_frame = inspect.stack()[2]
caller_filename = os.path.splitext(os.path.basename(caller_frame.filename))[0]
return caller_filename
# configure logging
def setup_logging(log_filename=None):
# default the log file name to the calling script's name
if log_filename is None:
caller_filename = get_caller_filename()
common_log_dir = get_log_directory()
current_date = datetime.now().strftime('%Y%m%d')
# build the log file name, inserting the date before the extension
log_filename = f'{common_log_dir}/{caller_filename}_{current_date}.log'
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] (%(funcName)s) - %(message)s',
handlers=[
logging.FileHandler(log_filename),
logging.StreamHandler()
])

View File

@@ -9,8 +9,13 @@ db_config = {
'database': 'stockdb'
}
global_share_data_dir = '/root/sharedata'
global_stock_data_dir = '/root/hostdir/stock_data'
home_dir = os.path.expanduser("~")
global_host_data_dir = f'{home_dir}/hostdir/stock_data'
global_share_db_dir = f'{home_dir}/sharedata/sqlite'
# backward-compatible aliases for the old names
global_stock_data_dir = global_host_data_dir
global_share_data_dir = f'{home_dir}/sharedata'
# directory of this file
current_dir = Path(__file__).resolve().parent

View File

@@ -4,8 +4,6 @@ import requests
import time
import logging
from bs4 import BeautifulSoup
import sqlite_utils as db_tools
import config
# Fetch one page of the per-stock research report list
def fetch_reports_by_stock(page_no, start_date="2023-03-10", end_date="2025-03-10", page_size=50, max_retries = 3):

src/db_utils/reports.py Normal file
View File

@@ -0,0 +1,123 @@
import sqlite3
import logging
import sys
from datetime import datetime
class DatabaseConnectionError(Exception):
pass
class StockReportDB:
# class-level constants (table names)
TBL_STOCK = 'reports_stock'
TBL_NEW_STOCK = 'reports_newstrock'
TBL_STRATEGY = 'reports_strategy'
TBL_MACRESEARCH = 'reports_macresearch'
TBL_INDUSTRY = 'reports_industry'
def __init__(self, db_path):
self.DB_PATH = db_path
self.conn = None
self.cursor = None
try:
self.conn = sqlite3.connect(self.DB_PATH)
self.cursor = self.conn.cursor()
except sqlite3.Error as e:
logging.error(f"数据库连接失败: {e}")
raise DatabaseConnectionError("数据库连接失败")
def __get_table_columns_and_defaults(self, tbl_name):
try:
self.cursor.execute(f"PRAGMA table_info({tbl_name})")
columns = self.cursor.fetchall()
column_info = {}
for col in columns:
col_name = col[1]
default_value = col[4]
column_info[col_name] = default_value
return column_info
except sqlite3.Error as e:
logging.error(f"Error getting table columns: {e}")
return None
def __check_and_process_data(self, data, tbl_name):
column_info = self.__get_table_columns_and_defaults(tbl_name=tbl_name)
if column_info is None:
return None
processed_data = {}
for col, default in column_info.items():
if col == 'id': # auto-increment primary key; callers never supply it
continue
if col == 'created_at' or col == 'updated_at': # maintained via SQL datetime functions
continue
if col in ['author', 'authorID']:
values = data.get(col, [])
processed_data[col] = ','.join(values)
elif col in data:
processed_data[col] = data[col]
else:
if default is not None:
processed_data[col] = default
else:
processed_data[col] = None
return processed_data
def insert_or_update_common(self, data, tbl_name, uniq_key='infoCode'):
try:
processed_data = self.__check_and_process_data(data, tbl_name)
if processed_data is None:
return None
columns = ', '.join(processed_data.keys())
values = list(processed_data.values())
placeholders = ', '.join(['?' for _ in values])
update_clause = ', '.join([f"{col}=EXCLUDED.{col}" for col in processed_data.keys() if col != 'infoCode']) + ', updated_at=datetime(\'now\', \'localtime\')'
sql = f'''
INSERT INTO {tbl_name} ({columns}, updated_at)
VALUES ({placeholders}, datetime('now', 'localtime'))
ON CONFLICT (infoCode) DO UPDATE SET {update_clause}
'''
self.cursor.execute(sql, values)
self.conn.commit()
# fetch the id of the inserted/updated row
self.cursor.execute(f"SELECT id FROM {tbl_name} WHERE {uniq_key} = ?", (data[uniq_key],))
report_id = self.cursor.fetchone()[0]
return report_id
except sqlite3.Error as e:
logging.error(f"Error inserting or updating data: {e}")
return None
def query_reports_comm(self, tbl_name, querystr='', limit=None):
try:
if tbl_name in [StockReportDB.TBL_STOCK, StockReportDB.TBL_NEW_STOCK, StockReportDB.TBL_INDUSTRY, StockReportDB.TBL_MACRESEARCH, StockReportDB.TBL_STRATEGY]:
sql = f"SELECT id, infoCode, title, orgSName, industryName, stockName, publishDate FROM {tbl_name} WHERE 1=1 {querystr}"
else:
logging.warning(f'wrong table name: {tbl_name}')
return None
if limit:
sql = sql + f' limit {limit}'
self.cursor.execute(sql)
results = self.cursor.fetchall()
# column names
column_names = [description[0] for description in self.cursor.description]
# convert rows to a list of dicts
result_dict_list = []
for row in results:
row_dict = {column_names[i]: value for i, value in enumerate(row)}
result_dict_list.append(row_dict)
return result_dict_list
except sqlite3.Error as e:
logging.error(f"查询 href 失败: {e}")
return None
def __del__(self):
if self.conn:
self.conn.close()
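
A minimal usage sketch for the class, run against an in-memory database with a cut-down hypothetical schema (the real stock_report.db tables carry many more columns), assuming the package layout above is importable:

from src.db_utils.reports import StockReportDB

db = StockReportDB(":memory:")  # any path works; ":memory:" keeps the demo self-contained
db.cursor.execute(
    "CREATE TABLE reports_stock ("
    " id INTEGER PRIMARY KEY AUTOINCREMENT, infoCode TEXT UNIQUE,"
    " title TEXT, orgSName TEXT, industryName TEXT,"
    " stockName TEXT, publishDate TEXT, updated_at TEXT)"
)
report_id = db.insert_or_update_common(
    {"infoCode": "AP2025", "title": "demo report", "stockName": "ACME"},
    StockReportDB.TBL_STOCK,
)
print(report_id)  # 1
print(db.query_reports_comm(StockReportDB.TBL_STOCK, limit=5))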

View File

@@ -9,34 +9,41 @@ import shutil
import logging
from datetime import datetime, timedelta
from functools import partial
import config
import sqlite_utils as db_tools
import em_reports as em
import utils
import src.crawler.em.reports as em
import src.utils.utils as utils
from src.config.config import global_host_data_dir, global_share_db_dir
from src.db_utils.reports import StockReportDB, DatabaseConnectionError
from src.logger.logger import setup_logging
config.setup_logging()
# initialize logging
setup_logging()
debug = False
force = False
pdf_base_dir = f"{config.global_host_data_dir}/pdfs" # 下载 PDF 存放目录
pdf_base_dir = f"{global_host_data_dir}/pdfs" # 下载 PDF 存放目录
# page-URL templates for each report type
map_pdf_page = {
utils.tbl_stock : "https://data.eastmoney.com/report/info/{}.html",
utils.tbl_new_stock : "https://data.eastmoney.com/report/info/{}.html",
utils.tbl_strategy : "https://data.eastmoney.com/report/zw_strategy.jshtml?encodeUrl={}",
utils.tbl_macresearch : "https://data.eastmoney.com/report/zw_macresearch.jshtml?encodeUrl={}",
utils.tbl_industry : "https://data.eastmoney.com/report/zw_industry.jshtml?infocode={}"
StockReportDB.TBL_STOCK : "https://data.eastmoney.com/report/info/{}.html",
StockReportDB.TBL_NEW_STOCK : "https://data.eastmoney.com/report/info/{}.html",
StockReportDB.TBL_STRATEGY : "https://data.eastmoney.com/report/zw_strategy.jshtml?encodeUrl={}",
StockReportDB.TBL_MACRESEARCH : "https://data.eastmoney.com/report/zw_macresearch.jshtml?encodeUrl={}",
StockReportDB.TBL_INDUSTRY : "https://data.eastmoney.com/report/zw_industry.jshtml?infocode={}"
}
# human-readable table-name mapping, used for storage paths
map_tbl_name = {
utils.tbl_stock : '个股研报',
utils.tbl_new_stock : '新股研报',
utils.tbl_strategy : '策略报告',
utils.tbl_macresearch : '宏观研究',
utils.tbl_industry : '行业研报'
StockReportDB.TBL_STOCK : '个股研报',
StockReportDB.TBL_NEW_STOCK : '新股研报',
StockReportDB.TBL_STRATEGY : '策略报告',
StockReportDB.TBL_MACRESEARCH : '宏观研究',
StockReportDB.TBL_INDUSTRY : '行业研报'
}
# initialize the database connection
db_path = f"{global_share_db_dir}/stock_report.db"
db_tools = None
current_date = datetime.now()
seven_days_ago = current_date - timedelta(days=7)
two_years_ago = current_date - timedelta(days=2*365)
@@ -45,6 +52,7 @@ start_date = two_years_ago.strftime("%Y-%m-%d")
end_date = current_date.strftime("%Y-%m-%d")
this_week_date = seven_days_ago.strftime("%Y-%m-%d")
def fetch_reports_list_general(fetch_func, table_name, s_date, e_date, data_dir_prefix):
# example: fetch the first 3 pages of data
max_pages = 100000
@@ -148,39 +156,39 @@ def download_pdf_stock_general(parse_func, tbl_name, querystr='', s_date=start_d
# Fetch the stock report list
def fetch_reports_list_stock(s_date=start_date, e_date=end_date):
return fetch_reports_list_general(em.fetch_reports_by_stock, utils.tbl_stock, s_date, e_date, 'stock')
return fetch_reports_list_general(em.fetch_reports_by_stock, StockReportDB.TBL_STOCK, s_date, e_date, 'stock')
# Fetch the new-stock report list
def fetch_reports_list_newstock(s_date=start_date, e_date=end_date):
return fetch_reports_list_general(em.fetch_reports_by_newstock, utils.tbl_new_stock, s_date, e_date, 'new')
return fetch_reports_list_general(em.fetch_reports_by_newstock, StockReportDB.TBL_NEW_STOCK, s_date, e_date, 'new')
# Fetch the industry report list
def fetch_reports_list_industry(s_date=start_date, e_date=end_date):
return fetch_reports_list_general(em.fetch_reports_by_industry, utils.tbl_industry, s_date, e_date, 'industry')
return fetch_reports_list_general(em.fetch_reports_by_industry, StockReportDB.TBL_INDUSTRY, s_date, e_date, 'industry')
# Fetch the macro research report list
def fetch_reports_list_macresearch(s_date=start_date, e_date=end_date):
return fetch_reports_list_general(em.fetch_reports_by_macresearch, utils.tbl_macresearch, s_date, e_date, 'macresearch')
return fetch_reports_list_general(em.fetch_reports_by_macresearch, StockReportDB.TBL_MACRESEARCH, s_date, e_date, 'macresearch')
# Fetch the strategy report list
def fetch_reports_list_strategy(s_date=start_date, e_date=end_date):
return fetch_reports_list_general(em.fetch_reports_by_strategy, utils.tbl_strategy, s_date, e_date, 'strategy')
return fetch_reports_list_general(em.fetch_reports_by_strategy, StockReportDB.TBL_STRATEGY, s_date, e_date, 'strategy')
# Download stock PDFs
def download_pdf_stock(s_date=start_date, e_date=end_date):
download_pdf_stock_general(parse_func_general, utils.tbl_stock, ' ', s_date, e_date, limit=2 if debug else None)
download_pdf_stock_general(parse_func_general, StockReportDB.TBL_STOCK, ' ', s_date, e_date, limit=2 if debug else None)
def download_pdf_newstock(s_date=start_date, e_date=end_date):
download_pdf_stock_general(parse_func_general, utils.tbl_new_stock, ' ', s_date, e_date, limit=2 if debug else None)
download_pdf_stock_general(parse_func_general, StockReportDB.TBL_NEW_STOCK, ' ', s_date, e_date, limit=2 if debug else None)
def download_pdf_industry(s_date=start_date, e_date=end_date):
download_pdf_stock_general(parse_func_general, utils.tbl_industry, ' ', s_date, e_date, limit=2 if debug else None)
download_pdf_stock_general(parse_func_general, StockReportDB.TBL_INDUSTRY, ' ', s_date, e_date, limit=2 if debug else None)
def download_pdf_macresearch(s_date=start_date, e_date=end_date):
download_pdf_stock_general(parse_func_general, utils.tbl_macresearch, ' ', s_date, e_date, limit=2 if debug else None)
download_pdf_stock_general(parse_func_general, StockReportDB.TBL_MACRESEARCH, ' ', s_date, e_date, limit=2 if debug else None)
def download_pdf_strategy(s_date=start_date, e_date=end_date):
download_pdf_stock_general(parse_func_general, utils.tbl_strategy, ' ', s_date, e_date, limit=2 if debug else None)
download_pdf_stock_general(parse_func_general, StockReportDB.TBL_STRATEGY, ' ', s_date, e_date, limit=2 if debug else None)
# map command abbreviations to functions
@@ -265,6 +273,15 @@ def main(cmd, mode, args_debug, args_force, begin, end):
global end_date
end_date = end if end else end_date
# initialize the DB
global db_tools
try:
db_tools = StockReportDB(db_path)
# database operations proceed from here
except DatabaseConnectionError as e:
logging.error(f"数据库连接失败: {e}")
return False
# start the task
#task_id = db_tools.insert_task_log()
task_id = 0
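
The map_pdf_page entries above are plain str.format templates keyed by table name; a sketch of how a detail-page URL would be assembled (the infoCode value is illustrative):

url = map_pdf_page[StockReportDB.TBL_STOCK].format("AP202503251234567890")
# -> https://data.eastmoney.com/report/info/AP202503251234567890.html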

View File

@@ -1,10 +1,46 @@
import logging
import os
import inspect
import time
from datetime import datetime
from pathlib import Path
from logging.handlers import RotatingFileHandler
from collections import defaultdict
from src.config.config import get_log_directory, get_src_directory
# Track logging frequency
log_count = defaultdict(int) # how many times each message has been logged
last_log_time = defaultdict(float) # timestamp of the last write per message
class RateLimitFilter(logging.Filter):
"""
频率限制过滤器:
1. 在 60 秒内,同样的日志最多写入 60 次,超过则忽略
2. 如果日志速率超过 100 条/秒,发出告警
"""
LOG_LIMIT = 600 # 每分钟最多记录相同消息 10 次
def filter(self, record):
global log_count, last_log_time
message_key = record.getMessage() # the rendered log message
# current time
now = time.time()
elapsed = now - last_log_time[message_key]
# throttle identical messages
if elapsed < 60: # within the 60-second window
log_count[message_key] += 1
if log_count[message_key] > self.LOG_LIMIT:
return False # drop the record
else:
log_count[message_key] = 1 # window expired; restart the count
last_log_time[message_key] = now
return True # allow the record
def get_caller_filename():
# walk the call stack
stack = inspect.stack()
@@ -26,7 +62,7 @@ def get_caller_filename():
return os.path.splitext(os.path.basename(frame_info.filename))[0]
return None
# configure logging
def setup_logging(log_filename=None):
# default the log file name to the calling script's name
if log_filename is None:
@@ -35,9 +71,29 @@ def setup_logging(log_filename=None):
current_date = datetime.now().strftime('%Y%m%d')
# build the log file name, inserting the date before the extension
log_filename = f'{common_log_dir}/{caller_filename}_{current_date}.log'
max_log_size = 100 * 1024 * 1024 # 100 MB
max_log_files = 10 # keep at most 10 rotated log files
file_handler = RotatingFileHandler(log_filename, maxBytes=max_log_size, backupCount=max_log_files)
file_handler.setFormatter(logging.Formatter(
'%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] (%(funcName)s) - %(message)s'
))
console_handler = logging.StreamHandler()
console_handler.setFormatter(logging.Formatter(
'%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] (%(funcName)s) - %(message)s'
))
# create the root logger
logger = logging.getLogger()
logger.setLevel(logging.INFO)
logger.handlers = [] # clear existing handlers to avoid duplicates
logger.addHandler(file_handler)
logger.addHandler(console_handler)
# attach the rate-limit filter
rate_limit_filter = RateLimitFilter()
file_handler.addFilter(rate_limit_filter)
console_handler.addFilter(rate_limit_filter)
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] (%(funcName)s) - %(message)s',
handlers=[
logging.FileHandler(log_filename),
logging.StreamHandler()
])

View File

@@ -1,6 +1,7 @@
import pandas as pd
import numpy as np
import os
import warnings
from src.strategy.prepare import fetch_his_kline
import src.config.config as config
import src.crawler.em.stock as em_stock
@@ -15,10 +16,15 @@ def select_stocks(stock_map):
df = fetch_his_kline(stock_code)
close_prices = df['close'].values
# compute RSI over several periods with the MyTT library
rsi_6 = RSI(close_prices, 6)
rsi_12 = RSI(close_prices, 12)
rsi_24 = RSI(close_prices, 24)
# capture RuntimeWarning
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always", RuntimeWarning)
# compute RSI over several periods with the MyTT library
rsi_6 = RSI(close_prices, 6)
rsi_12 = RSI(close_prices, 12)
rsi_24 = RSI(close_prices, 24)
if w:
print(f"股票代码 {stock_code} {stock_name} 在计算 RSI 时出现警告: {w[0].message}")
df['rsi_6'] = rsi_6
df['rsi_12'] = rsi_12
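
warnings.catch_warnings(record=True) collects warnings into a list instead of emitting them; a self-contained sketch of the pattern, with a NumPy zero division standing in for the MyTT RSI call:

import warnings
import numpy as np

with warnings.catch_warnings(record=True) as w:
    warnings.simplefilter("always", RuntimeWarning)
    result = np.array([1.0]) / np.array([0.0])  # triggers a RuntimeWarning
if w:
    print(f"caught: {w[0].message}")  # e.g. "divide by zero encountered in divide"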
@@ -96,7 +102,8 @@ if __name__ == "__main__":
codes = ['105.QFIN', '105.FUTU']
# fetch from the network
stock_map = em_stock.code_by_fs('hk_famous', em_stock.em_market_fs_types['hk_famous'])
plat_id = 'cn_hs300'
stock_map = em_stock.code_by_fs(plat_id, em_stock.em_market_fs_types[plat_id])
if stock_map:
select_stocks(stock_map)

View File

@@ -5,15 +5,6 @@ import time
import csv
import logging
from datetime import datetime
import config
tbl_stock = 'reports_stock'
tbl_new_stock = 'reports_newstrock'
tbl_strategy = 'reports_strategy'
tbl_macresearch = 'reports_macresearch'
tbl_industry = 'reports_industry'
json_data_dir = f'{config.global_host_data_dir}/em_reports/json_data'
# Save JSON data to a local file
def save_json_to_file(data, file_path, file_name):