# Source code for mixmo.networks

"""
Networks used in the main paper
"""

from mixmo.utils.logger import get_logger

from mixmo.networks import resnet, wrn


# Module-level logger for this module, requested from the project's
# get_logger helper with level="DEBUG".
LOGGER = get_logger(__name__, level="DEBUG")


def get_network(config_network, config_args):
    """
    Return a new instance of network.

    The dataset name found at ``config_args["data"]["dataset"]`` selects which
    factory table is used: names starting with ``'tinyimagenet'`` use
    ``resnet.resnet_network_factory`` and names starting with ``'cifar'`` use
    ``wrn.wrn_network_factory``. The entry stored under
    ``config_network["name"]`` is then called with both configs as keyword
    arguments and its result is returned.

    Args:
        config_network: Mapping describing the network; its ``"name"`` entry
            is the key looked up in the selected factory table.
        config_args: Mapping holding the run configuration; only
            ``config_args["data"]["dataset"]`` is read here, the whole object
            is forwarded to the factory.

    Returns:
        Whatever the selected factory entry returns when called with
        ``config_network=config_network, config_args=config_args``.

    Raises:
        NotImplementedError: If the dataset name starts with neither
            ``'tinyimagenet'`` nor ``'cifar'``.
    """
    dataset = config_args["data"]["dataset"]

    # The dataset name prefix decides which family of networks is available.
    if dataset.startswith('tinyimagenet'):
        network_factory = resnet.resnet_network_factory
    elif dataset.startswith('cifar'):
        network_factory = wrn.wrn_network_factory
    else:
        # Name the offending dataset so a misconfiguration is easy to spot.
        raise NotImplementedError(
            f"No network factory available for dataset {dataset!r}; "
            "expected a name starting with 'tinyimagenet' or 'cifar'"
        )

    LOGGER.warning(f"Loading network: {config_network['name']}")
    # NOTE(review): an unknown network name is not handled here; presumably
    # the factory tables are dicts, so the lookup would raise KeyError.
    return network_factory[config_network["name"]](
        config_network=config_network, config_args=config_args)