Files
stock/stockapp/reports_em/sqlite_utils.py
2025-03-11 14:22:37 +08:00

250 lines
9.3 KiB
Python

import sqlite3
import json
import config
import utils
import logging
import sys
from datetime import datetime
# Open the SQLite database shared by every helper in this module.
# Edit the file name below to point at a different database file.
DB_PATH = "{}/stock_report.db".format(config.global_host_data_dir)
conn = sqlite3.connect(DB_PATH)
cursor = conn.cursor()
# Get the column names and default values of a table.
def get_table_columns_and_defaults(tbl_name):
    """Return ``{column_name: default_sql_text}`` for table ``tbl_name``.

    Values come from the ``dflt_value`` field of ``PRAGMA table_info``. That
    field is the *SQL text* of the column's DEFAULT expression (for example
    ``'abc'`` including its quotes, or ``0``), or ``None`` when the column
    has no DEFAULT clause.

    Returns ``None`` (after logging) when the PRAGMA fails or when the table
    does not exist.
    """
    try:
        cursor.execute(f"PRAGMA table_info({tbl_name})")
        rows = cursor.fetchall()
    except sqlite3.Error as e:
        logging.error(f"Error getting table columns: {e}")
        return None
    # PRAGMA table_info does not raise for an unknown table; it just returns
    # no rows. Report that as a failure so callers' ``is None`` checks catch
    # it, instead of letting them build an INSERT with an empty column list.
    if not rows:
        logging.error(f"Error getting table columns: table {tbl_name} not found or has no columns")
        return None
    # Row layout: (cid, name, type, notnull, dflt_value, pk)
    return {row[1]: row[4] for row in rows}
def _sql_default_to_value(default_sql):
    """Turn a column's DEFAULT expression text into the Python value to bind.

    ``PRAGMA table_info`` reports defaults as SQL source text. Binding that
    text directly would store, e.g., the two characters ``''`` instead of an
    empty string, or the word ``NULL`` instead of a real NULL. Simple literals
    (NULL, single-quoted strings, integers, reals) are converted; anything
    else (e.g. a function-call expression) is returned unchanged.
    """
    if default_sql is None:
        return None
    text = str(default_sql).strip()
    if text.upper() == 'NULL':
        return None
    if len(text) >= 2 and text[0] == "'" and text[-1] == "'":
        # SQL escapes a quote inside a string literal by doubling it.
        return text[1:-1].replace("''", "'")
    for convert in (int, float):
        try:
            return convert(text)
        except ValueError:
            pass
    return default_sql


def _join_list_field(values):
    """Join a list of names/IDs into one comma-separated string.

    ``None`` becomes ``''`` (same as a missing key), and a value that is
    already a string is kept as-is rather than being split into characters.
    """
    if values is None:
        return ''
    if isinstance(values, str):
        return values
    return ','.join(str(v) for v in values)


# Check incoming data against the table's columns and fill in missing values.
def check_and_process_data(data, tbl_name):
    """Build the ``{column: value}`` mapping to write into ``tbl_name``.

    Every column of the table is considered except ``id`` (auto-increment
    primary key) and ``created_at``/``updated_at`` (filled in by SQL date
    functions, not taken from ``data``).

    * ``author``/``authorID`` are joined into a comma-separated string.
    * Columns present in ``data`` take the value from ``data``.
    * Other columns take the table's DEFAULT converted to a Python value,
      or ``None`` if the column has no DEFAULT.

    Returns ``None`` if the table's columns cannot be read.
    """
    column_info = get_table_columns_and_defaults(tbl_name=tbl_name)
    if column_info is None:
        return None
    processed_data = {}
    for col, default in column_info.items():
        if col == 'id':  # auto-increment primary key; never supplied by the caller
            continue
        if col in ('created_at', 'updated_at'):  # set via SQL date functions
            continue
        if col in ('author', 'authorID'):
            processed_data[col] = _join_list_field(data.get(col))
        elif col in data:
            processed_data[col] = data[col]
        else:
            processed_data[col] = _sql_default_to_value(default)
    return processed_data
