195 lines
7.7 KiB
Python
195 lines
7.7 KiB
Python
import pymysql
|
||
from pymysql.cursors import DictCursor
|
||
import logging
|
||
import sys
|
||
from datetime import datetime
|
||
|
||
class DatabaseConnectionError(Exception):
|
||
pass
|
||
|
||
class StockReportMysql:
|
||
# 定义类属性(静态变量)
|
||
TBL_STOCK = 'reports_stock'
|
||
TBL_NEW_STOCK = 'reports_newstrock' # 注意原拼写可能存在笔误(strock应为stock)
|
||
TBL_STRATEGY = 'reports_strategy'
|
||
TBL_MACRESEARCH = 'reports_macresearch'
|
||
TBL_INDUSTRY = 'reports_industry'
|
||
|
||
def __init__(self, db_host, db_user, db_password, db_name, port=3306):
|
||
"""
|
||
初始化MySQL连接
|
||
:param db_host: 数据库主机地址
|
||
:param db_user: 数据库用户名
|
||
:param db_password: 数据库密码
|
||
:param db_name: 数据库名称
|
||
:param port: 数据库端口,默认3306
|
||
"""
|
||
self.db_host = db_host
|
||
self.db_user = db_user
|
||
self.db_password = db_password
|
||
self.db_name = db_name
|
||
self.port = port
|
||
self.conn = None
|
||
self.cursor = None
|
||
try:
|
||
self.conn = pymysql.connect(
|
||
host=self.db_host,
|
||
user=self.db_user,
|
||
password=self.db_password,
|
||
database=self.db_name,
|
||
port=self.port,
|
||
charset='utf8mb4' # 支持中文
|
||
)
|
||
#self.cursor = self.conn.cursor(dictionary=False) # 使用非字典游标,保持与原代码兼容
|
||
self.cursor = self.conn.cursor() # 使用默认游标
|
||
except pymysql.MySQLError as e:
|
||
logging.error(f"数据库连接失败: {e}")
|
||
raise DatabaseConnectionError("数据库连接失败")
|
||
|
||
def __get_table_columns_and_defaults(self, tbl_name):
|
||
"""获取表的列信息及默认值(适配MySQL)"""
|
||
try:
|
||
# 查询information_schema获取列信息
|
||
self.cursor.execute("""
|
||
SELECT COLUMN_NAME, COLUMN_DEFAULT
|
||
FROM information_schema.COLUMNS
|
||
WHERE TABLE_NAME = %s AND TABLE_SCHEMA = %s
|
||
""", (tbl_name, self.db_name))
|
||
columns = self.cursor.fetchall()
|
||
column_info = {}
|
||
for col in columns:
|
||
col_name = col[0]
|
||
default_value = col[1]
|
||
# MySQL默认值可能包含函数(如CURRENT_TIMESTAMP),需要特殊处理
|
||
column_info[col_name] = default_value
|
||
return column_info
|
||
except pymysql.MySQLError as e:
|
||
logging.error(f"获取表结构失败: {e}")
|
||
return None
|
||
|
||
def __check_and_process_data(self, data, tbl_name):
|
||
"""数据校验和处理(逻辑与原代码保持一致)"""
|
||
column_info = self.__get_table_columns_and_defaults(tbl_name=tbl_name)
|
||
if column_info is None:
|
||
return None
|
||
processed_data = {}
|
||
for col, default in column_info.items():
|
||
if col == 'id': # 自增主键,不需要用户提供
|
||
continue
|
||
if col == 'created_at' or col == 'updated_at': # 日期字段由数据库或代码控制
|
||
continue
|
||
if col in ['author', 'authorID']:
|
||
# 将列表转换为逗号分隔字符串
|
||
values = data.get(col, [])
|
||
processed_data[col] = ','.join(values) if values else None
|
||
# 确保不超过255字符
|
||
if processed_data[col] and len(processed_data[col]) > 250:
|
||
processed_data[col] = processed_data[col][:250]
|
||
elif col in data:
|
||
processed_data[col] = data[col]
|
||
else:
|
||
# 使用默认值
|
||
pass
|
||
return processed_data
|
||
|
||
def insert_or_update_common(self, data, tbl_name, uniq_key='infoCode'):
|
||
"""插入或更新数据(适配MySQL的ON DUPLICATE KEY UPDATE)"""
|
||
try:
|
||
processed_data = self.__check_and_process_data(data, tbl_name)
|
||
if processed_data is None:
|
||
return None
|
||
|
||
columns = ', '.join(processed_data.keys())
|
||
values = list(processed_data.values())
|
||
placeholders = ', '.join(['%s' for _ in values]) # MySQL使用%s作为占位符
|
||
|
||
# 构造更新子句(排除唯一键字段)
|
||
update_clause = ', '.join(
|
||
[f"{col}=VALUES({col})" for col in processed_data.keys() if col != uniq_key]
|
||
) + ', updated_at=NOW()' # MySQL使用NOW()获取当前时间
|
||
|
||
sql = f'''
|
||
INSERT INTO {tbl_name} ({columns}, created_at, updated_at)
|
||
VALUES ({placeholders}, NOW(), NOW())
|
||
ON DUPLICATE KEY UPDATE {update_clause}
|
||
'''
|
||
self.cursor.execute(sql, values)
|
||
self.conn.commit()
|
||
|
||
# 获取插入或更新后的记录ID
|
||
self.cursor.execute(f"SELECT id FROM {tbl_name} WHERE {uniq_key} = %s", (data[uniq_key],))
|
||
result = self.cursor.fetchone()
|
||
return result[0] if result else None
|
||
except pymysql.MySQLError as e:
|
||
logging.error(f"插入或更新数据失败: {e}, data: {data}")
|
||
self.conn.rollback() # 出错时回滚
|
||
return None
|
||
|
||
def update_pages(self, data, tbl_name, uniq_key='infoCode'):
|
||
"""更新附件页数(使用参数化查询防止SQL注入)"""
|
||
try:
|
||
# 注意:原代码直接拼接SQL有注入风险,此处改为参数化查询
|
||
sql = f'''
|
||
UPDATE {tbl_name}
|
||
SET attachPages = %s
|
||
WHERE id = %s
|
||
'''
|
||
self.cursor.execute(sql, (data['attachPages'], data['id']))
|
||
self.conn.commit()
|
||
return data['id']
|
||
except pymysql.MySQLError as e:
|
||
logging.error(f"更新页数失败: {e}")
|
||
self.conn.rollback()
|
||
return None
|
||
|
||
def query_reports_comm(self, tbl_name, querystr='', limit=None):
|
||
"""查询报告列表(适配MySQL语法)"""
|
||
try:
|
||
# 验证表名合法性
|
||
valid_tables = [
|
||
self.TBL_STOCK, self.TBL_NEW_STOCK,
|
||
self.TBL_INDUSTRY, self.TBL_MACRESEARCH,
|
||
self.TBL_STRATEGY
|
||
]
|
||
if tbl_name not in valid_tables:
|
||
logging.warning(f'无效的表名: {tbl_name}')
|
||
return None
|
||
|
||
# 构造查询SQL
|
||
sql = f"""
|
||
SELECT id, infoCode, title, orgSName, industryName, stockName, publishDate
|
||
FROM {tbl_name}
|
||
WHERE 1=1 {querystr}
|
||
"""
|
||
# 添加限制条件
|
||
if limit:
|
||
sql += f' LIMIT {limit}'
|
||
|
||
self.cursor.execute(sql)
|
||
results = self.cursor.fetchall()
|
||
|
||
# 获取列名
|
||
column_names = [description[0] for description in self.cursor.description]
|
||
|
||
# 转换为字典列表
|
||
result_dict_list = []
|
||
for row in results:
|
||
row_dict = {column_names[i]: value for i, value in enumerate(row)}
|
||
result_dict_list.append(row_dict)
|
||
|
||
return result_dict_list
|
||
except pymysql.MySQLError as e:
|
||
logging.error(f"查询失败: {e}")
|
||
return None
|
||
|
||
def __del__(self):
|
||
"""析构函数,关闭数据库连接(适配pymysql)"""
|
||
try:
|
||
# pymysql中通过ping()判断连接是否有效,若连接已关闭会抛出异常
|
||
if self.conn:
|
||
self.conn.ping() # 尝试检测连接是否存活
|
||
self.cursor.close()
|
||
self.conn.close()
|
||
except (pymysql.MySQLError, AttributeError):
|
||
# 捕获连接已关闭、游标不存在等异常,避免析构时报错
|
||
pass |