import json
import logging
import sqlite3
import sys
from datetime import datetime

import config
import utils

# Connect to the SQLite database.
DB_PATH = f"{config.global_share_db_dir}/stock_report.db"  # replace with your database file
conn = sqlite3.connect(DB_PATH)
cursor = conn.cursor()

# Columns that the caller never supplies: the auto-increment primary key and
# the timestamp columns (updated_at is set explicitly in the upsert SQL).
_AUTO_COLUMNS = ('id', 'created_at', 'updated_at')

# Columns stored as a single comma-separated string.
_LIST_COLUMNS = ('author', 'authorID')


def get_table_columns_and_defaults(tbl_name):
    """Return a ``{column_name: default_value}`` mapping for ``tbl_name``.

    The information is read with ``PRAGMA table_info``; an unknown table
    yields an empty dict because the pragma then returns no rows.

    Args:
        tbl_name: Table name. It is interpolated into the SQL text (a pragma
            argument cannot be bound), so it must come from trusted code.

    Returns:
        dict mapping column name to the default reported by SQLite
        (``None`` when the column has no default), or ``None`` if the
        query raised ``sqlite3.Error``.
    """
    try:
        cursor.execute(f"PRAGMA table_info({tbl_name})")
        # Each row is (cid, name, type, notnull, dflt_value, pk).
        return {col[1]: col[4] for col in cursor.fetchall()}
    except sqlite3.Error as e:
        logging.error(f"Error getting table columns: {e}")
        return None


def _join_list_field(value):
    """Flatten a list-like field into a comma-separated string.

    ``None`` becomes an empty string, an existing string is kept as-is
    (joining it would split it into characters), and any other iterable
    has its items converted with ``str`` before joining.
    """
    if value is None:
        return ''
    if isinstance(value, str):
        return value
    return ','.join(str(item) for item in value)


def check_and_process_data(data, tbl_name):
    """Project ``data`` onto the columns of ``tbl_name``.

    Args:
        data: Mapping of field name to value for one record.
        tbl_name: Target table whose columns define the output keys.

    Returns:
        dict with one entry per table column except ``id``, ``created_at``
        and ``updated_at``. ``author`` / ``authorID`` are flattened to a
        comma-separated string; columns missing from ``data`` fall back to
        the default reported for the column (``None`` if there is none).
        Returns ``None`` when the column list could not be read.
    """
    column_info = get_table_columns_and_defaults(tbl_name=tbl_name)
    if column_info is None:
        return None

    processed_data = {}
    for col, default in column_info.items():
        if col in _AUTO_COLUMNS:
            continue
        if col in _LIST_COLUMNS:
            processed_data[col] = _join_list_field(data.get(col))
        elif col in data:
            processed_data[col] = data[col]
        else:
            # NOTE(review): the default comes straight from PRAGMA table_info;
            # confirm it is a plain value rather than SQL expression text
            # before relying on it being stored verbatim.
            processed_data[col] = default
    return processed_data


def insert_or_update_common(data, tbl_name, uniq_key='infoCode'):
    """Insert ``data`` into ``tbl_name`` or update the row matching ``uniq_key``.

    Args:
        data: Mapping of field name to value; must contain ``uniq_key``.
        tbl_name: Target table (interpolated into the SQL, must be trusted).
        uniq_key: Column used as the ``ON CONFLICT`` target and to look up
            the resulting row id. Defaults to ``'infoCode'``.

    Returns:
        The ``id`` of the inserted/updated row, or ``None`` on any failure
        (missing key, unreadable table, database error, row not found).
    """
    if uniq_key not in data:
        logging.error(f"Error inserting or updating data: missing key '{uniq_key}'")
        return None

    try:
        processed_data = check_and_process_data(data, tbl_name)
        if processed_data is None:
            return None

        columns = list(processed_data)
        values = list(processed_data.values())
        now_expr = "datetime('now', 'localtime')"

        column_list = ', '.join(columns + ['updated_at'])
        placeholders = ', '.join(['?'] * len(values) + [now_expr])
        # Build the SET list as a list so that an empty column set does not
        # produce a leading comma.
        assignments = [f"{col}=EXCLUDED.{col}" for col in columns if col != uniq_key]
        assignments.append(f"updated_at={now_expr}")
        update_clause = ', '.join(assignments)

        sql = f'''
            INSERT INTO {tbl_name} ({column_list})
            VALUES ({placeholders})
            ON CONFLICT ({uniq_key}) DO UPDATE SET {update_clause}
        '''
        cursor.execute(sql, values)
        conn.commit()

        # Fetch the id of the row that was just inserted or updated.
        cursor.execute(
            f"SELECT id FROM {tbl_name} WHERE {uniq_key} = ?",
            (data[uniq_key],),
        )
        row = cursor.fetchone()
        return row[0] if row else None
    except sqlite3.Error as e:
        # Discard any partially applied statement before reporting failure.
        conn.rollback()
        logging.error(f"Error inserting or updating data: {e}")
        return None


