basecore.utils.registry 源代码
#!/usr/bin/env python3
# Copyright (c) Megvii Inc. All rights reserved.
import pprint
from typing import Dict, Optional
from tabulate import tabulate
# design of Registry is inspired by fvcore, please check
# https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/registry.py
# for more details
[文档]class Registry:
"""
The registry that provides name -> object mapping, to support third-party
users' custom modules.
To create a registry (e.g. a backbone registry):
.. code-block:: python
OPTIMIZER = Registry('optimizer')
To register an object:
.. code-block:: python
@OPTIMIZER.register()
class MyOptimizer():
...
Or:
.. code-block:: python
OPTIMIZER.register(MyOptimizer)
Or:
.. code-block:: python
@OPTIMIZER.register("Name for Registry")
class MyOptimizer():
...
"""
def __init__(self, name: str) -> None:
"""
Args:
name (str): the name of this registry
"""
self._name: str = name
self._obj_map: Dict[str, object] = {}
def _do_register(self, name: str, obj: object) -> None:
assert (
name not in self._obj_map
), "An object named '{}' was already registered in '{}' registry!".format(
name, self._name
)
self._obj_map[name] = obj
[文档] def register(self, obj: object = None, name: str = None) -> Optional[object]:
"""
Register the given object under the the name `obj.__name__`.
Can be used as either a decorator or not. See docstring of this class for usage.
"""
if obj is None:
# used as a decorator
def deco(func_or_class: object) -> object:
nonlocal name
if name is None:
name = func_or_class.__name__ # pyre-ignore
self._do_register(name, func_or_class)
return func_or_class
return deco
# used as a function call
if name is None:
name = obj.__name__ # pyre-ignore
self._do_register(name, obj)
[文档] def get(self, name: str) -> object:
ret = self._obj_map.get(name)
if ret is None:
raise KeyError(
"No object named '{}' found in '{}' registry!".format(name, self._name)
)
return ret
def __contains__(self, name: str) -> bool:
return name in self._obj_map
def __repr__(self) -> str:
table_headers = ["Names", "Objects"]
table_content = [(k, pprint.pformat(v)) for k, v in self._obj_map.items()]
table = tabulate(table_content, headers=table_headers, tablefmt="fancy_grid")
return "Registry of {}:\n".format(self._name) + table
[文档] def items(self):
return self._obj_map.items()
[文档] def keys(self):
return self._obj_map.keys()
[文档] def values(self):
return self._obj_map.values()