Source code for mixmo.core.optimizer

"""
Optimizer factory
"""

import torch.optim as optim
from mixmo.utils.logger import get_logger


LOGGER = get_logger(__name__, level="DEBUG")


# Supported optimizer names (as written in the config) mapped to their
# torch.optim classes. Adding an entry here is enough to support a new one;
# the error message below is derived from these keys so it cannot go stale.
_OPTIMIZER_CLASSES = {
    "sgd": optim.SGD,
    "adam": optim.Adam,
    "adadelta": optim.Adadelta,
    "rmsprop": optim.RMSprop,
}


def get_optimizer(optimizer, list_param_groups):
    """
    Builds optimizer objects from config

    Args:
        optimizer: config mapping with a ``"name"`` entry (one of the keys of
            ``_OPTIMIZER_CLASSES``) and a ``"params"`` mapping that is
            forwarded as keyword arguments to the optimizer constructor.
        list_param_groups: passed unchanged as the first positional argument
            of the torch optimizer (parameters or parameter-group dicts).

    Returns:
        The instantiated ``torch.optim`` optimizer.

    Raises:
        KeyError: if the optimizer name is not supported, or if the config
            lacks a ``"name"`` / ``"params"`` entry.
    """
    optimizer_name = optimizer["name"]
    optimizer_params = optimizer["params"]
    LOGGER.info(f"Using optimizer {optimizer_name} with params {optimizer_params}")

    try:
        optimizer_class = _OPTIMIZER_CLASSES[optimizer_name]
    except KeyError:
        # Re-raise with a readable message listing every supported name;
        # `from None` hides the redundant internal lookup traceback.
        raise KeyError(
            f"Bad optimizer name {optimizer_name!r} or not implemented "
            f"({', '.join(_OPTIMIZER_CLASSES)})."
        ) from None

    # Use a separate local instead of rebinding the `optimizer` config argument.
    return optimizer_class(list_param_groups, **optimizer_params)