import sqlite3
import json
import config
import utils
import logging
import sys
from datetime import datetime


# Connect to the SQLite database
DB_PATH = f"{config.global_host_data_dir}/stock_report.db"  # replace with your database file
conn = sqlite3.connect(DB_PATH)
cursor = conn.cursor()

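# The external `config` and `utils` modules are not shown here; based on how they are
# used below, they are assumed to look roughly like this sketch (the concrete values
# are hypothetical, only the attribute names come from this file):
#
#     # config.py
#     global_host_data_dir = "/data/stock"        # directory that holds stock_report.db
#
#     # utils.py -- table-name constants referenced by query_reports_comm()
#     tbl_stock = "stock_report"
#     tbl_new_stock = "new_stock_report"
#     tbl_industry = "industry_report"
#     tbl_macresearch = "macresearch_report"
#     tbl_strategy = "strategy_report"
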
# Get a table's column names and default values
def get_table_columns_and_defaults(tbl_name):
    try:
        cursor.execute(f"PRAGMA table_info({tbl_name})")
        columns = cursor.fetchall()
        column_info = {}
        for col in columns:
            # each PRAGMA table_info row is (cid, name, type, notnull, dflt_value, pk)
            col_name = col[1]
            default_value = col[4]
            column_info[col_name] = default_value
        return column_info
    except sqlite3.Error as e:
        logging.error(f"Error getting table columns: {e}")
        return None

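# Usage sketch (hypothetical: assumes a stock_report table already exists in the database):
#
#     cols = get_table_columns_and_defaults("stock_report")
#     # -> {'id': None, 'infoCode': None, 'title': None, ...}
#     # keys are column names, values are the DDL default expressions (or None)
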
# Validate and normalize incoming data against the table schema
def check_and_process_data(data, tbl_name):
    column_info = get_table_columns_and_defaults(tbl_name=tbl_name)
    if column_info is None:
        return None

    processed_data = {}
    for col, default in column_info.items():
        if col == 'id':  # auto-increment primary key, not supplied by the caller
            continue
        if col == 'created_at' or col == 'updated_at':  # timestamp columns are set in SQL below
            continue
        if col in ['author', 'authorID']:
            # these fields arrive as lists and are stored as comma-separated strings
            values = data.get(col, [])
            processed_data[col] = ','.join(values)
        elif col in data:
            processed_data[col] = data[col]
        else:
            # fall back to the column's default value, or NULL if there is none
            if default is not None:
                processed_data[col] = default
            else:
                processed_data[col] = None
    return processed_data

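# Usage sketch (the field names follow the columns referenced elsewhere in this file;
# the values are hypothetical):
#
#     raw = {
#         "infoCode": "AP0001",
#         "title": "Some report title",
#         "author": ["Zhang San", "Li Si"],    # stored as "Zhang San,Li Si"
#         "authorID": ["1001", "1002"],
#     }
#     row = check_and_process_data(raw, utils.tbl_stock)
#     # columns missing from `raw` fall back to the table default, or None
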
# Insert or update a row (upsert keyed on uniq_key)
def insert_or_update_common(data, tbl_name, uniq_key='infoCode'):
    try:
        processed_data = check_and_process_data(data, tbl_name)
        if processed_data is None:
            return None

        columns = ', '.join(processed_data.keys())
        values = list(processed_data.values())
        placeholders = ', '.join(['?' for _ in values])
        update_clause = ', '.join([f"{col}=EXCLUDED.{col}" for col in processed_data.keys() if col != uniq_key]) + ", updated_at=datetime('now', 'localtime')"

        sql = f'''
            INSERT INTO {tbl_name} ({columns}, updated_at)
            VALUES ({placeholders}, datetime('now', 'localtime'))
            ON CONFLICT ({uniq_key}) DO UPDATE SET {update_clause}
        '''
        cursor.execute(sql, values)
        conn.commit()

        # Fetch the row id after the insert/update
        cursor.execute(f"SELECT id FROM {tbl_name} WHERE {uniq_key} = ?", (data[uniq_key],))
        report_id = cursor.fetchone()[0]
        return report_id
    except sqlite3.Error as e:
        conn.rollback()
        logging.error(f"Error inserting or updating data: {e}")
        return None

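# Usage sketch (hypothetical report dict; infoCode is the conflict key, so it must be
# present, while other columns fall back to table defaults via check_and_process_data):
#
#     report = {
#         "infoCode": "AP0001",
#         "title": "Some report title",
#         "publishDate": "2024-01-01",
#     }
#     report_id = insert_or_update_common(report, utils.tbl_stock)
#     # calling again with the same infoCode updates the existing row and returns the same id
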
# Query report rows
def query_reports_comm(tbl_name, querystr='', limit=None):
    try:
        # querystr is appended verbatim, so it must come from trusted code, not user input
        if tbl_name in [utils.tbl_stock, utils.tbl_new_stock]:
            sql = f"SELECT id, infoCode, title, orgSName, industryName, stockName, publishDate FROM {tbl_name} WHERE 1=1 {querystr}"
        elif tbl_name in [utils.tbl_industry, utils.tbl_macresearch, utils.tbl_strategy]:
            sql = f"SELECT id, infoCode, title, orgSName, industryName, publishDate FROM {tbl_name} WHERE 1=1 {querystr}"
        else:
            logging.warning(f'wrong table name: {tbl_name}')
            return None

        if limit:
            sql = sql + f' limit {limit}'

        cursor.execute(sql)
        results = cursor.fetchall()

        # Column names from the cursor description
        column_names = [description[0] for description in cursor.description]

        # Convert rows to a list of dicts
        result_dict_list = []
        for row in results:
            row_dict = {column_names[i]: value for i, value in enumerate(row)}
            result_dict_list.append(row_dict)

        return result_dict_list

    except sqlite3.Error as e:
        logging.error(f"Query failed: {e}")
        return None

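# Usage sketch (the WHERE fragment and limit are hypothetical; querystr is concatenated
# into the SQL, so only pass trusted, code-controlled fragments):
#
#     rows = query_reports_comm(utils.tbl_stock,
#                               querystr=" AND publishDate >= '2024-01-01'",
#                               limit=20)
#     # -> list of dicts keyed by the selected column names, or None on error
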

# Legacy per-table implementations, kept commented out for reference:
'''
# Insert or update an industry_report row
def insert_or_update_report(report):
    try:
        sql = """
            INSERT INTO industry_report (infoCode, title, orgCode, orgName, orgSName, publishDate,
                                         industryCode, industryName, emRatingCode, emRatingValue,
                                         emRatingName, attachSize, attachPages, updated_at)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, datetime('now', 'localtime'))
            ON CONFLICT(infoCode) DO UPDATE SET
                title=excluded.title,
                orgCode=excluded.orgCode,
                orgName=excluded.orgName,
                orgSName=excluded.orgSName,
                publishDate=excluded.publishDate,
                industryCode=excluded.industryCode,
                industryName=excluded.industryName,
                emRatingCode=excluded.emRatingCode,
                emRatingValue=excluded.emRatingValue,
                emRatingName=excluded.emRatingName,
                attachSize=excluded.attachSize,
                attachPages=excluded.attachPages,
                updated_at=datetime('now', 'localtime')
        """

        values = (
            report["infoCode"], report["title"], report["orgCode"], report["orgName"],
            report["orgSName"], report["publishDate"], report["industryCode"],
            report["industryName"], report["emRatingCode"], report["emRatingValue"],
            report["emRatingName"], report.get("attachSize", 0), report.get("attachPages", 0)
        )

        cursor.execute(sql, values)
        conn.commit()

        # Fetch the report id after the insert/update
        cursor.execute("SELECT id FROM industry_report WHERE infoCode = ?", (report["infoCode"],))
        report_id = cursor.fetchone()[0]
        return report_id

    except sqlite3.Error as e:
        conn.rollback()
        logging.error(f"Database error: {e}")
        return None
    except Exception as e:
        conn.rollback()
        logging.error(f"Unexpected error: {e}")
        return None

# Query research report rows
def query_industry_reports(querystr='', limit=None):
    try:
        sql = f"SELECT id, infoCode, title, orgSName, industryName, publishDate FROM industry_report WHERE 1=1 {querystr}"
        if limit:
            sql = sql + f' limit {limit}'

        cursor.execute(sql)
        results = cursor.fetchall()
        return results

    except sqlite3.Error as e:
        logging.error(f"Query failed: {e}")
        return None

# Insert or update a stock_report row
def insert_or_update_stock_report(report):
    try:
        sql = """
            INSERT INTO stock_report (infoCode, title, stockName, stockCode, orgCode, orgName, orgSName,
                                      publishDate, industryCode, industryName, emIndustryCode, emRatingCode,
                                      emRatingValue, emRatingName, attachPages, attachSize, author, authorID, updated_at)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, datetime('now', 'localtime'))
            ON CONFLICT(infoCode) DO UPDATE SET
                title=excluded.title,
                stockName=excluded.stockName,
                stockCode=excluded.stockCode,
                orgCode=excluded.orgCode,
                orgName=excluded.orgName,
                orgSName=excluded.orgSName,
                publishDate=excluded.publishDate,
                industryCode=excluded.industryCode,
                industryName=excluded.industryName,
                emIndustryCode=excluded.emIndustryCode,
                emRatingCode=excluded.emRatingCode,
                emRatingValue=excluded.emRatingValue,
                emRatingName=excluded.emRatingName,
                attachPages=excluded.attachPages,
                attachSize=excluded.attachSize,
                author=excluded.author,
                authorID=excluded.authorID,
                updated_at=datetime('now', 'localtime')
        """

        values = (
            report["infoCode"], report["title"], report["stockName"], report["stockCode"],
            report["orgCode"], report["orgName"], report["orgSName"], report["publishDate"],
            report.get("industryCode", ""), report.get("industryName", ""), report.get("emIndustryCode", ""),
            report["emRatingCode"], report["emRatingValue"], report["emRatingName"],
            report.get("attachPages", 0), report.get("attachSize", 0),
            ",".join(report.get("author", [])), ",".join(report.get("authorID", []))
        )

        cursor.execute(sql, values)
        conn.commit()

        # Fetch the report id after the insert/update
        cursor.execute("SELECT id FROM stock_report WHERE infoCode = ?", (report["infoCode"],))
        report_id = cursor.fetchone()[0]
        return report_id

    except sqlite3.Error as e:
        conn.rollback()
        logging.error(f"Database error: {e}")
        return None
    except Exception as e:
        conn.rollback()
        logging.error(f"Unexpected error: {e}")
        return None

# Query research report rows
def query_stock_reports(querystr='', limit=None):
    try:
        sql = f"SELECT id, infoCode, title, orgSName, stockName, publishDate FROM stock_report WHERE 1=1 {querystr}"
        if limit:
            sql = sql + f' limit {limit}'

        cursor.execute(sql)
        results = cursor.fetchall()
        return results

    except sqlite3.Error as e:
        logging.error(f"Query failed: {e}")
        return None
'''