Реализация Neural Architecture Search (NAS) для проектирования оптимальной архитектуры модели

Проектируем и внедряем системы искусственного интеллекта: от прототипа до production-ready решения. Наша команда объединяет экспертизу в машинном обучении, дата-инжиниринге и MLOps, чтобы AI работал не в лаборатории, а в реальном бизнесе.
Показано 1 из 1 услугВсе 1566 услуг
Реализация Neural Architecture Search (NAS) для проектирования оптимальной архитектуры модели
Сложная
от 1 недели до 3 месяцев
Часто задаваемые вопросы
Направления AI-разработки
Этапы разработки AI-решения
Последние работы
  • image_website-b2b-advance_0.png
    Разработка сайта компании B2B ADVANCE
    1218
  • image_web-applications_feedme_466_0.webp
    Разработка веб-приложения для компании FEEDME
    1161
  • image_websites_belfingroup_462_0.webp
    Разработка веб-сайта для компании БЕЛФИНГРУПП
    853
  • image_ecommerce_furnoro_435_0.webp
    Разработка интернет магазина для компании FURNORO
    1047
  • image_logo-advance_0.png
    Разработка логотипа компании B2B Advance
    561
  • image_crm_enviok_479_0.webp
    Разработка веб-приложения для компании Enviok
    825

Neural Architecture Search (NAS)

Вручную спроектированные архитектуры — результат опыта и интуиции. NAS — алгоритмический перебор архитектурного пространства с оптимизацией под конкретную задачу, датасет и hardware target. Не замена архитектурного мышления, а способ найти конфигурации, до которых человек просто не дойдёт за разумное время.

Почему наивный NAS убивает GPU-бюджет

Классический NAS в исполнении NASNet (Google, 2017) — 500 GPU-дней на A100-эквиваленте. Проблема в том, что каждая кандидатная архитектура обучалась с нуля до сходимости. При пространстве поиска в 10^10 конфигураций полный перебор невозможен в принципе.

Современные подходы решают это через три принципиально разные идеи:

One-shot NAS / Weight Sharing. Суперсеть (supernet) включает все возможные подграфы. Каждый кандидат — «путь» через эту суперсеть, который использует уже обученные веса. DARTS, SNAS, Single-Path NAS — все они строятся на этой идее. Время поиска падает с сотен GPU-дней до 1–4 дней.

Predictor-based NAS. Обучается surrogate-модель, которая предсказывает accuracy архитектуры без её полного обучения. BANANAS, NASBOWL, NAO используют этот подход. Выборка из пространства поиска + 100–200 реальных оценок → предиктор точности для следующих миллиона кандидатов.

Hardware-aware NAS. Оптимизация не только по accuracy, но по latency на конкретном устройстве. MNasNet, FBNet, Once-for-All — ищут Pareto-front в пространстве (accuracy, latency/MACs). Критично для edge deployment.

Глубокий разбор: DARTS и его проблемы в production

DARTS (Differentiable Architecture Search) — наиболее используемый one-shot метод. Идея: вместо дискретного выбора операции (3×3 conv vs 5×5 conv vs skip) используем непрерывные веса α для каждой операции, оптимизируемые через gradient descent.

import torch
import torch.nn as nn
from torch.nn import functional as F

