diff --git a/reports_em/config.py b/reports_em/config.py
deleted file mode 100644
index 289007a..0000000
--- a/reports_em/config.py
+++ /dev/null
@@ -1,86 +0,0 @@
-import logging
-import os
-import inspect
-import time
-from datetime import datetime
-from logging.handlers import RotatingFileHandler
-from collections import defaultdict
-
-home_dir = os.path.expanduser("~")
-global_host_data_dir = f'{home_dir}/hostdir/stock_data'
-global_share_db_dir = f'{home_dir}/sharedata/sqlite'
-
-# Track logging frequency
-log_count = defaultdict(int)        # number of times each message was logged
-last_log_time = defaultdict(float)  # timestamp of the last write per message
-
-class RateLimitFilter(logging.Filter):
-    """
-    Rate-limit filter:
-    1. Within a 60-second window the same message is written at most 600 times; anything beyond that is dropped.
-    2. TODO (not implemented): alert when the log rate exceeds 100 messages/second.
-    """
-    LOG_LIMIT = 600  # at most 600 identical messages per minute
-
-    def filter(self, record):
-        global log_count, last_log_time
-        message_key = record.getMessage()  # the rendered message is the dedup key
-
-        # current time
-        now = time.time()
-        elapsed = now - last_log_time[message_key]
-
-        # throttle repeats of the same message
-        if elapsed < 60:  # still inside the 60-second window
-            log_count[message_key] += 1
-            if log_count[message_key] > self.LOG_LIMIT:
-                return False  # drop it
-        else:
-            log_count[message_key] = 1  # window expired, restart the count
-
-        last_log_time[message_key] = now
-
-        return True  # allow the record through
-
-
-
-def setup_logging(log_filename=None):
-    if log_filename is None:
-        caller_frame = inspect.stack()[1]
-        caller_filename = os.path.splitext(os.path.basename(caller_frame.filename))[0]
-        current_date = datetime.now().strftime('%Y%m%d')
-        log_filename = f'../log/{caller_filename}_{current_date}.log'
-
-    max_log_size = 100 * 1024 * 1024  # 100 MB
-    max_log_files = 10  # keep at most 10 rotated log files
-
-    file_handler = RotatingFileHandler(log_filename, maxBytes=max_log_size, backupCount=max_log_files)
-    file_handler.setFormatter(logging.Formatter(
-        '%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] (%(funcName)s) - %(message)s'
-    ))
-
-    console_handler = logging.StreamHandler()
-    console_handler.setFormatter(logging.Formatter(
-        '%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] (%(funcName)s) - %(message)s'
-    ))
-
-    # configure the root logger
-    logger = logging.getLogger()
-    logger.setLevel(logging.INFO)
-    logger.handlers = []  # avoid adding duplicate handlers
-    logger.addHandler(file_handler)
-    logger.addHandler(console_handler)
-
-    # attach the rate limiter
-    rate_limit_filter = RateLimitFilter()
-    file_handler.addFilter(rate_limit_filter)
-    console_handler.addFilter(rate_limit_filter)
-
-
-# usage example
-if __name__ == "__main__":
-    setup_logging()
-
-    for i in range(1000):
-        logging.info("test message to exercise the rate limiter")
-        time.sleep(0.01)  # simulate rapid logging
\ No newline at end of file
diff --git a/reports_em/deploy.sh b/reports_em/deploy.sh
deleted file mode 100755
index d18eaa6..0000000
--- a/reports_em/deploy.sh
+++ /dev/null
@@ -1,115 +0,0 @@
-#!/bin/bash
-
-# Remote server list; adjust as needed
-SERVERS=("175.178.54.98" "1.12.218.143" "43.139.169.25" "129.204.180.174" "42.194.142.169")
-REMOTE_USER="ubuntu"
-REMOTE_SCRIPT_DIR="/home/ubuntu/pyscripts/stockapp/reports_em"
-REMOTE_LOG_DIR="/home/ubuntu/pyscripts/stockapp/log"
-DATA_DIR="/home/ubuntu/hostdir/stock_data/pdfs"
-LOCAL_FILES=("config.py" "em_reports.py" "fetch.py")
-
-# Remote task parameters (different arguments for each server)
-TASK_PARAMS=(
-    "--mode=down --begin=2024-01-01 --end=2024-06-30"
-    "--mode=down --begin=2023-01-01 --end=2023-06-30"
-    "--mode=down --begin=2022-01-01 --end=2022-06-30"
-    "--mode=down --begin=2021-01-01 --end=2021-06-30"
-    "--mode=down --begin=2020-01-01 --end=2020-06-30"
-)
-
-# Push code to all servers
-function push_code() {
-    for SERVER in "${SERVERS[@]}"; do
-        echo "Pushing code to $SERVER..."
-        scp "${LOCAL_FILES[@]}" "$REMOTE_USER@$SERVER:$REMOTE_SCRIPT_DIR/"
-    done
-}
-
-# Start tasks
-function start_tasks() {
-    for i in "${!SERVERS[@]}"; do
-        SERVER="${SERVERS[$i]}"
-        PARAMS="${TASK_PARAMS[$i]}"
-        echo "Starting task on $SERVER with params: $PARAMS"
-        #ssh "$REMOTE_USER@$SERVER" "cd $REMOTE_SCRIPT_DIR && nohup python3 ./fetch.py $PARAMS > ../log/nohup.log 2>&1 &"
-        #ssh "$REMOTE_USER@$SERVER" "cd $REMOTE_SCRIPT_DIR && nohup python3 ./fetch.py $PARAMS > ../log/nohup.log 2>&1 < /dev/null & disown"
-        # nohup ... < /dev/null: keep nohup from waiting on stdin, so control returns immediately.
-        # & disown: fully detach the process from the SSH session so SIGHUP cannot reach it; ssh returns at once.
-        # ssh -n: disable ssh's stdin so the remote process never blocks waiting for input.
-        ssh -n "$REMOTE_USER@$SERVER" "cd $REMOTE_SCRIPT_DIR && nohup python3 ./fetch.py $PARAMS > ../log/nohup.log 2>&1 < /dev/null & disown"
-    done
-}
-
-# Stop tasks
-function stop_tasks() {
-    for SERVER in "${SERVERS[@]}"; do
-        echo "Stopping task on $SERVER..."
-        ssh "$REMOTE_USER@$SERVER" "pkill -f 'python3 ./fetch.py'"
-    done
-}
-
-# Check task progress
-function check_progress() {
-    for SERVER in "${SERVERS[@]}"; do
-        echo -e "\nChecking progress on $SERVER..."
-        FILE_COUNT=$(ssh "$REMOTE_USER@$SERVER" "ls -lRh $DATA_DIR | grep pdf | wc -l")
-        FILE_SIZE=$(ssh "$REMOTE_USER@$SERVER" "du -sh $DATA_DIR")
-        PROCESS_COUNT=$(ssh "$REMOTE_USER@$SERVER" "ps aux | grep '[f]etch.py' | wc -l")
-
-        if [ "$PROCESS_COUNT" -gt 0 ]; then
-            echo "Process status: Running ($PROCESS_COUNT instances; a count of 2 includes the parent process)"
-        else
-            echo "Process status: Not running"
-        fi
-
-        echo "Total files: $FILE_COUNT"
-        echo "Total size : $FILE_SIZE"
-
-        ERROR_LINES=$(ssh "$REMOTE_USER@$SERVER" "grep -v INFO $REMOTE_LOG_DIR/fetch_202503* | wc -l")
-        echo "Error lines: $ERROR_LINES"
-
-        TASK_COUNT=$(ssh "$REMOTE_USER@$SERVER" "grep 'running task. id' ~/pyscripts/stockapp/log/fetch_20250316.log | wc -l")
-        echo "Task count: $TASK_COUNT"
-
-    done
-}
-
-# Check task progress and report via the robot
-function check_progress_robot() {
-    result=""
-    for SERVER in "${SERVERS[@]}"; do
-        result+="\nChecking progress on $SERVER...\n"
-        FILE_COUNT=$(ssh "$REMOTE_USER@$SERVER" "ls -lRh $DATA_DIR | grep pdf | wc -l")
-        FILE_SIZE=$(ssh "$REMOTE_USER@$SERVER" "du -sh $DATA_DIR")
-        PROCESS_COUNT=$(ssh "$REMOTE_USER@$SERVER" "ps aux | grep '[f]etch.py' | wc -l")
-
-        if [ "$PROCESS_COUNT" -gt 0 ]; then
-            result+="Process status: Running ($PROCESS_COUNT instances; a count of 2 includes the parent process)\n"
-        else
-            result+="Process status: Not running\n"
-        fi
-
-        result+="Total files: $FILE_COUNT\n"
-        result+="Total size : $FILE_SIZE\n"
-
-        ERROR_LINES=$(ssh "$REMOTE_USER@$SERVER" "grep -v INFO $REMOTE_LOG_DIR/fetch_202503* | wc -l")
-        result+="Error lines: $ERROR_LINES\n"
-
-        TASK_COUNT=$(ssh "$REMOTE_USER@$SERVER" "grep 'running task. id' ~/pyscripts/stockapp/log/fetch_20250316.log | wc -l")
-    done
-    echo -e "$result"
-    # send the message via the Python helper
-    python3 ./robot.py "$result"
-}
-
-# Command menu
-case "$1" in
-    push) push_code ;;
-    start) start_tasks ;;
-    stop) stop_tasks ;;
-    check) check_progress ;;
-    *)
-        echo "Usage: $0 {push|start|stop|check}"
-        exit 1
-        ;;
-esac
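Note: the detach pattern in start_tasks (ssh -n plus nohup ... < /dev/null & disown) is what lets deploy.sh return immediately while the remote job survives the SSH session. Below is a minimal Python sketch of the same pattern, in case the launcher is ever driven from Python; the helper name and default paths are illustrative, not part of the repo.

    import subprocess

    # Hypothetical helper mirroring deploy.sh's start_tasks: "-n" closes ssh's
    # stdin, "< /dev/null" keeps nohup from waiting on stdin, and "& disown"
    # detaches the remote job so it outlives the SSH session.
    def start_remote_task(server: str, params: str, user: str = "ubuntu",
                          script_dir: str = "/home/ubuntu/pyscripts/stockapp/reports_em") -> None:
        remote_cmd = (f"cd {script_dir} && nohup python3 ./fetch.py {params} "
                      f"> ../log/nohup.log 2>&1 < /dev/null & disown")
        subprocess.run(["ssh", "-n", f"{user}@{server}", remote_cmd], check=True)

    start_remote_task("175.178.54.98", "--mode=down --begin=2024-01-01 --end=2024-06-30")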
diff --git a/reports_em/sqlite_utils.py b/reports_em/sqlite_utils.py
deleted file mode 100644
index 938de21..0000000
--- a/reports_em/sqlite_utils.py
+++ /dev/null
@@ -1,248 +0,0 @@
-import sqlite3
-import json
-import config
-import utils
-import logging
-import sys
-from datetime import datetime
-
-# Connect to the SQLite database
-DB_PATH = f"{config.global_share_db_dir}/stock_report.db"  # replace with your database file
-conn = sqlite3.connect(DB_PATH)
-cursor = conn.cursor()
-
-# Get column names and default values for a table
-def get_table_columns_and_defaults(tbl_name):
-    try:
-        cursor.execute(f"PRAGMA table_info({tbl_name})")
-        columns = cursor.fetchall()
-        column_info = {}
-        for col in columns:
-            col_name = col[1]
-            default_value = col[4]
-            column_info[col_name] = default_value
-        return column_info
-    except sqlite3.Error as e:
-        logging.error(f"Error getting table columns: {e}")
-        return None
-
-# Validate and normalize incoming data
-def check_and_process_data(data, tbl_name):
-    column_info = get_table_columns_and_defaults(tbl_name=tbl_name)
-    if column_info is None:
-        return None
-    processed_data = {}
-    for col, default in column_info.items():
-        if col == 'id':  # autoincrement primary key; not supplied by the caller
-            continue
-        if col == 'created_at' or col == 'updated_at':  # timestamps are set via SQL datetime()
-            continue
-        if col in ['author', 'authorID']:
-            values = data.get(col, [])
-            processed_data[col] = ','.join(values)
-        elif col in data:
-            processed_data[col] = data[col]
-        else:
-            if default is not None:
-                processed_data[col] = default
-            else:
-                processed_data[col] = None
-    return processed_data
-
-
-# Insert or update a row (UPSERT)
-def insert_or_update_common(data, tbl_name, uniq_key='infoCode'):
-    try:
-        processed_data = check_and_process_data(data, tbl_name)
-        if processed_data is None:
-            return None
-
-        columns = ', '.join(processed_data.keys())
-        values = list(processed_data.values())
-        placeholders = ', '.join(['?' for _ in values])
-        update_clause = ', '.join([f"{col}=EXCLUDED.{col}" for col in processed_data.keys() if col != 'infoCode']) + ', updated_at=datetime(\'now\', \'localtime\')'
-
-        sql = f'''
-            INSERT INTO {tbl_name} ({columns}, updated_at)
-            VALUES ({placeholders}, datetime('now', 'localtime'))
-            ON CONFLICT (infoCode) DO UPDATE SET {update_clause}
-        '''
-        cursor.execute(sql, values)
-        conn.commit()
-
-        # fetch the id of the inserted or updated row
-        cursor.execute(f"SELECT id FROM {tbl_name} WHERE {uniq_key} = ?", (data["infoCode"],))
-        report_id = cursor.fetchone()[0]
-        return report_id
-    except sqlite3.Error as e:
-        logging.error(f"Error inserting or updating data: {e}")
-        return None
-
-
-# Query reports
-def query_reports_comm(tbl_name, querystr='', limit=None):
-    try:
-        if tbl_name in [utils.tbl_stock, utils.tbl_new_stock, utils.tbl_industry, utils.tbl_macresearch, utils.tbl_strategy]:
-            sql = f"SELECT id, infoCode, title, orgSName, industryName, stockName, publishDate FROM {tbl_name} WHERE 1=1 {querystr}"
-        else:
-            logging.warning(f'wrong table name: {tbl_name}')
-            return None
-
-        if limit:
-            sql = sql + f' limit {limit}'
-
-        cursor.execute(sql)
-        results = cursor.fetchall()
-
-        # column names
-        column_names = [description[0] for description in cursor.description]
-
-        # convert rows to a list of dicts
-        result_dict_list = []
-        for row in results:
-            row_dict = {column_names[i]: value for i, value in enumerate(row)}
-            result_dict_list.append(row_dict)
-
-        return result_dict_list
-
-    except sqlite3.Error as e:
-        logging.error(f"Query failed: {e}")
-        return None
-
-'''
-# Insert or update an industry_report row
-def insert_or_update_report(report):
-    try:
-        sql = """
-            INSERT INTO industry_report (infoCode, title, orgCode, orgName, orgSName, publishDate,
-                                         industryCode, industryName, emRatingCode, emRatingValue,
-                                         emRatingName, attachSize, attachPages, updated_at)
-            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, datetime('now', 'localtime'))
-            ON CONFLICT(infoCode) DO UPDATE SET
-                title=excluded.title,
-                orgCode=excluded.orgCode,
-                orgName=excluded.orgName,
-                orgSName=excluded.orgSName,
-                publishDate=excluded.publishDate,
-                industryCode=excluded.industryCode,
-                industryName=excluded.industryName,
-                emRatingCode=excluded.emRatingCode,
-                emRatingValue=excluded.emRatingValue,
-                emRatingName=excluded.emRatingName,
-                attachSize=excluded.attachSize,
-                attachPages=excluded.attachPages,
-                updated_at=datetime('now', 'localtime')
-        """
-
-        values = (
-            report["infoCode"], report["title"], report["orgCode"], report["orgName"],
-            report["orgSName"], report["publishDate"], report["industryCode"],
-            report["industryName"], report["emRatingCode"], report["emRatingValue"],
-            report["emRatingName"], report.get("attachSize", 0), report.get("attachPages", 0)
-        )
-
-        cursor.execute(sql, values)
-        conn.commit()
-
-        # fetch the id of the inserted or updated row
-        cursor.execute("SELECT id FROM industry_report WHERE infoCode = ?", (report["infoCode"],))
-        report_id = cursor.fetchone()[0]
-        return report_id
-
-    except sqlite3.Error as e:
-        conn.rollback()
-        logging.error(f"Database error: {e}")
-        return None
-    except Exception as e:
-        conn.rollback()
-        logging.error(f"Unexpected error: {e}")
-        return None
-
-# Query research reports
-def query_industry_reports(querystr='', limit=None):
-    try:
-        sql = f"SELECT id, infoCode, title, orgSName, industryName, publishDate FROM industry_report WHERE 1=1 {querystr}"
-        if limit:
-            sql = sql + f' limit {limit}'
-
-        cursor.execute(sql)
-        results = cursor.fetchall()
-        return results
-
-    except sqlite3.Error as e:
-        logging.error(f"Query failed: {e}")
-        return None
-
-# Insert or update a stock_report row
-def insert_or_update_stock_report(report):
-    try:
-        sql = """
-            INSERT INTO stock_report (infoCode, title, stockName, stockCode, orgCode, orgName, orgSName,
-                                      publishDate, industryCode, industryName, emIndustryCode, emRatingCode,
-                                      emRatingValue, emRatingName, attachPages, attachSize, author, authorID, updated_at)
-            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, datetime('now', 'localtime'))
-            ON CONFLICT(infoCode) DO UPDATE SET
-                title=excluded.title,
-                stockName=excluded.stockName,
-                stockCode=excluded.stockCode,
-                orgCode=excluded.orgCode,
-                orgName=excluded.orgName,
-                orgSName=excluded.orgSName,
-                publishDate=excluded.publishDate,
-                industryCode=excluded.industryCode,
-                industryName=excluded.industryName,
-                emIndustryCode=excluded.emIndustryCode,
-                emRatingCode=excluded.emRatingCode,
-                emRatingValue=excluded.emRatingValue,
-                emRatingName=excluded.emRatingName,
-                attachPages=excluded.attachPages,
-                attachSize=excluded.attachSize,
-                author=excluded.author,
-                authorID=excluded.authorID,
-                updated_at=datetime('now', 'localtime')
-        """
-
-        values = (
-            report["infoCode"], report["title"], report["stockName"], report["stockCode"],
-            report["orgCode"], report["orgName"], report["orgSName"], report["publishDate"],
-            report.get("industryCode", ""), report.get("industryName", ""), report.get("emIndustryCode", ""),
-            report["emRatingCode"], report["emRatingValue"], report["emRatingName"],
-            report.get("attachPages", 0), report.get("attachSize", 0),
-            ",".join(report.get("author", [])), ",".join(report.get("authorID", []))
-        )
-
-        cursor.execute(sql, values)
-        conn.commit()
-
-        # fetch the id of the inserted or updated row
-        cursor.execute("SELECT id FROM stock_report WHERE infoCode = ?", (report["infoCode"],))
-        report_id = cursor.fetchone()[0]
-        return report_id
-
-
-    except sqlite3.Error as e:
-        conn.rollback()
-        logging.error(f"Database error: {e}")
-        return None
-    except Exception as e:
-        conn.rollback()
-        logging.error(f"Unexpected error: {e}")
-        return None
-
-
-# Query research reports
-def query_stock_reports(querystr='', limit=None):
-    try:
-        sql = f"SELECT id, infoCode, title, orgSName, stockName, publishDate FROM stock_report WHERE 1=1 {querystr}"
-        if limit:
-            sql = sql + f' limit {limit}'
-
-        cursor.execute(sql)
-        results = cursor.fetchall()
-        return results
-
-    except sqlite3.Error as e:
-        logging.error(f"Query failed: {e}")
-        return None
-
-'''
\ No newline at end of file
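Note: both the deleted module above and its class-based replacement below rely on SQLite's UPSERT, where EXCLUDED.<col> refers to the row that failed to insert. A self-contained sketch of the pattern (table name and values are illustrative; ON CONFLICT ... DO UPDATE needs SQLite >= 3.24):

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute("""
        CREATE TABLE reports_demo (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            infoCode TEXT UNIQUE,
            title TEXT,
            updated_at TEXT
        )
    """)
    sql = """
        INSERT INTO reports_demo (infoCode, title, updated_at)
        VALUES (?, ?, datetime('now', 'localtime'))
        ON CONFLICT (infoCode) DO UPDATE SET
            title = EXCLUDED.title,
            updated_at = datetime('now', 'localtime')
    """
    conn.execute(sql, ("AP123", "first title"))
    conn.execute(sql, ("AP123", "second title"))  # same key: updates in place
    print(conn.execute("SELECT id, title FROM reports_demo").fetchall())
    # -> [(1, 'second title')]: one row, updated, id preserved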
diff --git a/src/config/bak.py b/src/config/bak.py
deleted file mode 100644
index 246bf64..0000000
--- a/src/config/bak.py
+++ /dev/null
@@ -1,65 +0,0 @@
-import logging
-import os
-import inspect
-from datetime import datetime
-from pathlib import Path
-
-# MySQL configuration
-db_config = {
-    'host': 'testdb',
-    'user': 'root',
-    'password': 'mysqlpw',
-    'database': 'stockdb'
-}
-
-log_dir_prefix = '../log'
-
-global_share_data_dir = '/root/sharedata'
-global_stock_data_dir = '/root/hostdir/stock_data'
-
-# Directory containing this file
-current_dir = os.path.dirname(os.path.abspath(__file__))
-# Project root (assumes this file sits two levels below it, e.g. under src)
-project_root = os.path.abspath(os.path.join(current_dir, '..', '..'))
-
-# Locate the log directory
-def get_log_directory():
-    """
-    Return the path of the log directory under the project root, creating it if it does not exist.
-    """
-    # directory containing this file
-    current_dir = Path(__file__).resolve().parent
-
-    # walk up to the src directory, then one level higher is the project root
-    project_root = current_dir
-    while project_root.name != 'src' and project_root != project_root.parent:
-        project_root = project_root.parent
-    project_root = project_root.parent  # step up from src to the project root
-
-    # ensure the log directory exists
-    log_dir = project_root / 'log'
-    log_dir.mkdir(parents=True, exist_ok=True)
-
-    return log_dir
-
-def get_caller_filename():
-    # name of the script that called setup_logging
-    caller_frame = inspect.stack()[2]
-    caller_filename = os.path.splitext(os.path.basename(caller_frame.filename))[0]
-    return caller_filename
-
-# Configure logging
-def setup_logging(log_filename=None):
-    # default the log file name to the calling script's name
-    if log_filename is None:
-        caller_filename = get_caller_filename()
-        common_log_dir = get_log_directory()
-        current_date = datetime.now().strftime('%Y%m%d')
-        # build the log file name with the date before the extension
-        log_filename = f'{common_log_dir}/{caller_filename}_{current_date}.log'
-
-    logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] (%(funcName)s) - %(message)s',
-                        handlers=[
-                            logging.FileHandler(log_filename),
-                            logging.StreamHandler()
-                        ])
\ No newline at end of file
diff --git a/src/config/config.py b/src/config/config.py
index 8a8c7f3..b187a7e 100644
--- a/src/config/config.py
+++ b/src/config/config.py
@@ -9,8 +9,13 @@ db_config = {
     'database': 'stockdb'
 }
 
-global_share_data_dir = '/root/sharedata'
-global_stock_data_dir = '/root/hostdir/stock_data'
+home_dir = os.path.expanduser("~")
+global_host_data_dir = f'{home_dir}/hostdir/stock_data'
+global_share_db_dir = f'{home_dir}/sharedata/sqlite'
+
+# Backwards compatibility with the old names
+global_stock_data_dir = global_host_data_dir
+global_share_data_dir = f'{home_dir}/sharedata'
 
 # Directory containing this file
 current_dir = Path(__file__).resolve().parent
diff --git a/reports_em/em_reports.py b/src/crawler/em/reports.py
similarity index 99%
rename from reports_em/em_reports.py
rename to src/crawler/em/reports.py
index 390f6a7..1e21b1d 100644
--- a/reports_em/em_reports.py
+++ b/src/crawler/em/reports.py
@@ -4,8 +4,6 @@ import requests
 import time
 import logging
 from bs4 import BeautifulSoup
-import sqlite_utils as db_tools
-import config
 
 # Fetch one page of the per-stock report list
 def fetch_reports_by_stock(page_no, start_date="2023-03-10", end_date="2025-03-10", page_size=50, max_retries = 3):
diff --git a/src/db_utils/reports.py b/src/db_utils/reports.py
new file mode 100644
index 0000000..783b771
--- /dev/null
+++ b/src/db_utils/reports.py
@@ -0,0 +1,123 @@
+import sqlite3
+import logging
+import sys
+from datetime import datetime
+
+class DatabaseConnectionError(Exception):
+    pass
+
+class StockReportDB:
+    # Table-name class attributes (constants)
+    TBL_STOCK = 'reports_stock'
+    TBL_NEW_STOCK = 'reports_newstrock'
+    TBL_STRATEGY = 'reports_strategy'
+    TBL_MACRESEARCH = 'reports_macresearch'
+    TBL_INDUSTRY = 'reports_industry'
+
+    def __init__(self, db_path):
+        self.DB_PATH = db_path
+        self.conn = None
+        self.cursor = None
+        try:
+            self.conn = sqlite3.connect(self.DB_PATH)
+            self.cursor = self.conn.cursor()
+        except sqlite3.Error as e:
+            logging.error(f"Database connection failed: {e}")
+            raise DatabaseConnectionError("Database connection failed")
+
+    def __get_table_columns_and_defaults(self, tbl_name):
+        try:
+            self.cursor.execute(f"PRAGMA table_info({tbl_name})")
+            columns = self.cursor.fetchall()
+            column_info = {}
+            for col in columns:
+                col_name = col[1]
+                default_value = col[4]
+                column_info[col_name] = default_value
+            return column_info
+        except sqlite3.Error as e:
+            logging.error(f"Error getting table columns: {e}")
+            return None
+
+    def __check_and_process_data(self, data, tbl_name):
+        column_info = self.__get_table_columns_and_defaults(tbl_name=tbl_name)
+        if column_info is None:
+            return None
+        processed_data = {}
+        for col, default in column_info.items():
+            if col == 'id':  # autoincrement primary key; not supplied by the caller
+                continue
+            if col == 'created_at' or col == 'updated_at':  # timestamps are set via SQL datetime()
+                continue
+            if col in ['author', 'authorID']:
+                values = data.get(col, [])
+                processed_data[col] = ','.join(values)
+            elif col in data:
+                processed_data[col] = data[col]
+            else:
+                if default is not None:
+                    processed_data[col] = default
+                else:
+                    processed_data[col] = None
+        return processed_data
+
+
+    def insert_or_update_common(self, data, tbl_name, uniq_key='infoCode'):
+        try:
+            processed_data = self.__check_and_process_data(data, tbl_name)
+            if processed_data is None:
+                return None
+
+            columns = ', '.join(processed_data.keys())
+            values = list(processed_data.values())
+            placeholders = ', '.join(['?' for _ in values])
+            update_clause = ', '.join([f"{col}=EXCLUDED.{col}" for col in processed_data.keys() if col != 'infoCode']) + ', updated_at=datetime(\'now\', \'localtime\')'
+
+            sql = f'''
+                INSERT INTO {tbl_name} ({columns}, updated_at)
+                VALUES ({placeholders}, datetime('now', 'localtime'))
+                ON CONFLICT (infoCode) DO UPDATE SET {update_clause}
+            '''
+            self.cursor.execute(sql, values)
+            self.conn.commit()
+
+            # fetch the id of the inserted or updated row
+            self.cursor.execute(f"SELECT id FROM {tbl_name} WHERE {uniq_key} = ?", (data["infoCode"],))
+            report_id = self.cursor.fetchone()[0]
+            return report_id
+        except sqlite3.Error as e:
+            logging.error(f"Error inserting or updating data: {e}")
+            return None
+
+    def query_reports_comm(self, tbl_name, querystr='', limit=None):
+        try:
+            if tbl_name in [StockReportDB.TBL_STOCK, StockReportDB.TBL_NEW_STOCK, StockReportDB.TBL_INDUSTRY, StockReportDB.TBL_MACRESEARCH, StockReportDB.TBL_STRATEGY]:
+                sql = f"SELECT id, infoCode, title, orgSName, industryName, stockName, publishDate FROM {tbl_name} WHERE 1=1 {querystr}"
+            else:
+                logging.warning(f'wrong table name: {tbl_name}')
+                return None
+
+            if limit:
+                sql = sql + f' limit {limit}'
+
+            self.cursor.execute(sql)
+            results = self.cursor.fetchall()
+
+            # column names
+            column_names = [description[0] for description in self.cursor.description]
+
+            # convert rows to a list of dicts
+            result_dict_list = []
+            for row in results:
+                row_dict = {column_names[i]: value for i, value in enumerate(row)}
+                result_dict_list.append(row_dict)
+
+            return result_dict_list
+        except sqlite3.Error as e:
+            logging.error(f"Query failed: {e}")
+            return None
+
+    def __del__(self):
+        if self.conn:
+            self.conn.close()
+
\ No newline at end of file
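Note: a usage sketch for the new class (the database path and report fields here are illustrative; the target table must already exist with a unique infoCode column):

    from src.db_utils.reports import StockReportDB, DatabaseConnectionError

    try:
        db = StockReportDB("/path/to/stock_report.db")  # hypothetical path
    except DatabaseConnectionError:
        raise SystemExit(1)

    report = {
        "infoCode": "AP202503150001",  # unique key used for the UPSERT
        "title": "demo report",
        "author": ["alice", "bob"],    # joined to "alice,bob" internally
    }
    report_id = db.insert_or_update_common(report, StockReportDB.TBL_STOCK)

    # querystr is appended verbatim after "WHERE 1=1", so it must start with AND
    rows = db.query_reports_comm(StockReportDB.TBL_STOCK,
                                 querystr=" AND publishDate >= '2025-01-01'",
                                 limit=10)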
"https://data.eastmoney.com/report/zw_industry.jshtml?infocode={}" + StockReportDB.TBL_STOCK : "https://data.eastmoney.com/report/info/{}.html", + StockReportDB.TBL_NEW_STOCK : "https://data.eastmoney.com/report/info/{}.html", + StockReportDB.TBL_STRATEGY : "https://data.eastmoney.com/report/zw_strategy.jshtml?encodeUrl={}", + StockReportDB.TBL_MACRESEARCH : "https://data.eastmoney.com/report/zw_macresearch.jshtml?encodeUrl={}", + StockReportDB.TBL_INDUSTRY : "https://data.eastmoney.com/report/zw_industry.jshtml?infocode={}" } +# 定义表名的映射,作为存储路径用 map_tbl_name = { - utils.tbl_stock : '个股研报', - utils.tbl_new_stock : '新股研报', - utils.tbl_strategy : '策略报告', - utils.tbl_macresearch : '宏观研究', - utils.tbl_industry : '行业研报' + StockReportDB.TBL_STOCK : '个股研报', + StockReportDB.TBL_NEW_STOCK : '新股研报', + StockReportDB.TBL_STRATEGY : '策略报告', + StockReportDB.TBL_MACRESEARCH : '宏观研究', + StockReportDB.TBL_INDUSTRY : '行业研报' } +# 初始化数据库连接 +db_path = f"{global_share_db_dir}/stock_report.db" +db_tools = None + current_date = datetime.now() seven_days_ago = current_date - timedelta(days=7) two_years_ago = current_date - timedelta(days=2*365) @@ -45,6 +52,7 @@ start_date = two_years_ago.strftime("%Y-%m-%d") end_date = current_date.strftime("%Y-%m-%d") this_week_date = seven_days_ago.strftime("%Y-%m-%d") + def fetch_reports_list_general(fetch_func, table_name, s_date, e_date, data_dir_prefix): # 示例:获取前 3 页的数据 max_pages = 100000 @@ -148,39 +156,39 @@ def download_pdf_stock_general(parse_func, tbl_name, querystr='', s_date=start_d # 获取股票报告列表 def fetch_reports_list_stock(s_date=start_date, e_date=end_date): - return fetch_reports_list_general(em.fetch_reports_by_stock, utils.tbl_stock, s_date, e_date, 'stock') + return fetch_reports_list_general(em.fetch_reports_by_stock, StockReportDB.TBL_STOCK, s_date, e_date, 'stock') # 获取股票报告列表 def fetch_reports_list_newstock(s_date=start_date, e_date=end_date): - return fetch_reports_list_general(em.fetch_reports_by_newstock, utils.tbl_new_stock, s_date, e_date, 'new') + return fetch_reports_list_general(em.fetch_reports_by_newstock, StockReportDB.TBL_NEW_STOCK, s_date, e_date, 'new') # 获取行业报告列表 def fetch_reports_list_industry(s_date=start_date, e_date=end_date): - return fetch_reports_list_general(em.fetch_reports_by_industry, utils.tbl_industry, s_date, e_date, 'industry') + return fetch_reports_list_general(em.fetch_reports_by_industry, StockReportDB.TBL_INDUSTRY, s_date, e_date, 'industry') # 获取行业报告列表 def fetch_reports_list_macresearch(s_date=start_date, e_date=end_date): - return fetch_reports_list_general(em.fetch_reports_by_macresearch, utils.tbl_macresearch, s_date, e_date, 'macresearch') + return fetch_reports_list_general(em.fetch_reports_by_macresearch, StockReportDB.TBL_MACRESEARCH, s_date, e_date, 'macresearch') # 获取行业报告列表 def fetch_reports_list_strategy(s_date=start_date, e_date=end_date): - return fetch_reports_list_general(em.fetch_reports_by_strategy, utils.tbl_strategy, s_date, e_date, 'strategy') + return fetch_reports_list_general(em.fetch_reports_by_strategy, StockReportDB.TBL_STRATEGY, s_date, e_date, 'strategy') # 下载股票pdf def download_pdf_stock(s_date=start_date, e_date=end_date): - download_pdf_stock_general(parse_func_general, utils.tbl_stock, ' ', s_date, e_date, limit=2 if debug else None) + download_pdf_stock_general(parse_func_general, StockReportDB.TBL_STOCK, ' ', s_date, e_date, limit=2 if debug else None) def download_pdf_newstock(s_date=start_date, e_date=end_date): - download_pdf_stock_general(parse_func_general, utils.tbl_new_stock, ' ', 
s_date, e_date, limit=2 if debug else None) + download_pdf_stock_general(parse_func_general, StockReportDB.TBL_NEW_STOCK, ' ', s_date, e_date, limit=2 if debug else None) def download_pdf_industry(s_date=start_date, e_date=end_date): - download_pdf_stock_general(parse_func_general, utils.tbl_industry, ' ', s_date, e_date, limit=2 if debug else None) + download_pdf_stock_general(parse_func_general, StockReportDB.TBL_INDUSTRY, ' ', s_date, e_date, limit=2 if debug else None) def download_pdf_macresearch(s_date=start_date, e_date=end_date): - download_pdf_stock_general(parse_func_general, utils.tbl_macresearch, ' ', s_date, e_date, limit=2 if debug else None) + download_pdf_stock_general(parse_func_general, StockReportDB.TBL_MACRESEARCH, ' ', s_date, e_date, limit=2 if debug else None) def download_pdf_strategy(s_date=start_date, e_date=end_date): - download_pdf_stock_general(parse_func_general, utils.tbl_strategy, ' ', s_date, e_date, limit=2 if debug else None) + download_pdf_stock_general(parse_func_general, StockReportDB.TBL_STRATEGY, ' ', s_date, e_date, limit=2 if debug else None) # 建立缩写到函数的映射 @@ -265,6 +273,15 @@ def main(cmd, mode, args_debug, args_force, begin, end): global end_date end_date = end if end else end_date + # 初始化DB + global db_tools + try: + db_tools = StockReportDB(db_path) + # 进行数据库操作 + except DatabaseConnectionError as e: + logging.error(f"数据库连接失败: {e}") + return False + # 开启任务 #task_id = db_tools.insert_task_log() task_id = 0 diff --git a/src/logger/logger.py b/src/logger/logger.py index 1aab99e..a546b38 100644 --- a/src/logger/logger.py +++ b/src/logger/logger.py @@ -1,10 +1,46 @@ import logging import os import inspect +import time from datetime import datetime from pathlib import Path +from logging.handlers import RotatingFileHandler +from collections import defaultdict from src.config.config import get_log_directory, get_src_directory +# 统计日志频率 +log_count = defaultdict(int) # 记录日志的次数 +last_log_time = defaultdict(float) # 记录上次写入的时间戳 + +class RateLimitFilter(logging.Filter): + """ + 频率限制过滤器: + 1. 在 60 秒内,同样的日志最多写入 60 次,超过则忽略 + 2. 
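Note: the map_pdf_page templates above are plain format strings keyed by table name; a report's infoCode fills the {} placeholder. An illustrative example (the infoCode value is made up, and the lookup assumes the module scope of src/em_reports/fetch.py):

    info_code = "AP202503150001"  # hypothetical infoCode of a stored report
    url = map_pdf_page[StockReportDB.TBL_INDUSTRY].format(info_code)
    # -> "https://data.eastmoney.com/report/zw_industry.jshtml?infocode=AP202503150001"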
diff --git a/src/logger/logger.py b/src/logger/logger.py
index 1aab99e..a546b38 100644
--- a/src/logger/logger.py
+++ b/src/logger/logger.py
@@ -1,10 +1,46 @@
 import logging
 import os
 import inspect
+import time
 from datetime import datetime
 from pathlib import Path
+from logging.handlers import RotatingFileHandler
+from collections import defaultdict
 from src.config.config import get_log_directory, get_src_directory
 
+# Track logging frequency
+log_count = defaultdict(int)        # number of times each message was logged
+last_log_time = defaultdict(float)  # timestamp of the last write per message
+
+class RateLimitFilter(logging.Filter):
+    """
+    Rate-limit filter:
+    1. Within a 60-second window the same message is written at most 600 times; anything beyond that is dropped.
+    2. TODO (not implemented): alert when the log rate exceeds 100 messages/second.
+    """
+    LOG_LIMIT = 600  # at most 600 identical messages per minute
+
+    def filter(self, record):
+        global log_count, last_log_time
+        message_key = record.getMessage()  # the rendered message is the dedup key
+
+        # current time
+        now = time.time()
+        elapsed = now - last_log_time[message_key]
+
+        # throttle repeats of the same message
+        if elapsed < 60:  # still inside the 60-second window
+            log_count[message_key] += 1
+            if log_count[message_key] > self.LOG_LIMIT:
+                return False  # drop it
+        else:
+            log_count[message_key] = 1  # window expired, restart the count
+
+        last_log_time[message_key] = now
+
+        return True  # allow the record through
+
+
 def get_caller_filename():
     # walk the call stack
     stack = inspect.stack()
@@ -26,7 +62,7 @@ def get_caller_filename():
             return os.path.splitext(os.path.basename(frame_info.filename))[0]
     return None
 
-# Configure logging
+
 def setup_logging(log_filename=None):
     # default the log file name to the calling script's name
     if log_filename is None:
@@ -35,9 +71,29 @@ def setup_logging(log_filename=None):
         current_date = datetime.now().strftime('%Y%m%d')
         # build the log file name with the date before the extension
         log_filename = f'{common_log_dir}/{caller_filename}_{current_date}.log'
+
+    max_log_size = 100 * 1024 * 1024  # 100 MB
+    max_log_files = 10  # keep at most 10 rotated log files
+
+    file_handler = RotatingFileHandler(log_filename, maxBytes=max_log_size, backupCount=max_log_files)
+    file_handler.setFormatter(logging.Formatter(
+        '%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] (%(funcName)s) - %(message)s'
+    ))
+
+    console_handler = logging.StreamHandler()
+    console_handler.setFormatter(logging.Formatter(
+        '%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] (%(funcName)s) - %(message)s'
+    ))
+
+    # configure the root logger
+    logger = logging.getLogger()
+    logger.setLevel(logging.INFO)
+    logger.handlers = []  # avoid adding duplicate handlers
+    logger.addHandler(file_handler)
+    logger.addHandler(console_handler)
+
+    # attach the rate limiter
+    rate_limit_filter = RateLimitFilter()
+    file_handler.addFilter(rate_limit_filter)
+    console_handler.addFilter(rate_limit_filter)
 
-    logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] (%(funcName)s) - %(message)s',
-                        handlers=[
-                            logging.FileHandler(log_filename),
-                            logging.StreamHandler()
-                        ])
\ No newline at end of file
diff --git a/src/strategy/rsi.py b/src/strategy/rsi.py
index 4e261d7..16ed990 100644
--- a/src/strategy/rsi.py
+++ b/src/strategy/rsi.py
@@ -1,6 +1,7 @@
 import pandas as pd
 import numpy as np
 import os
+import warnings
 from src.strategy.prepare import fetch_his_kline
 import src.config.config as config
 import src.crawler.em.stock as em_stock
@@ -15,10 +16,15 @@ def select_stocks(stock_map):
         df = fetch_his_kline(stock_code)
         close_prices = df['close'].values
 
-        # compute RSI over several periods with the MyTT library
-        rsi_6 = RSI(close_prices, 6)
-        rsi_12 = RSI(close_prices, 12)
-        rsi_24 = RSI(close_prices, 24)
+        # capture any RuntimeWarning raised during the computation
+        with warnings.catch_warnings(record=True) as w:
+            warnings.simplefilter("always", RuntimeWarning)
+            # compute RSI over several periods with the MyTT library
+            rsi_6 = RSI(close_prices, 6)
+            rsi_12 = RSI(close_prices, 12)
+            rsi_24 = RSI(close_prices, 24)
+            if w:
+                print(f"Stock {stock_code} {stock_name} raised a warning while computing RSI: {w[0].message}")
 
         df['rsi_6'] = rsi_6
         df['rsi_12'] = rsi_12
@@ -96,7 +102,8 @@ if __name__ == "__main__":
     codes = ['105.QFIN', '105.FUTU']
 
     # fetch the code list from the network
-    stock_map = em_stock.code_by_fs('hk_famous', em_stock.em_market_fs_types['hk_famous'])
+    plat_id = 'cn_hs300'
+    stock_map = em_stock.code_by_fs(plat_id, em_stock.em_market_fs_types[plat_id])
 
     if stock_map:
         select_stocks(stock_map)
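Note: the warning handling added to select_stocks relies on two details of the stdlib API: record=True collects warnings into a list instead of printing them, and simplefilter("always", RuntimeWarning) defeats the default once-per-location suppression so repeated warnings are still caught. A standalone sketch of the same pattern (the function and data below are illustrative, not from the repo):

    import warnings
    import numpy as np

    def mean_change(prices: np.ndarray) -> float:
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always", RuntimeWarning)
            # empty input makes np.mean emit a RuntimeWarning and return nan
            result = float(np.mean(prices[1:] - prices[:-1]))
            for w in caught:
                print(f"warning while computing: {w.message}")
        return result

    mean_change(np.array([]))  # triggers the "Mean of empty slice" RuntimeWarning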
diff --git a/reports_em/utils.py b/src/utils/utils.py
similarity index 73%
rename from reports_em/utils.py
rename to src/utils/utils.py
index 79707f8..1a067e2 100644
--- a/reports_em/utils.py
+++ b/src/utils/utils.py
@@ -5,15 +5,6 @@ import time
 import csv
 import logging
 from datetime import datetime
-import config
-
-tbl_stock = 'reports_stock'
-tbl_new_stock = 'reports_newstrock'
-tbl_strategy = 'reports_strategy'
-tbl_macresearch = 'reports_macresearch'
-tbl_industry = 'reports_industry'
-
-json_data_dir = f'{config.global_host_data_dir}/em_reports/json_data'
 
 # Save JSON data to a local file
 def save_json_to_file(data, file_path, file_name):
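Note: only the signature of save_json_to_file is visible in this hunk. A hypothetical call, assuming it writes data as JSON into file_path/file_name (the arguments are illustrative):

    import src.utils.utils as utils

    # hypothetical arguments; the function body is not shown in this diff
    utils.save_json_to_file({"infoCode": "AP202503150001", "title": "demo"},
                            "/tmp/em_reports/json_data", "demo.json")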