modify scripts
This commit is contained in:
248
reports_em/sqlite_utils.py
Normal file
248
reports_em/sqlite_utils.py
Normal file
@ -0,0 +1,248 @@
|
||||
import sqlite3
|
||||
import json
|
||||
import config
|
||||
import utils
|
||||
import logging
|
||||
import sys
|
||||
from datetime import datetime
|
||||
|
||||
# Connect to the SQLite database.
# These statements run at import time; the functions in this module execute
# their SQL through the module-level `cursor` and commit through `conn`.
DB_PATH = f"{config.global_host_data_dir}/stock_report.db" # replace with your database file
conn = sqlite3.connect(DB_PATH)
cursor = conn.cursor()
|
||||
|
||||
def get_table_columns_and_defaults(tbl_name):
    """Map every column of ``tbl_name`` to the default value SQLite reports.

    Runs ``PRAGMA table_info`` through the module-level cursor and reads
    field 1 (column name) and field 4 (default value) of each returned row.

    Args:
        tbl_name: Name of the table to inspect. It is interpolated directly
            into the PRAGMA statement.

    Returns:
        A dict of column name -> default value, or ``None`` when SQLite
        raises an error (the error is logged).
    """
    try:
        cursor.execute(f"PRAGMA table_info({tbl_name})")
        # Row layout used here: index 1 is the name, index 4 the default.
        return {row[1]: row[4] for row in cursor.fetchall()}
    except sqlite3.Error as e:
        logging.error(f"Error getting table columns: {e}")
        return None
|
||||
|
||||
def _join_values(values):
    """Return ``values`` as a single comma-separated string.

    ``None`` becomes an empty string, a plain string is returned unchanged
    (instead of being split into characters), and any other iterable has its
    items converted with ``str`` before joining. A non-iterable value is
    returned as ``str(values)``.
    """
    if values is None:
        return ''
    if isinstance(values, str):
        return values
    try:
        items = iter(values)
    except TypeError:
        return str(values)
    return ','.join(str(item) for item in items)


def check_and_process_data(data, tbl_name):
    """Build the column -> value mapping to write for one record.

    The table's columns are read with ``get_table_columns_and_defaults`` and
    each one is resolved as follows:

    * ``id``, ``created_at`` and ``updated_at`` are skipped; they are not
      taken from ``data``.
    * ``author`` and ``authorID`` are always emitted, flattened to a
      comma-separated string (empty string when absent from ``data``).
    * A column present in ``data`` takes the supplied value.
    * A column absent from ``data`` that has no declared default is emitted
      as ``None`` (an explicit NULL).
    * A column absent from ``data`` that has a declared default is left out
      of the result so that SQLite applies the default itself. The default
      reported by ``PRAGMA table_info`` is SQL expression text (for example
      the two-character string ``''``), so binding it as a value would store
      the expression text rather than the value it denotes.

    Args:
        data: Mapping of column name to value for the record.
        tbl_name: Target table name.

    Returns:
        A dict of column name -> value to insert, or ``None`` if the column
        metadata could not be read.
    """
    column_info = get_table_columns_and_defaults(tbl_name=tbl_name)
    if column_info is None:
        return None

    # Auto-increment key and timestamp columns are never taken from `data`.
    skipped_columns = ('id', 'created_at', 'updated_at')
    # Columns whose values arrive as a collection and are stored as text.
    joined_columns = ('author', 'authorID')

    processed_data = {}
    for col, default in column_info.items():
        if col in skipped_columns:
            continue
        if col in joined_columns:
            processed_data[col] = _join_values(data.get(col))
        elif col in data:
            processed_data[col] = data[col]
        elif default is None:
            processed_data[col] = None
        # Otherwise the column has a declared default: omit it and let
        # SQLite evaluate the default expression on insert.
    return processed_data
