Files
stock/src/db_utils/reports_mysql.py
2025-08-10 19:26:47 +08:00

195 lines
7.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import pymysql
from pymysql.cursors import DictCursor
import logging
import sys
from datetime import datetime
class DatabaseConnectionError(Exception):
    """Raised by StockReportMysql.__init__ when pymysql cannot open the connection."""
class StockReportMysql:
    """Read/write access to the research-report tables in a MySQL database.

    One pymysql connection and one default (tuple-row) cursor are opened in
    ``__init__`` and shared by every method. Call ``close()`` (or use the
    instance as a context manager) to release them deterministically;
    ``__del__`` only performs best-effort cleanup.
    """

    # Table names (class attributes shared by all instances).
    TBL_STOCK = 'reports_stock'
    # NOTE: "newstrock" looks like a typo for "newstock"; presumably it matches
    # the existing table name, so it is kept as-is -- confirm before renaming.
    TBL_NEW_STOCK = 'reports_newstrock'
    TBL_STRATEGY = 'reports_strategy'
    TBL_MACRESEARCH = 'reports_macresearch'
    TBL_INDUSTRY = 'reports_industry'

    # Columns never taken from caller data: the auto-increment primary key and
    # the timestamps that insert_or_update_common sets itself with NOW().
    _SKIPPED_COLUMNS = frozenset({'id', 'created_at', 'updated_at'})
    # Columns whose value arrives as a list and is stored comma-joined.
    _MULTI_VALUE_COLUMNS = frozenset({'author', 'authorID'})

    def __init__(self, db_host, db_user, db_password, db_name, port=3306):
        """Open a MySQL connection and a default cursor.

        :param db_host: database host address
        :param db_user: database user name
        :param db_password: database password
        :param db_name: database (schema) name; also used to look up table
            columns in information_schema
        :param port: database port, 3306 by default
        :raises DatabaseConnectionError: if pymysql fails to connect or to
            create the cursor (the pymysql error is chained as the cause).
        """
        self.db_host = db_host
        self.db_user = db_user
        self.db_password = db_password
        self.db_name = db_name
        self.port = port
        self.conn = None
        self.cursor = None
        try:
            self.conn = pymysql.connect(
                host=self.db_host,
                user=self.db_user,
                password=self.db_password,
                database=self.db_name,
                port=self.port,
                charset='utf8mb4',  # full Unicode, including Chinese text
            )
            # Default cursor: rows come back as tuples and are indexed
            # positionally by the methods below.
            self.cursor = self.conn.cursor()
        except pymysql.MySQLError as e:
            logging.error(f"数据库连接失败: {e}")
            raise DatabaseConnectionError("数据库连接失败") from e

    @staticmethod
    def _quote_ident(name):
        """Backtick-quote a MySQL identifier, doubling any embedded backticks."""
        return '`' + str(name).replace('`', '``') + '`'

    @staticmethod
    def _join_multi_value(values, max_len=250):
        """Collapse a multi-valued field into one comma-separated string.

        :param values: a list/iterable of values, a single string, or None.
        :param max_len: maximum length of the result in characters. The original
            comment ties 250 to a 255-character limit; presumably the column is
            VARCHAR(255) -- confirm.
        :return: None for empty/None input, otherwise the joined string
            truncated to ``max_len`` characters.
        """
        if not values:
            return None
        if isinstance(values, str):
            # Joining a str would put a comma between every character.
            joined = values
        else:
            joined = ','.join(str(v) for v in values if v is not None)
        return joined[:max_len]

    def __get_table_columns_and_defaults(self, tbl_name):
        """Return ``{column_name: column_default}`` for ``tbl_name``.

        Defaults are returned unmodified as information_schema reports them.
        Returns an empty dict when no columns are found for the table in
        ``self.db_name``, or None on a database error (which is logged).
        """
        try:
            self.cursor.execute(
                """
                SELECT COLUMN_NAME, COLUMN_DEFAULT
                FROM information_schema.COLUMNS
                WHERE TABLE_NAME = %s AND TABLE_SCHEMA = %s
                ORDER BY ORDINAL_POSITION
                """,
                (tbl_name, self.db_name),
            )
            return {name: default for name, default in self.cursor.fetchall()}
        except pymysql.MySQLError as e:
            logging.error(f"获取表结构失败: {e}")
            return None

    def __check_and_process_data(self, data, tbl_name):
        """Pick and normalise the entries of ``data`` that map to table columns.

        - ``id``, ``created_at`` and ``updated_at`` are skipped.
        - ``author`` / ``authorID`` are stored comma-joined (None when missing
          or empty) and truncated to 250 characters.
        - Any other column is copied from ``data`` when present; columns absent
          from ``data`` are left out, so they are not written at all.

        :return: the processed dict, or None if the table schema could not be read.
        """
        column_info = self.__get_table_columns_and_defaults(tbl_name=tbl_name)
        if column_info is None:
            return None
        processed_data = {}
        for col in column_info:
            if col in self._SKIPPED_COLUMNS:
                continue
            if col in self._MULTI_VALUE_COLUMNS:
                processed_data[col] = self._join_multi_value(data.get(col))
            elif col in data:
                processed_data[col] = data[col]
        return processed_data

    def insert_or_update_common(self, data, tbl_name, uniq_key='infoCode'):
        """Upsert one row with ``INSERT ... ON DUPLICATE KEY UPDATE``.

        On insert, ``created_at``/``updated_at`` are set to NOW(). On a
        duplicate key, every written column except ``uniq_key`` is overwritten
        and ``updated_at`` is refreshed.

        :param data: mapping of column -> value; must contain ``uniq_key``.
        :param tbl_name: target table. It is interpolated into the SQL, so pass
            only trusted names such as the TBL_* constants.
        :param uniq_key: column used to re-read the row id afterwards. Presumably
            backed by a UNIQUE index -- TODO confirm, otherwise duplicates are
            never detected.
        :return: the row's id, or None on failure (errors are logged and the
            transaction is rolled back).
        """
        if uniq_key not in data:
            # Checked up front: the old code committed the row and then raised
            # KeyError while looking the id up.
            logging.error(f"插入或更新数据失败: 缺少唯一键 {uniq_key}, data: {data}")
            return None
        try:
            processed_data = self.__check_and_process_data(data, tbl_name)
            if processed_data is None:
                return None
            if not processed_data:
                # An empty column list would produce invalid SQL: "(, created_at ...".
                logging.error(f"插入或更新数据失败: 表 {tbl_name} 无可写入的列, data: {data}")
                return None
            quote = self._quote_ident
            col_names = list(processed_data)
            columns = ', '.join(quote(c) for c in col_names)
            placeholders = ', '.join(['%s'] * len(col_names))
            # VALUES(col) is the value that would have been inserted. It is
            # deprecated since MySQL 8.0.20 but still accepted, and MariaDB uses it too.
            assignments = [
                f"{quote(c)}=VALUES({quote(c)})" for c in col_names if c != uniq_key
            ]
            assignments.append('`updated_at`=NOW()')
            sql = (
                f"INSERT INTO {tbl_name} ({columns}, `created_at`, `updated_at`) "
                f"VALUES ({placeholders}, NOW(), NOW()) "
                f"ON DUPLICATE KEY UPDATE {', '.join(assignments)}"
            )
            self.cursor.execute(sql, list(processed_data.values()))
            self.conn.commit()
            # Re-read the id by the unique key so both the insert and the
            # update path return it.
            self.cursor.execute(
                f"SELECT id FROM {tbl_name} WHERE {quote(uniq_key)} = %s",
                (data[uniq_key],),
            )
            row = self.cursor.fetchone()
            return row[0] if row else None
        except pymysql.MySQLError as e:
            logging.error(f"插入或更新数据失败: {e}, data: {data}")
            self.conn.rollback()
            return None

    def update_pages(self, data, tbl_name, uniq_key='infoCode'):
        """Set ``attachPages`` on the row whose id is ``data['id']``.

        Values are bound as query parameters; only the (trusted) table name is
        interpolated into the SQL.

        :param data: mapping with at least ``'id'`` and ``'attachPages'``.
        :param tbl_name: target table (e.g. one of the TBL_* constants).
        :param uniq_key: unused; kept for call compatibility.
        :return: ``data['id']`` after the commit (even if no row matched), or
            None on a database error (logged, transaction rolled back).
        """
        try:
            sql = f"UPDATE {tbl_name} SET attachPages = %s WHERE id = %s"
            self.cursor.execute(sql, (data['attachPages'], data['id']))
            self.conn.commit()
            return data['id']
        except pymysql.MySQLError as e:
            logging.error(f"更新页数失败: {e}")
            self.conn.rollback()
            return None

    def query_reports_comm(self, tbl_name, querystr='', limit=None, params=None):
        """Return summary rows from one of the whitelisted report tables.

        :param tbl_name: one of the TBL_* constants; anything else is rejected.
        :param querystr: raw SQL appended after ``WHERE 1=1`` (for example
            ``" AND stockName = %s"``). It is NOT escaped, so never build it from
            untrusted input; use %s placeholders together with ``params``.
        :param limit: optional maximum number of rows; must be convertible to
            int. A falsy value means no limit.
        :param params: optional sequence/mapping bound to %s placeholders in
            ``querystr``. When given, literal ``%`` in ``querystr`` must be
            written as ``%%``.
        :return: list of dicts keyed by column name, or None for an invalid
            table/limit or a database error.
        """
        valid_tables = {
            self.TBL_STOCK, self.TBL_NEW_STOCK,
            self.TBL_INDUSTRY, self.TBL_MACRESEARCH,
            self.TBL_STRATEGY,
        }
        if tbl_name not in valid_tables:
            logging.warning(f'无效的表名: {tbl_name}')
            return None
        sql = (
            "SELECT id, infoCode, title, orgSName, industryName, stockName, publishDate "
            f"FROM {tbl_name} WHERE 1=1 {querystr or ''}"
        )
        if limit:
            # Coerce to int so that nothing but a number can reach the SQL text.
            try:
                limit = int(limit)
            except (TypeError, ValueError):
                logging.warning(f'无效的limit: {limit}')
                return None
            sql += f' LIMIT {limit}'
        try:
            # With params=None pymysql does no %-substitution, so querystr is
            # sent verbatim, exactly as before.
            self.cursor.execute(sql, params)
            column_names = [desc[0] for desc in self.cursor.description]
            return [dict(zip(column_names, row)) for row in self.cursor.fetchall()]
        except pymysql.MySQLError as e:
            logging.error(f"查询失败: {e}")
            return None

    def close(self):
        """Close the cursor and the connection. Safe to call more than once.

        The connection is closed only if it is still open. It is never pinged,
        because pymysql's ``ping()`` reconnects by default.
        """
        cursor = getattr(self, 'cursor', None)
        conn = getattr(self, 'conn', None)
        self.cursor = None
        self.conn = None
        if cursor is not None:
            try:
                cursor.close()
            except pymysql.MySQLError as e:
                logging.warning(f"关闭游标失败: {e}")
        if conn is not None and conn.open:
            try:
                conn.close()
            except pymysql.MySQLError as e:
                logging.warning(f"关闭数据库连接失败: {e}")

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False  # never suppress exceptions from the with-body

    def __del__(self):
        """Best-effort cleanup when the object is garbage-collected."""
        try:
            self.close()
        except Exception:
            # Destructors must not raise (e.g. module globals may already be
            # torn down during interpreter shutdown).
            pass