Pruning (обрезка) нейросетевой модели для оптимизации
Pruning — удаление малозначимых параметров (весов, нейронов, attention heads, слоёв) из обученной нейросети. Цель — снизить размер модели и ускорить инференс при минимальной потере качества. В контексте LLM pruning часто комбинируют с квантизацией и дистилляцией для максимального сжатия.
Виды pruning
Unstructured pruning: обнуляются отдельные веса по всей матрице. Высокое сжатие, но требует sparse computation — стандартные GPU не ускоряют sparse операции «из коробки».
Structured pruning: удаляются целые структурные элементы — нейроны, attention heads, слои. Результат — реально меньшая плотная модель, которая работает быстрее на стандартном железе.
Semi-structured pruning (N:M sparsity): удаляются N весов из каждого блока M. Формат 2:4 поддерживается NVIDIA Ampere и выше на аппаратном уровне (до 2× ускорение).
LLM-Pruner: структурированный pruning LLM
# Пример использования LLM-Pruner
# pip install llm-pruner
from LLMPruner.pruner import LlamaStructuredPruner
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3.1-7B")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-7B")
pruner = LlamaStructuredPruner(
model=model,
tokenizer=tokenizer,
pruning_ratio=0.25, # Удалить 25% параметров
)
# Вычисление важности параметров на calibration data
calibration_data = ["Текст для анализа важности весов...", ...]
pruner.get_mask(calibration_data, method="taylor") # Taylor expansion importance
# Применение маски и pruning
pruned_model = pruner.prune()
SparseGPT: эффективный unstructured pruning без retraining
SparseGPT — метод, позволяющий pruning 50–60% весов LLM за несколько часов без повторного обучения:
# sparsegpt — библиотека от авторов метода
# Пример концептуального кода
from sparsegpt import SparseGPT
sparsegpt = SparseGPT(model)
sparsegpt.fasterprune(
sparsity=0.5, # 50% sparsity
prunen=2, # N в N:M
prunem=4, # M в N:M (2:4 — поддерживается аппаратно)
percdamp=0.01,
blocksize=128,
)
При 2:4 sparsity (50%) на NVIDIA A100/H100 ускорение inference на Tensor Core около 1.7–2×.
Wanda: простой и эффективный pruning
Wanda (Pruning by Weights and Activations) — один из самых эффективных методов, использующий произведение |W| × ||X|| для определения важности весов:
# Wanda проще SparseGPT, но сопоставимо по качеству
# Работает за несколько минут на 7B модели
def wanda_pruning(model, calibration_loader, sparsity=0.5):
"""Упрощённая реализация Wanda"""
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear):
# Накапливаем статистику активаций
activation_norms = get_activation_norms(module, calibration_loader)
# Importance score = |W| * ||X||
importance = module.weight.abs() * activation_norms
# Pruning по threshold
threshold = torch.quantile(importance, sparsity)
mask = importance > threshold
module.weight.data *= mask
return model
Depth pruning: удаление слоёв
Для LLM средние слои часто менее критичны, чем первые и последние:
def depth_prune_llm(model, layers_to_remove: list[int]):
"""Удаление указанных decoder layers"""
# Для Llama-архитектуры
remaining_layers = [
layer for i, layer in enumerate(model.model.layers)
if i not in layers_to_remove
]
model.model.layers = torch.nn.ModuleList(remaining_layers)
return model
# Пример: удаляем 8 средних слоёв из 32 (25% depth reduction)
pruned_model = depth_prune_llm(model, layers_to_remove=list(range(12, 20)))
# Результат: 24-слойная модель из 32-слойной
Практический кейс: оптимизация edge-деплоя
Задача: дообученная Llama 3.1 8B для промышленного контроллера (ARM-сервер, 16GB RAM, нет GPU). Требование: инференс < 2с на запрос.
Стратегия оптимизации:
- GGUF Q4_K_M квантизация: 8B → 4.1GB, 8 tok/s на CPU (недостаточно)
- Depth pruning (удаление 8 слоёв из 32): -25% latency, -3% качества
- Width pruning attention heads (удаление 20% голов): -15% latency
- Повторная квантизация: GGUF Q4_K_M на pruned модели
Итоговые характеристики pruned+quantized модели:
- Размер: 3.1GB (vs 4.1GB)
- Throughput: 14 tok/s на ARM (vs 8 tok/s)
- Latency для 100-токенного ответа: 7с → 1.8с (цель достигнута)
- Потеря качества (LLM-judge): 7%
Recovery Fine-Tuning после pruning
Pruning всегда вызывает деградацию. Recovery training восстанавливает часть качества:
# После pruning — краткий fine-tuning для восстановления
from trl import SFTTrainer, SFTConfig
# Используем тот же датасет, что для fine-tuning, но с меньшим LR
recovery_config = SFTConfig(
num_train_epochs=1, # 1 эпоха для recovery
learning_rate=5e-5, # Ниже, чем при full fine-tuning
gradient_checkpointing=True,
bf16=True,
)
trainer = SFTTrainer(model=pruned_model, args=recovery_config, train_dataset=dataset)
trainer.train()
Recovery fine-tuning типично возвращает 50–70% потерянного качества при 1 эпохе обучения.
Сроки
- Выбор стратегии pruning: 3–5 дней
- Calibration и pruning: 4–24 часа (зависит от метода и размера)
- Recovery fine-tuning: 2–8 часов
- Benchmarking и оценка: 3–5 дней
- Итого: 2–4 недели







