Разработка AI-модели на базе Temporal Fusion Transformer для рынков
Temporal Fusion Transformer (TFT) — архитектура, разработанная в Google Brain специально для задач прогнозирования временных рядов с разнородными входными данными. В отличие от vanilla Transformer, TFT явно обрабатывает три типа переменных: статические (не меняются со временем), known future (известны наперёд) и unknown (наблюдаемые только до момента прогноза).
Что делает TFT особенным для финансов
Три категории входных переменных:
| Тип | Примеры для рынка | Обработка |
|---|---|---|
| Static covariates | Тикер, сектор, market cap | Static embeddings |
| Known future | Даты earnings, FOMC даты, праздники | Future encoder |
| Past observed | Returns, volume, VIX, RSI | Past encoder |
Это принципиально важно: зная, что через 5 дней будет заседание ФРС, модель должна учитывать это при прогнозе прямо сейчас. TFT делает это явно.
Variable Selection Network (VSN): Learnable веса для каждой входной переменной. Позволяет автоматически отфильтровать нерелевантные признаки и получить interpretability — какие переменные реально важны для прогноза.
Gated Residual Network (GRN): Нелинейная обработка с gate-механизмом, контролирующим насколько нелинейное преобразование применяется (gate = 0: pass-through, gate = 1: full nonlinear).
Полная архитектура TFT
Static covariates → Static Covariate Encoders
↓
Past observed → LSTM encoder ─────────────┐
├→ Multi-head Attention → GRN → Quantile Output
Known future → LSTM decoder ──────────────┘
Внутри attention: temporal self-attention, где каждый шаг прогноза может "смотреть" на релевантную историю.
Реализация для рыночных данных
from pytorch_forecasting import TemporalFusionTransformer, TimeSeriesDataSet
from pytorch_forecasting.metrics import QuantileLoss
data = prepare_market_dataframe(
tickers=['AAPL', 'MSFT', ...], # 100+ инструментов
start='2015-01-01'
)
training = TimeSeriesDataSet(
data[data.date < '2022-01-01'],
time_idx="time_idx",
target="forward_5d_return",
group_ids=["ticker"],
max_encoder_length=126, # 6 месяцев истории
max_prediction_length=5, # 5 дней прогноза
static_categoricals=["sector", "country"],
static_reals=["log_market_cap", "beta"],
time_varying_known_reals=["days_to_earnings", "fomc_flag", "vix"],
time_varying_unknown_reals=[
"return", "volume_ratio", "rsi", "atr_normalized",
"momentum_12_1", "short_interest_ratio"
],
)
tft = TemporalFusionTransformer.from_dataset(
training,
learning_rate=0.001,
hidden_size=160,
attention_head_size=4,
dropout=0.1,
hidden_continuous_size=64,
loss=QuantileLoss(quantiles=[0.1, 0.25, 0.5, 0.75, 0.9])
)
Обучение и гиперпараметры
Ключевые гиперпараметры:
-
hidden_size: 64-256 (основная ёмкость модели) -
attention_head_size: 1-4 -
max_encoder_length: 60-252 (1 квартал — 1 год) -
dropout: 0.05-0.3
Learning rate finding:
res = trainer.tuner.lr_find(
tft, train_dataloaders=train_dataloader,
val_dataloaders=val_dataloader,
max_lr=0.1
)
optimal_lr = res.suggestion()
Early stopping:
from pytorch_lightning.callbacks import EarlyStopping
early_stop_callback = EarlyStopping(
monitor="val_loss", patience=10, mode="min"
)
Quantile прогнозы и их применение
TFT нативно выдаёт квантильные прогнозы (p10, p25, p50, p75, p90). Это ценно для:
Risk-based position sizing:
point_forecast = forecasts['p50']
uncertainty = forecasts['p90'] - forecasts['p10']
position_size = base_size × (1 / (uncertainty / expected_return))
Asymmetric return profiles: Если p90 − p50 >> p50 − p10 → правостороннее распределение → потенциал роста превышает риск.
Interpretability: Variable Importance
raw_predictions, x = tft.predict(val_dataloader, mode="raw", return_x=True)
interpretation = tft.interpret_output(raw_predictions, reduction="sum")
fig = tft.plot_interpretation(interpretation)
Пример результата: Variable importance показывает, что momentum_12_1 (0.22), vix (0.18) и days_to_earnings (0.15) — главные предикторы. short_interest_ratio (0.04) — незначимый.
Attention pattern visualization: модель обращает максимальное внимание на точки за 5 и 20 дней до прогноза — соответствует недельному и месячному momentum эффекту.
Benchmark против других методов
На M5 конкурсе (Walmart demand forecasting, 2020):
- TFT: ÙDL 0.1127 (топ-10%)
- LightGBM: 0.1152
- DeepAR: 0.1189
- Prophet: 0.1402
Преимущество TFT особенно выражено при наличии known future covariates и static features.
Сроки: TFT baseline для 50+ инструментов — 4-5 недель. Полноценная система с earnings calendar, macro covariates и portfolio construction — 3-4 месяца.







