Настройка FSDP (Fully Sharded Data Parallel) для обучения

Проектируем и внедряем системы искусственного интеллекта: от прототипа до production-ready решения. Наша команда объединяет экспертизу в машинном обучении, дата-инжиниринге и MLOps, чтобы AI работал не в лаборатории, а в реальном бизнесе.
Показано 1 из 1 услугВсе 1566 услуг
Настройка FSDP (Fully Sharded Data Parallel) для обучения
Сложная
~3-5 рабочих дней
Часто задаваемые вопросы
Направления AI-разработки
Этапы разработки AI-решения
Последние работы
  • image_website-b2b-advance_0.png
    Разработка сайта компании B2B ADVANCE
    1218
  • image_web-applications_feedme_466_0.webp
    Разработка веб-приложения для компании FEEDME
    1161
  • image_websites_belfingroup_462_0.webp
    Разработка веб-сайта для компании БЕЛФИНГРУПП
    854
  • image_ecommerce_furnoro_435_0.webp
    Разработка интернет магазина для компании FURNORO
    1047
  • image_logo-advance_0.png
    Разработка логотипа компании B2B Advance
    561
  • image_crm_enviok_479_0.webp
    Разработка веб-приложения для компании Enviok
    825

Настройка FSDP (Fully Sharded Data Parallel) для обучения

FSDP — нативная реализация fully sharded data parallelism в PyTorch (появилась в версии 1.11). В отличие от DeepSpeed ZeRO, FSDP является частью PyTorch core и не требует дополнительных зависимостей. Шардирует параметры, градиенты и состояние оптимизатора между GPU аналогично DeepSpeed ZeRO Stage 3.

Принцип работы

При forward pass: параметры каждого sharded layer собираются (all-gather) со всех GPU перед вычислением. После forward — немедленно освобождаются, если включён reshard_after_forward. При backward pass: параметры снова собираются, градиенты вычисляются, затем reduce-scatter распределяет шарды градиентов по GPU.

Это устраняет ситуацию, когда каждый GPU хранит полную копию модели, как в обычном DDP.

Базовая настройка

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    CPUOffload,
    BackwardPrefetch,
)
from torch.distributed.fsdp.wrap import (
    size_based_auto_wrap_policy,
    enable_wrap,
    wrap,
)
import functools

def setup_fsdp(rank, world_size):
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

def wrap_model_with_fsdp(model, rank):
    # Политика автоматического оборачивания: шардировать слои > 100M параметров
    auto_wrap_policy = functools.partial(
        size_based_auto_wrap_policy,
        min_num_params=100_000_000
    )

    model = FSDP(
        model,
        auto_wrap_policy=auto_wrap_policy,
        cpu_offload=CPUOffload(offload_params=False),
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        device_id=torch.cuda.current_device(),
        sharding_strategy=ShardingStrategy.FULL_SHARD,  # Аналог ZeRO-3
        mixed_precision=MixedPrecision(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.float32,
            buffer_dtype=torch.bfloat16,
        ),
    )
    return model

Стратегии шардирования

from torch.distributed.fsdp import ShardingStrategy

# FULL_SHARD — полное шардирование (аналог ZeRO-3)
# Максимальная экономия памяти, максимальный overhead на коммуникацию
strategy = ShardingStrategy.FULL_SHARD

# SHARD_GRAD_OP — шардирование только градиентов и оптимизатора (ZeRO-2)
# Баланс между памятью и скоростью
strategy = ShardingStrategy.SHARD_GRAD_OP

# NO_SHARD — обычный DDP без шардирования
strategy = ShardingStrategy.NO_SHARD

# HYBRID_SHARD — FULL_SHARD внутри узла, репликация между узлами
# Оптимален для multi-node с быстрым NVLink внутри узла
strategy = ShardingStrategy.HYBRID_SHARD

Wrap policy для Transformer моделей

Для трансформеров важно оборачивать каждый Transformer block отдельно:

from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

# Каждый LlamaDecoderLayer будет отдельным FSDP unit
llama_auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={LlamaDecoderLayer},
)

model = FSDP(model, auto_wrap_policy=llama_auto_wrap_policy)

Сохранение и загрузка checkpoint

С FSDP checkpoint требует специальной обработки, так как параметры шардированы:

from torch.distributed.fsdp import FullStateDictConfig, StateDictType

# Сохранение — собираем полный state dict на rank 0
save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
    cpu_state = model.state_dict()
    if rank == 0:
        torch.save(cpu_state, "checkpoint.pt")

# Загрузка — загружаем на CPU, затем распределяем
if rank == 0:
    state_dict = torch.load("checkpoint.pt")
else:
    state_dict = {}

with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
    model.load_state_dict(state_dict)

FSDP vs DeepSpeed ZeRO: сравнение

Критерий FSDP DeepSpeed ZeRO-3
Интеграция с PyTorch Нативная Внешняя библиотека
CPU/NVMe offload Ограниченный Продвинутый (ZeRO-Infinity)
Поддержка Hugging Face Через Accelerate Нативная
Производительность Сопоставимо Незначительно быстрее для очень больших моделей
Сложность настройки Ниже Выше

Интеграция с Accelerate

from accelerate import Accelerator
from accelerate.utils import FullyShardedDataParallelPlugin
from torch.distributed.fsdp.fully_sharded_data_parallel import FullOptimStateDictConfig, FullStateDictConfig

fsdp_plugin = FullyShardedDataParallelPlugin(
    state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=False),
    optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=False),
)

accelerator = Accelerator(fsdp_plugin=fsdp_plugin)

FSDP — правильный выбор для команд, работающих в экосистеме PyTorch без желания добавлять DeepSpeed как зависимость. Для LLaMA-2 70B на 8x A100 80GB FSDP FULL_SHARD обеспечивает ~800-900 tokens/s при BF16 обучении.