modify scripts

This commit is contained in:
oscarz
2025-04-13 16:51:58 +08:00
parent 8f0c5b3eea
commit 26cd6b52ca
4 changed files with 80 additions and 11 deletions

View File

@ -14,11 +14,12 @@ commit_msg="$1"
# 如果没有提供 commit message提示用户输入 # 如果没有提供 commit message提示用户输入
if [ -z "$commit_msg" ]; then if [ -z "$commit_msg" ]; then
read -p "请输入 commit message: " commit_msg commit_msg="modify scripts"
if [ -z "$commit_msg" ]; then #read -p "请输入 commit message: " commit_msg
echo "❌ 提交信息不能为空!" #if [ -z "$commit_msg" ]; then
exit 1 # echo "❌ 提交信息不能为空!"
fi # exit 1
#fi
fi fi
# 添加所有更改 # 添加所有更改

View File

@ -89,6 +89,19 @@ class StockReportDB:
logging.error(f"Error inserting or updating data: {e}") logging.error(f"Error inserting or updating data: {e}")
return None return None
def update_pages(self, data, tbl_name, uniq_key='infoCode'):
try:
sql = f'''
update {tbl_name} SET attachPages={data['attachPages']} where id={data['id']}
'''
self.cursor.execute(sql)
self.conn.commit()
return data['id']
except sqlite3.Error as e:
logging.error(f"Error inserting or updating data: {e}")
return None
def query_reports_comm(self, tbl_name, querystr='', limit=None): def query_reports_comm(self, tbl_name, querystr='', limit=None):
try: try:
if tbl_name in [StockReportDB.TBL_STOCK, StockReportDB.TBL_NEW_STOCK, StockReportDB.TBL_INDUSTRY, StockReportDB.TBL_MACRESEARCH, StockReportDB.TBL_STRATEGY]: if tbl_name in [StockReportDB.TBL_STOCK, StockReportDB.TBL_NEW_STOCK, StockReportDB.TBL_INDUSTRY, StockReportDB.TBL_MACRESEARCH, StockReportDB.TBL_STRATEGY]:

View File

