basecls.utils.env 源代码

#!/usr/bin/env python3
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
import ctypes
import os
import subprocess

# Names exported by a star-import of this module.
__all__ = ["set_nccl_env", "set_num_threads"]


def set_nccl_env():
    """Set NCCL environments, which is essential to multi-node training.

    Writes four variables into ``os.environ``:

    * ``NCCL_LAUNCH_MODE`` is set to ``"PARALLEL"``.
    * ``NCCL_IB_HCA`` is set to the names of the ``mlx5_*`` devices under
      ``/sys/class/infiniband`` whose port 1 has at least one GID type entry
      containing the letter ``v``, one name per line, or to an empty string
      when no such device is found.
    * ``NCCL_IB_GID_INDEX`` is set to ``"3"``.
    * ``NCCL_IB_TC`` is set to ``"106"``.

    The device scan is done in Python instead of a shell pipeline. The earlier
    pipeline relied on ``pushd``/``popd`` via ``subprocess.getoutput``, which
    runs the command through the system shell and merges stderr into the
    result; when the shell lacked those builtins or the sysfs directory was
    missing, the shell's error text was stored in ``NCCL_IB_HCA``.
    """
    os.environ["NCCL_LAUNCH_MODE"] = "PARALLEL"
    # NOTE(review): names are joined with newlines to match what the previous
    # shell pipeline produced. NCCL documents NCCL_IB_HCA as a comma-separated
    # list -- confirm whether multi-device hosts need "," here instead.
    os.environ["NCCL_IB_HCA"] = "\n".join(_list_mlx5_hcas())
    os.environ["NCCL_IB_GID_INDEX"] = "3"
    os.environ["NCCL_IB_TC"] = "106"


def _list_mlx5_hcas(root="/sys/class/infiniband"):
    """Return sorted ``mlx5_*`` names under *root* that have a ``v`` GID type on port 1."""
    try:
        names = sorted(name for name in os.listdir(root) if name.startswith("mlx5_"))
    except OSError:
        # No such directory (or it is unreadable): nothing to report.
        return []
    return [
        name
        for name in names
        if _any_gid_type_has_v(os.path.join(root, name, "ports", "1", "gid_attrs", "types"))
    ]


def _any_gid_type_has_v(types_dir):
    """Return True if any readable file directly inside *types_dir* contains ``v``."""
    try:
        entries = os.listdir(types_dir)
    except OSError:
        return False
    for entry in entries:
        # Shell globbing (`types/*`) skipped dot-files; keep that behavior.
        if entry.startswith("."):
            continue
        try:
            with open(os.path.join(types_dir, entry), "rb") as handle:
                if b"v" in handle.read():
                    return True
        except OSError:
            # Unreadable entries were silenced with `2>/dev/null` before.
            continue
    return False
def set_num_threads(num: int = 1):
    """Set number of threads in OpenMP, OpenCV, MKL, OPENBLAS, VECLIB, NUMEXPR, etc.

    Every step is best-effort: a missing module, a missing shared library, or
    a failing call is ignored and the remaining steps still run.

    Args:
        num: number of threads. Default: 1
    """
    # 1) The optional ``mkl`` Python module, if it can be imported.
    try:
        import mkl

        mkl.set_num_threads(num)
    except Exception:
        pass

    # 2) The MKL runtime library loaded directly; each candidate file name is
    #    attempted regardless of whether an earlier one succeeded.
    for lib_name in ["libmkl_rt.so", "libmkl_rt.dylib", "mkl_Rt.dll"]:
        try:
            runtime = ctypes.CDLL(lib_name)
            # The thread count is passed by reference.
            runtime.mkl_set_num_threads(ctypes.byref(ctypes.c_int(num)))
        except Exception:
            continue

    # 3) Environment variables, written in the same order as before.
    count = str(num)
    os.environ.update(
        {
            "OMP_NUM_THREADS": count,
            "OPENBLAS_NUM_THREADS": count,
            "MKL_NUM_THREADS": count,
            "VECLIB_MAXIMUM_THREADS": count,
            "NUMEXPR_NUM_THREADS": count,
            "OPENCV_OPENCL_RUNTIME": "disabled",
        }
    )

    # 4) OpenCV, imported only after the environment variables are in place.
    try:
        import cv2

        cv2.setNumThreads(num)
        cv2.ocl.setUseOpenCL(False)
    except Exception:
        pass