This repository has been archived on 2026-01-07. You can view files and clone it, but cannot push or open issues or pull requests.
Files
resources/scrapy_proj/scrapy_proj/db_wapper/sqlite_base.py
2025-07-28 19:34:14 +08:00

588 lines
24 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import re
import sqlite3
import logging
from datetime import datetime
from typing import List, Dict, Optional, Any
# Per-user shared data root and the SQLite database files kept beneath it.
home_dir = os.path.expanduser("~")
global_share_data_dir = "/".join((home_dir, "sharedata"))
# Used by SQLiteDBHandler when no explicit db_path is given.
default_dbpath = "/".join((global_share_data_dir, "sqlite", "scrapy.db"))
shared_db_path = "/".join((global_share_data_dir, "sqlite", "shared.db"))
test_db_path = "/".join((global_share_data_dir, "sqlite", "test.db"))
# Singleton metaclass.
class SingletonMeta(type):
    """Metaclass that hands out one shared instance per class.

    The first call ``Cls(...)`` constructs the instance; every later call
    returns that same object without running ``__init__`` again, so any
    arguments passed on later calls are ignored.
    """

    _instances = {}  # class -> its single instance

    def __call__(cls, *args, **kwargs):
        registry = cls._instances
        if cls in registry:
            return registry[cls]
        instance = super().__call__(*args, **kwargs)
        registry[cls] = instance
        return instance