@ -14,6 +14,7 @@ import src.utils.utils as utils
from src.config.config import global_host_data_dir, global_share_db_dir from src.config.config import global_host_data_dir, global_share_db_dir
from src.db_utils.reports import StockReportDB, DatabaseConnectionError from src.db_utils.reports import StockReportDB, DatabaseConnectionError
from src.logger.logger import setup_logging from src.logger.logger import setup_logging
import PyPDF2
# 初始化日志 # 初始化日志
setup_logging() setup_logging()
@ -52,6 +53,19 @@ start_date = two_years_ago.strftime("%Y-%m-%d")
end_date = current_date.strftime("%Y-%m-%d") end_date = current_date.strftime("%Y-%m-%d")
this_week_date = seven_days_ago.strftime("%Y-%m-%d") this_week_date = seven_days_ago.strftime("%Y-%m-%d")
min_down_pages = 10
def get_pdf_page_count(pdf_path):
try:
# 以二进制只读模式打开 PDF 文件
with open(pdf_path, 'rb') as file:
# 创建一个 PdfReader 对象
pdf_reader = PyPDF2.PdfReader(file)
# 获取 PDF 文件的页数
page_count = len(pdf_reader.pages)
return page_count
except Exception as e:
logging.warning(f"处理文件 {pdf_path} 时出错: {e}")
return None
def fetch_reports_list_general(fetch_func, table_name, s_date, e_date, data_dir_prefix): def fetch_reports_list_general(fetch_func, table_name, s_date, e_date, data_dir_prefix):
# 示例:获取前 3 页的数据 # 示例:获取前 3 页的数据
@ -111,18 +125,49 @@ def parse_func_general(row, tbl_name):
url = url.format(info_code) url = url.format(info_code)
# 拼目录 # 拼目录
dir_year = publish_date[:4] if len(publish_date)>4 else '' dir_year = publish_date[:4] if len(publish_date)>4 else ''
dir_path = f'{pdf_base_dir}/{dir_year}/{map_tbl_name[tbl_name]}' #dir_path = f'{pdf_base_dir}/{dir_year}/{map_tbl_name[tbl_name]}'
dir_path = f'{pdf_base_dir}/{dir_year}'
os.makedirs(dir_path, exist_ok=True) os.makedirs(dir_path, exist_ok=True)
return url, os.path.join(dir_path, file_name) return url, os.path.join(dir_path, file_name)
# 检查pdf的页数如果小于限定的值则移动到其他目录
def check_pdf_pages(file_path, row, tbl_name):
pages = get_pdf_page_count(file_path)
if pages is None or pages < min_down_pages:
# 获取文件所在目录
file_dir = os.path.dirname(file_path)
# 创建 tmp 子目录
tmp_dir = os.path.join(file_dir, 'tmp')
if not os.path.exists(tmp_dir):
os.makedirs(tmp_dir)
# 移动文件到 tmp 子目录
file_name = os.path.basename(file_path)
new_path = os.path.join(tmp_dir, file_name)
shutil.move(file_path, new_path)
logging.debug(f"move {file_name} to {tmp_dir}")
# macro 和 stra 表,需要更新页码回去
if tbl_name == StockReportDB.TBL_MACRESEARCH or tbl_name == StockReportDB.TBL_STRATEGY:
data={}
data['infoCode'] = row['infoCode']
data['id'] = row['id']
data['attachPages'] = pages
row_id = db_tools.update_pages(data, tbl_name)
if row_id:
logging.debug(f"update one row. tbl: {tbl_name}, rowid:{row_id}")
else:
logging.warning(f"update data failed. tbl: {tbl_name}, rowid:{row['id']}")
return False
# 通用下载函数 # 通用下载函数
def download_pdf_stock_general(parse_func, tbl_name, querystr='', s_date=start_date, e_date=end_date, limit=None): def download_pdf_stock_general(parse_func, tbl_name, querystr='', s_date=start_date, e_date=end_date, limit=None, min_page=None):
# 下载pdf # 下载pdf
if s_date: if s_date:
querystr += f" AND publishDate >= '{s_date} 00:00:00.000' " querystr += f" AND publishDate >= '{s_date} 00:00:00.000' "
if e_date: if e_date:
querystr += f" AND publishDate <= '{e_date} 23:59:59.999' " querystr += f" AND publishDate <= '{e_date} 23:59:59.999' "
if min_page:
querystr += f" AND attachPages >= {min_page} "
rows = db_tools.query_reports_comm(tbl_name, querystr=querystr, limit=limit) rows = db_tools.query_reports_comm(tbl_name, querystr=querystr, limit=limit)
if rows is None: if rows is None:
@ -145,6 +190,7 @@ def download_pdf_stock_general(parse_func, tbl_name, querystr='', s_date=start_d
down = em.download_pdf(pdf_url, file_path) down = em.download_pdf(pdf_url, file_path)
if down: if down:
logging.info(f'saved file {file_path}') logging.info(f'saved file {file_path}')
check_pdf_pages(file_path, row, tbl_name)
else: else:
logging.warning(f'download pdf file error. file_path: {pdf_url}, save_path: {file_path}') logging.warning(f'download pdf file error. file_path: {pdf_url}, save_path: {file_path}')
else: else:
@ -176,13 +222,13 @@ def fetch_reports_list_strategy(s_date=start_date, e_date=end_date):
# 下载股票pdf # 下载股票pdf
def download_pdf_stock(s_date=start_date, e_date=end_date): def download_pdf_stock(s_date=start_date, e_date=end_date):
download_pdf_stock_general(parse_func_general, StockReportDB.TBL_STOCK, ' ', s_date, e_date, limit=2 if debug else None) download_pdf_stock_general(parse_func_general, StockReportDB.TBL_STOCK, ' ', s_date, e_date, limit=2 if debug else None, min_page=min_down_pages)
def download_pdf_newstock(s_date=start_date, e_date=end_date): def download_pdf_newstock(s_date=start_date, e_date=end_date):
download_pdf_stock_general(parse_func_general, StockReportDB.TBL_NEW_STOCK, ' ', s_date, e_date, limit=2 if debug else None) download_pdf_stock_general(parse_func_general, StockReportDB.TBL_NEW_STOCK, ' ', s_date, e_date, limit=2 if debug else None, min_page=min_down_pages)
def download_pdf_industry(s_date=start_date, e_date=end_date): def download_pdf_industry(s_date=start_date, e_date=end_date):
download_pdf_stock_general(parse_func_general, StockReportDB.TBL_INDUSTRY, ' ', s_date, e_date, limit=2 if debug else None) download_pdf_stock_general(parse_func_general, StockReportDB.TBL_INDUSTRY, ' ', s_date, e_date, limit=2 if debug else None, min_page=min_down_pages)
def download_pdf_macresearch(s_date=start_date, e_date=end_date): def download_pdf_macresearch(s_date=start_date, e_date=end_date):
download_pdf_stock_general(parse_func_general, StockReportDB.TBL_MACRESEARCH, ' ', s_date, e_date, limit=2 if debug else None) download_pdf_stock_general(parse_func_general, StockReportDB.TBL_MACRESEARCH, ' ', s_date, e_date, limit=2 if debug else None)
@ -263,6 +309,9 @@ def run_func(function_names, function_map):
def main(cmd, mode, args_debug, args_force, begin, end): def main(cmd, mode, args_debug, args_force, begin, end):
global debug global debug
debug = args_debug debug = args_debug
if debug:
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
global force global force
force = args_force force = args_force

View File

@ -56,13 +56,19 @@ def get_stock_by_fs():
save_to_csv(df, config.global_host_input_dir, f"em_index_all.csv") save_to_csv(df, config.global_host_input_dir, f"em_index_all.csv")
return df return df
# 东财-概念板块成分股
def get_stock_by_xx(sy):
df = ak.stock_board_concept_cons_em(symbol=sy)
save_to_csv(df, config.global_host_input_dir, f"em_{sy}.csv")
def refresh_main_index(): def refresh_main_index():
#get_csindex(symbol='000300') #get_csindex(symbol='000300')
#get_csindex(symbol='000510') #get_csindex(symbol='000510')
#get_sina_index_stock(symbol='000300') #get_sina_index_stock(symbol='000300')
#get_ah_stock() #get_ah_stock()
#get_hk_ggt_components() #get_hk_ggt_components()
get_stock_by_fs() #get_stock_by_fs()
get_stock_by_xx('半导体概念')