Neural Architecture Search (NAS)
Вручную спроектированные архитектуры — результат опыта и интуиции. NAS — алгоритмический перебор архитектурного пространства с оптимизацией под конкретную задачу, датасет и hardware target. Не замена архитектурного мышления, а способ найти конфигурации, до которых человек просто не дойдёт за разумное время.
Почему наивный NAS убивает GPU-бюджет
Классический NAS в исполнении NASNet (Google, 2017) — 500 GPU-дней на A100-эквиваленте. Проблема в том, что каждая кандидатная архитектура обучалась с нуля до сходимости. При пространстве поиска в 10^10 конфигураций полный перебор невозможен в принципе.
Современные подходы решают это через три принципиально разные идеи:
One-shot NAS / Weight Sharing. Суперсеть (supernet) включает все возможные подграфы. Каждый кандидат — «путь» через эту суперсеть, который использует уже обученные веса. DARTS, SNAS, Single-Path NAS — все они строятся на этой идее. Время поиска падает с сотен GPU-дней до 1–4 дней.
Predictor-based NAS. Обучается surrogate-модель, которая предсказывает accuracy архитектуры без её полного обучения. BANANAS, NASBOWL, NAO используют этот подход. Выборка из пространства поиска + 100–200 реальных оценок → предиктор точности для следующих миллиона кандидатов.
Hardware-aware NAS. Оптимизация не только по accuracy, но по latency на конкретном устройстве. MNasNet, FBNet, Once-for-All — ищут Pareto-front в пространстве (accuracy, latency/MACs). Критично для edge deployment.
Глубокий разбор: DARTS и его проблемы в production
DARTS (Differentiable Architecture Search) — наиболее используемый one-shot метод. Идея: вместо дискретного выбора операции (3×3 conv vs 5×5 conv vs skip) используем непрерывные веса α для каждой операции, оптимизируемые через gradient descent.
import torch
import torch.nn as nn
from torch.nn import functional as F
class MixedOp(nn.Module):
"""
DARTS mixed operation: взвешенная сумма всех кандидатных операций.
Веса alpha оптимизируются через архитектурный градиент.
"""
def __init__(self, C: int, stride: int):
super().__init__()
self._ops = nn.ModuleList()
for primitive in PRIMITIVES: # ['none', 'skip_connect', 'sep_conv_3x3', ...]
op = OPS[primitive](C, stride, affine=False)
self._ops.append(op)
def forward(self, x: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
# weights = softmax(alpha) — архитектурные веса
return sum(w * op(x) for w, op in zip(weights, self._ops))
class DARTSCell(nn.Module):
def __init__(self, steps: int, multiplier: int, C_prev_prev: int,
C_prev: int, C: int, reduction: bool, reduction_prev: bool):
super().__init__()
self._steps = steps # число промежуточных узлов (обычно 4)
self._multiplier = multiplier # сколько узлов конкатенируется на выходе
# ... инициализация preprocess и mixed ops
def forward(self, s0: torch.Tensor, s1: torch.Tensor,
weights: torch.Tensor) -> torch.Tensor:
states = [s0, s1]
offset = 0
for i in range(self._steps):
s = sum(
self._ops[offset + j](h, weights[offset + j])
for j, h in enumerate(states)
)
offset += len(states)
states.append(s)
return torch.cat(states[-self._multiplier:], dim=1)
Двухуровневая оптимизация DARTS — главная инженерная сложность. Сетевые веса w и архитектурные веса α оптимизируются попеременно:
def train_darts_step(model, architect, optimizer_w, optimizer_alpha,
train_queue, valid_queue, lr_w: float):
"""
DARTS: чередование шагов оптимизации весов сети и архитектурных весов.
"""
for step, (input_train, target_train) in enumerate(train_queue):
# 1. Архитектурный шаг: обновляем alpha по валидационной потере
input_valid, target_valid = next(iter(valid_queue))
architect.step(
input_train, target_train,
input_valid, target_valid,
lr=lr_w, optimizer=optimizer_w,
unrolled=False # True = second-order DARTS, в 2x дороже
)
# 2. Шаг весов: обновляем w по тренировочной потере
optimizer_w.zero_grad()
logits = model(input_train)
loss = F.cross_entropy(logits, target_train)
loss.backward()
nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
optimizer_w.step()
Проблема коллапса операций. В чистом DARTS skip-connection операции почти всегда «побеждают» — у них нулевые параметры, они хорошо обучаются на ранних этапах, и архитектурные веса α[skip] устойчиво растут. Результат: найденная архитектура вырождается в почти skip-only сеть с плохим обобщением. Решения:
- DARTS+: отсечение skip-connections с наибольшим α на финальном этапе
- P-DARTS: прогрессивное увеличение глубины сети во время поиска
- GDAS: Gumbel-softmax вместо softmax для α — разреженный выбор операций
Hardware-aware NAS на практике
Для мобильного деплоя (Android, CoreML) accuracy — не единственная метрика. Latency на целевом железе важнее FLOP-подсчёта, потому что разные операции выполняются по-разному на реальном железе.
Once-for-All (MIT) — обучается одна суперсеть, из которой без дообучения извлекаются подсети под любой hardware constraint:
from ofa.model_zoo import ofa_net
# Загружаем предобученную OFA суперсеть
ofa_network = ofa_net('ofa_mbv3_d234_e346_k357_w1.0', pretrained=True)
# Специализируем под конкретный device с latency constraint
from ofa.nas.efficiency_predictor import Latency_MBV3_MeasuredNet
efficiency_predictor = Latency_MBV3_MeasuredNet(
'note10', # Samsung Note10 — реальные замеры латентности
ofa_network
)
# Evolutionary search: ищем подсеть с latency < 25ms и max accuracy
from ofa.nas.search_algorithm.evolution_finder import EvolutionFinder
finder = EvolutionFinder(
efficiency_constraint=25, # ms
efficiency_predictor=efficiency_predictor,
accuracy_predictor=accuracy_predictor,
population_size=100,
max_time_budget=500 # эволюционных шагов
)
best_valids, best_info = finder.run_evolution_search()
В реальном проекте: NAS под MobileNetV3-space для задачи классификации производственного брака (640×480, 12 классов). Целевая платформа — NVIDIA Jetson Nano (4GB RAM, 128 CUDA cores). Ограничение: latency < 30ms при batch=1. Ручная архитектура MobileNetV3-Large давала 28.4ms и accuracy 91.3%. OFA-поиск за 6 часов нашёл подсеть: 22.1ms, accuracy 92.7%. Без единого ручного изменения архитектуры.
Практический стек и когда NAS оправдан
| Сценарий | Подход | Время поиска | Инструмент |
|---|---|---|---|
| Image classification, стандартный | DARTS / PC-DARTS | 1–2 дня (4× A100) | nni (Microsoft) или automl (torchvision) |
| Edge deployment (мобайл, MCU) | OFA / MNasNet-style | 6–24 часа | Once-for-All, TuNAS |
| NLP / Transformer architecture | NAS-BERT, AutoFormer | 2–5 дней | Hugging Face NAS toolkit |
| Кастомные операции, custom hardware | Predictor-based NAS | 1–3 дня + 100 eval | BANANAS, NASBOWL |
Microsoft NNI — наиболее зрелый open-source фреймворк для NAS. Поддерживает DARTS, ENAS, Random NAS, SPOS из коробки. Интеграция с PyTorch и TensorFlow.
import nni
from nni.nas.pytorch.darts import DartsTrainer
from nni.nas.pytorch.callbacks import LRSchedulerCallback
trainer = DartsTrainer(
model=model,
loss=nn.CrossEntropyLoss(),
metrics=lambda output, target: accuracy(output, target, topk=(1,)),
optimizer=optimizer,
num_epochs=50,
dataset_train=dataset_train,
dataset_valid=dataset_valid,
batch_size=64,
log_frequency=10,
callbacks=[LRSchedulerCallback(scheduler)]
)
trainer.fit()
# Получаем финальную архитектуру
export_result = trainer.export()
Когда NAS не нужен. Если задача стандартная и данных < 50k примеров — возьмите предобученную ResNet-50 или EfficientNet-B0 и fine-tune. NAS оправдан при: кастомных hardware-ограничениях, нетипичных входных данных (гиперспектральные снимки, специфические модальности), необходимости кардинально уменьшить модель без потери качества.
Процесс работы
- Определение search space — критически важный шаг: задаём блоки, операции, диапазоны каналов, максимальную глубину. Неправильный search space = плохой результат независимо от алгоритма
- Выбор стратегии поиска — DARTS для GPU-rich окружения, evolutionary для hardware-aware, predictor-based при ограниченном бюджете оценок
- Профилирование целевого железа — реальные замеры latency/throughput для операций из search space на production hardware
- Поиск и оценка кандидатов — Weights & Biases для трекинга, MLflow для хранения найденных архитектур
- Full training найденной архитектуры с нуля — веса из фазы поиска не используются
- Validation на holdout set, профилирование на production-железе
Сроки: определение search space и setup — 1 неделя. Сам поиск — 1–5 дней вычислений. Full training кандидата + валидация — 1–2 недели. Итого: 3–6 недель на полный NAS-цикл.







