Настройка FSDP (Fully Sharded Data Parallel) для обучения
FSDP — нативная реализация fully sharded data parallelism в PyTorch (появилась в версии 1.11). В отличие от DeepSpeed ZeRO, FSDP является частью PyTorch core и не требует дополнительных зависимостей. Шардирует параметры, градиенты и состояние оптимизатора между GPU аналогично DeepSpeed ZeRO Stage 3.
Принцип работы
При forward pass: параметры каждого sharded layer собираются (all-gather) со всех GPU перед вычислением. После forward — немедленно освобождаются, если включён reshard_after_forward. При backward pass: параметры снова собираются, градиенты вычисляются, затем reduce-scatter распределяет шарды градиентов по GPU.
Это устраняет ситуацию, когда каждый GPU хранит полную копию модели, как в обычном DDP.
Базовая настройка
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import (
CPUOffload,
BackwardPrefetch,
)
from torch.distributed.fsdp.wrap import (
size_based_auto_wrap_policy,
enable_wrap,
wrap,
)
import functools
def setup_fsdp(rank, world_size):
dist.init_process_group("nccl", rank=rank, world_size=world_size)
torch.cuda.set_device(rank)
def wrap_model_with_fsdp(model, rank):
# Политика автоматического оборачивания: шардировать слои > 100M параметров
auto_wrap_policy = functools.partial(
size_based_auto_wrap_policy,
min_num_params=100_000_000
)
model = FSDP(
model,
auto_wrap_policy=auto_wrap_policy,
cpu_offload=CPUOffload(offload_params=False),
backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
device_id=torch.cuda.current_device(),
sharding_strategy=ShardingStrategy.FULL_SHARD, # Аналог ZeRO-3
mixed_precision=MixedPrecision(
param_dtype=torch.bfloat16,
reduce_dtype=torch.float32,
buffer_dtype=torch.bfloat16,
),
)
return model
Стратегии шардирования
from torch.distributed.fsdp import ShardingStrategy
# FULL_SHARD — полное шардирование (аналог ZeRO-3)
# Максимальная экономия памяти, максимальный overhead на коммуникацию
strategy = ShardingStrategy.FULL_SHARD
# SHARD_GRAD_OP — шардирование только градиентов и оптимизатора (ZeRO-2)
# Баланс между памятью и скоростью
strategy = ShardingStrategy.SHARD_GRAD_OP
# NO_SHARD — обычный DDP без шардирования
strategy = ShardingStrategy.NO_SHARD
# HYBRID_SHARD — FULL_SHARD внутри узла, репликация между узлами
# Оптимален для multi-node с быстрым NVLink внутри узла
strategy = ShardingStrategy.HYBRID_SHARD
Wrap policy для Transformer моделей
Для трансформеров важно оборачивать каждый Transformer block отдельно:
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
# Каждый LlamaDecoderLayer будет отдельным FSDP unit
llama_auto_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
transformer_layer_cls={LlamaDecoderLayer},
)
model = FSDP(model, auto_wrap_policy=llama_auto_wrap_policy)
Сохранение и загрузка checkpoint
С FSDP checkpoint требует специальной обработки, так как параметры шардированы:
from torch.distributed.fsdp import FullStateDictConfig, StateDictType
# Сохранение — собираем полный state dict на rank 0
save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
cpu_state = model.state_dict()
if rank == 0:
torch.save(cpu_state, "checkpoint.pt")
# Загрузка — загружаем на CPU, затем распределяем
if rank == 0:
state_dict = torch.load("checkpoint.pt")
else:
state_dict = {}
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
model.load_state_dict(state_dict)
FSDP vs DeepSpeed ZeRO: сравнение
| Критерий | FSDP | DeepSpeed ZeRO-3 |
|---|---|---|
| Интеграция с PyTorch | Нативная | Внешняя библиотека |
| CPU/NVMe offload | Ограниченный | Продвинутый (ZeRO-Infinity) |
| Поддержка Hugging Face | Через Accelerate | Нативная |
| Производительность | Сопоставимо | Незначительно быстрее для очень больших моделей |
| Сложность настройки | Ниже | Выше |
Интеграция с Accelerate
from accelerate import Accelerator
from accelerate.utils import FullyShardedDataParallelPlugin
from torch.distributed.fsdp.fully_sharded_data_parallel import FullOptimStateDictConfig, FullStateDictConfig
fsdp_plugin = FullyShardedDataParallelPlugin(
state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=False),
optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=False),
)
accelerator = Accelerator(fsdp_plugin=fsdp_plugin)
FSDP — правильный выбор для команд, работающих в экосистеме PyTorch без желания добавлять DeepSpeed как зависимость. Для LLaMA-2 70B на 8x A100 80GB FSDP FULL_SHARD обеспечивает ~800-900 tokens/s при BF16 обучении.







