PyTorch NMF Documentation¶
PyTorch NMF is an extension library for PyTorch.
It consists of the basic NMF algorithm and its convolutional variants, which are rarely found in other NMF packages. In addition, by using PyTorch's automatic differentiation feature, the classic multiplicative update rules can be adapted to more complex NMF structures, making it possible to train these complex models in a simple end-to-end fashion.
Installation¶
pypi¶
Our package can be installed through the Python Package Index (PyPI). The only main dependency is PyTorch; please refer to its page to install the version you need.
After installing PyTorch, execute the following command:
pip install torchnmf
source¶
Alternatively, the latest development version can also be installed via pip:
pip install git+https://github.com/yoyololicon/pytorch-NMF
Introduction by Example¶
Background¶
The goal of Non-negative Matrix Factorization (NMF) is, given an N-by-M non-negative matrix V, to find an R-by-M non-negative matrix H (typically called the activation matrix) and an N-by-R non-negative matrix W (typically called the template matrix) such that their matrix product WH approximates V to some degree.
Generally, R is chosen to be smaller than min(N, M), which implies that the high-dimensional data V can be reduced to some low-dimensional space, as sketched below.
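As a quick shape check, here is a minimal sketch using random non-negative matrices and the notation above; it does not use the library yet:
import torch

N, M, R = 64, 1024, 10      # R is chosen smaller than min(N, M)
W = torch.rand(N, R)        # template matrix, non-negative
H = torch.rand(R, M)        # activation matrix, non-negative
V_approx = W @ H            # reconstruction WH with shape (N, M)
print(V_approx.shape)       # torch.Size([64, 1024])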
Basic Non-negative Matrix Factorization¶
Let’s see how PyTorch NMF works in action!
First, assume that we have a target matrix V with shape [64, 1024]:
V.size()
>>> torch.Size([64, 1024])
Second, we need to construct an NMF instance by giving it the shape of the target matrix and a latent rank R.
We use R = 10:
import torch
from torchnmf.nmf import NMF
model = NMF(V.t().shape, rank=10)
model.W.size()
>>> torch.Size([64, 10])
model.H.size()
>>> torch.Size([1024, 10])
Now our model has two attributes, W and H, with the shapes shown above.
Note
H is actually stored as a transposed matrix in our implementation.
Then, fit the two matrices to our target data:
model.fit(V.t())
The reconstructed matrix is the matrix product of the two trained matrices:
WH = model.W @ model.H.t()
Or you can simply call the model, and it will do this for you:
WH = model()
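To get a rough idea of how well the factorization approximates the target, you can compare the reconstruction with the target directly. A small sketch continuing the example above (torch.dist is plain PyTorch, not part of torchnmf):
WH = model()                  # reconstruction of V.t(), shape [1024, 64]
err = torch.dist(WH, V.t())   # Frobenius distance between reconstruction and target
print(err.item())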
Training on GPU¶
If you have a machine with an NVIDIA GPU and CUDA installed, you can try moving your model and target matrix to the GPU and see how it speeds up the fitting process:
V = V.cuda().t()
model = model.cuda()
model.fit(V)
In PyTorch NMF, we implemented the different kinds of NMF by inheriting and extending the torch.nn.Module class, so you can treat them just like any other PyTorch Module (e.g., moving them between devices, casting them to a different data type, etc.).
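For example, here is a small sketch of casting and device movement; it assumes the fitting routines also accept float64 tensors, and the CUDA branch is skipped when no GPU is available:
import torch
from torchnmf.nmf import NMF

V = torch.rand(1024, 64)
model = NMF(V.shape, rank=10)
model = model.double()                  # cast parameters to float64
V = V.double()
if torch.cuda.is_available():
    model, V = model.cuda(), V.cuda()   # move both to the GPU
model.fit(V)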
Model Concatenation¶
Starting with version 0.3, you can combine different NMF modules into a single module and train it in an end-to-end fashion.
Let’s use the previous example again. Instead of factorizing matrix V into 2 matrices, we factorize it into 4 matrices.
That is:
\[V \approx H W_1^T W_2^T W_3^T\]
It’s actually just chaining 3 NMF modules together, with 2 of them using the output of another NMF as their activation matrix.
Here is how you do it:
import torch
import torch.nn as nn
from torchnmf.nmf import NMF
class Chain(nn.Module):
    def __init__(self):
        super().__init__()
        self.nmf1 = NMF((1024, 10), rank=4)
        self.nmf2 = NMF(W=(24, 10))
        self.nmf3 = NMF(W=(64, 24))

    def forward(self):
        WH = self.nmf1()
        WWH = self.nmf2(H=WH)
        WWWH = self.nmf3(H=WWH)
        return WWWH
model = Chain()
output = model()
You can also use torch.nn.Sequential
to construct this kind of chained model:
model = nn.Sequential(
    NMF((1024, 10), rank=4),
    NMF(W=(24, 10)),
    NMF(W=(64, 24))
)

# In newer versions of PyTorch, at least one input should be given.
# We can simply pass `None`.
output = model(None)
To fit the model, instead of calling the class method fit
, you now need to construct an NMF trainer:
from torchnmf.trainer import BetaMu
trainer = BetaMu(model.parameters())
To update the parameters, you need to call the step()
function at every iteration:
epochs = 200
for e in range(epochs):
    def closure():
        trainer.zero_grad()
        return V.t(), model(None)
    trainer.step(closure)
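If you want to monitor convergence during training, one option is to evaluate the divergence between the target and the current output after each step. Below is a sketch that continues the example above and uses torchnmf.metrics.kl_div (documented later in this page), matching the trainer's default β = 1; the argument order (reconstruction first, target second) is our assumption, so check the metrics reference for the exact convention:
import torch
from torchnmf.metrics import kl_div

epochs = 200
for e in range(epochs):
    def closure():
        trainer.zero_grad()
        return V.t(), model(None)
    trainer.step(closure)
    with torch.no_grad():
        loss = kl_div(model(None), V.t())   # assumed argument order
    print(f"epoch {e}: KL divergence = {loss.item():.4f}")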
torchnmf.nmf¶
-
class
torchnmf.nmf.
BaseComponent
(rank=None, W=None, H=None, trainable_W=True, trainable_H=True)[source]¶ Bases:
torch.nn.modules.module.Module
Base class for all NMF modules.
You can’t use this module directly. Your models should also subclass this class.
- Parameters
rank (int) – size of hidden dimension
W (tuple or Tensor) – size or initial weights of template tensor W
H (tuple or Tensor) – size or initial weights of activation tensor H
trainable_W (bool) – controls whether template tensor W is trainable when initial weights are given. Default:
True
trainable_H (bool) – controls whether activation tensor H is trainable when initial weights are given. Default:
True
-
W
¶ the template tensor of the module if corresponding argument is given. If size is given, values are initialized non-negatively.
- Type
Tensor or None
-
H
¶ the activation tensor of the module if corresponding argument is given. If size is given, values are initialized non-negatively.
- Type
Tensor or None
-
fit
(V, beta=1, tol=0.0001, max_iter=200, verbose=False, alpha=0, l1_ratio=0)[source]¶ Learn a NMF model for the data V by minimizing beta divergence.
To invoke this function, attributes H and W should be present in this module.
- Parameters
V (Tensor) – data tensor to be decomposed. Can be a sparse tensor returned by
torch.sparse_coo_tensor()
beta (float) – beta divergence to be minimized, measuring the distance between V and the NMF model. Default:
1.
tol (float) – tolerance of the stopping condition. Default:
1e-4
max_iter (int) – maximum number of iterations before timing out. Default:
200
verbose (bool) – whether to be verbose. Default:
False
alpha (float) – constant that multiplies the regularization terms. Set it to zero to have no regularization. Default:
0
l1_ratio (float) – the regularization mixing parameter, with 0 <= l1_ratio <= 1. For l1_ratio = 0 the penalty is an elementwise L2 penalty (aka Frobenius Norm). For l1_ratio = 1 it is an elementwise L1 penalty. For 0 < l1_ratio < 1, the penalty is a combination of L1 and L2. Default:
0
- Returns
total number of iterations
- Return type
int
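For example, a hypothetical usage sketch through the NMF subclass (BaseComponent itself cannot be instantiated directly); the matrix and parameter values below are arbitrary:
import torch
from torchnmf.nmf import NMF

V = torch.rand(100, 50)
model = NMF(V.shape, rank=8)
# minimize the generalized KL divergence (beta=1) with a light elementwise L1 penalty
n_iter = model.fit(V, beta=1, tol=1e-5, max_iter=500, alpha=0.01, l1_ratio=1.0)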
-
forward
(H=None, W=None)[source]¶ An outer wrapper of
self.reconstruct(H,W)
.Note
One should call the BaseComponent instance instead of this method, since the former takes care of running the registered hooks while the latter silently ignores them.
-
static
reconstruct
(H, W)[source]¶ Defines the computation performed at every call.
Should be overridden by all subclasses.
-
sparse_fit
(V, beta=2, max_iter=200, verbose=False, sW=None, sH=None)[source]¶ Learn a NMF model for the data V by minimizing beta divergence with sparseness constraints proposed in Non-negative Matrix Factorization with Sparseness Constraints.
To invoke this function, attributes H and W should be present in this module.
Note
Although the value range of beta is unrestricted, the original implementation only uses Euclidean distance (which means beta=2) as its loss function, and we have no guarantee for values other than 2.
- Parameters
V (Tensor) – data tensor to be decomposed. Can be a sparse tensor returned by
torch.sparse_coo_tensor()
beta (float) – beta divergence to be minimized, measuring the distance between V and the NMF model. Default:
2.
max_iter (int) – maximum number of iterations before timing out. Default:
200
verbose (bool) – whether to be verbose. Default:
False
sW (float or None) – the target sparseness for template tensor
W
, with 0 < sW < 1. Set it to None to have no constraint. Default: None
sH (float or None) – the target sparseness for activation tensor
H
, with 0 < sH < 1. Set it to None to have no constraint. Default: None
- Returns
total number of iterations
- Return type
int
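For example, a hypothetical sketch through the NMF subclass, constraining only the template tensor; the values are arbitrary:
import torch
from torchnmf.nmf import NMF

V = torch.rand(100, 50)
model = NMF(V.shape, rank=8)
# constrain the template tensor W to a target sparseness of 0.7, leave H unconstrained
n_iter = model.sparse_fit(V, max_iter=300, sW=0.7)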
-
class
torchnmf.nmf.
NMF
(Vshape=None, rank=None, **kwargs)[source]¶ Bases:
torchnmf.nmf.BaseComponent
Non-Negative Matrix Factorization (NMF).
Find two non-negative matrices (W, H) whose product approximates the non-negative matrix V: \(V \approx HW^T\).
This factorization can be used for example for dimensionality reduction, source separation or topic extraction.
Note
If Vshape argument is given, the model will try to infer the size of W and H, and override arguments passed through to BaseComponent.
- Parameters
Vshape (tuple, optional) – size of target matrix V
rank (int, optional) – size of hidden dimension
**kwargs – arguments passed through to
BaseComponent
- Shape:
V: \((N, C)\)
W: \((C, R)\)
H: \((N, R)\)
Examples:
>>> V = torch.rand(20, 30)
>>> m = NMF(V.shape, 5)
>>> m.W.size()
torch.Size([30, 5])
>>> m.H.size()
torch.Size([20, 5])
>>> HWt = m()
>>> HWt.size()
torch.Size([20, 30])
-
class
torchnmf.nmf.
NMF2D
(Vshape=None, rank=None, kernel_size=1, **kwargs)[source]¶ Bases:
torchnmf.nmf.BaseComponent
Nonnegative Matrix Factor 2-D Deconvolution (NMF2D).
Find non-negative 3-dimensional tensor H and 4-dimensional tensor W whose 2D convolutional output approximates the non-negative 3-dimensional tensor V:
\[\mathbf{V} \approx \sum_{\tau} \sum_{\phi} \stackrel{\downarrow \phi}{\mathbf{W}^{\tau}} \stackrel{\rightarrow \tau}{\mathbf{H}^{\phi}}\]More precisely:
\[V_{i,j,k} \approx \sum_{l=0}^{k_1-1} \sum_{m=0}^{k_2-1} \sum_{r=0}^{\text{rank}-1} W_{i,r,l,m} * H_{r, j-l,k-m}\]Look at the paper: Nonnegative Matrix Factor 2-D Deconvolution for Blind Single Channel Source Separation by Schmidt et al. (2006) for more details.
Note
To match with PyTorch convention, an extra batch dimension is required for target tensor V.
Note
If Vshape argument is given, the model will try to infer the size of W and H, and override arguments passed through to BaseComponent.
Warning
Using sparse tensor as target when calling NMF2D.fit() or NMF2D.sparse_fit() is currently not supported.
- Parameters
Vshape (tuple, optional) – size of target tensor V
rank (int, optional) – size of hidden dimension
kernel_size (int or tuple, optional) – size of the convolving kernel
**kwargs – arguments passed through to
BaseComponent
- Shape:
V: \((N, C, L_{out}, M_{out})\)
W: \((C, R, \text{kernel_size}[0], \text{kernel_size}[1])\)
H: \((N, R, L_{in}, M_{in})\) where
\[L_{in} = L_{out} - \text{kernel_size}[0] + 1\]\[M_{in} = M_{out} - \text{kernel_size}[1] + 1\]
Examples:
>>> V = torch.rand(33, 50).unsqueeze(0).unsqueeze(0)
>>> m = NMF2D(V.shape, 16, 3)
>>> m.W.size()
torch.Size([1, 16, 3, 3])
>>> m.H.size()
torch.Size([1, 16, 31, 48])
>>> HWt = m()
>>> HWt.size()
torch.Size([1, 1, 33, 50])
-
class
torchnmf.nmf.
NMF3D
(Vshape=None, rank=None, kernel_size=1, **kwargs)[source]¶ Bases:
torchnmf.nmf.BaseComponent
Nonnegative Matrix Factor 3-D Deconvolution (NMF3D).
Find non-negative 4-dimensional tensor H and 5-dimensional tensor W whose 3D convolutional output approximates the non-negative 4-dimensional tensor V:
\[V_{i,j,k,l} \approx \sum_{m=0}^{k_1-1} \sum_{n=0}^{k_2-1} \sum_{u=0}^{k_3-1} \sum_{r=0}^{\text{rank}-1} W_{i,r,m,n,u} * H_{r,j-m,k-n,l-u}\]Note
To match with PyTorch convention, an extra batch dimension is required for target tensor V.
Note
If Vshape argument is given, the model will try to infer the size of W and H, and override arguments passed through to BaseComponent.
Warning
Using sparse tensor as target when calling NMF3D.fit() or NMF3D.sparse_fit() is currently not supported.
- Parameters
Vshape (tuple, optional) – size of target tensor V
rank (int, optional) – size of hidden dimension
kernel_size (int or tuple, optional) – size of the convolving kernel
**kwargs – arguments passed through to
BaseComponent
- Shape:
V: \((N, C, L_{out}, M_{out}, O_{out})\)
W: \((C, R, \text{kernel_size}[0], \text{kernel_size}[1], \text{kernel_size}[2])\)
H: \((N, R, L_{in}, M_{in}, O_{in})\) where
\[L_{in} = L_{out} - \text{kernel_size}[0] + 1\]\[M_{in} = M_{out} - \text{kernel_size}[1] + 1\]\[O_{in} = O_{out} - \text{kernel_size}[2] + 1\]
Examples:
>>> V = torch.rand(3, 64, 64, 100).unsqueeze(0)
>>> m = NMF3D(V.shape, 8, (5, 5, 20))
>>> m.W.size()
torch.Size([3, 8, 5, 5, 20])
>>> m.H.size()
torch.Size([1, 8, 60, 60, 81])
>>> HWt = m()
>>> HWt.size()
torch.Size([1, 3, 64, 64, 100])
-
class
torchnmf.nmf.
NMFD
(Vshape=None, rank=None, T=1, **kwargs)[source]¶ Bases:
torchnmf.nmf.BaseComponent
Non-negative Matrix Factor Deconvolution (NMFD).
Find non-negative matrix H and 3-dimensional tensor W whose convolutional output approximates the non-negative matrix V:
\[\mathbf{V} \approx \sum_{t=0}^{T-1} \mathbf{W}_{t} \cdot \stackrel{t \rightarrow}{\mathbf{H}}\]More precisely:
\[V_{i,j} \approx \sum_{t=0}^{T-1} \sum_{r=0}^{\text{rank}-1} W_{i,r,t} * H_{r, j - t}\]Look at the paper: Non-negative Matrix Factor Deconvolution; Extraction of Multiple Sound Sources from Monophonic Inputs by Paris Smaragdis (2004) for more details.
Note
To match with PyTorch convention, an extra batch dimension is required for target matrix V.
Note
If Vshape argument is given, the model will try to infer the size of W and H, and override arguments passed through to BaseComponent.
Warning
Using sparse tensor as target when calling NMFD.fit() or NMFD.sparse_fit() is currently not supported.
- Parameters
Vshape (tuple, optional) – size of target matrix V
rank (int, optional) – size of hidden dimension
T (int, optional) – size of the convolving window
**kwargs – arguments passed through to
BaseComponent
- Shape:
V: \((N, C, L_{out})\)
W: \((C, R, T)\)
H: \((N, R, L_{in})\) where
\[L_{in} = L_{out} - T + 1\]
Examples:
>>> V = torch.rand(33, 50).unsqueeze(0)
>>> m = NMFD(V.shape, 16, 3)
>>> m.W.size()
torch.Size([33, 16, 3])
>>> m.H.size()
torch.Size([1, 16, 48])
>>> HWt = m()
>>> HWt.size()
torch.Size([1, 33, 50])
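A typical use case for NMFD is decomposing an audio magnitude spectrogram. The following is a hedged sketch; the random signal, STFT settings, rank and window length are placeholders, not library requirements:
import torch
from torchnmf.nmf import NMFD

signal = torch.randn(16000)                       # stand-in for a mono waveform
window = torch.hann_window(1024)
S = torch.stft(signal, n_fft=1024, window=window, return_complex=True)
V = S.abs().unsqueeze(0)                          # (1, freq_bins, frames), non-negative
model = NMFD(V.shape, rank=8, T=5)                # 8 templates, each spanning 5 frames
model.fit(V, beta=1)                              # fit with generalized KL divergence
W, H = model.W, model.H                           # (freq_bins, 8, 5) and (1, 8, frames - 4)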
torchnmf.plca¶
-
class
torchnmf.plca.
BaseComponent
(rank=None, W=None, H=None, Z=None, trainable_W=True, trainable_H=True, trainable_Z=True)[source]¶ Bases:
torch.nn.modules.module.Module
Base class for all PLCA modules.
You can’t use this module directly. Your models should also subclass this class.
- Parameters
rank (int) – size of hidden dimension
W (tuple or Tensor) – size or initial probabilities of template tensor W
H (tuple or Tensor) – size or initial probabilities of activation tensor H
Z (Tensor) – initial probabilities of latent vector Z
trainable_W (bool) – controls whether template tensor W is trainable when initial probabilities are given. Default:
True
trainable_H (bool) – controls whether activation tensor H is trainable when initial probabilities are given. Default:
True
trainable_Z (bool) – controls whether latent vector Z is trainable when initial probabilities are given. Default:
True
-
W
¶ the template tensor of the module if corresponding argument is given. If size is given, values are initialized randomly.
- Type
Tensor or None
-
H
¶ the activation tensor of the module if corresponding argument is given. If size is given, values are initialized randomly.
- Type
Tensor or None
-
Z
¶ the latent vector of the module if corresponding argument or rank is given. If rank is given, values are initialized uniformly.
- Type
Tensor or None
-
fit
(V, tol=0.0001, max_iter=200, verbose=False, W_alpha=1.0, H_alpha=1.0, Z_alpha=1.0)[source]¶ Learn a PLCA model for the data V by maximizing the following log probability of V and model params \(\theta\) using the EM algorithm:
\[\begin{split}\mathcal{L} (\theta)= \sum_{k_1...k_N} v_{k_1...k_N}\log{\hat{v}_{k_1...k_N}} \\ + \sum_k (\alpha_{z,k} - 1) \log z_k \\ + \sum_{f_1...f_M} (\alpha_{w,f_1...f_M} - 1) \log w_{f_1...f_M} \\ + \sum_{\tau_1...\tau_L} (\alpha_{h,\tau_1...\tau_L} - 1) \log h_{\tau_1...\tau_L} \\\end{split}\]Where \(\hat{V}\) is the reconstructed output, N is the number of dimensions of target tensor \(V\), M is the number of dimensions of tensor \(W\), and L is the number of dimensions of tensor \(H\). The last three terms come from Dirichlet prior assumption.
To invoke this function, attributes H, W and Z should be present in this module.
- Parameters
V (Tensor) – data tensor to be decomposed
tol (float) – tolerance of the stopping condition. Default:
1e-4
max_iter (int) – maximum number of iterations before timing out. Default:
200
verbose (bool) – whether to be verbose. Default:
False
W_alpha (float) – hyper parameter of Dirichlet prior on W. Can be a scalar or a tensor that is broadcastable to W. Set it to one to have no regularization. Default:
1
H_alpha (float) – hyper parameter of Dirichlet prior on H. Can be a scalar or a tensor that is broadcastable to H. Set it to one to have no regularization. Default:
1
Z_alpha (float) – hyper parameter of Dirichlet prior on Z. Can be a scalar or a tensor that is broadcastable to Z. Set it to one to have no regularization. Default:
1
- Returns
a length-2 tuple whose first element is the total number of iterations, and whose second element is the sum of tensor V
- Return type
tuple
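For example, a hypothetical sketch through the PLCA subclass; the values below are arbitrary, and a Dirichlet concentration below one favors sparser distributions:
import torch
from torchnmf.plca import PLCA

V = torch.rand(100, 50)
model = PLCA(V.shape, rank=8)
# mildly sparsity-inducing Dirichlet prior on the activations
n_iter, v_sum = model.fit(V, max_iter=300, H_alpha=0.9)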
-
forward
(H=None, W=None, Z=None, norm=None)[source]¶ An outer wrapper of
self.reconstruct(H,W,Z)
.Note
One should call the BaseComponent instance instead of this method, since the former takes care of running the registered hooks while the latter silently ignores them.
- Parameters
H (Tensor, optional) – input activation tensor H. If not given, H from this module is used instead.
W (Tensor, optional) – input template tensor W. If not given, W from this module is used instead.
Z (Tensor, optional) – input latent vector Z. If not given, Z from this module is used instead.
norm (float, optional) – a scaling value multiplied on the output before it is returned. Default:
1
- Returns
tensor
- Return type
Tensor
-
class
torchnmf.plca.
PLCA
(Vshape=None, rank=None, **kwargs)[source]¶ Bases:
torchnmf.plca.BaseComponent
Probabilistic Latent Component Analysis (PLCA).
Estimate two marginals \(P(c|z)\) and \(P(n|z)\), which are the matrices W and H, and a prior \(P(z)\), which is the vector Z, that approximate the observed \(P(n,c)\), where \(P(n,c)\) is obtained via
V / V.sum()
so the total probabilities sum to 1.More precisely:
\[P(n, c) \approx \sum_{z}P(c|z)P(z)P(n|z)\]In matrix form:
\[V \approx H \mathrm{diag}(Z) W^T\]Its formulation is very similar to NMF, but it introduces an extra latent vector to incorporate the concept of probability.
Note
If Vshape argument is given, the model will try to infer the size of W, H and Z, and override arguments passed through to BaseComponent.
- Parameters
Vshape (tuple, optional) – size of target matrix V
rank (int, optional) – size of hidden dimension
**kwargs – arguments passed through to
BaseComponent
- Shape:
V: \((N, C)\)
W: \((C, R)\)
H: \((N, R)\)
Z: \((R,)\)
Examples:
>>> V = torch.rand(20, 30)
>>> m = PLCA(V.shape, 5)
>>> m.W.size()
torch.Size([30, 5])
>>> m.H.size()
torch.Size([20, 5])
>>> m.Z.size()
torch.Size([5])
>>> HZWt = m()
>>> HZWt.size()
torch.Size([20, 30])
-
class
torchnmf.plca.
SIPLCA
(Vshape=None, rank=None, T=1, **kwargs)[source]¶ Bases:
torchnmf.plca.BaseComponent
Shift Invariant Probabilistic Latent Component Analysis (SI-PLCA).
Estimate two marginals \(P(c,t|z)\) and \(P(n,l|z)\), which are the tensors W and H, and a prior \(P(z)\), which is the vector Z, that approximate the observed \(P(n,c,l)\), where \(P(n,c,l)\) is obtained via
V / V.sum()
so the total probabilities sum to 1.More precisely:
\[P(n, c, l) \approx \sum_{z} \sum_{t} P(c,t|z)P(z)P(n,l-t|z)\]Look at the paper: Shift-Invariant Probabilistic Latent Component Analysis by Paris Smaragdis and Bhiksha Raj (2007) for more details.
Note
If Vshape argument is given, the model will try to infer the size of W, H and Z, and override arguments passed through to BaseComponent.
- Parameters
Vshape (tuple, optional) – size of target matrix V
rank (int, optional) – size of hidden dimension
T (int, optional) – size of the convolving window
**kwargs – arguments passed through to
BaseComponent
- Shape:
V: \((N, C, L_{out})\)
W: \((C, R, T)\)
H: \((N, R, L_{in})\)
Z: \((R,)\) where
\[L_{in} = L_{out} - T + 1\]
Examples:
>>> V = torch.rand(33, 50).unsqueeze(0)
>>> m = SIPLCA(V.shape, 16, 3)
>>> m.W.size()
torch.Size([33, 16, 3])
>>> m.H.size()
torch.Size([1, 16, 48])
>>> m.Z.size()
torch.Size([16])
>>> HZWt = m()
>>> HZWt.size()
torch.Size([1, 33, 50])
-
class
torchnmf.plca.
SIPLCA2
(Vshape=None, rank=None, kernel_size=1, **kwargs)[source]¶ Bases:
torchnmf.plca.BaseComponent
Shift Invariant Probabilistic Latent Component Analysis across 2 dimensions (SI-PLCA 2D).
Estimate two marginals \(P(c,k_1,k_2|z)\) and \(P(n,l,m|z)\), which are the tensors W and H, and a prior \(P(z)\), which is the vector Z, that approximate the observed \(P(n,c,l,m)\), where \(P(n,c,l,m)\) is obtained via
V / V.sum()
so the total probabilities sum to 1.More precisely:
\[P(n,c,l,m) \approx \sum_{z} \sum_{k_1} \sum_{k_2} P(c,k_1,k_2|z)P(z)P(n,l-k_1,m-k_2|z)\]Look at the paper: Shift-Invariant Probabilistic Latent Component Analysis by Paris Smaragdis and Bhiksha Raj (2007) for more details.
Note
If Vshape argument is given, the model will try to infer the size of W, H and Z, and override arguments passed through to BaseComponent.
- Parameters
Vshape (tuple, optional) – size of target tensor V
rank (int, optional) – size of hidden dimension
kernel_size (int or tuple, optional) – size of the convolving kernel
**kwargs – arguments passed through to
BaseComponent
- Shape:
V: \((N, C, L_{out}, M_{out})\)
W: \((C, R, \text{kernel_size}[0], \text{kernel_size}[1])\)
H: \((N, R, L_{in}, M_{in})\)
Z: \((R,)\) where
\[L_{in} = L_{out} - \text{kernel_size}[0] + 1\]\[M_{in} = M_{out} - \text{kernel_size}[1] + 1\]
Examples:
>>> V = torch.rand(33, 50).unsqueeze(0).unsqueeze(0)
>>> m = SIPLCA2(V.shape, 16, 3)
>>> m.W.size()
torch.Size([1, 16, 3, 3])
>>> m.H.size()
torch.Size([1, 16, 31, 48])
>>> m.Z.size()
torch.Size([16])
>>> HZWt = m()
>>> HZWt.size()
torch.Size([1, 1, 33, 50])
-
class
torchnmf.plca.
SIPLCA3
(Vshape=None, rank=None, kernel_size=1, **kwargs)[source]¶ Bases:
torchnmf.plca.BaseComponent
Shift Invariant Probabilistic Latent Component Analysis across 3 dimensions (SI-PLCA 3D).
Estimate two marginals \(P(c,k_1,k_2,k_3|z)\) and \(P(n,l,m,o|z)\), which are the tensors W and H, and a prior \(P(z)\), which is the vector Z, that approximate the observed \(P(n,c,l,m,o)\), where \(P(n,c,l,m,o)\) is obtained via
V / V.sum()
so the total probabilities sum to 1.More precisely:
\[P(n,c,l,m,o) \approx \sum_{z} \sum_{k_1} \sum_{k_2} \sum_{k_3} P(c,k_1,k_2,k_3|z)P(z)P(n,l-k_1,m-k_2,o-k_3|z)\]Look at the paper: Shift-Invariant Probabilistic Latent Component Analysis by Paris Smaragdis and Bhiksha Raj (2007) for more details.
Note
If Vshape argument is given, the model will try to infer the size of W, H and Z, and override arguments passed through to BaseComponent.
- Parameters
Vshape (tuple, optional) – size of target tensor V
rank (int, optional) – size of hidden dimension
kernel_size (int or tuple, optional) – size of the convolving kernel
**kwargs – arguments passed through to
BaseComponent
- Shape:
V: \((N, C, L_{out}, M_{out}, O_{out})\)
W: \((C, R, \text{kernel_size}[0], \text{kernel_size}[1], \text{kernel_size}[2])\)
H: \((N, R, L_{in}, M_{in}, O_{in})\)
Z: \((R,)\) where
\[L_{in} = L_{out} - \text{kernel_size}[0] + 1\]\[M_{in} = M_{out} - \text{kernel_size}[1] + 1\]\[O_{in} = O_{out} - \text{kernel_size}[2] + 1\]
Examples:
>>> V = torch.rand(3, 64, 64, 100).unsqueeze(0)
>>> m = SIPLCA3(V.shape, 8, (5, 5, 20))
>>> m.W.size()
torch.Size([3, 8, 5, 5, 20])
>>> m.H.size()
torch.Size([1, 8, 60, 60, 81])
>>> m.Z.size()
torch.Size([8])
>>> HZWt = m()
>>> HZWt.size()
torch.Size([1, 3, 64, 64, 100])
torchnmf.metrics¶
-
torchnmf.metrics.
beta_div
(input, target, beta=2)[source]¶ The β-divergence loss measure
The loss can be described as:
\[\ell(x, y) = \sum_{n = 0}^{N - 1} \frac{1}{\beta (\beta - 1)}\left ( x_n^{\beta} + \left (\beta - 1 \right ) y_n^{\beta} - \beta x_n y_n^{\beta-1}\right )\]- Parameters
input (Tensor) – tensor of arbitrary shape
target (Tensor) – tensor of the same shape as input
beta (float) – a real value that controls the shape of the loss function
- Returns
single element tensor
- Return type
Tensor
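A small usage sketch; by the formulas in this section, beta=2 coincides with the Euclidean distance defined below:
import torch
from torchnmf.metrics import beta_div, euclidean

x = torch.rand(8, 8) + 0.1    # keep values strictly positive
y = torch.rand(8, 8) + 0.1
loss = beta_div(x, y, beta=2)
# with beta=2 this matches euclidean(x, y), since both follow the same formula
print(loss, euclidean(x, y))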
-
torchnmf.metrics.
euclidean
(input, target)[source]¶ The Euclidean distance, which equals the β-divergence loss when β = 2.
\[\ell(x, y) = \frac{1}{2} \sum_{n = 0}^{N - 1} (x_n - y_n)^2\]- Parameters
input (Tensor) – tensor of arbitrary shape
target (Tensor) – tensor of the same shape as input
- Returns
single element tensor
- Return type
Tensor
-
torchnmf.metrics.
is_div
(input, target)[source]¶ The Itakura–Saito divergence, which equals the β-divergence loss when β = 0.
\[\ell(x, y) = \sum_{n = 0}^{N - 1} \frac{x_n}{y_n} - \log\left(\frac{x_n}{y_n}\right) - 1\]- Parameters
input (Tensor) – tensor of arbitrary shape
target (Tensor) – tensor of the same shape as input
- Returns
single element tensor
- Return type
Tensor
-
torchnmf.metrics.
kl_div
(input, target)[source]¶ The generalized Kullback-Leibler divergence loss, which equals the β-divergence loss when β = 1.
The loss can be described as:
\[\ell(x, y) = \sum_{n = 0}^{N - 1} x_n \log\left(\frac{x_n}{y_n}\right) - x_n + y_n\]- Parameters
input (Tensor) – tensor of arbitrary shape
target (Tensor) – tensor of the same shape as input
- Returns
single element tensor
- Return type
Tensor
-
torchnmf.metrics.
sparseness
(x)[source]¶ The sparseness measure proposed in Non-negative Matrix Factorization with Sparseness Constraints, which can be calculated as:
\[f(x) = \frac{\sqrt{N} - \frac{\sum_{n=0}^{N-1} |x_n|}{\sqrt{\sum_{n=0}^{N-1} x_n^2}}}{\sqrt{N} - 1}\]- Parameters
x (Tensor) – tensor of arbitrary shape
- Returns
single element tensor with value ranging from 0 (the least sparse, all components equal) to 1 (the most sparse, a single non-zero component)
- Return type
Tensor
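A small sketch of the two extremes, assuming the measure is computed over all elements of the input:
import torch
from torchnmf.metrics import sparseness

dense = torch.ones(100)        # all components equal
peaky = torch.zeros(100)
peaky[0] = 1.0                 # a single non-zero component
print(sparseness(dense))       # 0, the least sparse case
print(sparseness(peaky))       # 1, the most sparse case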
torchnmf.trainer¶
torchnmf.trainer
is a package implementing various parameter updating algorithms for NMF, and is based on
the same optimizer interface from torch.optim
.
Taking an update step¶
Because the currently available trainers reevaluate the function multiple times, a closure function is required at each step. The closure should clear the gradients, compute the output (or even the loss), and return it.
for i in range(iterations):
    def closure():
        trainer.zero_grad()
        return target, model()
    trainer.step(closure)
For torchnmf.trainer.SparsityProj
:
for i in range(iterations):
    def closure():
        trainer.zero_grad()
        output = model()
        loss = loss_fn(output, target)
        return loss
    trainer.step(closure)
Algorithms¶
-
class
torchnmf.trainer.
BetaMu
(params, beta=1, l1_reg=0, l2_reg=0, orthogonal=0)[source]¶ Implements the classic multiplicative updater for NMF models minimizing β-divergence.
Note
To use this optimizer, make sure not only that your model parameters are non-negative, but also that the gradients along the whole computational graph are always non-negative.
- Parameters
params (iterable) – iterable of parameters to optimize or dicts defining parameter groups
beta (float, optional) – beta divergence to be minimized, measuring the distance between target and the NMF model. Default:
1.
l1_reg (float, optional) – L1 regularization penalty. Default:
0.
l2_reg (float, optional) – L2 regularization penalty (weight decay). Default:
0.
orthogonal (float, optional) – orthogonal regularization penalty. Default:
0.
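A minimal usage sketch following the closure pattern described above; the rank, number of iterations and regularization weight are arbitrary:
import torch
from torchnmf.nmf import NMF
from torchnmf.trainer import BetaMu

V = torch.rand(1024, 64)
model = NMF(V.shape, rank=10)
# minimize the Euclidean distance (beta=2) with a small L1 penalty
trainer = BetaMu(model.parameters(), beta=2, l1_reg=0.01)
for _ in range(100):
    def closure():
        trainer.zero_grad()
        return V, model()      # the closure returns (target, output)
    trainer.step(closure)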
-
class
torchnmf.trainer.
SparsityProj
(params, sparsity, dim=1, max_iter=10)[source]¶ Implements the sparseness-constrained gradient projection method described in Non-negative Matrix Factorization with Sparseness Constraints.
- Parameters
params (iterable) – iterable of parameters to optimize or dicts defining parameter groups
sparsity (float) – the target sparseness for params, with 0 < sparsity < 1
dim (int, optional) – dimension over which to compute the sparseness for each parameter. Default:
1
max_iter (int, optional) – maximal number of function evaluations per optimization step. Default:
10
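A minimal usage sketch following the closure pattern shown earlier; euclidean from torchnmf.metrics is just one possible loss choice, and the rank, sparsity target and iteration count are arbitrary:
import torch
from torchnmf.nmf import NMF
from torchnmf.trainer import SparsityProj
from torchnmf.metrics import euclidean

V = torch.rand(1024, 64)
model = NMF(V.shape, rank=10)
trainer = SparsityProj(model.parameters(), sparsity=0.5)   # target sparseness of 0.5
for _ in range(100):
    def closure():
        trainer.zero_grad()
        loss = euclidean(model(), V)   # compute the loss and return it
        return loss
    trainer.step(closure)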