basecls.zoo.benchmark 源代码

#!/usr/bin/env python3
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
import argparse
import datetime
import multiprocessing as mp
import time

import megengine as mge
import megengine.amp as amp
import megengine.autodiff as autodiff
import megengine.distributed as dist
import megengine.functional as F
import megengine.jit as jit
import megengine.optimizer as optim
from basecore.utils import log_every_n_seconds
from loguru import logger

from basecls.data.fake_data import FakeDataLoader
from basecls.layers import Preprocess
from basecls.utils import registers, set_nccl_env, set_num_threads


[文档]def main(): parser = argparse.ArgumentParser() parser.add_argument("-m", "--model", default="resnet50", type=str) parser.add_argument("--mode", default="eval", type=str) parser.add_argument("-d", "--device", default="gpu", type=str) parser.add_argument("--amp", default=0, type=int) parser.add_argument("--fastrun", action="store_true") parser.add_argument("--trace", action="store_true") parser.add_argument("--dtr", action="store_true") parser.add_argument("-b", "--batch-size", default=32, type=int) parser.add_argument("--channel", default=3, type=int) parser.add_argument("--height", default=224, type=int) parser.add_argument("--width", default=224, type=int) parser.add_argument("-n", "--world-size", default=8, type=int) parser.add_argument("--warm-iters", default=50, type=int) parser.add_argument("-t", "--total-iters", default=200, type=int) parser.add_argument("--log-seconds", default=2, type=int) args = parser.parse_args() mp.set_start_method("spawn") set_nccl_env() set_num_threads() if args.world_size == 1: worker(args) else: dist.launcher(worker, n_gpus=args.world_size)(args)
[文档]@logger.catch def worker(args: argparse.Namespace): if dist.get_rank() != 0: logger.remove() logger.info(f"args: {args}") if args.fastrun: logger.info("Using fastrun mode...") mge.functional.debug_param.set_execution_strategy("PROFILE") if args.dtr: logger.info("Enabling DTR...") mge.dtr.enable() mge.set_default_device(f"{args.device}{dist.get_rank()}") model = registers.models.get(args.model)(head=dict(w_out=1000)) dataloader = FakeDataLoader( args.batch_size, (args.height, args.width), args.channel, length=args.warm_iters + args.total_iters, num_classes=1000, ) if args.mode == "train": BenchCls = TrainBench elif args.mode == "eval": BenchCls = EvalBench else: raise NotImplementedError(f"Benchmark mode '{args.mode}' not supported") bench = BenchCls(model, dataloader, args.trace, args.amp) bench.benchmark(args.warm_iters, args.log_seconds)
[文档]class ClsBench: def __init__(self, model, dataloader, trace: bool = False): self.model = model self.dataloader = dataloader self.preprocess = Preprocess(mean=127, std=128) if trace: self.model_step = jit.trace(self.model_step, symbolic=True)
[文档] def benchmark(self, warm_iters=50, log_seconds=2): total_iters = len(self.dataloader) - warm_iters total_time = 0 for i, data in enumerate(self.dataloader, 1): if i == warm_iters + 1: total_time = 0 samples, targets = self.preprocess(data) mge._full_sync() t = time.perf_counter() self.model_step(samples, targets) mge._full_sync() total_time += time.perf_counter() - t if log_seconds > 0: cnt = i - warm_iters if i > warm_iters else i tot = total_iters if i > warm_iters else warm_iters cycle = total_time / cnt eta = (tot - cnt) * cycle log_every_n_seconds( "{} process {}/{}, average speed:{:0.3f}ms/iters. ETA:{}".format( "Benchmark" if i > warm_iters else "Warmup", cnt, tot, cycle * 1000, datetime.timedelta(seconds=int(eta)), ), n=log_seconds, ) avg_speed_ms = total_time / total_iters * 1000 logger.info( "Benchmark total time:{}, average speed:{:0.3f}ms/iters.".format( datetime.timedelta(seconds=int(total_time)), avg_speed_ms ) ) return avg_speed_ms
[文档] def model_step(self, samples, targets): raise NotImplementedError
[文档]class TrainBench(ClsBench): def __init__(self, model, dataloader, trace: bool = False, amp_version: int = 0): model.train() super().__init__(model, dataloader, trace) self.opt = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=0.0001) self.gm = autodiff.GradManager() callbacks = ( [dist.make_allreduce_cb("mean", dist.WORLD)] if dist.get_world_size() > 1 else None ) self.gm.attach(model.parameters(), callbacks=callbacks) self.amp_version = amp_version self.scaler = ( amp.GradScaler(init_scale=65536.0, growth_interval=2000) if amp_version == 2 else amp.GradScaler(init_scale=128.0, growth_interval=0) )
[文档] def model_step(self, samples, targets): with self.gm: with amp.autocast(enabled=self.amp_version > 0): pred = self.model(samples) loss = F.loss.cross_entropy(pred, targets) if self.amp_version > 0: self.scaler.backward(self.gm, loss, update_scale=False) self.scaler.update() else: self.gm.backward(loss) self.opt.step().clear_grad()
[文档]class EvalBench(ClsBench): def __init__(self, model, dataloader, trace: bool = False, amp_version: int = 0): model.eval() super().__init__(model, dataloader, trace) self.amp_version = amp_version
[文档] def model_step(self, samples, targets): with amp.autocast(enabled=self.amp_version > 0): self.model(samples)
if __name__ == "__main__": main()