class MixedOp(nn.Module):
    """
    DARTS mixed operation: взвешенная сумма всех кандидатных операций.
    Веса alpha оптимизируются через архитектурный градиент.
    """
    def __init__(self, C: int, stride: int):
        super().__init__()
        self._ops = nn.ModuleList()
        for primitive in PRIMITIVES:  # ['none', 'skip_connect', 'sep_conv_3x3', ...]
            op = OPS[primitive](C, stride, affine=False)
            self._ops.append(op)

    def forward(self, x: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
        # weights = softmax(alpha) — архитектурные веса
        return sum(w * op(x) for w, op in zip(weights, self._ops))


class DARTSCell(nn.Module):
    def __init__(self, steps: int, multiplier: int, C_prev_prev: int,
                 C_prev: int, C: int, reduction: bool, reduction_prev: bool):
        super().__init__()
        self._steps = steps       # число промежуточных узлов (обычно 4)
        self._multiplier = multiplier  # сколько узлов конкатенируется на выходе
        # ... инициализация preprocess и mixed ops

    def forward(self, s0: torch.Tensor, s1: torch.Tensor,
                weights: torch.Tensor) -> torch.Tensor:
        states = [s0, s1]
        offset = 0
        for i in range(self._steps):
            s = sum(
                self._ops[offset + j](h, weights[offset + j])
                for j, h in enumerate(states)
            )
            offset += len(states)
            states.append(s)
        return torch.cat(states[-self._multiplier:], dim=1)

Двухуровневая оптимизация DARTS — главная инженерная сложность. Сетевые веса w и архитектурные веса α оптимизируются попеременно:

def train_darts_step(model, architect, optimizer_w, optimizer_alpha,
                     train_queue, valid_queue, lr_w: float):
    """
    DARTS: чередование шагов оптимизации весов сети и архитектурных весов.
    """
    for step, (input_train, target_train) in enumerate(train_queue):
        # 1. Архитектурный шаг: обновляем alpha по валидационной потере
        input_valid, target_valid = next(iter(valid_queue))
        architect.step(
            input_train, target_train,
            input_valid, target_valid,
            lr=lr_w, optimizer=optimizer_w,
            unrolled=False  # True = second-order DARTS, в 2x дороже
        )

        # 2. Шаг весов: обновляем w по тренировочной потере
        optimizer_w.zero_grad()
        logits = model(input_train)
        loss = F.cross_entropy(logits, target_train)
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
        optimizer_w.step()

Проблема коллапса операций. В чистом DARTS skip-connection операции почти всегда «побеждают» — у них нулевые параметры, они хорошо обучаются на ранних этапах, и архитектурные веса α[skip] устойчиво растут. Результат: найденная архитектура вырождается в почти skip-only сеть с плохим обобщением. Решения:

  • DARTS+: отсечение skip-connections с наибольшим α на финальном этапе
  • P-DARTS: прогрессивное увеличение глубины сети во время поиска
  • GDAS: Gumbel-softmax вместо softmax для α — разреженный выбор операций

Hardware-aware NAS на практике

Для мобильного деплоя (Android, CoreML) accuracy — не единственная метрика. Latency на целевом железе важнее FLOP-подсчёта, потому что разные операции выполняются по-разному на реальном железе.

Once-for-All (MIT) — обучается одна суперсеть, из которой без дообучения извлекаются подсети под любой hardware constraint:

from ofa.model_zoo import ofa_net

# Загружаем предобученную OFA суперсеть
ofa_network = ofa_net('ofa_mbv3_d234_e346_k357_w1.0', pretrained=True)

# Специализируем под конкретный device с latency constraint
from ofa.nas.efficiency_predictor import Latency_MBV3_MeasuredNet
efficiency_predictor = Latency_MBV3_MeasuredNet(
    'note10',   # Samsung Note10 — реальные замеры латентности
    ofa_network
)

# Evolutionary search: ищем подсеть с latency < 25ms и max accuracy
from ofa.nas.search_algorithm.evolution_finder import EvolutionFinder
finder = EvolutionFinder(
    efficiency_constraint=25,           # ms
    efficiency_predictor=efficiency_predictor,
    accuracy_predictor=accuracy_predictor,
    population_size=100,
    max_time_budget=500                  # эволюционных шагов
)
best_valids, best_info = finder.run_evolution_search()

В реальном проекте: NAS под MobileNetV3-space для задачи классификации производственного брака (640×480, 12 классов). Целевая платформа — NVIDIA Jetson Nano (4GB RAM, 128 CUDA cores). Ограничение: latency < 30ms при batch=1. Ручная архитектура MobileNetV3-Large давала 28.4ms и accuracy 91.3%. OFA-поиск за 6 часов нашёл подсеть: 22.1ms, accuracy 92.7%. Без единого ручного изменения архитектуры.

Практический стек и когда NAS оправдан

Сценарий Подход Время поиска Инструмент
Image classification, стандартный DARTS / PC-DARTS 1–2 дня (4× A100) nni (Microsoft) или automl (torchvision)
Edge deployment (мобайл, MCU) OFA / MNasNet-style 6–24 часа Once-for-All, TuNAS
NLP / Transformer architecture NAS-BERT, AutoFormer 2–5 дней Hugging Face NAS toolkit
Кастомные операции, custom hardware Predictor-based NAS 1–3 дня + 100 eval BANANAS, NASBOWL

Microsoft NNI — наиболее зрелый open-source фреймворк для NAS. Поддерживает DARTS, ENAS, Random NAS, SPOS из коробки. Интеграция с PyTorch и TensorFlow.

import nni
from nni.nas.pytorch.darts import DartsTrainer
from nni.nas.pytorch.callbacks import LRSchedulerCallback

trainer = DartsTrainer(
    model=model,
    loss=nn.CrossEntropyLoss(),
    metrics=lambda output, target: accuracy(output, target, topk=(1,)),
    optimizer=optimizer,
    num_epochs=50,
    dataset_train=dataset_train,
    dataset_valid=dataset_valid,
    batch_size=64,
    log_frequency=10,
    callbacks=[LRSchedulerCallback(scheduler)]
)
trainer.fit()
# Получаем финальную архитектуру
export_result = trainer.export()

Когда NAS не нужен. Если задача стандартная и данных < 50k примеров — возьмите предобученную ResNet-50 или EfficientNet-B0 и fine-tune. NAS оправдан при: кастомных hardware-ограничениях, нетипичных входных данных (гиперспектральные снимки, специфические модальности), необходимости кардинально уменьшить модель без потери качества.

Процесс работы

  1. Определение search space — критически важный шаг: задаём блоки, операции, диапазоны каналов, максимальную глубину. Неправильный search space = плохой результат независимо от алгоритма
  2. Выбор стратегии поиска — DARTS для GPU-rich окружения, evolutionary для hardware-aware, predictor-based при ограниченном бюджете оценок
  3. Профилирование целевого железа — реальные замеры latency/throughput для операций из search space на production hardware
  4. Поиск и оценка кандидатов — Weights & Biases для трекинга, MLflow для хранения найденных архитектур
  5. Full training найденной архитектуры с нуля — веса из фазы поиска не используются
  6. Validation на holdout set, профилирование на production-железе

Сроки: определение search space и setup — 1 неделя. Сам поиск — 1–5 дней вычислений. Full training кандидата + валидация — 1–2 недели. Итого: 3–6 недель на полный NAS-цикл.