# Base database handler that wraps the common SQLite operations.
class SQLiteDBHandler(metaclass=SingletonMeta):  # one shared handler per class
    """Shared SQLite access layer.

    Because the class uses ``SingletonMeta``, each class (and each subclass)
    is constructed at most once; later calls return the cached instance and
    ignore any new ``db_path``.  Subclasses must implement :meth:`insert_item`.

    Table names, column names, ``where`` snippets and ``order_by`` values are
    interpolated into the SQL text directly (SQLite cannot bind identifiers),
    so they must come from trusted code, never from scraped content.
    """

    # Matches "column AS alias" (case-insensitive); group 1 = column, group 2 = alias.
    _ALIAS_RE = re.compile(r'^(\w+)\s+as\s+(\w+)$', re.IGNORECASE)

    def __init__(self, db_path=None):
        """Open the SQLite database.

        :param db_path: database file path; ``default_dbpath`` when omitted.
                        For an explicit path, its parent directory is created
                        if it does not exist yet.
        """
        # SingletonMeta already runs __init__ only once; this guard also covers
        # explicit re-invocation of __init__.
        if getattr(self, 'initialized', False):
            return

        self.DB_PATH = db_path or default_dbpath
        if db_path:
            db_dir = os.path.dirname(db_path)
            # dirname() is '' for bare file names and ':memory:', and
            # os.makedirs('') raises, so only create a real parent directory.
            if db_dir:
                os.makedirs(db_dir, exist_ok=True)

        self.conn = sqlite3.connect(self.DB_PATH, check_same_thread=False)
        self.conn.row_factory = sqlite3.Row  # rows support access by column name
        self.cursor = self.conn.cursor()
        # Explicitly use the classic rollback-journal mode rather than WAL.
        self.conn.execute('PRAGMA journal_mode = DELETE')
        self.conn.commit()  # make sure the setting takes effect

        # UPSERT (INSERT ... ON CONFLICT DO UPDATE / DO NOTHING) needs SQLite
        # >= 3.24.0; older libraries use insert_or_update_common_lower().
        self.lower_sqlite_version = sqlite3.sqlite_version_info < (3, 24, 0)

        self._column_cache = {}  # table name -> {column: default}, avoids repeated PRAGMAs
        self.initialized = True  # mark initialization as complete

    def __del__(self):
        """Best-effort close when the handler is garbage-collected."""
        try:
            self.close()
        except Exception:
            pass  # never let an exception escape a finalizer (e.g. at interpreter exit)

    def _create_tables(self):
        """Create the handler's tables; a no-op here (presumably overridden by subclasses)."""
        pass

    # Interface method: every subclass must implement it.
    def insert_item(self, item):
        """Persist one scraped item.

        :raises NotImplementedError: always, in the base class.
        """
        raise NotImplementedError("子类必须实现 insert_item 方法")

    def get_table_columns_and_defaults(self, tbl_name):
        """Return ``{column_name: dflt_value}`` for ``tbl_name`` as reported by PRAGMA table_info.

        Non-empty results are cached per table.  An empty result (PRAGMA
        table_info yields no rows for a table that does not exist) is not
        cached, so a table created later is picked up on the next call.

        :return: the mapping (possibly empty), or ``None`` if the PRAGMA failed.
        """
        cached = self._column_cache.get(tbl_name)
        if cached is not None:
            return cached
        try:
            self.cursor.execute(f"PRAGMA table_info({tbl_name})")
            columns = self.cursor.fetchall()
        except sqlite3.Error as e:
            logging.error(f"Error getting table columns: {e}")
            return None
        # Each row is (cid, name, type, notnull, dflt_value, pk).
        column_info = {col[1]: col[4] for col in columns}
        if column_info:
            self._column_cache[tbl_name] = column_info
        return column_info

    def check_and_process_data(self, data, tbl_name):
        """Project ``data`` onto the columns of ``tbl_name``.

        Keys that are not table columns are dropped.  ``id`` (auto-increment
        key) and ``created_at`` (left to the table default) are always skipped.
        ``updated_at`` gets the current local time unless ``data`` provides it.

        :return: ``{column: value}`` in table-column order, or ``None`` if the
                 table's columns could not be read.
        """
        column_info = self.get_table_columns_and_defaults(tbl_name)
        if column_info is None:
            return None
        processed_data = {}
        for col in column_info:
            if col in ('id', 'created_at'):
                continue
            if col in data:
                processed_data[col] = data[col]
            elif col == 'updated_at':
                processed_data[col] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        return processed_data

    def _insert_row(self, processed_data, tbl_name):
        """INSERT ``processed_data`` into ``tbl_name``, commit, and return the new rowid."""
        columns = ', '.join(processed_data)
        placeholders = ', '.join(['?'] * len(processed_data))
        self.cursor.execute(
            f"INSERT INTO {tbl_name} ({columns}) VALUES ({placeholders})",
            list(processed_data.values()),
        )
        self.conn.commit()
        return self.cursor.lastrowid

    def insert_or_update_common(self, data, tbl_name, uniq_key='url', exists_do_nothing=False):
        """Insert ``data`` or update the existing row that has the same ``uniq_key``.

        Uses SQLite's native UPSERT; on SQLite < 3.24 it delegates to
        :meth:`insert_or_update_common_lower`.

        :param data: mapping of column -> value; keys that are not columns are ignored.
        :param tbl_name: target table (trusted identifier).
        :param uniq_key: column backed by a UNIQUE index (required by ON CONFLICT),
                         or ``None`` to always insert a new row.
        :param exists_do_nothing: leave an existing row untouched instead of updating it.
        :return: ``id`` of the inserted/updated/existing row, or ``None`` on error.
        """
        if self.lower_sqlite_version:
            return self.insert_or_update_common_lower(data, tbl_name, uniq_key, exists_do_nothing)
        try:
            processed_data = self.check_and_process_data(data, tbl_name)
            if processed_data is None:
                return None
            if uniq_key is None:
                # No unique key: always insert a new row.
                return self._insert_row(processed_data, tbl_name)
            if uniq_key not in processed_data:
                logging.error(f"Unique key '{uniq_key}' missing from data for table {tbl_name}")
                return None

            columns = ', '.join(processed_data)
            placeholders = ', '.join(['?'] * len(processed_data))
            update_cols = [col for col in processed_data if col != uniq_key]
            if exists_do_nothing or not update_cols:
                # An empty "DO UPDATE SET" list would be a syntax error.
                conflict_clause = f'ON CONFLICT ({uniq_key}) DO NOTHING'
            else:
                update_clause = ', '.join(f"{col}=EXCLUDED.{col}" for col in update_cols)
                conflict_clause = f"ON CONFLICT ({uniq_key}) DO UPDATE SET {update_clause}"
            self.cursor.execute(
                f"INSERT INTO {tbl_name} ({columns}) VALUES ({placeholders}) {conflict_clause}",
                list(processed_data.values()),
            )
            self.conn.commit()
            # The row may have been updated or skipped, so look its id up by key.
            return self.get_id_by_key(tbl_name, uniq_key, processed_data[uniq_key])
        except sqlite3.Error as e:
            logging.error(f"Error inserting or updating data: {e}")
            return None

    def insert_or_update_common_lower(self, data, tbl_name, uniq_key='url', exists_do_nothing=False):
        """UPSERT fallback for SQLite < 3.24.

        Tries a plain INSERT; on an IntegrityError (e.g. duplicate ``uniq_key``)
        it updates the row matching ``uniq_key`` unless ``exists_do_nothing``.
        Parameters and return value are the same as :meth:`insert_or_update_common`.
        """
        try:
            processed_data = self.check_and_process_data(data, tbl_name)
            if processed_data is None:
                return None
            if uniq_key is None:
                # No unique key: always insert a new row.
                return self._insert_row(processed_data, tbl_name)
            if uniq_key not in processed_data:
                logging.error(f"Unique key '{uniq_key}' missing from data for table {tbl_name}")
                return None

            try:
                self._insert_row(processed_data, tbl_name)
            except sqlite3.IntegrityError:  # constraint violation, typically the unique key
                update_cols = [col for col in processed_data if col != uniq_key]
                if not exists_do_nothing and update_cols:
                    set_clause = ', '.join(f"{col}=?" for col in update_cols)
                    update_values = [processed_data[col] for col in update_cols]
                    update_values.append(processed_data[uniq_key])
                    self.cursor.execute(
                        f"UPDATE {tbl_name} SET {set_clause} WHERE {uniq_key} = ?",
                        update_values,
                    )
                # Always commit: it persists the UPDATE and also ends the implicit
                # transaction the failed INSERT left open (which may hold a write lock).
                self.conn.commit()
            return self.get_id_by_key(tbl_name, uniq_key, processed_data[uniq_key])
        except sqlite3.Error as e:
            logging.error(f"Error inserting or updating data: {e}")
            return None

    def insert_or_update_with_composite_pk(self, data, tbl_name, composite_pk, exists_do_nothing=True):
        """Insert or update a row of a table keyed by a composite primary key.

        Uses SELECT followed by INSERT or UPDATE (no UPSERT), so it also works
        on old SQLite versions.  Changes are committed.

        :param data: mapping of column -> value to write.
        :param tbl_name: target table (trusted identifier).
        :param composite_pk: list (or tuple) of at least two key column names.
        :param exists_do_nothing: if True (default), leave an existing row untouched.
        :return: 2 if a row was inserted, 1 if an existing row was updated,
                 0 if nothing was changed, ``None`` on invalid input or DB error.
        """
        try:
            if not isinstance(composite_pk, (list, tuple)) or len(composite_pk) < 2:
                logging.error(f"联合主键必须是包含至少两个字段的列表: {composite_pk}")
                return None
            processed_data = self.check_and_process_data(data, tbl_name)
            if processed_data is None:
                return None  # columns could not be read; already logged
            # Every key column needs a value, otherwise the lookup is meaningless.
            for pk_field in composite_pk:
                if pk_field not in processed_data:
                    logging.error(f"联合主键字段 '{pk_field}' 未在数据中提供")
                    return None

            where_conditions = " AND ".join(f"{pk} = ?" for pk in composite_pk)
            pk_values = [processed_data[pk] for pk in composite_pk]
            self.cursor.execute(f"SELECT 1 FROM {tbl_name} WHERE {where_conditions}", pk_values)
            row_exists = self.cursor.fetchone() is not None

            if row_exists:
                if exists_do_nothing:
                    return 0
                update_fields = [col for col in processed_data if col not in composite_pk]
                if not update_fields:
                    return 0
                set_clause = ", ".join(f"{col} = ?" for col in update_fields)
                update_values = [processed_data[col] for col in update_fields] + pk_values
                self.cursor.execute(
                    f"UPDATE {tbl_name} SET {set_clause} WHERE {where_conditions}",
                    update_values,
                )
                outcome = 1
            else:
                columns = ", ".join(processed_data)
                placeholders = ", ".join(["?"] * len(processed_data))
                self.cursor.execute(
                    f"INSERT INTO {tbl_name} ({columns}) VALUES ({placeholders})",
                    list(processed_data.values()),
                )
                outcome = 2
            # Commit like the other insert helpers; otherwise the write would stay
            # in an open transaction and be discarded if nothing commits before close().
            self.conn.commit()
            return outcome
        except sqlite3.Error as e:
            logging.error(f"Error inserting or updating data: {e}")
            return None

    def get_id_by_key(self, tbl, uniq_key, val):
        """Return the ``id`` of the first row of ``tbl`` whose ``uniq_key`` equals ``val``, or ``None``."""
        self.cursor.execute(f"SELECT id FROM {tbl} WHERE {uniq_key} = ?", (val,))
        row = self.cursor.fetchone()
        return row[0] if row else None

    def close(self):
        """Close the shared cursor and the connection.

        Safe to call repeatedly and on a partially initialized instance.
        sqlite3's ``close()`` does not commit, so uncommitted changes are lost.
        """
        conn = getattr(self, 'conn', None)
        if conn is None:
            return  # __init__ never got as far as connecting
        cursor = getattr(self, 'cursor', None)
        if cursor is not None:
            try:
                cursor.close()
            except sqlite3.ProgrammingError:
                pass  # connection already closed, so the cursor is unusable anyway
        conn.close()

    def get_stat(self):
        """Return handler statistics; the base class has none."""
        return {}

    def _validate_fields(self, tbl_name: str, fields: List[str]) -> List[str]:
        """Keep the query fields whose column exists in ``tbl_name``.

        Fields may be plain column names or ``"column AS alias"``; unknown
        columns are logged and dropped.
        """
        column_info = self.get_table_columns_and_defaults(tbl_name)
        if not column_info:
            return []
        valid_fields = []
        for field in fields:
            match = self._ALIAS_RE.match(field)
            raw_field = match.group(1) if match else field  # strip an "AS alias" suffix
            if raw_field in column_info:
                valid_fields.append(field)
            else:
                logging.warning(f"无效查询字段: 表={tbl_name}, 字段={field}")
        return valid_fields

    def _validate_filter_fields(self, tbl_name: str, filters: Dict[str, Any],
                                condition_mapping: Dict[str, str]) -> Dict[str, Any]:
        """Keep the filters whose (mapped) field is a column of ``tbl_name``.

        Keys look like ``"field"`` or ``"field__op"``; ``field`` is translated
        through ``condition_mapping`` before the check.  The special keys
        ``order_by`` and ``limit`` pass through unchecked.  Rejected keys are logged.
        """
        column_info = self.get_table_columns_and_defaults(tbl_name)
        if not column_info:
            return {}
        valid_filters = {}
        for key, value in filters.items():
            if key in ("order_by", "limit"):
                valid_filters[key] = value
                continue
            field_base = key.split("__", 1)[0]
            mapped_field = condition_mapping.get(field_base, field_base)
            if mapped_field in column_info:
                valid_filters[key] = value
            else:
                logging.warning(f"无效过滤字段: 表={tbl_name}, 字段={field_base} (映射后={mapped_field})")
        return valid_filters

    def _validate_order_fields(self, tbl_name: str, allowed_order_fields: List[str]) -> List[str]:
        """Keep order specs such as ``"id DESC"`` whose column exists in ``tbl_name``; log the rest."""
        column_info = self.get_table_columns_and_defaults(tbl_name)
        if not column_info:
            return []
        valid_order_fields = []
        for field in allowed_order_fields:
            tokens = field.split()
            raw_field = tokens[0] if tokens else ""  # a blank spec counts as invalid
            if raw_field in column_info:
                valid_order_fields.append(field)
            else:
                logging.warning(f"无效排序字段: 表={tbl_name}, 字段={field}")
        return valid_order_fields

    def _build_conditions(self, conditions: Dict[str, Any], condition_mapping: Dict[str, str],
                          skip_keys=()) -> tuple:
        """Translate ``{"field[__op]": value}`` filters into SQL AND-fragments.

        Supported ops: ``eq`` (default when the key has no ``__``), ``like``
        (substring match), ``in`` / ``not_in`` (value must be a list), ``gt``,
        ``lt``.  Unsupported ops and non-list IN values are logged and ignored.

        :param skip_keys: keys to ignore entirely (e.g. ``order_by``/``limit``).
        :return: ``(sql, params)`` where ``sql`` is a concatenation of
                 ``" AND ..."`` fragments (empty if there are none).
        """
        fragments = []
        params = []
        for key, value in conditions.items():
            if key in skip_keys:
                continue
            if "__" in key:
                field_base, condition_type = key.split("__", 1)
            else:
                field_base, condition_type = key, "eq"
            field = condition_mapping.get(field_base, field_base)

            if condition_type == "eq":
                fragments.append(f" AND {field} = ?")
                params.append(value)
            elif condition_type == "like":
                fragments.append(f" AND {field} LIKE ?")
                params.append(f"%{value}%")
            elif condition_type == "in":
                if isinstance(value, list):
                    placeholders = ", ".join(["?"] * len(value))
                    fragments.append(f" AND {field} IN ({placeholders})")
                    params.extend(value)
                else:
                    logging.warning(f"IN条件值必须是列表键: {key}")
            elif condition_type == "not_in":
                if isinstance(value, list):
                    placeholders = ", ".join(["?"] * len(value))
                    fragments.append(f" AND {field} NOT IN ({placeholders})")
                    params.extend(value)
                else:
                    logging.warning(f"NOT IN条件值必须是列表键: {key}")
            elif condition_type == "gt":
                fragments.append(f" AND {field} > ?")
                params.append(value)
            elif condition_type == "lt":
                fragments.append(f" AND {field} < ?")
                params.append(value)
            else:
                logging.warning(f"不支持的条件类型: {condition_type},键: {key}")
        return "".join(fragments), params

    def generic_query(
        self,
        table_name: str,
        fields: List[str],
        filters: Optional[Dict[str, Any]],
        condition_mapping: Optional[Dict[str, str]] = None,
        allowed_order_fields: Optional[List[str]] = None,
        simplify_single_field: bool = True,  # flatten single-field results to plain values
    ) -> Optional[List[Any]]:
        """Single-table SELECT with validation of query, filter and order fields.

        :param table_name: table to query (trusted identifier).
        :param fields: column names, optionally ``"column AS alias"``; unknown
                       columns are logged and dropped.
        :param filters: ``{"field[__op]": value}`` conditions (see
                        :meth:`_build_conditions`) plus optional ``order_by``
                        (raw SQL text) and ``limit``; unknown fields are dropped.
        :param condition_mapping: maps filter field names to real column names.
        :param allowed_order_fields: order specs that are checked against the
                                     table; currently this only logs warnings.
        :param simplify_single_field: when exactly one field is selected, return
                                      a flat list of its values instead of dicts.
        :return: list of dicts (or of values, see above); ``[]`` when nothing
                 matched; ``None`` when no field is valid or the query failed.
        """
        try:
            condition_mapping = condition_mapping or {}
            allowed_order_fields = allowed_order_fields or []
            filters = filters or {}

            valid_fields = self._validate_fields(table_name, fields)
            if not valid_fields:
                logging.error(f"无有效查询字段: 表={table_name}")
                return None
            valid_filters = self._validate_filter_fields(table_name, filters, condition_mapping)
            # Called only for its warnings: the whitelist check that used this
            # result was deliberately disabled (see the ORDER BY note below).
            self._validate_order_fields(table_name, allowed_order_fields)

            where_sql, params = self._build_conditions(
                valid_filters, condition_mapping, skip_keys=("order_by", "limit")
            )
            sql = f"SELECT {', '.join(valid_fields)} FROM {table_name} WHERE 1=1{where_sql}"
            if "order_by" in valid_filters:
                # NOTE: order_by is inserted verbatim (it may carry ASC/DESC); the
                # check against allowed_order_fields was intentionally disabled,
                # so callers must only pass trusted values.
                sql += f" ORDER BY {valid_filters['order_by']}"
            if "limit" in valid_filters:
                sql += " LIMIT ?"
                params.append(valid_filters["limit"])

            self.cursor.execute(sql, params)
            rows = self.cursor.fetchall()
            if not rows:
                return []
            if len(valid_fields) == 1 and simplify_single_field:
                # One selected column: index it positionally, so any alias form
                # ("col AS name", any case/spacing) works without re-parsing.
                return [row[0] for row in rows]
            return [dict(row) for row in rows]
        except sqlite3.Error as e:
            logging.error(f"查询失败: 表={table_name}, 错误={e}")
            return None

    def generic_stats_query(self, stats_config: List[Dict[str, str]]) -> Dict[str, int]:
        """Run several COUNT(*) statistics in a single SELECT.

        :param stats_config: list of dicts with keys
            - ``'table'`` (required): table to count,
            - ``'alias'`` (required): result name, e.g. ``'actors'``,
            - ``'where'`` (optional): raw SQL condition, e.g.
              ``'uncensored=1 AND is_full_data=1'``.
            Entries without ``table`` or ``alias`` are logged and skipped.
            All values are inserted into the SQL verbatim and must be trusted.
        :return: ``{alias: count}``; ``{}`` if nothing valid was configured or on error.
        """
        try:
            subqueries = []
            for config in stats_config:
                table = config.get('table')
                alias = config.get('alias')
                where_clause = config.get('where')
                if not (table and alias):
                    logging.warning(f"统计项配置不完整:{config},跳过")
                    continue
                # e.g. "(SELECT COUNT(*) FROM actors WHERE uncensored=1) AS act_un"
                subquery = f"(SELECT COUNT(*) FROM {table}"
                if where_clause:
                    subquery += f" WHERE {where_clause}"
                subqueries.append(f"{subquery}) AS {alias}")
            if not subqueries:
                logging.warning("无有效统计项配置,返回空结果")
                return {}

            self.cursor.execute(f"SELECT {', '.join(subqueries)}")
            row = self.cursor.fetchone()
            if not row:
                logging.warning("统计查询无结果")
                return {}
            columns = [desc[0] for desc in self.cursor.description]  # the aliases
            return {col: int(val) if val is not None else 0 for col, val in zip(columns, row)}
        except sqlite3.Error as e:
            logging.error(f"统计查询失败: {e}")
            return {}

    def generic_get_record_count(
        self,
        table_name: str,
        conditions: Optional[Dict[str, Any]] = None,
        condition_mapping: Optional[Dict[str, str]] = None,
    ) -> int:
        """Count the rows of ``table_name`` that match ``conditions``.

        :param table_name: table to count (trusted identifier).
        :param conditions: filters in the same ``{"field[__op]": value}`` format
                           as :meth:`generic_query`, e.g.
                           ``{'is_full_data': 1, 'url': 'xxx'}``; fields are not
                           validated against the table here.
        :param condition_mapping: maps condition field names to real column names.
        :return: the matching row count, or 0 on error.
        """
        try:
            where_sql, params = self._build_conditions(conditions or {}, condition_mapping or {})
            self.cursor.execute(
                f"SELECT COUNT(*) AS cnt FROM {table_name} WHERE 1=1{where_sql}", params
            )
            row = self.cursor.fetchone()
            return int(row[0]) if row and row[0] is not None else 0
        except sqlite3.Error as e:
            logging.error(f"记录数查询失败: 表={table_name}, 错误={e}")
            return 0