|
||||
|
||||
|
||||
def insert_or_update_common(data, tbl_name, uniq_key='infoCode'):
    """Insert ``data`` into ``tbl_name``, or update the row it conflicts with.

    The column/value mapping comes from ``check_and_process_data``. The
    statement is an SQLite upsert keyed on ``uniq_key``: on conflict every
    written column except the key is overwritten with the new value, and
    ``updated_at`` is set to the current local time in both the insert and
    the update path.

    Args:
        data: Mapping of column name to value; must contain ``uniq_key``.
        tbl_name: Target table name. It is interpolated into the SQL text
            (identifiers cannot be bound as parameters), so it must come
            from trusted code.
        uniq_key: Column carrying the uniqueness constraint used for
            conflict detection and for looking the row up afterwards.

    Returns:
        The ``id`` of the inserted or updated row, or ``None`` if the key is
        missing from ``data``, the column metadata could not be read, the row
        could not be found afterwards, or SQLite raised an error (logged,
        and the transaction is rolled back).
    """
    # Fail before touching the database rather than after the commit.
    if uniq_key not in data:
        logging.error(
            f"Error inserting or updating data: missing unique key '{uniq_key}'"
        )
        return None

    try:
        processed_data = check_and_process_data(data, tbl_name)
        if processed_data is None:
            return None

        column_names = list(processed_data)
        values = list(processed_data.values())
        columns = ', '.join(column_names)
        placeholders = ', '.join('?' for _ in column_names)

        # Build the SET list as a list so there is never a dangling comma,
        # and exclude the conflict key itself from the overwritten columns.
        assignments = [
            f"{col}=EXCLUDED.{col}" for col in column_names if col != uniq_key
        ]
        assignments.append("updated_at=datetime('now', 'localtime')")
        update_clause = ', '.join(assignments)

        sql = f'''
            INSERT INTO {tbl_name} ({columns}, updated_at)
            VALUES ({placeholders}, datetime('now', 'localtime'))
            ON CONFLICT ({uniq_key}) DO UPDATE SET {update_clause}
        '''
        cursor.execute(sql, values)
        conn.commit()

        # Look up the id of the row that was just inserted or updated.
        cursor.execute(
            f"SELECT id FROM {tbl_name} WHERE {uniq_key} = ?",
            (data[uniq_key],),
        )
        row = cursor.fetchone()
        return row[0] if row is not None else None
    except sqlite3.Error as e:
        # Discard any partially applied statement before reporting failure.
        conn.rollback()
        logging.error(f"Error inserting or updating data: {e}")
        return None
|
||||
|
||||
|
||||
def query_reports_comm(tbl_name, querystr='', limit=None ):
    """Query summary columns of a report table and return the rows as dicts.

    Only the table names exposed as ``utils.tbl_stock``,
    ``utils.tbl_new_stock``, ``utils.tbl_industry``, ``utils.tbl_macresearch``
    and ``utils.tbl_strategy`` are accepted.

    Args:
        tbl_name: Table to query; must be one of the names listed above.
        querystr: Raw SQL fragment appended after ``WHERE 1=1``. It is
            inserted into the statement verbatim.
        limit: Optional row limit; appended as a ``limit`` clause when truthy.

    Returns:
        A list with one dict per row, keyed by the selected column names, or
        ``None`` when the table name is rejected (a warning is logged) or
        SQLite raises an error (the error is logged).
    """
    try:
        allowed_tables = [
            utils.tbl_stock,
            utils.tbl_new_stock,
            utils.tbl_industry,
            utils.tbl_macresearch,
            utils.tbl_strategy,
        ]
        # Reject unknown tables up front.
        if tbl_name not in allowed_tables:
            logging.warning(f'wrong table name: {tbl_name}')
            return None

        sql = f"SELECT id, infoCode, title, orgSName, industryName, stockName, publishDate FROM {tbl_name} WHERE 1=1 {querystr}"
        if limit:
            sql += f' limit {limit}'

        cursor.execute(sql)
        rows = cursor.fetchall()

        # Column names of the result set, in select order.
        headers = [meta[0] for meta in cursor.description]

        # Pair each row's values with the column names.
        return [dict(zip(headers, row)) for row in rows]
    except sqlite3.Error as e:
        logging.error(f"查询 href 失败: {e}")
        return None
