import torch
from torch.optim.optimizer import Optimizer
from .nmf import _proj_func, _get_norm
from .constants import eps
class BetaMu(Optimizer):
    r"""Implements the classic multiplicative update rule for NMF models minimizing the β-divergence.

    Note:
        To use this optimizer, make sure not only that your model parameters are non-negative, but also
        that the gradients along the whole computational graph are always non-negative.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        beta (float, optional): the β-divergence to be minimized, measuring the distance between the
            target and the NMF model. Default: ``1.``
        l1_reg (float, optional): L1 regularization penalty. Default: ``0.``
        l2_reg (float, optional): L2 regularization penalty (weight decay). Default: ``0.``
        orthogonal (float, optional): orthogonality regularization penalty. Default: ``0.``
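
    Example (a minimal usage sketch; the target ``V`` and the factor matrices ``W`` and ``H`` below
    are illustrative placeholders, not part of this module)::

        >>> V = torch.rand(8, 5)
        >>> W = torch.rand(8, 3, requires_grad=True)
        >>> H = torch.rand(3, 5, requires_grad=True)
        >>> optimizer = BetaMu([W, H], beta=1)
        >>> def closure():
        ...     # re-evaluate the model and return (target, prediction)
        ...     return V, W @ H
        >>> optimizer.step(closure)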
"""
def __init__(self, params, beta=1, l1_reg=0, l2_reg=0, orthogonal=0):
if not 0.0 <= l1_reg:
raise ValueError("Invalid l1_reg value: {}".format(l1_reg))
if not 0.0 <= l2_reg:
raise ValueError("Invalid l2_reg value: {}".format(l2_reg))
if not 0.0 <= orthogonal:
raise ValueError("Invalid orthogonal value: {}".format(orthogonal))
defaults = dict(beta=beta, l1_reg=l1_reg,
l2_reg=l2_reg, orthogonal=orthogonal)
super(BetaMu, self).__init__(params, defaults)
    @torch.no_grad()
def step(self, closure):
"""Performs a single update step.
Arguments:
closure (callable): a closure that reevaluates the model
and returns the target and predicted Tensor in the form:
``func()->Tuple(target,predict)``
"""
# Make sure the closure is always called with grad enabled
closure = torch.enable_grad()(closure)
status_cache = dict()
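        # Remember which parameters require grad, then freeze them all; each parameter is
        # re-enabled one at a time below so that every backward pass only fills its own .grad.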
for group in self.param_groups:
for p in group['params']:
status_cache[id(p)] = p.requires_grad
p.requires_grad = False
for group in self.param_groups:
beta = group['beta']
l1_reg = group['l1_reg']
l2_reg = group['l2_reg']
ortho = group['orthogonal']
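            # Exponent of the multiplicative update, as in the majorization-minimization
            # derivation of the β-divergence update (cf. Fevotte & Idier, 2011):
            # γ = 1/(2-β) for β < 1, 1/(β-1) for β > 2, and 1 otherwise.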
if beta < 1:
gamma = 1 / (2 - beta)
elif beta > 2:
gamma = 1 / (beta - 1)
else:
gamma = 1
for p in group['params']:
if not status_cache[id(p)]:
continue
p.requires_grad = True
V, WH = closure()
if not WH.requires_grad:
p.requires_grad = False
continue
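                # Split the derivative of the β-divergence w.r.t. the reconstruction,
                # ∂D_β(V|WH)/∂WH = WH^(β-1) - V·WH^(β-2), into its positive part
                # (output_pos) and the magnitude of its negative part (output_neg).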
if beta == 2:
output_neg = V
output_pos = WH
elif beta == 1:
output_neg = V / WH.add(eps)
output_pos = torch.ones_like(WH)
elif beta == 0:
WH_eps = WH.add(eps)
output_pos = WH_eps.reciprocal_()
output_neg = output_pos.square().mul_(V)
else:
WH_eps = WH.add(eps)
output_neg = WH_eps.pow(beta - 2).mul_(V)
output_pos = WH_eps.pow_(beta - 1)
                # first backward: the negative part of the gradient of the divergence w.r.t. p
WH.backward(output_neg, retain_graph=True)
neg = torch.clone(p.grad).relu_()
p.grad.zero_()
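                # second backward: the positive part of the gradient of the divergence w.r.t. p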
WH.backward(output_pos)
pos = torch.clone(p.grad).relu_()
p.grad.add_(-neg)
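                # p.grad is left holding pos - neg, the full gradient of the divergence w.r.t. p.
                # The penalties below are added to the denominator only, so the update stays
                # multiplicative and non-negative.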
if l1_reg > 0:
pos.add_(l1_reg)
if l2_reg > 0:
pos.add_(p, alpha=l2_reg)
if ortho > 0:
                    pos.add_(p.sum(1, keepdim=True) - p, alpha=ortho)
pos.add_(eps)
neg.add_(eps)
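                # multiplicative update: p <- p * (neg / pos)^γ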
multiplier = neg.div_(pos)
if gamma != 1:
multiplier.pow_(gamma)
p.mul_(multiplier)
p.requires_grad = False
for group in self.param_groups:
for p in group['params']:
p.requires_grad = status_cache[id(p)]
return None
class SparsityProj(Optimizer):
    r"""Implements the sparseness-constrained gradient projection method described in `Non-negative Matrix Factorization
    with Sparseness Constraints`_.

    .. _`Non-negative Matrix Factorization with Sparseness Constraints`:
        https://www.jmlr.org/papers/volume5/hoyer04a/hoyer04a.pdf
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
sparsity (float): the target sparseness for `params`, with 0 < sparsity < 1
dim (int, optional): dimension over which to compute the sparseness for each parameter. Default: ``1``
max_iter (int, optional): maximal number of function evaluations per optimization step. Default: ``10``
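
    Example (a minimal usage sketch; the factor matrices ``W`` and ``H``, the target ``V`` and the
    squared-error objective below are illustrative assumptions, not part of this module)::

        >>> V = torch.rand(8, 5)
        >>> W = torch.rand(8, 3, requires_grad=True)
        >>> H = torch.rand(3, 5)
        >>> optimizer = SparsityProj([W], sparsity=0.7, dim=1)
        >>> def closure():
        ...     # zero stale gradients, then return the reconstruction loss
        ...     optimizer.zero_grad()
        ...     return torch.nn.functional.mse_loss(W @ H, V)
        >>> loss = optimizer.step(closure)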
"""
def __init__(self, params, sparsity, dim=1, max_iter=10):
if not 0.0 < sparsity < 1.:
raise ValueError("Invalid sparsity value: {}".format(sparsity))
defaults = dict(sparsity=sparsity, lr=1, dim=dim, max_iter=max_iter)
super(SparsityProj, self).__init__(params, defaults)
    @torch.no_grad()
def step(self, closure):
"""Performs a single update step.
Arguments:
closure (callable): a closure that reevaluates the model and returns the loss
"""
loss = None
for group in self.param_groups:
sparsity = group['sparsity']
lr = group['lr']
dim = group['dim']
max_iter = group['max_iter']
with torch.enable_grad():
init_loss = closure()
init_loss.backward()
params = [(p, p.grad.clone())
for p in group['params'] if p.grad is not None]
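            # Backtracking line search: take a gradient step, project each slice along `dim`
            # back to the target sparseness, and keep the step only if the loss does not
            # increase; otherwise undo it and halve the step size.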
for i in range(max_iter):
for p, g in params:
norms = _get_norm(p, dim)
p.add_(g, alpha=-lr)
N = p.numel() // p.shape[dim]
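                    # Hoyer's sparseness measure is s(x) = (sqrt(N) - |x|_1/|x|_2) / (sqrt(N) - 1),
                    # so for a fixed L2 norm the target sparseness implies an L1 norm of
                    # |x|_2 * (sqrt(N) * (1 - s) + s).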
L1 = N ** 0.5 * (1 - sparsity) + sparsity
for j in range(p.shape[dim]):
slicer = (slice(None),) * dim + (j,)
p[slicer] = _proj_func(
p[slicer], L1 * norms[j], norms[j] ** 2)
loss = closure()
if loss <= init_loss:
break
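                # The projected step increased the loss: undo it and retry with half the step size.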
for p, g in params:
p.add_(g, alpha=lr)
lr *= 0.5
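            # Slightly grow the step size and carry it over to the next call to step().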
lr *= 1.2
group['lr'] = lr
return loss