Оптимизация ML-модели (pruning) для мобильного устройства
Pruning — удаление части весов или нейронов из модели. Логика: в нейросети, обученной на реальных данных, значительная доля весов близка к нулю и почти не влияет на выход. Их можно обнулить или удалить без существенной потери точности, но с выигрышем в скорости и объёме.
Звучит привлекательно. На практике — pruning сложнее квантизации, требует дообучения после прореживания и не всегда даёт ожидаемое ускорение на мобильных устройствах из-за особенностей реализации.
Два вида pruning
Unstructured pruning — обнуляем отдельные веса (sparse матрицы). Матрица 90% нулей — казалось бы, 10× экономия. Но GPU/NPU работают с плотными матрицами, sparse вычисления там не ускоряются. Практическая польза: уменьшение размера модели после сжатия (нули хорошо компрессируются). Но не скорость инференса на обычных устройствах.
Structured pruning — удаляем целые фильтры (каналы) в свёрточных слоях или головы в attention. Результат — физически меньший граф, который реально быстрее на любом железе. Это то, что реально нужно для мобиля.
Structured pruning: практика на PyTorch
import torch
import torch.nn.utils.prune as prune
# L1-based structured pruning: удаляем 30% фильтров из Conv2d слоёв
# по критерию минимальной L1-нормы (наименее важные фильтры)
for name, module in model.named_modules():
if isinstance(module, torch.nn.Conv2d):
prune.ln_structured(
module,
name='weight',
amount=0.3, # 30% каналов
n=1, # L1 норма
dim=0 # dim=0 — выходные фильтры
)
# После pruning — важно сделать веса постоянными (убрать mask)
for name, module in model.named_modules():
if isinstance(module, torch.nn.Conv2d):
prune.remove(module, 'weight')
После этого модель содержит нулевые фильтры, но они всё ещё в графе. Следующий шаг — фактическое удаление нулевых каналов:
# Кастомная функция удаления нулевых фильтров
def remove_zero_filters(conv_layer, next_layer=None):
"""Удаляем фильтры с нулевыми весами и синхронизируем следующий слой"""
weight = conv_layer.weight.data
# Маска: фильтры с ненулевыми весами
nonzero_mask = weight.abs().sum(dim=(1,2,3)) > 1e-6
conv_layer.weight = nn.Parameter(weight[nonzero_mask])
if conv_layer.bias is not None:
conv_layer.bias = nn.Parameter(conv_layer.bias.data[nonzero_mask])
conv_layer.out_channels = nonzero_mask.sum().item()
# Синхронизируем следующий слой (входные каналы)
if next_layer is not None and isinstance(next_layer, nn.Conv2d):
next_layer.weight = nn.Parameter(next_layer.weight.data[:, nonzero_mask])
next_layer.in_channels = nonzero_mask.sum().item()
Это нужно делать осторожно — BatchNorm слои после Conv тоже содержат параметры для каждого канала и требуют синхронизации.
Fine-tuning после pruning
После удаления 20–40% фильтров модель теряет точность. Обязательный этап — fine-tuning на обучающих данных. Правило: чем агрессивнее pruning, тем дольше fine-tuning.
# Fine-tuning после pruning — обычно 10-20% от исходного числа эпох
optimizer = torch.optim.Adam(pruned_model.parameters(), lr=1e-4) # меньший LR
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20)
for epoch in range(20):
train_one_epoch(pruned_model, train_loader, optimizer)
val_acc = evaluate(pruned_model, val_loader)
scheduler.step()
print(f"Epoch {epoch}: val_acc={val_acc:.4f}")
Iterative pruning — цикл pruning → fine-tuning → pruning — даёт лучший результат, чем однократное удаление большого числа фильтров.
Lottery Ticket Hypothesis: глубже
Для задач, где результат критичен, используем Lottery Ticket подход: обучаем полную сеть, находим «выигрышные билеты» — sparse subnetworks, которые можно обучить до сопоставимой точности с нуля. Реализация через torch_pruning библиотеку:
import torch_pruning as tp
# Анализ зависимостей между слоями
example_inputs = torch.zeros(1, 3, 224, 224)
DG = tp.DependencyGraph()
DG.build_dependency(model, example_inputs=example_inputs)
# Получаем группы связанных слоёв (pruning одного требует pruning связанных)
pruner = tp.pruner.MagnitudePruner(
model,
example_inputs,
importance=tp.importance.MagnitudeImportance(p=1),
pruning_ratio=0.5, # удалить 50% каналов
global_pruning=False,
iterative_steps=5 # итеративно за 5 шагов
)
Почему pruning не всегда даёт ускорение
MobileNetV3 уже оптимизирован: depthwise separable convolutions с малым числом каналов. Удалить 30% фильтров из слоя с 16 каналами — получаем 11 каналов. Разница в скорости — минимальная, overhead от tensor операций остаётся.
Pruning хорошо работает на больших моделях: ResNet-50, EfficientNet-B4, BERT. На уже компактных MobileNet/EfficientNet-lite — эффект ниже. В таких случаях лучше начать с более лёгкой базовой архитектуры, а не прунить тяжёлую.
Комбинация с квантизацией
Pruning + квантизация — стандартная двухшаговая оптимизация:
- Structured pruning 30–40% → fine-tuning → уменьшаем граф
- INT8 квантизация сжатого графа → финальная модель
Пример результата: EfficientNet-B0 (20 МБ FP32, 80 мс Android) → pruning 35% + INT8 → 4 МБ, 18 мс. Точность top-1 упала с 77.1% до 75.8%.
Процесс
Анализ модели на pruning-пригодность → выбор критерия и степени прореживания → итеративный pruning + fine-tuning → проверка точности → опционально: квантизация → замеры на целевых устройствах.
Ориентиры по срокам
Structured pruning с fine-tuning на готовом датасете — 2–4 недели. Итеративный pruning с полным экспериментальным циклом — 4–8 недель.