# Insert a row, or update it when its unique key already exists.
def insert_or_update_common(data, tbl_name, uniq_key='infoCode'):
    """Upsert ``data`` into ``tbl_name`` keyed on ``uniq_key``; return its ``id``.

    Column values come from ``check_and_process_data``. ``updated_at`` is set
    to the current local time on both the insert and the update path, and on
    conflict every column except ``uniq_key`` is overwritten. ``uniq_key`` is
    the ON CONFLICT target, so it must carry a UNIQUE constraint/index.

    Returns the row's ``id``. Returns ``None`` if the data cannot be
    processed, if a database error occurs (the transaction is rolled back),
    or if the row cannot be found afterwards.
    """
    try:
        processed_data = check_and_process_data(data, tbl_name)
        if not processed_data:
            # None: schema lookup failed (already logged); {}: nothing to write.
            return None
        columns = ', '.join(processed_data.keys())
        values = list(processed_data.values())
        placeholders = ', '.join('?' for _ in values)
        # Overwrite every column except the conflict key itself.
        assignments = [f"{col}=EXCLUDED.{col}" for col in processed_data if col != uniq_key]
        assignments.append("updated_at=datetime('now', 'localtime')")
        update_clause = ', '.join(assignments)
        sql = f'''
            INSERT INTO {tbl_name} ({columns}, updated_at)
            VALUES ({placeholders}, datetime('now', 'localtime'))
            ON CONFLICT ({uniq_key}) DO UPDATE SET {update_clause}
        '''
        cursor.execute(sql, values)
        conn.commit()
        # Re-read the id by unique key so the same code serves both the
        # insert and the update path of the upsert.
        cursor.execute(
            f"SELECT id FROM {tbl_name} WHERE {uniq_key} = ?",
            (processed_data.get(uniq_key),),
        )
        row = cursor.fetchone()
        if row is None:
            logging.error(f"Error inserting or updating data: no row found in {tbl_name} for {uniq_key}")
            return None
        return row[0]
    except sqlite3.Error as e:
        # Don't leave a half-finished implicit transaction open on the shared connection.
        conn.rollback()
        logging.error(f"Error inserting or updating data: {e}")
        return None
# Query report summary rows.
def query_reports_comm(tbl_name, querystr='', limit=None):
    """Return summary rows of ``tbl_name`` as a list of dicts (column -> value).

    ``querystr`` is raw SQL appended after ``WHERE 1=1`` (e.g. ``" AND ..."``).
    ``limit``, when truthy, is appended as a LIMIT clause. Both are inserted
    verbatim into the statement, so they must come from trusted code.

    Stock tables (``utils.tbl_stock``/``utils.tbl_new_stock``) also select
    ``stockName``; industry, macro-research and strategy tables do not.
    Returns ``None`` for any other table name or on a database error.
    """
    try:
        if tbl_name in [utils.tbl_stock, utils.tbl_new_stock]:
            fields = "id, infoCode, title, orgSName, industryName, stockName, publishDate"
        elif tbl_name in [utils.tbl_industry, utils.tbl_macresearch, utils.tbl_strategy]:
            fields = "id, infoCode, title, orgSName, industryName, publishDate"
        else:
            logging.warning(f'wrong table name: {tbl_name}')
            return None

        statement = f"SELECT {fields} FROM {tbl_name} WHERE 1=1 {querystr}"
        if limit:
            statement += f' limit {limit}'

        cursor.execute(statement)
        records = cursor.fetchall()
        names = [desc[0] for desc in cursor.description]
        return [dict(zip(names, record)) for record in records]
    except sqlite3.Error as e:
        # NOTE(review): the message says "href" but this queries report rows.
        logging.error(f"查询 href 失败: {e}")
        return None