def query_reports_comm(tbl_name, querystr='', limit=None, params=None):
    """Query report rows from one of the known report tables.

    Args:
        tbl_name: Must be one of the table names defined in ``utils``
            (stock, new stock, industry, macro research, strategy).
        querystr: Extra SQL appended after ``WHERE 1=1`` (for example
            ``" AND stockName = ?"``). It is inserted verbatim, so it must
            never be built from untrusted input; use ``?`` placeholders
            together with ``params`` instead.
        limit: Optional maximum number of rows; falsy means no limit. It is
            coerced with ``int()``, so only integer-like values are accepted.
        params: Optional sequence of values bound to the ``?`` placeholders
            that appear in ``querystr``.

    Returns:
        A list of dicts keyed by column name, or ``None`` if the table name
        is not allowed, ``limit`` is not an integer, or the query failed.
    """
    allowed_tables = (
        utils.tbl_stock,
        utils.tbl_new_stock,
        utils.tbl_industry,
        utils.tbl_macresearch,
        utils.tbl_strategy,
    )
    if tbl_name not in allowed_tables:
        logging.warning(f'wrong table name: {tbl_name}')
        return None

    sql = (
        "SELECT id, infoCode, title, orgSName, industryName, stockName, publishDate "
        f"FROM {tbl_name} WHERE 1=1 {querystr}"
    )
    bind_values = list(params) if params else []

    if limit:
        try:
            bind_values.append(int(limit))
        except (TypeError, ValueError):
            logging.warning(f'wrong limit value: {limit}')
            return None
        # Bind the limit instead of formatting it into the SQL text.
        sql += ' limit ?'

    try:
        cursor.execute(sql, bind_values)
        rows = cursor.fetchall()
        # Column names of the result set, used as dict keys.
        column_names = [description[0] for description in cursor.description]
        return [dict(zip(column_names, row)) for row in rows]
    except sqlite3.Error as e:
        logging.error(f"查询 href 失败: {e}")
        return None


