basecls.engine.tester 源代码
#!/usr/bin/env python3
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
import datetime
import time
import megengine as mge
import megengine.distributed as dist
import megengine.functional as F
import megengine.module as M
from basecore.config import ConfigDict
from basecore.engine import BaseTester
from basecore.network import adjust_stats
from basecore.utils import log_every_n_seconds
from loguru import logger
from megengine import jit
from basecls.data import DataLoaderType
from basecls.layers import Preprocess
__all__ = ["ClsTester"]
[文档]class ClsTester(BaseTester):
def __init__(self, cfg: ConfigDict, model: M.Module, dataloader: DataLoaderType):
super().__init__(model, dataloader)
self.cfg = cfg
self.preprocess = Preprocess(cfg.preprocess.img_mean, cfg.preprocess.img_std)
[文档] def test(self, warm_iters=5, log_seconds=5):
cnt = 0
acc1 = 0
acc5 = 0
total_iters = len(self.dataloader)
warm_iters = min(warm_iters, total_iters)
total_time = 0
with adjust_stats(self.model, training=False) as model:
model_step = jit.trace(model, symbolic=True) if self.cfg.trace else model
for iters, data in enumerate(self.dataloader, 1):
if iters == warm_iters + 1:
total_time = 0
samples, targets = self.preprocess(data)
start_time = time.perf_counter()
outputs = model_step(samples)
mge._full_sync() # use full_sync func to sync launch queue for dynamic execution
total_time += time.perf_counter() - start_time
accs = F.metric.topk_accuracy(outputs, targets, (1, 5))
cnt += targets.shape[0]
acc1 += accs[0].item() * 100 * targets.shape[0]
acc5 += accs[1].item() * 100 * targets.shape[0]
if log_seconds > 0:
count_iters = iters - warm_iters if iters > warm_iters else iters
time_per_iter = total_time / count_iters
infer_eta = (total_iters - iters) * time_per_iter
log_every_n_seconds(
"Inference process {}/{}, average speed:{:.4f}s/iters. ETA:{}".format(
iters,
total_iters,
time_per_iter,
datetime.timedelta(seconds=int(infer_eta)),
),
n=log_seconds,
)
logger.info(
"Finish inference process, total time:{}, average speed:{:.4f}s/iters.".format(
datetime.timedelta(seconds=int(total_time)),
total_time / (len(self.dataloader) - warm_iters),
)
)
cnt = dist.functional.all_reduce_sum(mge.Tensor(cnt)).item()
acc1 = dist.functional.all_reduce_sum(mge.Tensor(acc1)).item() / cnt
acc5 = dist.functional.all_reduce_sum(mge.Tensor(acc5)).item() / cnt
if dist.get_rank() == 0:
logger.info(f"Test Acc@1: {acc1:.3f}, Acc@5: {acc5:.3f}")
return acc1, acc5