# Legacy implementation, disabled: everything below is one bare string
# literal, so none of these functions (insert_or_update_report,
# query_industry_reports, insert_or_update_stock_report, query_stock_reports)
# are defined at runtime. They hard-code the column lists of industry_report
# and stock_report. The generic insert_or_update_common() and
# query_reports_comm() defined earlier in this module appear to cover the same
# tables — TODO confirm no caller still expects the old names, then delete.
'''
# 插入或更新 industry_report 数据
def insert_or_update_report(report):
try:
sql = """
INSERT INTO industry_report (infoCode, title, orgCode, orgName, orgSName, publishDate,
industryCode, industryName, emRatingCode, emRatingValue,
emRatingName, attachSize, attachPages, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, datetime('now', 'localtime'))
ON CONFLICT(infoCode) DO UPDATE SET
title=excluded.title,
orgCode=excluded.orgCode,
orgName=excluded.orgName,
orgSName=excluded.orgSName,
publishDate=excluded.publishDate,
industryCode=excluded.industryCode,
industryName=excluded.industryName,
emRatingCode=excluded.emRatingCode,
emRatingValue=excluded.emRatingValue,
emRatingName=excluded.emRatingName,
attachSize=excluded.attachSize,
attachPages=excluded.attachPages,
updated_at=datetime('now', 'localtime')
"""
values = (
report["infoCode"], report["title"], report["orgCode"], report["orgName"],
report["orgSName"], report["publishDate"], report["industryCode"],
report["industryName"], report["emRatingCode"], report["emRatingValue"],
report["emRatingName"], report.get("attachSize", 0), report.get("attachPages", 0)
)
cursor.execute(sql, values)
conn.commit()
# 获取插入或更新后的 report_id
cursor.execute("SELECT id FROM industry_report WHERE infoCode = ?", (report["infoCode"],))
report_id = cursor.fetchone()[0]
return report_id
except sqlite3.Error as e:
conn.rollback()
logging.error(f"数据库错误: {e}")
return None
except Exception as e:
conn.rollback()
logging.error(f"未知错误: {e}")
return None
# 查询研报数据
def query_industry_reports(querystr='', limit=None):
try:
sql = f"SELECT id, infoCode, title, orgSName, industryName, publishDate FROM industry_report WHERE 1=1 {querystr}"
if limit :
sql = sql + f' limit {limit}'
cursor.execute(sql)
results = cursor.fetchall()
return results
except sqlite3.Error as e:
logging.error(f"查询 href 失败: {e}")
return None
# 插入或更新 industry_report 数据
def insert_or_update_stock_report(report):
try:
sql = """
INSERT INTO stock_report (infoCode, title, stockName, stockCode, orgCode, orgName, orgSName,
publishDate, industryCode, industryName, emIndustryCode, emRatingCode,
emRatingValue, emRatingName, attachPages, attachSize, author, authorID, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, datetime('now', 'localtime'))
ON CONFLICT(infoCode) DO UPDATE SET
title=excluded.title,
stockName=excluded.stockName,
stockCode=excluded.stockCode,
orgCode=excluded.orgCode,
orgName=excluded.orgName,
orgSName=excluded.orgSName,
publishDate=excluded.publishDate,
industryCode=excluded.industryCode,
industryName=excluded.industryName,
emIndustryCode=excluded.emIndustryCode,
emRatingCode=excluded.emRatingCode,
emRatingValue=excluded.emRatingValue,
emRatingName=excluded.emRatingName,
attachPages=excluded.attachPages,
attachSize=excluded.attachSize,
author=excluded.author,
authorID=excluded.authorID,
updated_at=datetime('now', 'localtime')
"""
values = (
report["infoCode"], report["title"], report["stockName"], report["stockCode"],
report["orgCode"], report["orgName"], report["orgSName"], report["publishDate"],
report.get("industryCode", ""), report.get("industryName", ""), report.get("emIndustryCode", ""),
report["emRatingCode"], report["emRatingValue"], report["emRatingName"],
report.get("attachPages", 0), report.get("attachSize", 0),
",".join(report.get("author", [])), ",".join(report.get("authorID", []))
)
cursor.execute(sql, values)
conn.commit()
# 获取插入或更新后的 report_id
cursor.execute("SELECT id FROM stock_report WHERE infoCode = ?", (report["infoCode"],))
report_id = cursor.fetchone()[0]
return report_id
except sqlite3.Error as e:
conn.rollback()
logging.error(f"数据库错误: {e}")
return None
except Exception as e:
conn.rollback()
logging.error(f"未知错误: {e}")
return None
# 查询研报数据
def query_stock_reports(querystr='', limit=None):
try:
sql = f"SELECT id, infoCode, title, orgSName, stockName, publishDate FROM stock_report WHERE 1=1 {querystr}"
if limit :
sql = sql + f' limit {limit}'
cursor.execute(sql)
results = cursor.fetchall()
return results
except sqlite3.Error as e:
logging.error(f"查询 href 失败: {e}")
return None
'''