# Legacy table-specific helpers, disabled by wrapping them in a string literal.
# They are superseded by insert_or_update_common / query_reports_comm above and
# are retained here for reference only.
'''
# 插入或更新 industry_report 数据
def insert_or_update_report(report):
    try:
        sql = """
            INSERT INTO industry_report (infoCode, title, orgCode, orgName, orgSName, publishDate,
                                        industryCode, industryName, emRatingCode, emRatingValue,
                                        emRatingName, attachSize, attachPages, updated_at)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, datetime('now', 'localtime'))
            ON CONFLICT(infoCode) DO UPDATE SET
                title=excluded.title,
                orgCode=excluded.orgCode,
                orgName=excluded.orgName,
                orgSName=excluded.orgSName,
                publishDate=excluded.publishDate,
                industryCode=excluded.industryCode,
                industryName=excluded.industryName,
                emRatingCode=excluded.emRatingCode,
                emRatingValue=excluded.emRatingValue,
                emRatingName=excluded.emRatingName,
                attachSize=excluded.attachSize,
                attachPages=excluded.attachPages,
                updated_at=datetime('now', 'localtime')
        """

        values = (
            report["infoCode"], report["title"], report["orgCode"], report["orgName"],
            report["orgSName"], report["publishDate"], report["industryCode"], report["industryName"],
            report["emRatingCode"], report["emRatingValue"], report["emRatingName"],
            report.get("attachSize", 0), report.get("attachPages", 0)
        )

        cursor.execute(sql, values)
        conn.commit()

        # 获取插入或更新后的 report_id
        cursor.execute("SELECT id FROM industry_report WHERE infoCode = ?", (report["infoCode"],))
        report_id = cursor.fetchone()[0]
        return report_id

    except sqlite3.Error as e:
        conn.rollback()
        logging.error(f"数据库错误: {e}")
        return None
    except Exception as e:
        conn.rollback()
        logging.error(f"未知错误: {e}")
        return None

# 查询研报数据
def query_industry_reports(querystr='', limit=None):
    try:
        sql = f"SELECT id, infoCode, title, orgSName, industryName, publishDate FROM industry_report WHERE 1=1 {querystr}"
        if limit :
            sql = sql + f' limit {limit}'

        cursor.execute(sql)
        results = cursor.fetchall()
        return results

    except sqlite3.Error as e:
        logging.error(f"查询 href 失败: {e}")
        return None

# 插入或更新 industry_report 数据
def insert_or_update_stock_report(report):
    try:
        sql = """
            INSERT INTO stock_report (infoCode, title, stockName, stockCode, orgCode, orgName, orgSName,
                                    publishDate, industryCode, industryName, emIndustryCode, emRatingCode,
                                    emRatingValue, emRatingName, attachPages, attachSize, author, authorID, updated_at)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, datetime('now', 'localtime'))
            ON CONFLICT(infoCode) DO UPDATE SET
                title=excluded.title,
                stockName=excluded.stockName,
                stockCode=excluded.stockCode,
                orgCode=excluded.orgCode,
                orgName=excluded.orgName,
                orgSName=excluded.orgSName,
                publishDate=excluded.publishDate,
                industryCode=excluded.industryCode,
                industryName=excluded.industryName,
                emIndustryCode=excluded.emIndustryCode,
                emRatingCode=excluded.emRatingCode,
                emRatingValue=excluded.emRatingValue,
                emRatingName=excluded.emRatingName,
                attachPages=excluded.attachPages,
                attachSize=excluded.attachSize,
                author=excluded.author,
                authorID=excluded.authorID,
                updated_at=datetime('now', 'localtime')
        """

        values = (
            report["infoCode"], report["title"], report["stockName"], report["stockCode"],
            report["orgCode"], report["orgName"], report["orgSName"], report["publishDate"],
            report.get("industryCode", ""), report.get("industryName", ""), report.get("emIndustryCode", ""),
            report["emRatingCode"], report["emRatingValue"], report["emRatingName"],
            report.get("attachPages", 0), report.get("attachSize", 0),
            ",".join(report.get("author", [])), ",".join(report.get("authorID", []))
        )

        cursor.execute(sql, values)
        conn.commit()

        # 获取插入或更新后的 report_id
        cursor.execute("SELECT id FROM stock_report WHERE infoCode = ?", (report["infoCode"],))
        report_id = cursor.fetchone()[0]
        return report_id

    except sqlite3.Error as e:
        conn.rollback()
        logging.error(f"数据库错误: {e}")
        return None
    except Exception as e:
        conn.rollback()
        logging.error(f"未知错误: {e}")
        return None

# 查询研报数据
def query_stock_reports(querystr='', limit=None):
    try:
        sql = f"SELECT id, infoCode, title, orgSName, stockName, publishDate FROM stock_report WHERE 1=1 {querystr}"
        if limit :
            sql = sql + f' limit {limit}'

        cursor.execute(sql)
        results = cursor.fetchall()
        return results

    except sqlite3.Error as e:
        logging.error(f"查询 href 失败: {e}")
        return None
'''