modify scripts
@@ -1,86 +0,0 @@
import logging
import os
import inspect
import time
from datetime import datetime
from logging.handlers import RotatingFileHandler
from collections import defaultdict

home_dir = os.path.expanduser("~")
global_host_data_dir = f'{home_dir}/hostdir/stock_data'
global_share_db_dir = f'{home_dir}/sharedata/sqlite'

# Track logging frequency
log_count = defaultdict(int)        # times each message has been logged
last_log_time = defaultdict(float)  # timestamp of the last write per message


class RateLimitFilter(logging.Filter):
    """
    Rate-limit filter:
    1. Within a 60-second window, the same message is written at most LOG_LIMIT times; the rest are dropped.
    2. Warn if the overall log rate exceeds 100 messages/second (described but not implemented below).
    """
    LOG_LIMIT = 600  # at most 600 identical messages per 60-second window

    def filter(self, record):
        global log_count, last_log_time
        message_key = record.getMessage()  # the log message text

        # Current time and elapsed time since this message was last seen
        now = time.time()
        elapsed = now - last_log_time[message_key]

        # Throttle repeated identical messages
        if elapsed < 60:  # within the 60-second window
            log_count[message_key] += 1
            if log_count[message_key] > self.LOG_LIMIT:
                return False  # drop the record
        else:
            log_count[message_key] = 1  # more than 60 seconds passed; restart the count

        last_log_time[message_key] = now

        return True  # allow the record through


def setup_logging(log_filename=None):
    if log_filename is None:
        caller_frame = inspect.stack()[1]
        caller_filename = os.path.splitext(os.path.basename(caller_frame.filename))[0]
        current_date = datetime.now().strftime('%Y%m%d')
        log_filename = f'../log/{caller_filename}_{current_date}.log'

    max_log_size = 100 * 1024 * 1024  # 100 MB
    max_log_files = 10                # keep at most 10 rotated log files

    file_handler = RotatingFileHandler(log_filename, maxBytes=max_log_size, backupCount=max_log_files)
    file_handler.setFormatter(logging.Formatter(
        '%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] (%(funcName)s) - %(message)s'
    ))

    console_handler = logging.StreamHandler()
    console_handler.setFormatter(logging.Formatter(
        '%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] (%(funcName)s) - %(message)s'
    ))

    # Create the root logger
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    logger.handlers = []  # avoid adding handlers twice
    logger.addHandler(file_handler)
    logger.addHandler(console_handler)

    # Attach the rate limiter
    rate_limit_filter = RateLimitFilter()
    file_handler.addFilter(rate_limit_filter)
    console_handler.addFilter(rate_limit_filter)


# Usage example
if __name__ == "__main__":
    setup_logging()

    for i in range(1000):
        logging.info("test message for the rate-limit check")
        time.sleep(0.01)  # simulate rapid logging
@@ -1,115 +0,0 @@
#!/bin/bash

# Remote server list; adjust as needed
SERVERS=("175.178.54.98" "1.12.218.143" "43.139.169.25" "129.204.180.174" "42.194.142.169")
REMOTE_USER="ubuntu"
REMOTE_SCRIPT_DIR="/home/ubuntu/pyscripts/stockapp/reports_em"
REMOTE_LOG_DIR="/home/ubuntu/pyscripts/stockapp/log"
DATA_DIR="/home/ubuntu/hostdir/stock_data/pdfs"
LOCAL_FILES=("config.py" "em_reports.py" "fetch.py")

# Per-server task parameters
TASK_PARAMS=(
    "--mode=down --begin=2024-01-01 --end=2024-06-30"
    "--mode=down --begin=2023-01-01 --end=2023-06-30"
    "--mode=down --begin=2022-01-01 --end=2022-06-30"
    "--mode=down --begin=2021-01-01 --end=2021-06-30"
    "--mode=down --begin=2020-01-01 --end=2020-06-30"
)

# Push the code to all servers
function push_code() {
    for SERVER in "${SERVERS[@]}"; do
        echo "Pushing code to $SERVER..."
        scp "${LOCAL_FILES[@]}" "$REMOTE_USER@$SERVER:$REMOTE_SCRIPT_DIR/"
    done
}

# Start the tasks
function start_tasks() {
    for i in "${!SERVERS[@]}"; do
        SERVER="${SERVERS[$i]}"
        PARAMS="${TASK_PARAMS[$i]}"
        echo "Starting task on $SERVER with params: $PARAMS"
        #ssh "$REMOTE_USER@$SERVER" "cd $REMOTE_SCRIPT_DIR && nohup python3 ./fetch.py $PARAMS > ../log/nohup.log 2>&1 &"
        #ssh "$REMOTE_USER@$SERVER" "cd $REMOTE_SCRIPT_DIR && nohup python3 ./fetch.py $PARAMS > ../log/nohup.log 2>&1 < /dev/null & disown"
        # nohup ... < /dev/null: keeps nohup from waiting on stdin, so control is released immediately.
        # & disown: fully detaches the process from the SSH session so SIGHUP cannot reach it; ssh returns at once.
        # ssh -n: disables ssh's stdin so the remote process never blocks waiting for input.
        ssh -n "$REMOTE_USER@$SERVER" "cd $REMOTE_SCRIPT_DIR && nohup python3 ./fetch.py $PARAMS > ../log/nohup.log 2>&1 < /dev/null & disown"
    done
}

# Stop the tasks
function stop_tasks() {
    for SERVER in "${SERVERS[@]}"; do
        echo "Stopping task on $SERVER..."
        ssh "$REMOTE_USER@$SERVER" "pkill -f 'python3 ./fetch.py'"
    done
}

# Report task progress
function check_progress() {
    for SERVER in "${SERVERS[@]}"; do
        echo -e "\nChecking progress on $SERVER..."
        FILE_COUNT=$(ssh "$REMOTE_USER@$SERVER" "ls -lRh $DATA_DIR | grep pdf | wc -l")
        FILE_SIZE=$(ssh "$REMOTE_USER@$SERVER" "du -sh $DATA_DIR")
        PROCESS_COUNT=$(ssh "$REMOTE_USER@$SERVER" "ps aux | grep '[f]etch.py' | wc -l")

        if [ "$PROCESS_COUNT" -gt 0 ]; then
            echo "Process status: Running ($PROCESS_COUNT instances; a count of 2 includes the parent process)"
        else
            echo "Process status: Not running"
        fi

        echo "Total files: $FILE_COUNT"
        echo "Total size : $FILE_SIZE"

        ERROR_LINES=$(ssh "$REMOTE_USER@$SERVER" "grep -v INFO $REMOTE_LOG_DIR/fetch_202503* | wc -l")
        echo "Error lines: $ERROR_LINES"

        TASK_COUNT=$(ssh "$REMOTE_USER@$SERVER" "grep 'running task. id' ~/pyscripts/stockapp/log/fetch_20250316.log | wc -l")
        echo "Task count: $TASK_COUNT"

    done
}

# Report task progress and forward it to the chat robot
function check_progress_robot() {
    result=""
    for SERVER in "${SERVERS[@]}"; do
        result+="\nChecking progress on $SERVER...\n"
        FILE_COUNT=$(ssh "$REMOTE_USER@$SERVER" "ls -lRh $DATA_DIR | grep pdf | wc -l")
        FILE_SIZE=$(ssh "$REMOTE_USER@$SERVER" "du -sh $DATA_DIR")
        PROCESS_COUNT=$(ssh "$REMOTE_USER@$SERVER" "ps aux | grep '[f]etch.py' | wc -l")

        if [ "$PROCESS_COUNT" -gt 0 ]; then
            result+="Process status: Running ($PROCESS_COUNT instances; a count of 2 includes the parent process)\n"
        else
            result+="Process status: Not running\n"
        fi

        result+="Total files: $FILE_COUNT\n"
        result+="Total size : $FILE_SIZE\n"

        ERROR_LINES=$(ssh "$REMOTE_USER@$SERVER" "grep -v INFO $REMOTE_LOG_DIR/fetch_202503* | wc -l")
        result+="Error lines: $ERROR_LINES\n"

        TASK_COUNT=$(ssh "$REMOTE_USER@$SERVER" "grep 'running task. id' ~/pyscripts/stockapp/log/fetch_20250316.log | wc -l")
        result+="Task count: $TASK_COUNT\n"
    done
    echo -e "$result"
    # Send the message via the Python helper
    python3 ./robot.py "$result"
}

# Script menu
case "$1" in
    push) push_code ;;
    start) start_tasks ;;
    stop) stop_tasks ;;
    check) check_progress ;;
    *)
        echo "Usage: $0 {push|start|stop|check}"
        exit 1
        ;;
esac
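For reference, a typical operator workflow with the script above might look like this (the filename deploy.sh is hypothetical; the diff does not show what the script is called):

./deploy.sh push    # scp config.py, em_reports.py, fetch.py to every server
./deploy.sh start   # launch fetch.py on each server with its own date range
./deploy.sh check   # print process status, file counts, sizes, and error lines
./deploy.sh stop    # pkill the remote fetch.py processes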
@@ -1,248 +0,0 @@
import sqlite3
import json
import config
import utils
import logging
import sys
from datetime import datetime

# Connect to the SQLite database
DB_PATH = f"{config.global_share_db_dir}/stock_report.db"  # replace with your database file
conn = sqlite3.connect(DB_PATH)
cursor = conn.cursor()


# Get a table's column names and default values
def get_table_columns_and_defaults(tbl_name):
    try:
        cursor.execute(f"PRAGMA table_info({tbl_name})")
        columns = cursor.fetchall()
        column_info = {}
        for col in columns:
            col_name = col[1]
            default_value = col[4]
            column_info[col_name] = default_value
        return column_info
    except sqlite3.Error as e:
        logging.error(f"Error getting table columns: {e}")
        return None


# Validate and normalize incoming data against the table schema
def check_and_process_data(data, tbl_name):
    column_info = get_table_columns_and_defaults(tbl_name=tbl_name)
    if column_info is None:
        return None
    processed_data = {}
    for col, default in column_info.items():
        if col == 'id':  # auto-increment primary key; not supplied by the caller
            continue
        if col == 'created_at' or col == 'updated_at':  # set via SQL date functions
            continue
        if col in ['author', 'authorID']:
            values = data.get(col, [])
            processed_data[col] = ','.join(values)
        elif col in data:
            processed_data[col] = data[col]
        else:
            if default is not None:
                processed_data[col] = default
            else:
                processed_data[col] = None
    return processed_data


# Insert or update a row (UPSERT)
def insert_or_update_common(data, tbl_name, uniq_key='infoCode'):
    try:
        processed_data = check_and_process_data(data, tbl_name)
        if processed_data is None:
            return None

        columns = ', '.join(processed_data.keys())
        values = list(processed_data.values())
        placeholders = ', '.join(['?' for _ in values])
        update_clause = ', '.join([f"{col}=EXCLUDED.{col}" for col in processed_data.keys() if col != 'infoCode']) + ', updated_at=datetime(\'now\', \'localtime\')'

        sql = f'''
            INSERT INTO {tbl_name} ({columns}, updated_at)
            VALUES ({placeholders}, datetime('now', 'localtime'))
            ON CONFLICT (infoCode) DO UPDATE SET {update_clause}
        '''
        cursor.execute(sql, values)
        conn.commit()

        # Fetch the report_id of the inserted/updated row
        cursor.execute(f"SELECT id FROM {tbl_name} WHERE {uniq_key} = ?", (data["infoCode"],))
        report_id = cursor.fetchone()[0]
        return report_id
    except sqlite3.Error as e:
        logging.error(f"Error inserting or updating data: {e}")
        return None
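To make the dynamically built UPSERT above concrete, here is a sketch of one call and, in the comments, roughly the SQL it generates. The report fields are illustrative; the real column list is taken from PRAGMA table_info at run time:

report = {"infoCode": "AP2024001", "title": "FY24 review", "author": ["Zhang", "Li"]}
report_id = insert_or_update_common(report, tbl_name="reports_stock")
# Generated statement, roughly:
#   INSERT INTO reports_stock (infoCode, title, author, ..., updated_at)
#   VALUES (?, ?, ?, ..., datetime('now', 'localtime'))
#   ON CONFLICT (infoCode) DO UPDATE SET
#       title=EXCLUDED.title, author=EXCLUDED.author, ...,
#       updated_at=datetime('now', 'localtime')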
# Query report rows
def query_reports_comm(tbl_name, querystr='', limit=None):
    try:
        if tbl_name in [utils.tbl_stock, utils.tbl_new_stock, utils.tbl_industry, utils.tbl_macresearch, utils.tbl_strategy]:
            sql = f"SELECT id, infoCode, title, orgSName, industryName, stockName, publishDate FROM {tbl_name} WHERE 1=1 {querystr}"
        else:
            logging.warning(f'wrong table name: {tbl_name}')
            return None

        if limit:
            sql = sql + f' limit {limit}'

        cursor.execute(sql)
        results = cursor.fetchall()

        # Column names
        column_names = [description[0] for description in cursor.description]

        # Convert the rows to a list of dicts
        result_dict_list = []
        for row in results:
            row_dict = {column_names[i]: value for i, value in enumerate(row)}
            result_dict_list.append(row_dict)

        return result_dict_list

    except sqlite3.Error as e:
        logging.error(f"Report query failed: {e}")
        return None

'''
# Insert or update an industry_report row
def insert_or_update_report(report):
    try:
        sql = """
            INSERT INTO industry_report (infoCode, title, orgCode, orgName, orgSName, publishDate,
                                         industryCode, industryName, emRatingCode, emRatingValue,
                                         emRatingName, attachSize, attachPages, updated_at)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, datetime('now', 'localtime'))
            ON CONFLICT(infoCode) DO UPDATE SET
                title=excluded.title,
                orgCode=excluded.orgCode,
                orgName=excluded.orgName,
                orgSName=excluded.orgSName,
                publishDate=excluded.publishDate,
                industryCode=excluded.industryCode,
                industryName=excluded.industryName,
                emRatingCode=excluded.emRatingCode,
                emRatingValue=excluded.emRatingValue,
                emRatingName=excluded.emRatingName,
                attachSize=excluded.attachSize,
                attachPages=excluded.attachPages,
                updated_at=datetime('now', 'localtime')
        """

        values = (
            report["infoCode"], report["title"], report["orgCode"], report["orgName"],
            report["orgSName"], report["publishDate"], report["industryCode"],
            report["industryName"], report["emRatingCode"], report["emRatingValue"],
            report["emRatingName"], report.get("attachSize", 0), report.get("attachPages", 0)
        )

        cursor.execute(sql, values)
        conn.commit()

        # Fetch the report_id of the inserted/updated row
        cursor.execute("SELECT id FROM industry_report WHERE infoCode = ?", (report["infoCode"],))
        report_id = cursor.fetchone()[0]
        return report_id

    except sqlite3.Error as e:
        conn.rollback()
        logging.error(f"Database error: {e}")
        return None
    except Exception as e:
        conn.rollback()
        logging.error(f"Unexpected error: {e}")
        return None

# Query industry research reports
def query_industry_reports(querystr='', limit=None):
    try:
        sql = f"SELECT id, infoCode, title, orgSName, industryName, publishDate FROM industry_report WHERE 1=1 {querystr}"
        if limit:
            sql = sql + f' limit {limit}'

        cursor.execute(sql)
        results = cursor.fetchall()
        return results

    except sqlite3.Error as e:
        logging.error(f"Report query failed: {e}")
        return None

# Insert or update a stock_report row
def insert_or_update_stock_report(report):
    try:
        sql = """
            INSERT INTO stock_report (infoCode, title, stockName, stockCode, orgCode, orgName, orgSName,
                                      publishDate, industryCode, industryName, emIndustryCode, emRatingCode,
                                      emRatingValue, emRatingName, attachPages, attachSize, author, authorID, updated_at)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, datetime('now', 'localtime'))
            ON CONFLICT(infoCode) DO UPDATE SET
                title=excluded.title,
                stockName=excluded.stockName,
                stockCode=excluded.stockCode,
                orgCode=excluded.orgCode,
                orgName=excluded.orgName,
                orgSName=excluded.orgSName,
                publishDate=excluded.publishDate,
                industryCode=excluded.industryCode,
                industryName=excluded.industryName,
                emIndustryCode=excluded.emIndustryCode,
                emRatingCode=excluded.emRatingCode,
                emRatingValue=excluded.emRatingValue,
                emRatingName=excluded.emRatingName,
                attachPages=excluded.attachPages,
                attachSize=excluded.attachSize,
                author=excluded.author,
                authorID=excluded.authorID,
                updated_at=datetime('now', 'localtime')
        """

        values = (
            report["infoCode"], report["title"], report["stockName"], report["stockCode"],
            report["orgCode"], report["orgName"], report["orgSName"], report["publishDate"],
            report.get("industryCode", ""), report.get("industryName", ""), report.get("emIndustryCode", ""),
            report["emRatingCode"], report["emRatingValue"], report["emRatingName"],
            report.get("attachPages", 0), report.get("attachSize", 0),
            ",".join(report.get("author", [])), ",".join(report.get("authorID", []))
        )

        cursor.execute(sql, values)
        conn.commit()

        # Fetch the report_id of the inserted/updated row
        cursor.execute("SELECT id FROM stock_report WHERE infoCode = ?", (report["infoCode"],))
        report_id = cursor.fetchone()[0]
        return report_id

    except sqlite3.Error as e:
        conn.rollback()
        logging.error(f"Database error: {e}")
        return None
    except Exception as e:
        conn.rollback()
        logging.error(f"Unexpected error: {e}")
        return None


# Query stock research reports
def query_stock_reports(querystr='', limit=None):
    try:
        sql = f"SELECT id, infoCode, title, orgSName, stockName, publishDate FROM stock_report WHERE 1=1 {querystr}"
        if limit:
            sql = sql + f' limit {limit}'

        cursor.execute(sql)
        results = cursor.fetchall()
        return results

    except sqlite3.Error as e:
        logging.error(f"Report query failed: {e}")
        return None
'''
@@ -1,65 +0,0 @@
import logging
import os
import inspect
from datetime import datetime
from pathlib import Path

# MySQL configuration
db_config = {
    'host': 'testdb',
    'user': 'root',
    'password': 'mysqlpw',
    'database': 'stockdb'
}

log_dir_prefix = '../log'

global_share_data_dir = '/root/sharedata'
global_stock_data_dir = '/root/hostdir/stock_data'

# Directory containing this file
current_dir = os.path.dirname(os.path.abspath(__file__))
# Project root (assuming this file lives under src/strategy)
project_root = os.path.abspath(os.path.join(current_dir, '..', '..'))


# Locate the log directory
def get_log_directory():
    """
    Return the path of the log directory under the project root,
    creating it if it does not exist.
    """
    # Directory containing this file
    current_dir = Path(__file__).resolve().parent

    # Walk up to the project root, assumed to be the parent of the 'src' directory
    project_root = current_dir
    while project_root.name != 'src' and project_root != project_root.parent:
        project_root = project_root.parent
    project_root = project_root.parent  # step up to the project root

    # Make sure the log directory exists
    log_dir = project_root / 'log'
    log_dir.mkdir(parents=True, exist_ok=True)

    return log_dir
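For orientation, the directory walk above assumes a layout roughly like this (illustrative):

# project_root/
#     src/
#         config/config.py    <- this file; the loop climbs until it leaves 'src'
#     log/                    <- created here if missing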

def get_caller_filename():
    # Name of the script that called setup_logging
    caller_frame = inspect.stack()[2]
    caller_filename = os.path.splitext(os.path.basename(caller_frame.filename))[0]
    return caller_filename


# Configure logging
def setup_logging(log_filename=None):
    # If no log_filename is given, derive one from the calling script's name
    if log_filename is None:
        caller_filename = get_caller_filename()
        common_log_dir = get_log_directory()
        current_date = datetime.now().strftime('%Y%m%d')
        # Build the log file name with the date inserted before the extension
        log_filename = f'{common_log_dir}/{caller_filename}_{current_date}.log'

    logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] (%(funcName)s) - %(message)s',
                        handlers=[
                            logging.FileHandler(log_filename),
                            logging.StreamHandler()
                        ])
@@ -9,8 +9,13 @@ db_config = {
     'database': 'stockdb'
 }

-global_share_data_dir = '/root/sharedata'
-global_stock_data_dir = '/root/hostdir/stock_data'
+home_dir = os.path.expanduser("~")
+global_host_data_dir = f'{home_dir}/hostdir/stock_data'
+global_share_db_dir = f'{home_dir}/sharedata/sqlite'
+
+# Backward compatibility with the earlier names
+global_stock_data_dir = global_host_data_dir
+global_share_data_dir = f'{home_dir}/sharedata'

 # Directory containing this file
 current_dir = Path(__file__).resolve().parent
@@ -4,8 +4,6 @@ import requests
 import time
 import logging
 from bs4 import BeautifulSoup
-import sqlite_utils as db_tools
-import config

 # Fetch one page of the per-stock research report list
 def fetch_reports_by_stock(page_no, start_date="2023-03-10", end_date="2025-03-10", page_size=50, max_retries=3):
src/db_utils/reports.py (new file, 123 lines)
@@ -0,0 +1,123 @@
import sqlite3
import logging
import sys
from datetime import datetime


class DatabaseConnectionError(Exception):
    pass


class StockReportDB:
    # Class-level constants (table names)
    TBL_STOCK = 'reports_stock'
    TBL_NEW_STOCK = 'reports_newstrock'
    TBL_STRATEGY = 'reports_strategy'
    TBL_MACRESEARCH = 'reports_macresearch'
    TBL_INDUSTRY = 'reports_industry'

    def __init__(self, db_path):
        self.DB_PATH = db_path
        self.conn = None
        self.cursor = None
        try:
            self.conn = sqlite3.connect(self.DB_PATH)
            self.cursor = self.conn.cursor()
        except sqlite3.Error as e:
            logging.error(f"Database connection failed: {e}")
            raise DatabaseConnectionError("database connection failed")

    def __get_table_columns_and_defaults(self, tbl_name):
        try:
            self.cursor.execute(f"PRAGMA table_info({tbl_name})")
            columns = self.cursor.fetchall()
            column_info = {}
            for col in columns:
                col_name = col[1]
                default_value = col[4]
                column_info[col_name] = default_value
            return column_info
        except sqlite3.Error as e:
            logging.error(f"Error getting table columns: {e}")
            return None

    def __check_and_process_data(self, data, tbl_name):
        column_info = self.__get_table_columns_and_defaults(tbl_name=tbl_name)
        if column_info is None:
            return None
        processed_data = {}
        for col, default in column_info.items():
            if col == 'id':  # auto-increment primary key; not supplied by the caller
                continue
            if col == 'created_at' or col == 'updated_at':  # set via SQL date functions
                continue
            if col in ['author', 'authorID']:
                values = data.get(col, [])
                processed_data[col] = ','.join(values)
            elif col in data:
                processed_data[col] = data[col]
            else:
                if default is not None:
                    processed_data[col] = default
                else:
                    processed_data[col] = None
        return processed_data

    def insert_or_update_common(self, data, tbl_name, uniq_key='infoCode'):
        try:
            processed_data = self.__check_and_process_data(data, tbl_name)
            if processed_data is None:
                return None

            columns = ', '.join(processed_data.keys())
            values = list(processed_data.values())
            placeholders = ', '.join(['?' for _ in values])
            update_clause = ', '.join([f"{col}=EXCLUDED.{col}" for col in processed_data.keys() if col != 'infoCode']) + ', updated_at=datetime(\'now\', \'localtime\')'

            sql = f'''
                INSERT INTO {tbl_name} ({columns}, updated_at)
                VALUES ({placeholders}, datetime('now', 'localtime'))
                ON CONFLICT (infoCode) DO UPDATE SET {update_clause}
            '''
            self.cursor.execute(sql, values)
            self.conn.commit()

            # Fetch the report_id of the inserted/updated row
            self.cursor.execute(f"SELECT id FROM {tbl_name} WHERE {uniq_key} = ?", (data["infoCode"],))
            report_id = self.cursor.fetchone()[0]
            return report_id
        except sqlite3.Error as e:
            logging.error(f"Error inserting or updating data: {e}")
            return None

    def query_reports_comm(self, tbl_name, querystr='', limit=None):
        try:
            if tbl_name in [StockReportDB.TBL_STOCK, StockReportDB.TBL_NEW_STOCK, StockReportDB.TBL_INDUSTRY, StockReportDB.TBL_MACRESEARCH, StockReportDB.TBL_STRATEGY]:
                sql = f"SELECT id, infoCode, title, orgSName, industryName, stockName, publishDate FROM {tbl_name} WHERE 1=1 {querystr}"
            else:
                logging.warning(f'wrong table name: {tbl_name}')
                return None

            if limit:
                sql = sql + f' limit {limit}'

            self.cursor.execute(sql)
            results = self.cursor.fetchall()

            # Column names
            column_names = [description[0] for description in self.cursor.description]

            # Convert the rows to a list of dicts
            result_dict_list = []
            for row in results:
                row_dict = {column_names[i]: value for i, value in enumerate(row)}
                result_dict_list.append(row_dict)

            return result_dict_list
        except sqlite3.Error as e:
            logging.error(f"Report query failed: {e}")
            return None

    def __del__(self):
        if self.conn:
            self.conn.close()

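A minimal usage sketch for the new class (the database path and the date filter are illustrative):

db = StockReportDB("/home/ubuntu/sharedata/sqlite/stock_report.db")  # raises DatabaseConnectionError on failure
rows = db.query_reports_comm(StockReportDB.TBL_STOCK,
                             querystr="AND publishDate >= '2024-01-01'", limit=10)
for row in rows or []:
    print(row["infoCode"], row["title"])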
@@ -9,34 +9,41 @@ import shutil
 import logging
 from datetime import datetime, timedelta
 from functools import partial
-import config
-import sqlite_utils as db_tools
-import em_reports as em
-import utils
+import src.crawler.em.reports as em
+import src.utils.utils as utils
+from src.config.config import global_host_data_dir, global_share_db_dir
+from src.db_utils.reports import StockReportDB, DatabaseConnectionError
+from src.logger.logger import setup_logging

-config.setup_logging()
+# Initialize logging
+setup_logging()

 debug = False
 force = False
-pdf_base_dir = f"{config.global_host_data_dir}/pdfs"  # directory for downloaded PDFs
+pdf_base_dir = f"{global_host_data_dir}/pdfs"  # directory for downloaded PDFs

 # Links to the report download pages
 map_pdf_page = {
-    utils.tbl_stock : "https://data.eastmoney.com/report/info/{}.html",
-    utils.tbl_new_stock : "https://data.eastmoney.com/report/info/{}.html",
-    utils.tbl_strategy : "https://data.eastmoney.com/report/zw_strategy.jshtml?encodeUrl={}",
-    utils.tbl_macresearch : "https://data.eastmoney.com/report/zw_macresearch.jshtml?encodeUrl={}",
-    utils.tbl_industry : "https://data.eastmoney.com/report/zw_industry.jshtml?infocode={}"
+    StockReportDB.TBL_STOCK : "https://data.eastmoney.com/report/info/{}.html",
+    StockReportDB.TBL_NEW_STOCK : "https://data.eastmoney.com/report/info/{}.html",
+    StockReportDB.TBL_STRATEGY : "https://data.eastmoney.com/report/zw_strategy.jshtml?encodeUrl={}",
+    StockReportDB.TBL_MACRESEARCH : "https://data.eastmoney.com/report/zw_macresearch.jshtml?encodeUrl={}",
+    StockReportDB.TBL_INDUSTRY : "https://data.eastmoney.com/report/zw_industry.jshtml?infocode={}"
 }

 # Map table names to Chinese category names, used as storage paths
 map_tbl_name = {
-    utils.tbl_stock : '个股研报',
-    utils.tbl_new_stock : '新股研报',
-    utils.tbl_strategy : '策略报告',
-    utils.tbl_macresearch : '宏观研究',
-    utils.tbl_industry : '行业研报'
+    StockReportDB.TBL_STOCK : '个股研报',
+    StockReportDB.TBL_NEW_STOCK : '新股研报',
+    StockReportDB.TBL_STRATEGY : '策略报告',
+    StockReportDB.TBL_MACRESEARCH : '宏观研究',
+    StockReportDB.TBL_INDUSTRY : '行业研报'
 }

+# Database path; the connection itself is created in main()
+db_path = f"{global_share_db_dir}/stock_report.db"
+db_tools = None

 current_date = datetime.now()
 seven_days_ago = current_date - timedelta(days=7)
 two_years_ago = current_date - timedelta(days=2*365)
@@ -45,6 +52,7 @@ start_date = two_years_ago.strftime("%Y-%m-%d")
 end_date = current_date.strftime("%Y-%m-%d")
 this_week_date = seven_days_ago.strftime("%Y-%m-%d")

+
 def fetch_reports_list_general(fetch_func, table_name, s_date, e_date, data_dir_prefix):
     # Page cap for the fetch loop (effectively unbounded)
     max_pages = 100000
@@ -148,39 +156,39 @@ def download_pdf_stock_general(parse_func, tbl_name, querystr='', s_date=start_d

 # Fetch the per-stock report list
 def fetch_reports_list_stock(s_date=start_date, e_date=end_date):
-    return fetch_reports_list_general(em.fetch_reports_by_stock, utils.tbl_stock, s_date, e_date, 'stock')
+    return fetch_reports_list_general(em.fetch_reports_by_stock, StockReportDB.TBL_STOCK, s_date, e_date, 'stock')

 # Fetch the new-stock report list
 def fetch_reports_list_newstock(s_date=start_date, e_date=end_date):
-    return fetch_reports_list_general(em.fetch_reports_by_newstock, utils.tbl_new_stock, s_date, e_date, 'new')
+    return fetch_reports_list_general(em.fetch_reports_by_newstock, StockReportDB.TBL_NEW_STOCK, s_date, e_date, 'new')

 # Fetch the industry report list
 def fetch_reports_list_industry(s_date=start_date, e_date=end_date):
-    return fetch_reports_list_general(em.fetch_reports_by_industry, utils.tbl_industry, s_date, e_date, 'industry')
+    return fetch_reports_list_general(em.fetch_reports_by_industry, StockReportDB.TBL_INDUSTRY, s_date, e_date, 'industry')

 # Fetch the macro research report list
 def fetch_reports_list_macresearch(s_date=start_date, e_date=end_date):
-    return fetch_reports_list_general(em.fetch_reports_by_macresearch, utils.tbl_macresearch, s_date, e_date, 'macresearch')
+    return fetch_reports_list_general(em.fetch_reports_by_macresearch, StockReportDB.TBL_MACRESEARCH, s_date, e_date, 'macresearch')

 # Fetch the strategy report list
 def fetch_reports_list_strategy(s_date=start_date, e_date=end_date):
-    return fetch_reports_list_general(em.fetch_reports_by_strategy, utils.tbl_strategy, s_date, e_date, 'strategy')
+    return fetch_reports_list_general(em.fetch_reports_by_strategy, StockReportDB.TBL_STRATEGY, s_date, e_date, 'strategy')

 # Download per-stock PDFs
 def download_pdf_stock(s_date=start_date, e_date=end_date):
-    download_pdf_stock_general(parse_func_general, utils.tbl_stock, ' ', s_date, e_date, limit=2 if debug else None)
+    download_pdf_stock_general(parse_func_general, StockReportDB.TBL_STOCK, ' ', s_date, e_date, limit=2 if debug else None)

 def download_pdf_newstock(s_date=start_date, e_date=end_date):
-    download_pdf_stock_general(parse_func_general, utils.tbl_new_stock, ' ', s_date, e_date, limit=2 if debug else None)
+    download_pdf_stock_general(parse_func_general, StockReportDB.TBL_NEW_STOCK, ' ', s_date, e_date, limit=2 if debug else None)

 def download_pdf_industry(s_date=start_date, e_date=end_date):
-    download_pdf_stock_general(parse_func_general, utils.tbl_industry, ' ', s_date, e_date, limit=2 if debug else None)
+    download_pdf_stock_general(parse_func_general, StockReportDB.TBL_INDUSTRY, ' ', s_date, e_date, limit=2 if debug else None)

 def download_pdf_macresearch(s_date=start_date, e_date=end_date):
-    download_pdf_stock_general(parse_func_general, utils.tbl_macresearch, ' ', s_date, e_date, limit=2 if debug else None)
+    download_pdf_stock_general(parse_func_general, StockReportDB.TBL_MACRESEARCH, ' ', s_date, e_date, limit=2 if debug else None)

 def download_pdf_strategy(s_date=start_date, e_date=end_date):
-    download_pdf_stock_general(parse_func_general, utils.tbl_strategy, ' ', s_date, e_date, limit=2 if debug else None)
+    download_pdf_stock_general(parse_func_general, StockReportDB.TBL_STRATEGY, ' ', s_date, e_date, limit=2 if debug else None)


 # Map of command abbreviations to functions
@@ -265,6 +273,15 @@ def main(cmd, mode, args_debug, args_force, begin, end):
     global end_date
     end_date = end if end else end_date

+    # Initialize the DB
+    global db_tools
+    try:
+        db_tools = StockReportDB(db_path)
+        # database operations proceed from here
+    except DatabaseConnectionError as e:
+        logging.error(f"Database connection failed: {e}")
+        return False
+
     # Start the task
     #task_id = db_tools.insert_task_log()
     task_id = 0
@@ -1,10 +1,46 @@
 import logging
 import os
 import inspect
+import time
 from datetime import datetime
 from pathlib import Path
+from logging.handlers import RotatingFileHandler
+from collections import defaultdict
+from src.config.config import get_log_directory, get_src_directory

+# Track logging frequency
+log_count = defaultdict(int)        # times each message has been logged
+last_log_time = defaultdict(float)  # timestamp of the last write per message
+
+class RateLimitFilter(logging.Filter):
+    """
+    Rate-limit filter:
+    1. Within a 60-second window, the same message is written at most LOG_LIMIT times; the rest are dropped.
+    2. Warn if the overall log rate exceeds 100 messages/second (described but not implemented below).
+    """
+    LOG_LIMIT = 600  # at most 600 identical messages per 60-second window
+
+    def filter(self, record):
+        global log_count, last_log_time
+        message_key = record.getMessage()  # the log message text
+
+        # Current time and elapsed time since this message was last seen
+        now = time.time()
+        elapsed = now - last_log_time[message_key]
+
+        # Throttle repeated identical messages
+        if elapsed < 60:  # within the 60-second window
+            log_count[message_key] += 1
+            if log_count[message_key] > self.LOG_LIMIT:
+                return False  # drop the record
+        else:
+            log_count[message_key] = 1  # more than 60 seconds passed; restart the count
+
+        last_log_time[message_key] = now
+
+        return True  # allow the record through
+
+
 def get_caller_filename():
     # Walk the call stack
     stack = inspect.stack()
@@ -26,7 +62,7 @@ def get_caller_filename():
             return os.path.splitext(os.path.basename(frame_info.filename))[0]
     return None

-# Configure logging
+
 def setup_logging(log_filename=None):
     # If no log_filename is given, derive one from the calling script's name
     if log_filename is None:
@@ -35,9 +71,29 @@ def setup_logging(log_filename=None):
         current_date = datetime.now().strftime('%Y%m%d')
         # Build the log file name with the date inserted before the extension
         log_filename = f'{common_log_dir}/{caller_filename}_{current_date}.log'

+    max_log_size = 100 * 1024 * 1024  # 100 MB
+    max_log_files = 10                # keep at most 10 rotated log files
+
+    file_handler = RotatingFileHandler(log_filename, maxBytes=max_log_size, backupCount=max_log_files)
+    file_handler.setFormatter(logging.Formatter(
+        '%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] (%(funcName)s) - %(message)s'
+    ))
+
+    console_handler = logging.StreamHandler()
+    console_handler.setFormatter(logging.Formatter(
+        '%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] (%(funcName)s) - %(message)s'
+    ))
+
+    # Create the root logger
+    logger = logging.getLogger()
+    logger.setLevel(logging.INFO)
+    logger.handlers = []  # avoid adding handlers twice
+    logger.addHandler(file_handler)
+    logger.addHandler(console_handler)
+
+    # Attach the rate limiter
+    rate_limit_filter = RateLimitFilter()
+    file_handler.addFilter(rate_limit_filter)
+    console_handler.addFilter(rate_limit_filter)
-
-    logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] (%(funcName)s) - %(message)s',
-                        handlers=[
-                            logging.FileHandler(log_filename),
-                            logging.StreamHandler()
-                        ])
@@ -1,6 +1,7 @@
 import pandas as pd
 import numpy as np
 import os
+import warnings
 from src.strategy.prepare import fetch_his_kline
 import src.config.config as config
 import src.crawler.em.stock as em_stock
@@ -15,10 +16,15 @@ def select_stocks(stock_map):
         df = fetch_his_kline(stock_code)
         close_prices = df['close'].values

-        # Compute RSI for several periods with the MyTT library
-        rsi_6 = RSI(close_prices, 6)
-        rsi_12 = RSI(close_prices, 12)
-        rsi_24 = RSI(close_prices, 24)
+        # Capture RuntimeWarning during the computation
+        with warnings.catch_warnings(record=True) as w:
+            warnings.simplefilter("always", RuntimeWarning)
+            # Compute RSI for several periods with the MyTT library
+            rsi_6 = RSI(close_prices, 6)
+            rsi_12 = RSI(close_prices, 12)
+            rsi_24 = RSI(close_prices, 24)
+        if w:
+            print(f"Stock {stock_code} {stock_name}: warning while computing RSI: {w[0].message}")

         df['rsi_6'] = rsi_6
         df['rsi_12'] = rsi_12
@@ -96,7 +102,8 @@ if __name__ == "__main__":

     codes = ['105.QFIN', '105.FUTU']
     # Fetch the stock map from the network
-    stock_map = em_stock.code_by_fs('hk_famous', em_stock.em_market_fs_types['hk_famous'])
+    plat_id = 'cn_hs300'
+    stock_map = em_stock.code_by_fs(plat_id, em_stock.em_market_fs_types[plat_id])
     if stock_map:
         select_stocks(stock_map)
@@ -5,15 +5,6 @@ import time
 import csv
 import logging
 from datetime import datetime
-import config
-
-tbl_stock = 'reports_stock'
-tbl_new_stock = 'reports_newstrock'
-tbl_strategy = 'reports_strategy'
-tbl_macresearch = 'reports_macresearch'
-tbl_industry = 'reports_industry'
-
-json_data_dir = f'{config.global_host_data_dir}/em_reports/json_data'

 # Save JSON data to a local file
 def save_json_to_file(data, file_path, file_name):