Реализация Federated Learning для обучения моделей без передачи данных
Federated Learning — парадигма обучения ML-моделей, при которой данные остаются на устройствах клиентов (смартфоны, больничные серверы, банковские системы), а в центральный сервер передаются только обновления весов модели. Это позволяет обучать модели на чувствительных данных без их централизации.
Когда применять FL
- Медицинские данные: несколько больниц обучают модель диагностики без обмена данными пациентов
- Финансы: банки-конкуренты совместно обучают модель фрода без раскрытия транзакций
- Мобильные устройства: персонализированные модели на данных пользователей без их upload
- IoT: модели на данных промышленного оборудования, которые нельзя передавать по security причинам
FedAvg — базовый алгоритм
Federated Averaging (McMahan et al., 2017) — стандартный алгоритм FL:
- Сервер инициализирует глобальную модель $w_0$
- На каждом раунде t: сервер выбирает подмножество клиентов, рассылает текущие веса
- Каждый клиент: обучает модель на локальных данных (несколько эпох), возвращает $\Delta w_i$
- Сервер агрегирует: $w_{t+1} = \sum_i \frac{n_i}{n} w_i^t$, где $n_i$ — размер датасета клиента i
Реализация с PySyft / Flower
Flower (flwr) — наиболее зрелый open-source FL фреймворк:
import flwr as fl
import torch
from typing import Dict, List, Tuple, Optional
# Клиентская часть
class MedicalModelClient(fl.client.NumPyClient):
def __init__(self, model, train_loader, val_loader):
self.model = model
self.train_loader = train_loader
self.val_loader = val_loader
def get_parameters(self, config) -> List[np.ndarray]:
return [param.data.numpy() for param in self.model.parameters()]
def set_parameters(self, parameters: List[np.ndarray]):
for param, new_param in zip(self.model.parameters(), parameters):
param.data = torch.tensor(new_param)
def fit(self, parameters, config) -> Tuple[List[np.ndarray], int, Dict]:
self.set_parameters(parameters)
# Локальное обучение
optimizer = torch.optim.SGD(self.model.parameters(),
lr=config.get("lr", 0.01))
local_epochs = config.get("local_epochs", 3)
self.model.train()
for epoch in range(local_epochs):
for batch in self.train_loader:
optimizer.zero_grad()
loss = self.model(batch)
loss.backward()
optimizer.step()
return self.get_parameters(config), len(self.train_loader.dataset), {}
def evaluate(self, parameters, config) -> Tuple[float, int, Dict]:
self.set_parameters(parameters)
loss, accuracy = test(self.model, self.val_loader)
return float(loss), len(self.val_loader.dataset), {"accuracy": float(accuracy)}
# Серверная часть
class FedAvgWithDP(fl.server.strategy.FedAvg):
"""FedAvg с Differential Privacy"""
def aggregate_fit(self, server_round, results, failures):
aggregated_params, aggregated_metrics = super().aggregate_fit(
server_round, results, failures
)
if aggregated_params is not None:
# Добавление Gaussian noise для DP
noise_multiplier = 0.1
for param in fl.common.parameters_to_ndarrays(aggregated_params):
noise = np.random.normal(0, noise_multiplier, param.shape)
param += noise
return aggregated_params, aggregated_metrics
strategy = FedAvgWithDP(
min_fit_clients=5,
min_evaluate_clients=3,
min_available_clients=10,
fraction_fit=0.5, # 50% клиентов на каждый раунд
)
fl.server.start_server(
server_address="0.0.0.0:8080",
strategy=strategy,
config=fl.server.ServerConfig(num_rounds=50)
)
Differential Privacy в FL
DP гарантирует, что участие отдельного клиента не может быть обнаружено по глобальной модели:
from opacus import PrivacyEngine
privacy_engine = PrivacyEngine()
model, optimizer, train_loader = privacy_engine.make_private_with_epsilon(
module=model,
optimizer=optimizer,
data_loader=train_loader,
epochs=local_epochs,
target_epsilon=5.0, # ε-DP параметр (меньше = приватнее)
target_delta=1e-5,
max_grad_norm=1.0, # Gradient clipping
)
Проблемы и решения
Non-IID данные — данные на разных клиентах имеют разные распределения. Решения: FedProx (добавляет proximal term), SCAFFOLD, FedNova.
Коммуникационные накладные расходы — передача весов модели при тысячах клиентов. Решения: gradient compression (Top-k sparsification), quantization (8-bit weights).
Stragglers — медленные клиенты задерживают раунд. Решения: асинхронный FL (FedAsync), таймаут на участие клиента.
Backdoor атаки — вредоносный клиент отравляет глобальную модель. Защиты: Byzantine-robust aggregation (Krum, Median), anomaly detection на обновлениях.
Метрики оценки FL системы
- Communication efficiency: количество раундов до достижения target accuracy
- Accuracy gap: разрица между centralised training и FL (обычно 1-5%)
- Privacy budget: $(\epsilon, \delta)$-DP достигнутый по итогам обучения
- Participation rate: % клиентов, успешно завершивших каждый раунд
Типичный проект: медицинский консорциум из 10 больниц обучает модель детекции рака на рентгенограммах. FL позволяет достичь AUC 0.94 — против 0.87 у лучшей отдельной больницы — без единой передачи данных пациентов.