|
||||
|
||||
'''
|
||||
# 插入或更新 industry_report 数据
|
||||
def insert_or_update_report(report):
|
||||
try:
|
||||
sql = """
|
||||
INSERT INTO industry_report (infoCode, title, orgCode, orgName, orgSName, publishDate,
|
||||
industryCode, industryName, emRatingCode, emRatingValue,
|
||||
emRatingName, attachSize, attachPages, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, datetime('now', 'localtime'))
|
||||
ON CONFLICT(infoCode) DO UPDATE SET
|
||||
title=excluded.title,
|
||||
orgCode=excluded.orgCode,
|
||||
orgName=excluded.orgName,
|
||||
orgSName=excluded.orgSName,
|
||||
publishDate=excluded.publishDate,
|
||||
industryCode=excluded.industryCode,
|
||||
industryName=excluded.industryName,
|
||||
emRatingCode=excluded.emRatingCode,
|
||||
emRatingValue=excluded.emRatingValue,
|
||||
emRatingName=excluded.emRatingName,
|
||||
attachSize=excluded.attachSize,
|
||||
attachPages=excluded.attachPages,
|
||||
updated_at=datetime('now', 'localtime')
|
||||
"""
|
||||
|
||||
values = (
|
||||
report["infoCode"], report["title"], report["orgCode"], report["orgName"],
|
||||
report["orgSName"], report["publishDate"], report["industryCode"],
|
||||
report["industryName"], report["emRatingCode"], report["emRatingValue"],
|
||||
report["emRatingName"], report.get("attachSize", 0), report.get("attachPages", 0)
|
||||
)
|
||||
|
||||
cursor.execute(sql, values)
|
||||
conn.commit()
|
||||
|
||||
# 获取插入或更新后的 report_id
|
||||
cursor.execute("SELECT id FROM industry_report WHERE infoCode = ?", (report["infoCode"],))
|
||||
report_id = cursor.fetchone()[0]
|
||||
return report_id
|
||||
|
||||
except sqlite3.Error as e:
|
||||
conn.rollback()
|
||||
logging.error(f"数据库错误: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
conn.rollback()
|
||||
logging.error(f"未知错误: {e}")
|
||||
return None
|
||||
|
||||
# 查询研报数据
|
||||
def query_industry_reports(querystr='', limit=None):
|
||||
try:
|
||||
sql = f"SELECT id, infoCode, title, orgSName, industryName, publishDate FROM industry_report WHERE 1=1 {querystr}"
|
||||
if limit :
|
||||
sql = sql + f' limit {limit}'
|
||||
|
||||
cursor.execute(sql)
|
||||
results = cursor.fetchall()
|
||||
return results
|
||||
|
||||
except sqlite3.Error as e:
|
||||
logging.error(f"查询 href 失败: {e}")
|
||||
return None
|
||||
|
||||
# 插入或更新 industry_report 数据
|
||||
def insert_or_update_stock_report(report):
|
||||
try:
|
||||
sql = """
|
||||
INSERT INTO stock_report (infoCode, title, stockName, stockCode, orgCode, orgName, orgSName,
|
||||
publishDate, industryCode, industryName, emIndustryCode, emRatingCode,
|
||||
emRatingValue, emRatingName, attachPages, attachSize, author, authorID, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, datetime('now', 'localtime'))
|
||||
ON CONFLICT(infoCode) DO UPDATE SET
|
||||
title=excluded.title,
|
||||
stockName=excluded.stockName,
|
||||
stockCode=excluded.stockCode,
|
||||
orgCode=excluded.orgCode,
|
||||
orgName=excluded.orgName,
|
||||
orgSName=excluded.orgSName,
|
||||
publishDate=excluded.publishDate,
|
||||
industryCode=excluded.industryCode,
|
||||
industryName=excluded.industryName,
|
||||
emIndustryCode=excluded.emIndustryCode,
|
||||
emRatingCode=excluded.emRatingCode,
|
||||
emRatingValue=excluded.emRatingValue,
|
||||
emRatingName=excluded.emRatingName,
|
||||
attachPages=excluded.attachPages,
|
||||
attachSize=excluded.attachSize,
|
||||
author=excluded.author,
|
||||
authorID=excluded.authorID,
|
||||
updated_at=datetime('now', 'localtime')
|
||||
"""
|
||||
|
||||
values = (
|
||||
report["infoCode"], report["title"], report["stockName"], report["stockCode"],
|
||||
report["orgCode"], report["orgName"], report["orgSName"], report["publishDate"],
|
||||
report.get("industryCode", ""), report.get("industryName", ""), report.get("emIndustryCode", ""),
|
||||
report["emRatingCode"], report["emRatingValue"], report["emRatingName"],
|
||||
report.get("attachPages", 0), report.get("attachSize", 0),
|
||||
",".join(report.get("author", [])), ",".join(report.get("authorID", []))
|
||||
)
|
||||
|
||||
cursor.execute(sql, values)
|
||||
conn.commit()
|
||||
|
||||
# 获取插入或更新后的 report_id
|
||||
cursor.execute("SELECT id FROM stock_report WHERE infoCode = ?", (report["infoCode"],))
|
||||
report_id = cursor.fetchone()[0]
|
||||
return report_id
|
||||
|
||||
|
||||
except sqlite3.Error as e:
|
||||
conn.rollback()
|
||||
logging.error(f"数据库错误: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
conn.rollback()
|
||||
logging.error(f"未知错误: {e}")
|
||||
return None
|
||||
|
||||
|
||||
# 查询研报数据
|
||||
def query_stock_reports(querystr='', limit=None):
|
||||
try:
|
||||
sql = f"SELECT id, infoCode, title, orgSName, stockName, publishDate FROM stock_report WHERE 1=1 {querystr}"
|
||||
if limit :
|
||||
sql = sql + f' limit {limit}'
|
||||
|
||||
cursor.execute(sql)
|
||||
results = cursor.fetchall()
|
||||
return results
|
||||
|
||||
except sqlite3.Error as e:
|
||||
logging.error(f"查询 href 失败: {e}")
|
||||
return None
|
||||
|
||||
'''
|
||||
Reference in New Issue
Block a user