diff --git a/src/sqlalchemy/alembic.ini b/src/sqlalchemy/alembic.ini new file mode 100644 index 0000000..0d7a00e --- /dev/null +++ b/src/sqlalchemy/alembic.ini @@ -0,0 +1,147 @@ +# A generic, single database configuration. + +[alembic] +# path to migration scripts. +# this is typically a path given in POSIX (e.g. forward slashes) +# format, relative to the token %(here)s which refers to the location of this +# ini file +script_location = %(here)s + +# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s +# Uncomment the line below if you want the files to be prepended with date and time +# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file +# for all available tokens +# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s + +# sys.path path, will be prepended to sys.path if present. +# defaults to the current working directory. for multiple paths, the path separator +# is defined by "path_separator" below. +prepend_sys_path = . + + +# timezone to use when rendering the date within the migration file +# as well as the filename. +# If specified, requires the python>=3.9 or backports.zoneinfo library and tzdata library. +# Any required deps can installed by adding `alembic[tz]` to the pip requirements +# string value is passed to ZoneInfo() +# leave blank for localtime +# timezone = + +# max length of characters to apply to the "slug" field +# truncate_slug_length = 40 + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + +# set to 'true' to allow .pyc and .pyo files without +# a source .py file to be detected as revisions in the +# versions/ directory +# sourceless = false + +# version location specification; This defaults +# to /versions. When using multiple version +# directories, initial revisions must be specified with --version-path. +# The path separator used here should be the separator specified by "path_separator" +# below. +# version_locations = %(here)s/bar:%(here)s/bat:%(here)s/alembic/versions + +# path_separator; This indicates what character is used to split lists of file +# paths, including version_locations and prepend_sys_path within configparser +# files such as alembic.ini. +# The default rendered in new alembic.ini files is "os", which uses os.pathsep +# to provide os-dependent path splitting. +# +# Note that in order to support legacy alembic.ini files, this default does NOT +# take place if path_separator is not present in alembic.ini. If this +# option is omitted entirely, fallback logic is as follows: +# +# 1. Parsing of the version_locations option falls back to using the legacy +# "version_path_separator" key, which if absent then falls back to the legacy +# behavior of splitting on spaces and/or commas. +# 2. Parsing of the prepend_sys_path option falls back to the legacy +# behavior of splitting on spaces, commas, or colons. +# +# Valid values for path_separator are: +# +# path_separator = : +# path_separator = ; +# path_separator = space +# path_separator = newline +# +# Use os.pathsep. Default configuration used for new projects. +path_separator = os + +# set to 'true' to search source files recursively +# in each "version_locations" directory +# new in Alembic version 1.10 +# recursive_version_locations = false + +# the output encoding used when revision files +# are written from script.py.mako +# output_encoding = utf-8 + +# database URL. This is consumed by the user-maintained env.py script only. +# other means of configuring database URLs may be customized within the env.py +# file. +sqlalchemy.url = driver://user:pass@localhost/dbname + + +[post_write_hooks] +# post_write_hooks defines scripts or Python functions that are run +# on newly generated revision scripts. See the documentation for further +# detail and examples + +# format using "black" - use the console_scripts runner, against the "black" entrypoint +# hooks = black +# black.type = console_scripts +# black.entrypoint = black +# black.options = -l 79 REVISION_SCRIPT_FILENAME + +# lint with attempts to fix using "ruff" - use the module runner, against the "ruff" module +# hooks = ruff +# ruff.type = module +# ruff.module = ruff +# ruff.options = check --fix REVISION_SCRIPT_FILENAME + +# Alternatively, use the exec runner to execute a binary found on your PATH +# hooks = ruff +# ruff.type = exec +# ruff.executable = ruff +# ruff.options = check --fix REVISION_SCRIPT_FILENAME + +# Logging configuration. This is also consumed by the user-maintained +# env.py script only. +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARNING +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARNING +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/src/sqlalchemy/config.py b/src/sqlalchemy/config.py new file mode 100644 index 0000000..ad31e00 --- /dev/null +++ b/src/sqlalchemy/config.py @@ -0,0 +1,12 @@ +import os +from pathlib import Path + +# MySQL 配置 +db_config = { + 'host': 'testdb', + 'user': 'root', + 'password': 'mysqlpw', + 'database': 'stockdb' +} + +global_db_url = f"mysql+pymysql://{db_config['user']}:{db_config['password']}@{db_config['host']}:3306/{db_config['database']}?charset=utf8mb4" \ No newline at end of file diff --git a/src/sqlalchemy/migrations/stockdb/README b/src/sqlalchemy/migrations/stockdb/README new file mode 100644 index 0000000..98e4f9c --- /dev/null +++ b/src/sqlalchemy/migrations/stockdb/README @@ -0,0 +1 @@ +Generic single-database configuration. \ No newline at end of file diff --git a/src/sqlalchemy/migrations/stockdb/alembic.ini b/src/sqlalchemy/migrations/stockdb/alembic.ini new file mode 100644 index 0000000..0d7a00e --- /dev/null +++ b/src/sqlalchemy/migrations/stockdb/alembic.ini @@ -0,0 +1,147 @@ +# A generic, single database configuration. + +[alembic] +# path to migration scripts. +# this is typically a path given in POSIX (e.g. forward slashes) +# format, relative to the token %(here)s which refers to the location of this +# ini file +script_location = %(here)s + +# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s +# Uncomment the line below if you want the files to be prepended with date and time +# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file +# for all available tokens +# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s + +# sys.path path, will be prepended to sys.path if present. +# defaults to the current working directory. for multiple paths, the path separator +# is defined by "path_separator" below. +prepend_sys_path = . + + +# timezone to use when rendering the date within the migration file +# as well as the filename. +# If specified, requires the python>=3.9 or backports.zoneinfo library and tzdata library. +# Any required deps can installed by adding `alembic[tz]` to the pip requirements +# string value is passed to ZoneInfo() +# leave blank for localtime +# timezone = + +# max length of characters to apply to the "slug" field +# truncate_slug_length = 40 + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + +# set to 'true' to allow .pyc and .pyo files without +# a source .py file to be detected as revisions in the +# versions/ directory +# sourceless = false + +# version location specification; This defaults +# to /versions. When using multiple version +# directories, initial revisions must be specified with --version-path. +# The path separator used here should be the separator specified by "path_separator" +# below. +# version_locations = %(here)s/bar:%(here)s/bat:%(here)s/alembic/versions + +# path_separator; This indicates what character is used to split lists of file +# paths, including version_locations and prepend_sys_path within configparser +# files such as alembic.ini. +# The default rendered in new alembic.ini files is "os", which uses os.pathsep +# to provide os-dependent path splitting. +# +# Note that in order to support legacy alembic.ini files, this default does NOT +# take place if path_separator is not present in alembic.ini. If this +# option is omitted entirely, fallback logic is as follows: +# +# 1. Parsing of the version_locations option falls back to using the legacy +# "version_path_separator" key, which if absent then falls back to the legacy +# behavior of splitting on spaces and/or commas. +# 2. Parsing of the prepend_sys_path option falls back to the legacy +# behavior of splitting on spaces, commas, or colons. +# +# Valid values for path_separator are: +# +# path_separator = : +# path_separator = ; +# path_separator = space +# path_separator = newline +# +# Use os.pathsep. Default configuration used for new projects. +path_separator = os + +# set to 'true' to search source files recursively +# in each "version_locations" directory +# new in Alembic version 1.10 +# recursive_version_locations = false + +# the output encoding used when revision files +# are written from script.py.mako +# output_encoding = utf-8 + +# database URL. This is consumed by the user-maintained env.py script only. +# other means of configuring database URLs may be customized within the env.py +# file. +sqlalchemy.url = driver://user:pass@localhost/dbname + + +[post_write_hooks] +# post_write_hooks defines scripts or Python functions that are run +# on newly generated revision scripts. See the documentation for further +# detail and examples + +# format using "black" - use the console_scripts runner, against the "black" entrypoint +# hooks = black +# black.type = console_scripts +# black.entrypoint = black +# black.options = -l 79 REVISION_SCRIPT_FILENAME + +# lint with attempts to fix using "ruff" - use the module runner, against the "ruff" module +# hooks = ruff +# ruff.type = module +# ruff.module = ruff +# ruff.options = check --fix REVISION_SCRIPT_FILENAME + +# Alternatively, use the exec runner to execute a binary found on your PATH +# hooks = ruff +# ruff.type = exec +# ruff.executable = ruff +# ruff.options = check --fix REVISION_SCRIPT_FILENAME + +# Logging configuration. This is also consumed by the user-maintained +# env.py script only. +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARNING +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARNING +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/src/sqlalchemy/migrations/stockdb/env.py b/src/sqlalchemy/migrations/stockdb/env.py new file mode 100644 index 0000000..ac2fe60 --- /dev/null +++ b/src/sqlalchemy/migrations/stockdb/env.py @@ -0,0 +1,99 @@ +from logging.config import fileConfig + +from sqlalchemy import engine_from_config +from sqlalchemy import pool + +from alembic import context + +# this is the Alembic Config object, which provides +# access to the values within the .ini file in use. +config = context.config + +# Interpret the config file for Python logging. +# This line sets up loggers basically. +if config.config_file_name is not None: + fileConfig(config.config_file_name) + +# add your model's MetaData object here +# for 'autogenerate' support +# from myapp import mymodel +# target_metadata = mymodel.Base.metadata +target_metadata = None + +# other values from the config, defined by the needs of env.py, +# can be acquired: +# my_important_option = config.get_main_option("my_important_option") +# ... etc. + + +def run_migrations_offline() -> None: + """Run migrations in 'offline' mode. + + This configures the context with just a URL + and not an Engine, though an Engine is acceptable + here as well. By skipping the Engine creation + we don't even need a DBAPI to be available. + + Calls to context.execute() here emit the given string to the + script output. + + """ + url = config.get_main_option("sqlalchemy.url") + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + ) + + with context.begin_transaction(): + context.run_migrations() + + +''' 修改点 +from models.modelclass_b import Base +target_metadata = Base.metadata + +def run_migrations_online(): + url = "sqlite:///../databases/db_b.db" + connectable = create_engine(url) + # 保持其他代码不变 +''' + +import os +from alembic import context +from sqlalchemy import create_engine +from logging.config import fileConfig +from models.stockdb import Base +target_metadata = Base.metadata +from config import db_config + +def run_migrations_online() -> None: + """Run migrations in 'online' mode. + + In this scenario we need to create an Engine + and associate a connection with the context. + + connectable = engine_from_config( + config.get_section(config.config_ini_section, {}), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + """ + + url = f"mysql+pymysql://{db_config['user']}:{db_config['password']}@{db_config['host']}:3306/{db_config['database']}?charset=utf8mb4" + connectable = create_engine(url) + + with connectable.connect() as connection: + context.configure( + connection=connection, target_metadata=target_metadata + ) + + with context.begin_transaction(): + context.run_migrations() + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/src/sqlalchemy/migrations/stockdb/script.py.mako b/src/sqlalchemy/migrations/stockdb/script.py.mako new file mode 100644 index 0000000..1101630 --- /dev/null +++ b/src/sqlalchemy/migrations/stockdb/script.py.mako @@ -0,0 +1,28 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. +revision: str = ${repr(up_revision)} +down_revision: Union[str, Sequence[str], None] = ${repr(down_revision)} +branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} +depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} + + +def upgrade() -> None: + """Upgrade schema.""" + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + """Downgrade schema.""" + ${downgrades if downgrades else "pass"} diff --git a/src/sqlalchemy/migrations/stockdb/versions/0a39f1baab38_auto_update_from_stockdb.py b/src/sqlalchemy/migrations/stockdb/versions/0a39f1baab38_auto_update_from_stockdb.py new file mode 100644 index 0000000..c532c29 --- /dev/null +++ b/src/sqlalchemy/migrations/stockdb/versions/0a39f1baab38_auto_update_from_stockdb.py @@ -0,0 +1,79 @@ +"""Auto update from stockdb + +Revision ID: 0a39f1baab38 +Revises: df56ac7669f2 +Create Date: 2025-07-31 16:05:20.579725 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '0a39f1baab38' +down_revision: Union[str, Sequence[str], None] = 'df56ac7669f2' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('daily_snap', + sa.Column('code', sa.String(length=20), nullable=False, comment='股票代码'), + sa.Column('curr_date', sa.String(length=20), nullable=False, comment='交易日期'), + sa.Column('name', sa.String(length=50), nullable=True, comment='股票名称'), + sa.Column('market_id', sa.String(length=50), nullable=True, comment='市场名称'), + sa.Column('code_prefix', sa.Integer(), nullable=True, comment='代码前缀'), + sa.Column('industry', sa.String(length=50), nullable=True, comment='所处行业'), + sa.Column('listing_date', sa.Date(), nullable=True, comment='上市时间'), + sa.Column('latest_price', sa.Float(), nullable=True, comment='最新价'), + sa.Column('price_change_percent', sa.Float(), nullable=True, comment='涨跌幅(%)'), + sa.Column('price_change', sa.Float(), nullable=True, comment='涨跌额'), + sa.Column('volume', sa.Float(), nullable=True, comment='成交量'), + sa.Column('turnover', sa.Float(), nullable=True, comment='成交额'), + sa.Column('amplitude', sa.Float(), nullable=True, comment='振幅(%)'), + sa.Column('turnover_rate', sa.Float(), nullable=True, comment='换手率(%)'), + sa.Column('pe_dynamic', sa.Float(), nullable=True, comment='市盈率动'), + sa.Column('volume_ratio', sa.Float(), nullable=True, comment='量比'), + sa.Column('change_5min', sa.Float(), nullable=True, comment='5分钟涨跌(%)'), + sa.Column('highest', sa.Float(), nullable=True, comment='最高'), + sa.Column('lowest', sa.Float(), nullable=True, comment='最低'), + sa.Column('opening', sa.Float(), nullable=True, comment='今开'), + sa.Column('previous_close', sa.Float(), nullable=True, comment='昨收'), + sa.Column('price_speed', sa.Float(), nullable=True, comment='涨速(%)'), + sa.Column('total_market_cap', sa.Float(), nullable=True, comment='总市值'), + sa.Column('circulating_market_cap', sa.Float(), nullable=True, comment='流通市值'), + sa.Column('pb_ratio', sa.Float(), nullable=True, comment='市净率'), + sa.Column('change_60d', sa.Float(), nullable=True, comment='60日涨跌幅(%)'), + sa.Column('change_ytd', sa.Float(), nullable=True, comment='年初至今涨跌幅(%)'), + sa.Column('weighted_roe', sa.Float(), nullable=True, comment='加权净资产收益率(%)'), + sa.Column('total_shares', sa.Float(), nullable=True, comment='总股本'), + sa.Column('circulating_shares', sa.Float(), nullable=True, comment='已流通股份'), + sa.Column('operating_revenue', sa.Float(), nullable=True, comment='营业收入'), + sa.Column('revenue_growth', sa.Float(), nullable=True, comment='营业收入同比增长(%)'), + sa.Column('net_profit', sa.Float(), nullable=True, comment='归属净利润'), + sa.Column('net_profit_growth', sa.Float(), nullable=True, comment='归属净利润同比增长(%)'), + sa.Column('undistributed_profit_per_share', sa.Float(), nullable=True, comment='每股未分配利润'), + sa.Column('gross_margin', sa.Float(), nullable=True, comment='毛利率(%)'), + sa.Column('asset_liability_ratio', sa.Float(), nullable=True, comment='资产负债率(%)'), + sa.Column('reserve_per_share', sa.Float(), nullable=True, comment='每股公积金'), + sa.Column('earnings_per_share', sa.Float(), nullable=True, comment='每股收益'), + sa.Column('net_asset_per_share', sa.Float(), nullable=True, comment='每股净资产'), + sa.Column('pe_static', sa.Float(), nullable=True, comment='市盈率静'), + sa.Column('pe_ttm', sa.Float(), nullable=True, comment='市盈率TTM'), + sa.Column('report_period', sa.String(length=20), nullable=True, comment='报告期'), + sa.Column('created_at', sa.DateTime(), nullable=False, comment='记录创建时间'), + sa.PrimaryKeyConstraint('code', 'curr_date'), + comment='股票交易数据表' + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table('daily_snap') + # ### end Alembic commands ### diff --git a/src/sqlalchemy/migrations/stockdb/versions/838eec3adc23_auto_update_from_stockdb.py b/src/sqlalchemy/migrations/stockdb/versions/838eec3adc23_auto_update_from_stockdb.py new file mode 100644 index 0000000..1f06a2e --- /dev/null +++ b/src/sqlalchemy/migrations/stockdb/versions/838eec3adc23_auto_update_from_stockdb.py @@ -0,0 +1,38 @@ +"""Auto update from stockdb + +Revision ID: 838eec3adc23 +Revises: 0a39f1baab38 +Create Date: 2025-07-31 16:56:37.635401 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '838eec3adc23' +down_revision: Union[str, Sequence[str], None] = '0a39f1baab38' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('futu_trading_days', + sa.Column('market', sa.String(length=20), nullable=False, comment='市场标识'), + sa.Column('trade_date', sa.Date(), nullable=False, comment='交易日日期'), + sa.Column('trade_date_type', sa.String(length=20), nullable=False, comment='交易日类型'), + sa.PrimaryKeyConstraint('market', 'trade_date'), + comment='富途证券交易日历表' + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table('futu_trading_days') + # ### end Alembic commands ### diff --git a/src/sqlalchemy/migrations/stockdb/versions/9680e7b8e29b_auto_update_from_stockdb.py b/src/sqlalchemy/migrations/stockdb/versions/9680e7b8e29b_auto_update_from_stockdb.py new file mode 100644 index 0000000..eb6309a --- /dev/null +++ b/src/sqlalchemy/migrations/stockdb/versions/9680e7b8e29b_auto_update_from_stockdb.py @@ -0,0 +1,87 @@ +"""Auto update from stockdb + +Revision ID: 9680e7b8e29b +Revises: 838eec3adc23 +Create Date: 2025-07-31 17:35:37.108765 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '9680e7b8e29b' +down_revision: Union[str, Sequence[str], None] = '838eec3adc23' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('daily_snap', + sa.Column('code', sa.String(length=20), nullable=False, comment='股票代码'), + sa.Column('curr_date', sa.String(length=20), nullable=False, comment='交易日期'), + sa.Column('name', sa.String(length=50), nullable=True, comment='股票名称'), + sa.Column('market_id', sa.String(length=50), nullable=True, comment='市场名称'), + sa.Column('code_prefix', sa.Integer(), nullable=True, comment='代码前缀'), + sa.Column('industry', sa.String(length=50), nullable=True, comment='所处行业'), + sa.Column('listing_date', sa.Date(), nullable=True, comment='上市时间'), + sa.Column('latest_price', sa.Float(), nullable=True, comment='最新价'), + sa.Column('price_change_percent', sa.Float(), nullable=True, comment='涨跌幅(%)'), + sa.Column('price_change', sa.Float(), nullable=True, comment='涨跌额'), + sa.Column('volume', sa.Float(), nullable=True, comment='成交量'), + sa.Column('turnover', sa.Float(), nullable=True, comment='成交额'), + sa.Column('amplitude', sa.Float(), nullable=True, comment='振幅(%)'), + sa.Column('turnover_rate', sa.Float(), nullable=True, comment='换手率(%)'), + sa.Column('pe_dynamic', sa.Float(), nullable=True, comment='市盈率动'), + sa.Column('volume_ratio', sa.Float(), nullable=True, comment='量比'), + sa.Column('change_5min', sa.Float(), nullable=True, comment='5分钟涨跌(%)'), + sa.Column('highest', sa.Float(), nullable=True, comment='最高'), + sa.Column('lowest', sa.Float(), nullable=True, comment='最低'), + sa.Column('opening', sa.Float(), nullable=True, comment='今开'), + sa.Column('previous_close', sa.Float(), nullable=True, comment='昨收'), + sa.Column('price_speed', sa.Float(), nullable=True, comment='涨速(%)'), + sa.Column('total_market_cap', sa.Float(), nullable=True, comment='总市值'), + sa.Column('circulating_market_cap', sa.Float(), nullable=True, comment='流通市值'), + sa.Column('pb_ratio', sa.Float(), nullable=True, comment='市净率'), + sa.Column('change_60d', sa.Float(), nullable=True, comment='60日涨跌幅(%)'), + sa.Column('change_ytd', sa.Float(), nullable=True, comment='年初至今涨跌幅(%)'), + sa.Column('weighted_roe', sa.Float(), nullable=True, comment='加权净资产收益率(%)'), + sa.Column('total_shares', sa.Float(), nullable=True, comment='总股本'), + sa.Column('circulating_shares', sa.Float(), nullable=True, comment='已流通股份'), + sa.Column('operating_revenue', sa.Float(), nullable=True, comment='营业收入'), + sa.Column('revenue_growth', sa.Float(), nullable=True, comment='营业收入同比增长(%)'), + sa.Column('net_profit', sa.Float(), nullable=True, comment='归属净利润'), + sa.Column('net_profit_growth', sa.Float(), nullable=True, comment='归属净利润同比增长(%)'), + sa.Column('undistributed_profit_per_share', sa.Float(), nullable=True, comment='每股未分配利润'), + sa.Column('gross_margin', sa.Float(), nullable=True, comment='毛利率(%)'), + sa.Column('asset_liability_ratio', sa.Float(), nullable=True, comment='资产负债率(%)'), + sa.Column('reserve_per_share', sa.Float(), nullable=True, comment='每股公积金'), + sa.Column('earnings_per_share', sa.Float(), nullable=True, comment='每股收益'), + sa.Column('net_asset_per_share', sa.Float(), nullable=True, comment='每股净资产'), + sa.Column('pe_static', sa.Float(), nullable=True, comment='市盈率静'), + sa.Column('pe_ttm', sa.Float(), nullable=True, comment='市盈率TTM'), + sa.Column('report_period', sa.String(length=20), nullable=True, comment='报告期'), + sa.Column('created_at', sa.DateTime(), nullable=False, comment='记录创建时间'), + sa.PrimaryKeyConstraint('code', 'curr_date'), + comment='股票交易数据表' + ) + op.create_table('futu_trading_days', + sa.Column('market', sa.String(length=20), nullable=False, comment='市场标识'), + sa.Column('trade_date', sa.Date(), nullable=False, comment='交易日日期'), + sa.Column('trade_date_type', sa.String(length=20), nullable=False, comment='交易日类型'), + sa.PrimaryKeyConstraint('market', 'trade_date'), + comment='富途证券交易日历表' + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table('futu_trading_days') + op.drop_table('daily_snap') + # ### end Alembic commands ### diff --git a/src/sqlalchemy/migrations/stockdb/versions/df56ac7669f2_auto_update_from_stockdb.py b/src/sqlalchemy/migrations/stockdb/versions/df56ac7669f2_auto_update_from_stockdb.py new file mode 100644 index 0000000..0823c7e --- /dev/null +++ b/src/sqlalchemy/migrations/stockdb/versions/df56ac7669f2_auto_update_from_stockdb.py @@ -0,0 +1,32 @@ +"""Auto update from stockdb + +Revision ID: df56ac7669f2 +Revises: +Create Date: 2025-07-31 16:04:32.216988 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = 'df56ac7669f2' +down_revision: Union[str, Sequence[str], None] = None +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('daily_snap', sa.Column('market_id', sa.String(length=50), nullable=True, comment='市场名称')) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('daily_snap', 'market_id') + # ### end Alembic commands ### diff --git a/src/sqlalchemy/models/base.py b/src/sqlalchemy/models/base.py new file mode 100644 index 0000000..05a0fb1 --- /dev/null +++ b/src/sqlalchemy/models/base.py @@ -0,0 +1,23 @@ +from sqlalchemy import Column, Integer, DateTime +from sqlalchemy.ext.declarative import declarative_base +from datetime import datetime + +# 所有模型的基类 +Base = declarative_base() + +class BaseModel(Base): + __abstract__ = True # 抽象类,不生成实际表 + + # 通用字段 + id = Column(Integer, primary_key=True, autoincrement=True, comment="主键ID") + created_at = Column(DateTime, default=datetime.now, comment="创建时间") + updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now, comment="更新时间") + + # 通用方法:将模型实例转为字典(方便JSON序列化) + def to_dict(self, exclude=None): + exclude = exclude or [] + return { + c.name: getattr(self, c.name) + for c in self.__table__.columns + if c.name not in exclude + } \ No newline at end of file diff --git a/src/sqlalchemy/models/stockdb.py b/src/sqlalchemy/models/stockdb.py new file mode 100644 index 0000000..30401e6 --- /dev/null +++ b/src/sqlalchemy/models/stockdb.py @@ -0,0 +1,448 @@ +from sqlalchemy import BigInteger, Date, DateTime, Double, Float, Integer, String, text, Numeric +from sqlalchemy.dialects.mysql import TINYINT, VARCHAR +from typing import Optional + +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column +import datetime +import decimal + +class Base(DeclarativeBase): + pass + + +class FutuMarketSnapshot(Base): + __tablename__ = 'futu_market_snapshot' + + code: Mapped[str] = mapped_column(String(100), primary_key=True, server_default=text("''"), comment='股票代码') + name: Mapped[str] = mapped_column(String(255), server_default=text("''"), comment='股票名称') + update_time: Mapped[datetime.datetime] = mapped_column(DateTime, primary_key=True, comment='当前价更新时间') + last_price: Mapped[float] = mapped_column(Float, server_default=text("'0'"), comment='最新价格') + open_price: Mapped[float] = mapped_column(Float, server_default=text("'0'"), comment='今日开盘价') + high_price: Mapped[float] = mapped_column(Float, server_default=text("'0'"), comment='最高价格') + low_price: Mapped[float] = mapped_column(Float, server_default=text("'0'"), comment='最低价格') + prev_close_price: Mapped[float] = mapped_column(Float, server_default=text("'0'"), comment='昨收盘价格') + volume: Mapped[int] = mapped_column(BigInteger, server_default=text("'0'"), comment='成交数量') + turnover: Mapped[float] = mapped_column(Float, server_default=text("'0'"), comment='成交金额') + turnover_rate: Mapped[float] = mapped_column(Float, server_default=text("'0'"), comment='换手率') + suspension: Mapped[int] = mapped_column(TINYINT(1), server_default=text("'0'"), comment='是否停牌') + listing_date: Mapped[datetime.date] = mapped_column(Date, server_default=text("'1970-01-01'"), comment='上市日期') + equity_valid: Mapped[int] = mapped_column(TINYINT(1), server_default=text("'0'"), comment='是否正股') + issued_shares: Mapped[int] = mapped_column(BigInteger, server_default=text("'0'"), comment='总股本') + total_market_val: Mapped[float] = mapped_column(Float, server_default=text("'0'"), comment='总市值') + net_asset: Mapped[int] = mapped_column(BigInteger, server_default=text("'0'"), comment='资产净值') + net_profit: Mapped[int] = mapped_column(BigInteger, server_default=text("'0'"), comment='净利润') + earning_per_share: Mapped[float] = mapped_column(Float, server_default=text("'0'"), comment='每股盈利') + outstanding_shares: Mapped[int] = mapped_column(BigInteger, server_default=text("'0'"), comment='流通股本') + net_asset_per_share: Mapped[float] = mapped_column(Float, server_default=text("'0'"), comment='每股净资产') + circular_market_val: Mapped[float] = mapped_column(Float, server_default=text("'0'"), comment='流通市值') + ey_ratio: Mapped[float] = mapped_column(Float, server_default=text("'0'"), comment='收益率') + pe_ratio: Mapped[float] = mapped_column(Float, server_default=text("'0'"), comment='市盈率') + pb_ratio: Mapped[float] = mapped_column(Float, server_default=text("'0'"), comment='市净率') + pe_ttm_ratio: Mapped[float] = mapped_column(Float, server_default=text("'0'"), comment='市盈率 TTM') + dividend_ttm: Mapped[float] = mapped_column(Float, server_default=text("'0'"), comment='股息 TTM') + dividend_ratio_ttm: Mapped[float] = mapped_column(Float, server_default=text("'0'"), comment='股息率 TTM') + dividend_lfy: Mapped[float] = mapped_column(Float, server_default=text("'0'"), comment='股息 LFY') + dividend_lfy_ratio: Mapped[float] = mapped_column(Float, server_default=text("'0'"), comment='股息率 LFY') + + +class FutuPlatList(Base): + __tablename__ = 'futu_plat_list' + + up_date: Mapped[datetime.date] = mapped_column(Date, primary_key=True) + market: Mapped[str] = mapped_column(String(50)) + plat: Mapped[str] = mapped_column(String(50), primary_key=True, server_default=text("'INDUSTRY'")) + code: Mapped[str] = mapped_column(String(50), primary_key=True) + plate_id: Mapped[str] = mapped_column(String(50)) + plate_name: Mapped[str] = mapped_column(String(255)) + + +class FutuRehab(Base): + __tablename__ = 'futu_rehab' + + code: Mapped[str] = mapped_column(String(100), primary_key=True) + name: Mapped[str] = mapped_column(String(100)) + ex_div_date: Mapped[datetime.date] = mapped_column(Date, primary_key=True) + forward_adj_factorA: Mapped[float] = mapped_column(Float) + forward_adj_factorB: Mapped[float] = mapped_column(Float) + backward_adj_factorA: Mapped[float] = mapped_column(Float) + backward_adj_factorB: Mapped[float] = mapped_column(Float) + + +class Hs300(Base): + __tablename__ = 'hs300' + + up_date: Mapped[str] = mapped_column(String(100), primary_key=True, server_default=text("''")) + index_code: Mapped[str] = mapped_column(String(100), server_default=text("'000300'")) + index_name: Mapped[str] = mapped_column(String(100), server_default=text("''")) + index_name_eng: Mapped[str] = mapped_column(String(100), server_default=text("''")) + code_inner: Mapped[str] = mapped_column(String(100), primary_key=True, server_default=text("''")) + code_name: Mapped[str] = mapped_column(String(100), server_default=text("''")) + code_name_eng: Mapped[str] = mapped_column(String(100), server_default=text("''")) + exchange: Mapped[str] = mapped_column(String(100), server_default=text("''")) + exchange_eng: Mapped[str] = mapped_column(String(100), server_default=text("''")) + code: Mapped[str] = mapped_column(String(100), server_default=text("''")) + + +class Hs3003yearsYieldStats2410(Base): + __tablename__ = 'hs300_3years_yield_stats_2410' + + code: Mapped[str] = mapped_column(String(100), primary_key=True) + name: Mapped[str] = mapped_column(String(100)) + year_diff: Mapped[int] = mapped_column(Integer, primary_key=True) + max_yield_rate: Mapped[float] = mapped_column(Float) + max_yield_rate_start: Mapped[datetime.datetime] = mapped_column(DateTime) + max_yield_rate_end: Mapped[datetime.datetime] = mapped_column(DateTime) + min_yield_rate: Mapped[float] = mapped_column(Float) + min_yield_rate_start: Mapped[datetime.datetime] = mapped_column(DateTime) + min_yield_rate_end: Mapped[datetime.datetime] = mapped_column(DateTime) + avg_yield_rate: Mapped[float] = mapped_column(Float) + median_yield_rate: Mapped[float] = mapped_column(Float) + win_rate: Mapped[float] = mapped_column(Float) + annual_max_yield_rate: Mapped[float] = mapped_column(Float) + annual_max_yield_rate_start: Mapped[datetime.datetime] = mapped_column(DateTime) + annual_max_yield_rate_end: Mapped[datetime.datetime] = mapped_column(DateTime) + annual_min_yield_rate: Mapped[float] = mapped_column(Float) + annual_min_yield_rate_start: Mapped[datetime.datetime] = mapped_column(DateTime) + annual_min_yield_rate_end: Mapped[datetime.datetime] = mapped_column(DateTime) + annual_avg_yield_rate: Mapped[float] = mapped_column(Float) + annual_median_yield_rate: Mapped[float] = mapped_column(Float) + max_deficit_days: Mapped[int] = mapped_column(Integer) + max_deficit_start: Mapped[datetime.datetime] = mapped_column(DateTime) + max_deficit_end: Mapped[datetime.datetime] = mapped_column(DateTime) + annual_yield_variance: Mapped[float] = mapped_column(Float) + + +class Hs3005yearsYieldStats2410(Base): + __tablename__ = 'hs300_5years_yield_stats_2410' + + code: Mapped[str] = mapped_column(String(100), primary_key=True) + name: Mapped[str] = mapped_column(String(100)) + year_diff: Mapped[int] = mapped_column(Integer, primary_key=True) + max_yield_rate: Mapped[float] = mapped_column(Float) + max_yield_rate_start: Mapped[datetime.datetime] = mapped_column(DateTime) + max_yield_rate_end: Mapped[datetime.datetime] = mapped_column(DateTime) + min_yield_rate: Mapped[float] = mapped_column(Float) + min_yield_rate_start: Mapped[datetime.datetime] = mapped_column(DateTime) + min_yield_rate_end: Mapped[datetime.datetime] = mapped_column(DateTime) + avg_yield_rate: Mapped[float] = mapped_column(Float) + median_yield_rate: Mapped[float] = mapped_column(Float) + win_rate: Mapped[float] = mapped_column(Float) + annual_max_yield_rate: Mapped[float] = mapped_column(Float) + annual_max_yield_rate_start: Mapped[datetime.datetime] = mapped_column(DateTime) + annual_max_yield_rate_end: Mapped[datetime.datetime] = mapped_column(DateTime) + annual_min_yield_rate: Mapped[float] = mapped_column(Float) + annual_min_yield_rate_start: Mapped[datetime.datetime] = mapped_column(DateTime) + annual_min_yield_rate_end: Mapped[datetime.datetime] = mapped_column(DateTime) + annual_avg_yield_rate: Mapped[float] = mapped_column(Float) + annual_median_yield_rate: Mapped[float] = mapped_column(Float) + max_deficit_days: Mapped[int] = mapped_column(Integer) + max_deficit_start: Mapped[datetime.datetime] = mapped_column(DateTime) + max_deficit_end: Mapped[datetime.datetime] = mapped_column(DateTime) + annual_yield_variance: Mapped[float] = mapped_column(Float) + + +class Hs300AjustKline202410(Base): + __tablename__ = 'hs300_ajust_kline_202410' + + code: Mapped[str] = mapped_column(String(100), primary_key=True) + name: Mapped[str] = mapped_column(String(100)) + time_key: Mapped[datetime.datetime] = mapped_column(DateTime, primary_key=True) + hfq_open: Mapped[float] = mapped_column(Float) + hfq_close: Mapped[float] = mapped_column(Float) + qfq_open: Mapped[float] = mapped_column(Float) + qfq_close: Mapped[float] = mapped_column(Float) + none_open: Mapped[float] = mapped_column(Float) + none_close: Mapped[float] = mapped_column(Float) + + +class Hs300HisKlineHfq(Base): + __tablename__ = 'hs300_his_kline_hfq' + + code: Mapped[str] = mapped_column(String(100), primary_key=True, server_default=text("''")) + name: Mapped[str] = mapped_column(String(100), server_default=text("''")) + time_key: Mapped[datetime.datetime] = mapped_column(DateTime, primary_key=True) + open: Mapped[float] = mapped_column(Float) + close: Mapped[float] = mapped_column(Float) + high: Mapped[float] = mapped_column(Float) + low: Mapped[float] = mapped_column(Float) + pe_ratio: Mapped[Optional[float]] = mapped_column(Float) + turnover_rate: Mapped[Optional[float]] = mapped_column(Float) + volume: Mapped[Optional[int]] = mapped_column(BigInteger) + turnover: Mapped[Optional[decimal.Decimal]] = mapped_column(Double(asdecimal=True)) + change_rate: Mapped[Optional[float]] = mapped_column(Float) + last_close: Mapped[Optional[float]] = mapped_column(Float) + + +class Hs300HisKlineNone(Base): + __tablename__ = 'hs300_his_kline_none' + + code: Mapped[str] = mapped_column(String(100), primary_key=True, server_default=text("''")) + name: Mapped[str] = mapped_column(String(100), server_default=text("''")) + time_key: Mapped[datetime.datetime] = mapped_column(DateTime, primary_key=True) + open: Mapped[float] = mapped_column(Float) + close: Mapped[float] = mapped_column(Float) + high: Mapped[float] = mapped_column(Float) + low: Mapped[float] = mapped_column(Float) + pe_ratio: Mapped[Optional[float]] = mapped_column(Float) + turnover_rate: Mapped[Optional[float]] = mapped_column(Float) + volume: Mapped[Optional[int]] = mapped_column(BigInteger) + turnover: Mapped[Optional[decimal.Decimal]] = mapped_column(Double(asdecimal=True)) + change_rate: Mapped[Optional[float]] = mapped_column(Float) + last_close: Mapped[Optional[float]] = mapped_column(Float) + + +class Hs300QfqHis(Base): + __tablename__ = 'hs300_qfq_his' + + code: Mapped[str] = mapped_column(String(100), primary_key=True, server_default=text("''")) + name: Mapped[str] = mapped_column(String(100), server_default=text("''")) + time_key: Mapped[datetime.datetime] = mapped_column(DateTime, primary_key=True) + open: Mapped[float] = mapped_column(Float) + close: Mapped[float] = mapped_column(Float) + high: Mapped[float] = mapped_column(Float) + low: Mapped[float] = mapped_column(Float) + pe_ratio: Mapped[Optional[float]] = mapped_column(Float) + turnover_rate: Mapped[Optional[float]] = mapped_column(Float) + volume: Mapped[Optional[int]] = mapped_column(BigInteger) + turnover: Mapped[Optional[decimal.Decimal]] = mapped_column(Double(asdecimal=True)) + change_rate: Mapped[Optional[float]] = mapped_column(Float) + last_close: Mapped[Optional[float]] = mapped_column(Float) + + +class IndexHk(Base): + __tablename__ = 'index_hk' + + up_date: Mapped[str] = mapped_column(String(100), primary_key=True, server_default=text("''")) + index_code: Mapped[str] = mapped_column(String(100), primary_key=True, server_default=text("'000300'")) + index_name: Mapped[str] = mapped_column(String(100), server_default=text("''")) + code_inner: Mapped[str] = mapped_column(String(100), primary_key=True, server_default=text("''")) + code_name: Mapped[str] = mapped_column(String(100), server_default=text("''")) + weight: Mapped[str] = mapped_column(String(100), server_default=text("''")) + code: Mapped[str] = mapped_column(String(100), server_default=text("''")) + + +class IndexHs(Base): + __tablename__ = 'index_hs' + + up_date: Mapped[str] = mapped_column(String(100), primary_key=True, server_default=text("''")) + index_code: Mapped[str] = mapped_column(String(100), primary_key=True, server_default=text("'000300'")) + index_name: Mapped[str] = mapped_column(String(100), server_default=text("''")) + index_name_eng: Mapped[str] = mapped_column(String(100), server_default=text("''")) + code_inner: Mapped[str] = mapped_column(String(100), primary_key=True, server_default=text("''")) + code_name: Mapped[str] = mapped_column(String(100), server_default=text("''")) + code_name_eng: Mapped[str] = mapped_column(String(100), server_default=text("''")) + exchange: Mapped[str] = mapped_column(String(100), server_default=text("''")) + exchange_eng: Mapped[str] = mapped_column(String(100), server_default=text("''")) + code: Mapped[str] = mapped_column(String(100), server_default=text("''")) + + +class IndexUs(Base): + __tablename__ = 'index_us' + + up_date: Mapped[str] = mapped_column(String(100), primary_key=True, server_default=text("''")) + index_code: Mapped[str] = mapped_column(String(100), primary_key=True, server_default=text("'000300'")) + index_name: Mapped[str] = mapped_column(String(100), server_default=text("''")) + code_inner: Mapped[str] = mapped_column(String(100), primary_key=True, server_default=text("''")) + code_name: Mapped[str] = mapped_column(String(100), server_default=text("''")) + weight: Mapped[str] = mapped_column(String(100), server_default=text("''")) + code: Mapped[str] = mapped_column(String(100), server_default=text("''")) + + +class Sp500(Base): + __tablename__ = 'sp500' + + sp_no: Mapped[int] = mapped_column(Integer, server_default=text("'0'")) + code_name: Mapped[str] = mapped_column(VARCHAR(100), server_default=text("''")) + code_inner: Mapped[str] = mapped_column(String(100), primary_key=True, server_default=text("''")) + sector: Mapped[str] = mapped_column(String(100), server_default=text("''")) + code: Mapped[str] = mapped_column(String(100), server_default=text("''")) + up_date: Mapped[str] = mapped_column(String(100), primary_key=True, server_default=text("'2024-10-02'")) + + +class Sp5003yearsYieldStats2410(Base): + __tablename__ = 'sp500_3years_yield_stats_2410' + + code: Mapped[str] = mapped_column(String(100), primary_key=True) + name: Mapped[str] = mapped_column(String(100)) + year_diff: Mapped[int] = mapped_column(Integer, primary_key=True) + max_yield_rate: Mapped[float] = mapped_column(Float) + max_yield_rate_start: Mapped[datetime.datetime] = mapped_column(DateTime) + max_yield_rate_end: Mapped[datetime.datetime] = mapped_column(DateTime) + min_yield_rate: Mapped[float] = mapped_column(Float) + min_yield_rate_start: Mapped[datetime.datetime] = mapped_column(DateTime) + min_yield_rate_end: Mapped[datetime.datetime] = mapped_column(DateTime) + avg_yield_rate: Mapped[float] = mapped_column(Float) + median_yield_rate: Mapped[float] = mapped_column(Float) + win_rate: Mapped[float] = mapped_column(Float) + annual_max_yield_rate: Mapped[float] = mapped_column(Float) + annual_max_yield_rate_start: Mapped[datetime.datetime] = mapped_column(DateTime) + annual_max_yield_rate_end: Mapped[datetime.datetime] = mapped_column(DateTime) + annual_min_yield_rate: Mapped[float] = mapped_column(Float) + annual_min_yield_rate_start: Mapped[datetime.datetime] = mapped_column(DateTime) + annual_min_yield_rate_end: Mapped[datetime.datetime] = mapped_column(DateTime) + annual_avg_yield_rate: Mapped[float] = mapped_column(Float) + annual_median_yield_rate: Mapped[float] = mapped_column(Float) + max_deficit_days: Mapped[int] = mapped_column(Integer) + max_deficit_start: Mapped[datetime.datetime] = mapped_column(DateTime) + max_deficit_end: Mapped[datetime.datetime] = mapped_column(DateTime) + annual_yield_variance: Mapped[float] = mapped_column(Float) + + +class Sp5005yearsYieldStats2410(Base): + __tablename__ = 'sp500_5years_yield_stats_2410' + + code: Mapped[str] = mapped_column(String(100), primary_key=True) + name: Mapped[str] = mapped_column(String(100)) + year_diff: Mapped[int] = mapped_column(Integer, primary_key=True) + max_yield_rate: Mapped[float] = mapped_column(Float) + max_yield_rate_start: Mapped[datetime.datetime] = mapped_column(DateTime) + max_yield_rate_end: Mapped[datetime.datetime] = mapped_column(DateTime) + min_yield_rate: Mapped[float] = mapped_column(Float) + min_yield_rate_start: Mapped[datetime.datetime] = mapped_column(DateTime) + min_yield_rate_end: Mapped[datetime.datetime] = mapped_column(DateTime) + avg_yield_rate: Mapped[float] = mapped_column(Float) + median_yield_rate: Mapped[float] = mapped_column(Float) + win_rate: Mapped[float] = mapped_column(Float) + annual_max_yield_rate: Mapped[float] = mapped_column(Float) + annual_max_yield_rate_start: Mapped[datetime.datetime] = mapped_column(DateTime) + annual_max_yield_rate_end: Mapped[datetime.datetime] = mapped_column(DateTime) + annual_min_yield_rate: Mapped[float] = mapped_column(Float) + annual_min_yield_rate_start: Mapped[datetime.datetime] = mapped_column(DateTime) + annual_min_yield_rate_end: Mapped[datetime.datetime] = mapped_column(DateTime) + annual_avg_yield_rate: Mapped[float] = mapped_column(Float) + annual_median_yield_rate: Mapped[float] = mapped_column(Float) + max_deficit_days: Mapped[int] = mapped_column(Integer) + max_deficit_start: Mapped[datetime.datetime] = mapped_column(DateTime) + max_deficit_end: Mapped[datetime.datetime] = mapped_column(DateTime) + annual_yield_variance: Mapped[float] = mapped_column(Float) + + +class Sp500AjustKline202410(Base): + __tablename__ = 'sp500_ajust_kline_202410' + + code: Mapped[str] = mapped_column(String(100), primary_key=True) + name: Mapped[str] = mapped_column(String(100)) + time_key: Mapped[datetime.datetime] = mapped_column(DateTime, primary_key=True) + hfq_open: Mapped[float] = mapped_column(Float) + hfq_close: Mapped[float] = mapped_column(Float) + qfq_open: Mapped[float] = mapped_column(Float) + qfq_close: Mapped[float] = mapped_column(Float) + none_open: Mapped[float] = mapped_column(Float) + none_close: Mapped[float] = mapped_column(Float) + + +class Sp500HisKlineNone(Base): + __tablename__ = 'sp500_his_kline_none' + + code: Mapped[str] = mapped_column(String(100), primary_key=True) + name: Mapped[str] = mapped_column(String(100)) + time_key: Mapped[datetime.datetime] = mapped_column(DateTime, primary_key=True) + open: Mapped[float] = mapped_column(Float) + high: Mapped[float] = mapped_column(Float) + low: Mapped[float] = mapped_column(Float) + close: Mapped[float] = mapped_column(Float) + volume: Mapped[int] = mapped_column(BigInteger) + dividends: Mapped[float] = mapped_column(Float) + stock_splits: Mapped[float] = mapped_column(Float) + adj_close: Mapped[float] = mapped_column(Float) + + +class Sp500QfqHis202410(Base): + __tablename__ = 'sp500_qfq_his_202410' + + code: Mapped[str] = mapped_column(String(100), primary_key=True) + name: Mapped[str] = mapped_column(String(100)) + time_key: Mapped[datetime.datetime] = mapped_column(DateTime, primary_key=True) + open: Mapped[float] = mapped_column(Float) + high: Mapped[float] = mapped_column(Float) + low: Mapped[float] = mapped_column(Float) + close: Mapped[float] = mapped_column(Float) + volume: Mapped[int] = mapped_column(BigInteger) + dividends: Mapped[float] = mapped_column(Float) + stock_splits: Mapped[float] = mapped_column(Float) + adj_close: Mapped[float] = mapped_column(Float) + + + +class DailySanpModel(Base): + """股票数据模型类(使用mapped_column风格)""" + __tablename__ = 'daily_snap' + __table_args__ = {"comment": "股票交易数据表"} + + # 联合主键:股票代码+报告期 + code: Mapped[str] = mapped_column(String(20), primary_key=True, comment='股票代码') + curr_date: Mapped[str] = mapped_column(String(20), primary_key=True, comment='交易日期') + + # 基本信息字段 + name: Mapped[str | None] = mapped_column(String(50), comment='股票名称') + market_id: Mapped[str | None] = mapped_column(String(50), comment='市场名称') + code_prefix: Mapped[int | None] = mapped_column(Integer, comment='代码前缀') + industry: Mapped[str | None] = mapped_column(String(50), comment='所处行业') + listing_date: Mapped[datetime.date | None] = mapped_column(Date, comment='上市时间') + + # 交易数据字段 + latest_price: Mapped[float | None] = mapped_column(Float, comment='最新价') + price_change_percent: Mapped[float | None] = mapped_column(Float, comment='涨跌幅(%)') + price_change: Mapped[float | None] = mapped_column(Float, comment='涨跌额') + volume: Mapped[float | None] = mapped_column(Float, comment='成交量') + turnover: Mapped[float | None] = mapped_column(Float, comment='成交额') + amplitude: Mapped[float | None] = mapped_column(Float, comment='振幅(%)') + turnover_rate: Mapped[float | None] = mapped_column(Float, comment='换手率(%)') + pe_dynamic: Mapped[float | None] = mapped_column(Float, comment='市盈率动') + volume_ratio: Mapped[float | None] = mapped_column(Float, comment='量比') + change_5min: Mapped[float | None] = mapped_column(Float, comment='5分钟涨跌(%)') + highest: Mapped[float | None] = mapped_column(Float, comment='最高') + lowest: Mapped[float | None] = mapped_column(Float, comment='最低') + opening: Mapped[float | None] = mapped_column(Float, comment='今开') + previous_close: Mapped[float | None] = mapped_column(Float, comment='昨收') + price_speed: Mapped[float | None] = mapped_column(Float, comment='涨速(%)') + + # 市值数据字段 + total_market_cap: Mapped[float | None] = mapped_column(Float, comment='总市值') + circulating_market_cap: Mapped[float | None] = mapped_column(Float, comment='流通市值') + pb_ratio: Mapped[float | None] = mapped_column(Float, comment='市净率') + change_60d: Mapped[float | None] = mapped_column(Float, comment='60日涨跌幅(%)') + change_ytd: Mapped[float | None] = mapped_column(Float, comment='年初至今涨跌幅(%)') + + # 财务数据字段 + weighted_roe: Mapped[float | None] = mapped_column(Float, comment='加权净资产收益率(%)') + total_shares: Mapped[float | None] = mapped_column(Float, comment='总股本') + circulating_shares: Mapped[float | None] = mapped_column(Float, comment='已流通股份') + operating_revenue: Mapped[float | None] = mapped_column(Float, comment='营业收入') + revenue_growth: Mapped[float | None] = mapped_column(Float, comment='营业收入同比增长(%)') + net_profit: Mapped[float | None] = mapped_column(Float, comment='归属净利润') + net_profit_growth: Mapped[float | None] = mapped_column(Float, comment='归属净利润同比增长(%)') + undistributed_profit_per_share: Mapped[float | None] = mapped_column(Float, comment='每股未分配利润') + gross_margin: Mapped[float | None] = mapped_column(Float, comment='毛利率(%)') + asset_liability_ratio: Mapped[float | None] = mapped_column(Float, comment='资产负债率(%)') + reserve_per_share: Mapped[float | None] = mapped_column(Float, comment='每股公积金') + earnings_per_share: Mapped[float | None] = mapped_column(Float, comment='每股收益') + net_asset_per_share: Mapped[float | None] = mapped_column(Float, comment='每股净资产') + pe_static: Mapped[float | None] = mapped_column(Float, comment='市盈率静') + pe_ttm: Mapped[float | None] = mapped_column(Float, comment='市盈率TTM') + report_period: Mapped[str | None] = mapped_column(String(20), comment='报告期') + + # 记录创建时间 + created_at: Mapped[datetime.datetime] = mapped_column(DateTime, default=datetime.datetime.now, comment='记录创建时间') + + +class FutuTradingDayModel(Base): + """富途交易日历模型类""" + __tablename__ = "futu_trading_days" + __table_args__ = {"comment": "富途证券交易日历表"} + + # 联合主键:market + trade_date + market: Mapped[str] = mapped_column( + String(20), + primary_key=True, + comment="市场标识" + ) + trade_date: Mapped[datetime.date] = mapped_column( + Date, + primary_key=True, + comment="交易日日期" + ) + trade_date_type: Mapped[str] = mapped_column( + String(20), + comment="交易日类型" + ) diff --git a/src/sqlalchemy/readme.txt b/src/sqlalchemy/readme.txt new file mode 100644 index 0000000..c3ee1e4 --- /dev/null +++ b/src/sqlalchemy/readme.txt @@ -0,0 +1,13 @@ + +# 从数据库生成模型类 +sqlacodegen --outfile=models/stockdb.py "mysql+pymysql://root:mysqlpw@testdb:3306/stockdb?charset=utf8mb4" + +# 初始化 +mkdir migrations +alembic init migrations/stockdb +# 修改 alembic.ini 脚本,以及 env.py 导入 models + +# 同步修改到数据库(读取 models/shared.py ) +./scripts/sync_stockdb.sh + +### 对视图支持不好,主要是视图的字段没有类型,所以在导入导出时会出错,慎用! \ No newline at end of file diff --git a/src/sqlalchemy/scripts/sync_stockdb.sh b/src/sqlalchemy/scripts/sync_stockdb.sh new file mode 100755 index 0000000..aee0b93 --- /dev/null +++ b/src/sqlalchemy/scripts/sync_stockdb.sh @@ -0,0 +1,7 @@ +#!/bin/bash +cd $(dirname $0)/.. + +alembic -c migrations/stockdb/alembic.ini revision --autogenerate -m "Auto update from stockdb" +alembic -c migrations/stockdb/alembic.ini upgrade head + +echo "数据库 stockdb 同步完成" \ No newline at end of file diff --git a/src/static/daily_snap_em.py b/src/static/daily_snap_em.py new file mode 100644 index 0000000..51dbf5a --- /dev/null +++ b/src/static/daily_snap_em.py @@ -0,0 +1,231 @@ +""" +Script Name: +Description: 获取沪深300成分股的最新股价, 并计算年内涨幅, 924以来的涨幅, 市盈率, 股息率等。 + 调用em历史数据接口。 + +Author: [Your Name] +Created Date: YYYY-MM-DD +Last Modified: YYYY-MM-DD +Version: 1.0 + +Modification History: + - YYYY-MM-DD [Your Name]: + - YYYY-MM-DD [Your Name]: + - YYYY-MM-DD [Your Name]: +""" + +import pymysql +import logging +import csv +import os +import re +import time +import pandas as pd +import numpy as np +from datetime import datetime +import argparse +import src.crawling.stock_hist_em as his_em +import src.logger.logger as logger +from src.config.config import global_stock_data_dir +from src.crawler.zixuan.xueqiu_zixuan import XueQiuStockFetcher +from src.sqlalchemy.models.stockdb import DailySanpModel, Base +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker +from src.sqlalchemy.config import global_db_url + +# 配置日志 +logger.setup_logging() + +current_date = datetime.now().strftime("%Y%m%d") +current_year = datetime.now().strftime("%Y") + +res_dir = global_stock_data_dir + +# 刷新代码列表,并返回 +def flush_code_map(): + code_id_map_em_df = his_em.code_id_map_em() + print(code_id_map_em_df) + return code_id_map_em_df + +# 获取所有市场的当年股价快照,带重试机制。 +def fetch_snap_all(max_retries: int = 3) -> pd.DataFrame: + # 检查文件是否存在 + file_name = f'{res_dir}/snapshot_em_{current_date}.csv' + if os.path.exists(file_name): + try: + # 读取本地文件 + snap_data = pd.read_csv(file_name, encoding='utf-8') + logging.info(f"load snapshot data from local: {file_name}\n\n") + return snap_data + except Exception as e: + logging.warning(f"读取本地文件失败: {e},将重新拉取数据\n\n") + + # 拉取数据 + market_fs = {"cn": "m:0 t:6,m:0 t:80,m:1 t:2,m:1 t:23,m:0 t:81 s:2048", + "hk": "m:128 t:3,m:128 t:4,m:128 t:1,m:128 t:2", + "us": "m:105,m:106,m:107"} + + result = pd.DataFrame() + for market_id, fs in market_fs.items(): + df = his_em.stock_zh_a_spot_em(fs, fs_desc=market_id) + if df.empty: + logging.warning(f'{market_id} empty data. please check.') + return pd.DataFrame() + else: + logging.info(f'get {market_id} stock snapshot. stock count: {len(df)}') + # 关键步骤:添加market_id列,值为当前市场标识 + df['market_id'] = market_id # 新增一列,记录数据所属市场 + result = pd.concat([result, df], ignore_index=True) + + result.to_csv(file_name, index=False, encoding='utf-8') + logging.info(f"get snapshot data and write to file: {file_name}\n\n") + + return result + + +def load_xueqiu_codes(): + # 替换为你的实际cookie + USER_COOKIES = "u=5682299253; HMACCOUNT=AA6F9D2598CE96D7; xq_is_login=1; snbim_minify=true; _c_WBKFRo=BuebJX5KAbPh1PGBVFDvQTV7x7VF8W2cvWtaC99v; _nb_ioWEgULi=; cookiesu=661740133906455; device_id=fbe0630e603f726742fec4f9a82eb5fb; s=b312165egu; bid=1f3e6ffcb97fd2d9b4ddda47551d4226_m7fv1brw; Hm_lvt_1db88642e346389874251b5a1eded6e3=1751852390; xq_a_token=a0fd17a76966314ab80c960412f08e3fffb3ec0f; xqat=a0fd17a76966314ab80c960412f08e3fffb3ec0f; xq_id_token=eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJ1aWQiOjU2ODIyOTkyNTMsImlzcyI6InVjIiwiZXhwIjoxNzU0NzAzMjk5LCJjdG0iOjE3NTIxMTEyOTkyODYsImNpZCI6ImQ5ZDBuNEFadXAifQ.Vbs-LDgB4bCJI2N644DwfeptdcamKsAm2hbXxlPnJ_0fnTJhXp6T-2Gc6b6jmhTjXJIsWta8IuS0rQBB1L-9fKpUliNFHkv4lr7FW2x7QhrZ1D4lrvjihgBxKHq8yQl31uO6lmUOJkoRaS4LM1pmkSL_UOVyw8aUeuVjETFcJR1HFDHwWpHCLM8kY55fk6n1gEgDZnYNh1_FACqlm6LU4Vq14wfQgyF9sfrGzF8rxXX0nns_j-Dq2k8vN3mknh8yUHyzCyq6Sfqn6NeVdR0vPOciylyTtNq5kOUBFb8uJe48aV2uLGww3dYV8HbsgqW4k0zam3r3QDErfSRVIg-Usw; xq_r_token=1b73cbfb47fcbd8e2055ca4a6dc7a08905dacd7d; Hm_lpvt_1db88642e346389874251b5a1eded6e3=1752714700; is_overseas=0; ssxmod_itna=QqfxBD2D9DRQPY5i7YYxiwS4GhDYu0D0dGMD3qiQGglDFqAPKDHKm=lerDUhGr5h044VYmkTtDlxWeDZDG9dDqx0orXU7BB411D+iENYYe2GG+=3X0xOguYo7I=xmAkwKhSSIXNG2A+DnmeDQKDoxGkDivoD0IYwDiiTx0rD0eDPxDYDG4mDDvvQ84DjmEmFfoGImAeQIoDbORhz74DROdDS73A+IoGqW3Da1A3z8RGDmKDIhjozmoDFOL3Yq0k54i3Y=Ocaq0OZ+BGR0gvh849m1xkHYRr/oRCYQD4KDx5qAxOx20Z3isrfDxRvt70KGitCH4N4DGbh5gYH7x+GksdC58CNR3sx=1mt2qxkGd+QmoC5ZGYdixKG52q4iiqPj53js4D; ssxmod_itna2=QqfxBD2D9DRQPY5i7YYxiwS4GhDYu0D0dGMD3qiQGglDFqAPKDHKm=lerDUhGr5h044VYmkwYDioSBbrtN4=Htz/DUihxz=w4aD" + + # 初始化获取器 + fetcher = XueQiuStockFetcher( + cookies=USER_COOKIES, + size=1000, + retry_count=3 + ) + all_codes = [] + stocks = fetcher.get_stocks_by_group( + category=1, # 股票 + pid=-1 # 全部 + ) + if stocks: + for item in stocks: + code = item['symbol'] + mkt = item['marketplace'] + + if mkt: + if mkt.lower() == 'cn': + code = format_stock_code(code) + elif mkt.lower() == 'hk': + code = f"HK.{code}" + else: + code = f"US.{code}" + + all_codes.append({'code': code, 'code_name': item['name']}) + + return all_codes + + +def insert_stock_data_to_db(dataframe, db_url=global_db_url): + """ + 将pandas DataFrame中的股票数据插入到MySQL数据库 + + 参数: + dataframe: 包含股票数据的pandas DataFrame + db_url: 数据库连接字符串,格式如'mysql+mysqldb://user:password@host:port/dbname?charset=utf8mb4' + """ + # 创建数据库引擎 + engine = create_engine(db_url) + + # 创建数据表(如果不存在) + Base.metadata.create_all(engine) + + # 创建会话 + Session = sessionmaker(bind=engine) + session = Session() + + # 注意:pandas中NaN在数值列用np.nan,字符串列用pd.NA,统一替换为None + dataframe = dataframe.replace({np.nan: None, pd.NA: None}) + try: + count_insert = 0 + count_update = 0 + # 遍历DataFrame的每一行 + for _, row in dataframe.iterrows(): + # 先检查 code 是否存在且有效 + if not row.get('代码'): + logging.warning(f"警告:发现无效的 code 值,跳过该行数据。行数据:{row['名称']}") + continue # 跳过无效行 + + # 创建股票数据对象 + stock = DailySanpModel( + code=row['代码'], + curr_date=current_date, # TODO: 怎么判断当前的数据是哪一天的? 要看当前时间是否已经开盘,还是在盘前,还是前一个交易日的? + name=row['名称'], + market_id=row['market_id'], + code_prefix=row['代码前缀'], + industry=row['所处行业'], + listing_date=pd.to_datetime(row['上市时间']).date() if row['上市时间'] else None, + + latest_price=row['最新价'], + price_change_percent=row['涨跌幅'], + price_change=row['涨跌额'], + volume=row['成交量'], + turnover=row['成交额'], + amplitude=row['振幅'], + turnover_rate=row['换手率'], + pe_dynamic=row['市盈率动'], + volume_ratio=row['量比'], + change_5min=row['5分钟涨跌'], + highest=row['最高'], + lowest=row['最低'], + opening=row['今开'], + previous_close=row['昨收'], + price_speed=row['涨速'], + + total_market_cap=row['总市值'], + circulating_market_cap=row['流通市值'], + pb_ratio=row['市净率'], + change_60d=row['60日涨跌幅'], + change_ytd=row['年初至今涨跌幅'], + + weighted_roe=row['加权净资产收益率'], + total_shares=row['总股本'], + circulating_shares=row['已流通股份'], + operating_revenue=row['营业收入'], + revenue_growth=row['营业收入同比增长'], + net_profit=row['归属净利润'], + net_profit_growth=row['归属净利润同比增长'], + undistributed_profit_per_share=row['每股未分配利润'], + gross_margin=row['毛利率'], + asset_liability_ratio=row['资产负债率'], + reserve_per_share=row['每股公积金'], + earnings_per_share=row['每股收益'], + net_asset_per_share=row['每股净资产'], + pe_static=row['市盈率静'], + pe_ttm=row['市盈率TTM'], + report_period=row['报告期'] + ) + # 2. 执行merge:存在则更新,不存在则插入 + merged_stock = session.merge(stock) + + # 3. 统计插入/更新数量 + if merged_stock in session.new: # 新插入 + count_insert += 1 + elif merged_stock in session.dirty: # 已更新 + count_update += 1 + + # 提交事务 + session.commit() + logging.info(f"成功插入 {count_insert} 条,更新 {count_update} 条数据") + + except Exception as e: + # 发生错误时回滚 + session.rollback() + logging.warning(f"插入数据失败: {str(e)}") + finally: + # 关闭会话 + session.close() + + +def main(): + # 获取快照数据 + snap_data = fetch_snap_all() + if snap_data.empty: + logging.error(f"fetching snapshot data error!") + return + em_code_map = {row['代码']: row['代码前缀'] for _, row in snap_data.iterrows()} + + insert_stock_data_to_db(dataframe=snap_data) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/static/get_futu_trading_days.py b/src/static/get_futu_trading_days.py new file mode 100644 index 0000000..d5f56b6 --- /dev/null +++ b/src/static/get_futu_trading_days.py @@ -0,0 +1,140 @@ +from futu import * +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker +from src.sqlalchemy.models.stockdb import Base, FutuTradingDayModel +import datetime +import logging +import numpy as np +import src.logger.logger as logger +from src.sqlalchemy.config import global_db_url + +# 配置日志 +logger.setup_logging() + +def get_current_year(): + """获取当前年份""" + return datetime.datetime.now().year + +def get_year_date_range(year): + """获取指定年份的年初和年末日期""" + start_date = f"{year}-01-01" + end_date = f"{year}-12-31" + return start_date, end_date + +def init_db(db_url=global_db_url): + """初始化数据库连接和表结构""" + engine = create_engine(db_url) + # 创建表(如果不存在) + Base.metadata.create_all(engine) + return sessionmaker(bind=engine) + +def save_trading_days_to_db(session, market, trading_days): + """ + 将交易日历数据保存到数据库 + 存在则更新,不存在则插入 + """ + count_insert = 0 + count_update = 0 + + for day in trading_days: + # 转换日期字符串为date对象 + trade_date = datetime.datetime.strptime(day['time'], '%Y-%m-%d').date() + + # 查询是否已存在 + existing = session.query(FutuTradingDayModel).filter( + FutuTradingDayModel.market == market, + FutuTradingDayModel.trade_date == trade_date + ).first() + + if existing: + # 存在则更新类型 + if existing.trade_date_type != day['trade_date_type']: + existing.trade_date_type = day['trade_date_type'] + count_update += 1 + else: + # 不存在则插入新记录 + new_day = FutuTradingDayModel( + market=market, + trade_date=trade_date, + trade_date_type=day['trade_date_type'] + ) + session.add(new_day) + count_insert += 1 + + return count_insert, count_update + +def get_futu_trading_days(market, desc, db_url=global_db_url, year=None): + """ + 获取指定市场的交易日历并保存到数据库 + :param market: 市场类型(如TradeDateMarket.HK) + :param db_url: 数据库连接字符串 + :param year: 年份,默认当前年 + """ + # 确定年份 + target_year = year or get_current_year() + start_date, end_date = get_year_date_range(target_year) + logging.info(f"开始获取 {market} 市场 {target_year} 年的交易日历({start_date} 至 {end_date})") + + # 初始化富途连接 + quote_ctx = None + try: + quote_ctx = OpenQuoteContext(host='127.0.0.1', port=11111) + + # 请求交易日历 + ret, data = quote_ctx.request_trading_days( + market=market, + start=start_date, + end=end_date + ) + + if ret != RET_OK: + logging.error(f"获取交易日历失败: {data}") + return False + + logging.info(f"成功获取 {len(data)} 条交易日数据") + + # 初始化数据库会话 + Session = init_db(db_url) + session = Session() + + try: + # 保存到数据库 + count_insert, count_update = save_trading_days_to_db( + session=session, + market=desc, # 存储枚举值(如'HK') + trading_days=data + ) + session.commit() + logging.info(f"数据库操作完成:新增 {count_insert} 条,更新 {count_update} 条") + return True + + except Exception as e: + session.rollback() + logging.error(f"数据库操作失败: {str(e)}") + return False + finally: + session.close() + + except Exception as e: + logging.error(f"连接富途API失败: {str(e)}") + return False + finally: + # 确保连接关闭 + if quote_ctx: + quote_ctx.close() + +if __name__ == "__main__": + # 示例:获取港股(HK)和美股(US)的当前年交易日历 + markets = { + 'CN' : TradeDateMarket.CN, # 港股 + 'HK' : TradeDateMarket.HK, # 港股 + 'US' : TradeDateMarket.US # 美股 + } + + for desc, market in markets.items(): + get_futu_trading_days( + market=market, + desc=desc, + # db_url=DB_URL, + # year=2024 # 可选:指定年份,默认当前年 + ) diff --git a/src/static/trading_day.py b/src/static/trading_day.py new file mode 100644 index 0000000..deccf94 --- /dev/null +++ b/src/static/trading_day.py @@ -0,0 +1,122 @@ +from datetime import datetime, time, date, timedelta +from zoneinfo import ZoneInfo +from sqlalchemy.orm import Session +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker +from src.sqlalchemy.models.stockdb import FutuTradingDayModel +from src.sqlalchemy.config import global_db_url + + +# 市场时间配置(东八区 Asia/Shanghai) +# 新增盘前开始时间(pre_market_start)作为判断临界点 +MARKET_HOURS = { + # A股(沪深):盘前集合竞价9:15开始,实际交易9:30-15:00 + "CN": { + "pre_market_start": time(9, 15), # 盘前开始时间(集合竞价) + "morning_start": time(9, 30), + "morning_end": time(11, 30), + "afternoon_start": time(13, 0), + "afternoon_end": time(15, 0) + }, + # 港股:盘前竞价9:00开始,实际交易9:30-16:00 + "HK": { + "pre_market_start": time(9, 0), # 盘前开始时间 + "morning_start": time(9, 30), + "morning_end": time(12, 0), + "afternoon_start": time(13, 0), + "afternoon_end": time(16, 0) + }, + # 美股:盘前交易4:00-9:30(纽约时间),对应东八区夏令时16:00-21:30,冬令时17:00-22:30 + "US": { + # 东八区盘前开始时间(夏令时/冬令时) + "dst_pre_start": time(16, 0), # 夏令时盘前开始 + "std_pre_start": time(17, 0), # 冬令时盘前开始 + # 东八区交易时间(同之前) + "dst_start": time(21, 30), + "dst_end": time(4, 0), + "std_start": time(22, 30), + "std_end": time(5, 0) + } +} + +def get_trading_date(market: str, db_session: Session) -> str: + """ + 根据新逻辑返回交易日期: + 1. 若当前日期是交易日,且当前时间 >= 盘前开始时间 → 取当前交易日 + 2. 否则(非交易日,或交易日但未到盘前时间) → 取前一交易日 + + 参数: + market: 市场标识(CN=A股, HK=港股, US=美股) + db_session: SQLAlchemy数据库会话对象 + 返回: + 交易日期字符串(YYYY-MM-DD) + """ + # 1. 获取东八区当前时间和日期 + tz_sh = ZoneInfo("Asia/Shanghai") + now = datetime.now(tz_sh) # 含时区的当前时间 + current_date: date = now.date() # 东八区当前日期(仅日期) + current_time: time = now.time() # 东八区当前时间(仅时间) + + + # 2. 判断当前日期是否为该市场的交易日 + def is_trading_day(market: str, date: date, session: Session) -> bool: + """检查指定日期是否为该市场的交易日""" + return session.query(FutuTradingDayModel).filter( + FutuTradingDayModel.market == market, + FutuTradingDayModel.trade_date == date + ).first() is not None + + + # 3. 获取该市场的盘前开始时间(东八区) + def get_pre_market_start(market: str, now: datetime) -> time: + """根据市场和当前时间(判断夏令时)返回盘前开始时间""" + if market == "US": + # 美股需根据纽约时区判断夏令时 + tz_ny = ZoneInfo("America/New_York") + now_ny = now.astimezone(tz_ny) + is_dst = now_ny.dst() != timedelta(0) # 夏令时判断 + return MARKET_HOURS["US"]["dst_pre_start"] if is_dst else MARKET_HOURS["US"]["std_pre_start"] + else: + # A股/港股直接返回配置的盘前时间 + return MARKET_HOURS[market]["pre_market_start"] + + + # 4. 核心逻辑判断 + current_is_trading_day = is_trading_day(market, current_date, db_session) + pre_market_start = get_pre_market_start(market, now) + is_after_pre_market = current_time >= pre_market_start # 当前时间是否过了盘前开始时间 + + + # 5. 确定查询条件 + if current_is_trading_day and is_after_pre_market: + # 情况1:当前是交易日,且已过盘前时间 → 取当前交易日 + target_date = current_date + else: + # 情况2:非交易日,或未到盘前时间 → 取前一交易日 + # 查询小于当前日期的最大交易日 + prev_trading_day = db_session.query(FutuTradingDayModel.trade_date).filter( + FutuTradingDayModel.market == market, + FutuTradingDayModel.trade_date < current_date + ).order_by(FutuTradingDayModel.trade_date.desc()).first() + + if not prev_trading_day: + raise ValueError(f"未查询到{market}市场的前一交易日数据") + target_date = prev_trading_day.trade_date + + + return target_date.strftime("%Y-%m-%d") + +# 示例:获取各市场的目标交易日期 + +if __name__ == "__main__": + engine = create_engine(global_db_url) + Session = sessionmaker(bind=engine) + db_session = Session() + + try: + # 分别获取三个市场的交易日期 + print(f"A股目标交易日期:{get_trading_date('CN', db_session)}") + print(f"港股目标交易日期:{get_trading_date('HK', db_session)}") + print(f"美股目标交易日期:{get_trading_date('US', db_session)}") + finally: + db_session.close() \ No newline at end of file