import torch
from torch import Tensor
import torch.nn.functional as F
from torch.nn import Parameter
from .nmf import _size_1_t, _size_2_t, _size_3_t
from torch.nn.modules.utils import _single, _pair, _triple
from tqdm import tqdm
from typing import Union, Iterable, Optional, Tuple
from collections.abc import Iterable as Iterabc
from .metrics import kl_div
from .constants import eps
__all__ = [
'PLCA', 'SIPLCA', 'SIPLCA2', 'SIPLCA3', 'BaseComponent'
]
def _log_probability(V, WZH, W, Z, H, W_alpha, Z_alpha, H_alpha):
    # Log-likelihood of V under the reconstruction WZH, plus the Dirichlet
    # prior terms on W, H and Z (see the docstring of BaseComponent.fit).
    return (V.view(-1) @ WZH.add(eps).log().view(-1)
            + W.add(eps).log().mul(W_alpha - 1).sum()
            + H.add(eps).log().mul(H_alpha - 1).sum()
            + Z.add(eps).log().mul(Z_alpha - 1).sum())
@torch.no_grad()
def get_norm(x: Tensor):
    """Sum ``x`` over every dimension except the rank dimension (dim 1),
    keeping dims so the result broadcasts back against ``x``."""
    if x.ndim > 1:
        sum_dims = list(range(x.dim()))
        sum_dims.remove(1)
        norm = x.sum(sum_dims, keepdim=True)
    else:
        norm = x.sum()
    return norm
class BaseComponent(torch.nn.Module):
r"""Base class for all PLCA modules.
You can't use this module directly.
Your models should also subclass this class.
Args:
rank (int): size of hidden dimension
W (tuple or Tensor): size or initial probabilities of template tensor W
H (tuple or Tensor): size or initial probabilities of activation tensor H
Z (Tensor): initial probabilities of latent vector Z
        trainable_W (bool): controls whether template tensor W is trainable when initial probabilities are given. Default: ``True``
        trainable_H (bool): controls whether activation tensor H is trainable when initial probabilities are given. Default: ``True``
        trainable_Z (bool): controls whether latent vector Z is trainable when initial probabilities are given. Default: ``True``
Attributes:
        W (Tensor or None): the template tensor of the module if the corresponding argument is given.
            If a size is given, values are initialized randomly.
        H (Tensor or None): the activation tensor of the module if the corresponding argument is given.
            If a size is given, values are initialized randomly.
        Z (Tensor or None): the latent vector of the module if the corresponding argument or rank is given.
            If rank is given, values are initialized uniformly.
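
    A minimal subclass sketch (``Outer`` is just an illustrative name; its
    ``reconstruct`` mirrors :class:`PLCA` below)::

        >>> class Outer(BaseComponent):
        ...     @staticmethod
        ...     def reconstruct(H, W, Z):
        ...         return H @ (W * Z).T
        >>> m = Outer(rank=5, W=(30, 5), H=(20, 5))
        >>> m().shape
        torch.Size([20, 30])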
"""
__constants__ = ['rank']
__annotations__ = {'W': Optional[Tensor],
'H': Optional[Tensor],
'Z': Optional[Tensor],
'out_channels': Optional[int],
'kernel_size': Optional[Tuple[int, ...]]}
rank: int
W: Optional[Tensor]
H: Optional[Tensor]
Z: Optional[Tensor]
out_channels: Optional[int]
kernel_size: Optional[Tuple[int, ...]]
    def __init__(self,
                 rank: Optional[int] = None,
                 W: Optional[Union[Iterable[int], Tensor]] = None,
                 H: Optional[Union[Iterable[int], Tensor]] = None,
                 Z: Optional[Tensor] = None,
                 trainable_W: bool = True,
                 trainable_H: bool = True,
                 trainable_Z: bool = True):
super().__init__()
infer_rank = None
if isinstance(W, Tensor):
assert torch.all(W >= 0.), "Tensor W should be non-negative."
self.register_parameter('W', Parameter(
torch.empty(*W.size()), requires_grad=trainable_W))
self.W.data.copy_(W)
elif isinstance(W, Iterabc):
self.register_parameter('W', Parameter(torch.randn(*W).abs()))
else:
self.register_parameter('W', None)
if getattr(self, "W") is not None:
self.W.data.div_(get_norm(self.W))
infer_rank = self.W.shape[1]
if isinstance(H, Tensor):
assert torch.all(H >= 0.), "Tensor H should be non-negative."
H_shape = H.shape
self.register_parameter('H', Parameter(
torch.empty(*H_shape), requires_grad=trainable_H))
self.H.data.copy_(H)
elif isinstance(H, Iterabc):
self.register_parameter('H', Parameter(torch.randn(*H).abs()))
else:
self.register_parameter('H', None)
if getattr(self, "H") is not None:
self.H.data.div_(get_norm(self.H))
infer_rank = self.H.shape[1]
if isinstance(Z, Tensor):
assert Z.ndim == 1, "Z should be one dimensional."
assert torch.all(Z >= 0.), "Tensor Z should be non-negative."
rank = Z.numel()
self.register_parameter('Z', Parameter(
torch.empty(rank), requires_grad=trainable_Z))
self.Z.data.copy_(Z)
elif isinstance(rank, int):
self.register_parameter('Z', Parameter(torch.ones(rank) / rank))
else:
self.register_parameter('Z', None)
if getattr(self, "Z") is not None:
self.Z.data.div_(get_norm(self.Z))
infer_rank = self.Z.shape[0]
if infer_rank is None:
assert rank, "A rank should be given when W, H and Z are not available!"
else:
if getattr(self, "Z") is not None:
assert self.Z.shape[0] == infer_rank, "Latent size of Z does not match with others!"
if getattr(self, "H") is not None:
assert self.H.shape[1] == infer_rank, "Latent size of H does not match with others!"
if getattr(self, "W") is not None:
assert self.W.shape[1] == infer_rank, "Latent size of W does not match with others!"
self.out_channels = self.W.shape[0]
if self.W.ndim > 2:
self.kernel_size = self.W.shape[2:]
rank = infer_rank
self.rank = rank
    def forward(self, H: Optional[Tensor] = None, W: Optional[Tensor] = None, Z: Optional[Tensor] = None, norm: Optional[float] = None) -> Tensor:
r"""An outer wrapper of :meth:`self.reconstruct(H,W,Z) <torchnmf.plca.BaseComponent.reconstruct>`.
.. note::
Should call the :class:`BaseComponent` instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
Args:
            H (Tensor, optional): input activation tensor H. If not given, the module's own :attr:`H` is used instead
            W (Tensor, optional): input template tensor W. If not given, the module's own :attr:`W` is used instead
            Z (Tensor, optional): input latent vector Z. If not given, the module's own :attr:`Z` is used instead
            norm (float, optional): a scaling factor multiplied onto the output before returning. Default: ``1``
        Returns:
            Tensor: the reconstructed tensor
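
        A minimal illustration, using the :class:`PLCA` subclass defined
        below::

            >>> m = PLCA((20, 30), 5)
            >>> out = m()          # preferred: runs registered hooks
            >>> out = m.forward()  # same computation, but hooks are skipped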
"""
if H is None:
H = self.H
if W is None:
W = self.W
if Z is None:
Z = self.Z
result = self.reconstruct(H, W, Z)
if norm is None:
return result
return result * norm
    @staticmethod
def reconstruct(H: Tensor, W: Tensor, Z: Tensor) -> Tensor:
r"""Defines the computation performed at every call.
Should be overridden by all subclasses.
"""
raise NotImplementedError
    def fit(self,
V: Tensor,
tol: float = 1e-4,
max_iter: int = 200,
verbose: bool = False,
W_alpha: Union[float, Tensor] = 1.,
H_alpha: Union[float, Tensor] = 1.,
Z_alpha: Union[float, Tensor] = 1.):
r"""Learn a PLCA model for the data V by maximizing the following log probability of V
and model params :math:`\theta` using EM algorithm:
.. math::
            \mathcal{L}(\theta) = \sum_{k_1...k_N} v_{k_1...k_N}\log{\hat{v}_{k_1...k_N}} \\
            + \sum_k (\alpha_{z,k} - 1) \log z_k \\
            + \sum_{f_1...f_M} (\alpha_{w,f_1...f_M} - 1) \log w_{f_1...f_M} \\
            + \sum_{\tau_1...\tau_L} (\alpha_{h,\tau_1...\tau_L} - 1) \log h_{\tau_1...\tau_L}
        where :math:`\hat{V}` is the reconstructed output, N is the number of dimensions of target tensor :math:`V`,
        M is the number of dimensions of tensor :math:`W`, and L is the number of dimensions of tensor :math:`H`.
        The last three terms come from the Dirichlet prior assumption.
        To invoke this function, attributes :meth:`H <torchnmf.plca.BaseComponent.H>`,
        :meth:`W <torchnmf.plca.BaseComponent.W>` and :meth:`Z <torchnmf.plca.BaseComponent.Z>`
        should be present in this module.
Args:
V (Tensor): data tensor to be decomposed
tol (float): tolerance of the stopping condition. Default: ``1e-4``
max_iter (int): maximum number of iterations before timing out. Default: ``200``
verbose (bool): whether to be verbose. Default: ``False``
            W_alpha (float or Tensor): hyperparameter of the Dirichlet prior on W.
                Can be a scalar or a tensor that is broadcastable to W.
                Set it to one for no regularization. Default: ``1``
            H_alpha (float or Tensor): hyperparameter of the Dirichlet prior on H.
                Can be a scalar or a tensor that is broadcastable to H.
                Set it to one for no regularization. Default: ``1``
            Z_alpha (float or Tensor): hyperparameter of the Dirichlet prior on Z.
                Can be a scalar or a tensor that is broadcastable to Z.
                Set it to one for no regularization. Default: ``1``
        Returns:
            tuple: a length-2 tuple whose first element is the total number of iterations and whose second element is the sum of tensor V
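
        A minimal usage sketch with the :class:`PLCA` subclass defined below
        (iteration counts depend on the random initialization)::

            >>> V = torch.rand(20, 30)
            >>> m = PLCA(V.shape, 5)
            >>> n_iter, V_sum = m.fit(V)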
"""
assert torch.all(V >= 0.), "Target should be non-negative."
W = self.W
H = self.H
Z = self.Z
norm = V.sum()
V = V.contiguous() / norm
with torch.no_grad():
WZH = self.reconstruct(H, W, Z)
loss_init = previous_loss = kl_div(
WZH * norm, V * norm).mul(2).sqrt().item()
with tqdm(total=max_iter, disable=not verbose) as pbar:
for n_iter in range(max_iter):
self.zero_grad()
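                # E-step via autograd: backpropagating V / WZH through the
                # reconstruction leaves each factor's multiplicative EM update
                # numerator in its .grad attribute.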
WZH = self.reconstruct(H, W, Z)
WZH.backward(V / WZH.add(eps))
Z_prior = None
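                # M-step for Z: rescale by the accumulated posterior weight,
                # apply the Dirichlet prior (alpha - 1), floor at eps, then
                # renormalize. The pre-prior value is kept for the W and H updates.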
if Z.requires_grad:
Z.data.mul_(Z.grad.relu())
Z_prior = Z.clone()
if Z_alpha != 1:
Z.data.add_(Z_alpha - 1)
F.threshold(Z.data, eps, eps, True)
Z.data.div_(Z.sum())
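                # M-step for W: the per-component posterior mass (Z before its
                # prior was applied) is the proper normalizer for W's columns;
                # compute it here if Z was not updated.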
if W.requires_grad:
W.data.mul_(W.grad.relu())
if Z_prior is None:
W_divider = get_norm(W)
Z_prior = W_divider.squeeze()
else:
W_divider = Z_prior[(
slice(None),) + (None,) * (W.dim() - 2)]
W.data.div_(W_divider)
if W_alpha != 1:
W.data.add_(W_alpha - 1)
F.threshold(W.data, eps, eps, True)
W.data.div_(get_norm(W))
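                # M-step for H, normalized by the same per-component mass.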
if H.requires_grad:
H.data.mul_(H.grad.relu())
if Z_prior is None:
H_divider = get_norm(H)
else:
H_divider = Z_prior[(
slice(None),) + (None,) * (H.dim() - 2)]
H.data.div_(H_divider)
if H_alpha != 1:
H.data.add_(H_alpha - 1)
F.threshold(H.data, eps, eps, True)
H.data.div_(get_norm(H))
if n_iter % 10 == 9:
with torch.no_grad():
WZH = self.reconstruct(H, W, Z)
loss = kl_div(WZH * norm, V *
norm).mul(2).sqrt().item()
log_pro = _log_probability(
V, WZH, W, Z, H, W_alpha, Z_alpha, H_alpha).item()
pbar.set_postfix(loss=loss, log_likelihood=log_pro)
pbar.update(10)
if (previous_loss - loss) / loss_init < tol:
break
previous_loss = loss
return n_iter, norm
class PLCA(BaseComponent):
r"""Probabilistic Latent Component Analysis (PLCA).
    Estimate two marginals :math:`P(c|z)` and :math:`P(n|z)`, which are the matrices W and H,
    and a prior :math:`P(z)`, which is the vector Z, that approximate the observed :math:`P(n,c)`,
    where :math:`P(n,c)` is obtained via ``V / V.sum()`` so the total probability sums to 1.
More precisely:
.. math::
P(n, c) \approx \sum_{z}P(c|z)P(z)P(n|z)
In matrix form:
.. math::
        V \approx H \mathrm{diag}(Z) W^T
    Its formulation is very similar to NMF, but introduces an extra latent vector to incorporate the notion of probability.
Note:
        If the `Vshape` argument is given, the model will try to infer the size of :meth:`W <torchnmf.plca.BaseComponent.W>`,
        :meth:`H <torchnmf.plca.BaseComponent.H>` and :meth:`Z <torchnmf.plca.BaseComponent.Z>`, overriding the corresponding arguments passed through to :meth:`BaseComponent <torchnmf.plca.BaseComponent>`.
Args:
Vshape (tuple, optional): size of target matrix V
rank (int, optional): size of hidden dimension
**kwargs: arguments passed through to :meth:`BaseComponent <torchnmf.plca.BaseComponent>`
Shape:
- V: :math:`(N, C)`
- W: :math:`(C, R)`
- H: :math:`(N, R)`
- Z: :math:`(R,)`
Examples::
>>> V = torch.rand(20, 30)
>>> m = PLCA(V.shape, 5)
>>> m.W.size()
torch.Size([30, 5])
>>> m.H.size()
torch.Size([20, 5])
>>> m.Z.size()
torch.Size([5])
>>> HZWt = m()
>>> HZWt.size()
torch.Size([20, 30])
"""
    def __init__(self,
                 Vshape: Optional[Iterable[int]] = None,
                 rank: Optional[int] = None,
                 **kwargs):
if isinstance(Vshape, Iterabc):
M, K = Vshape
rank = rank if rank else K
kwargs['W'] = (K, rank)
kwargs['H'] = (M, rank)
super().__init__(rank, **kwargs)
@staticmethod
def reconstruct(H, W, Z):
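        # H: (N, R), W: (C, R), Z: (R,); equivalent to H @ diag(Z) @ W.T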
return H @ (W * Z).T
class SIPLCA(BaseComponent):
r"""Shift Invariant Probabilistic Latent Component Analysis (SI-PLCA).
    Estimate two marginals :math:`P(c,t|z)` and :math:`P(n,l|z)`,
    which are the tensors W and H, and a prior :math:`P(z)`, which is the vector Z,
    that approximate the observed :math:`P(n,c,l)`, where :math:`P(n,c,l)` is obtained via ``V / V.sum()``
    so the total probability sums to 1.
More precisely:
.. math::
P(n, c, l) \approx \sum_{z} \sum_{t} P(c,t|z)P(z)P(n,l-t|z)
    See the paper `Shift-Invariant Probabilistic Latent Component Analysis`_
    by Paris Smaragdis and Bhiksha Raj (2007) for more details.
Note:
        If the `Vshape` argument is given, the model will try to infer the size of :meth:`W <torchnmf.plca.BaseComponent.W>`,
        :meth:`H <torchnmf.plca.BaseComponent.H>` and :meth:`Z <torchnmf.plca.BaseComponent.Z>`, overriding the corresponding arguments passed through to :meth:`BaseComponent <torchnmf.plca.BaseComponent>`.
Args:
Vshape (tuple, optional): size of target matrix V
rank (int, optional): size of hidden dimension
T (int, optional): size of the convolving window
**kwargs: arguments passed through to :meth:`BaseComponent <torchnmf.plca.BaseComponent>`
Shape:
- V: :math:`(N, C, L_{out})`
- W: :math:`(C, R, T)`
- H: :math:`(N, R, L_{in})`
- Z: :math:`(R,)` where
.. math::
L_{in} = L_{out} - T + 1
Examples::
>>> V = torch.rand(33, 50).unsqueeze(0)
>>> m = SIPLCA(V.shape, 16, 3)
>>> m.W.size()
torch.Size([33, 16, 3])
>>> m.H.size()
torch.Size([1, 16, 48])
>>> m.Z.size()
torch.Size([16])
>>> HZWt = m()
>>> HZWt.size()
torch.Size([1, 33, 50])
.. _Shift-Invariant Probabilistic Latent Component Analysis:
https://paris.cs.illinois.edu/pubs/plca-report.pdf
"""
    def __init__(self,
                 Vshape: Optional[Iterable[int]] = None,
                 rank: Optional[int] = None,
                 T: _size_1_t = 1,
                 **kwargs):
if isinstance(Vshape, Iterabc):
T, = _single(T)
batch, K, M = Vshape
rank = rank if rank else K
kwargs['W'] = (K, rank, T)
kwargs['H'] = (batch, rank, M - T + 1)
super().__init__(rank, **kwargs)
@staticmethod
def reconstruct(H, W, Z):
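        # Flip the kernel so conv1d performs a true convolution; "full"
        # padding (T - 1) brings the output length to L_in + T - 1 = L_out.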
pad_size = W.shape[2] - 1
return F.conv1d(H, W.flip(2) * Z.view(-1, 1), padding=pad_size)
class SIPLCA2(BaseComponent):
r"""Shift Invariant Probabilistic Latent Component Analysis across 2 dimensions (SI-PLCA 2D).
    Estimate two marginals :math:`P(c,k_1,k_2|z)` and :math:`P(n,l,m|z)`,
    which are the tensors W and H, and a prior :math:`P(z)`, which is the vector Z,
    that approximate the observed :math:`P(n,c,l,m)`, where :math:`P(n,c,l,m)` is obtained via ``V / V.sum()``
    so the total probability sums to 1.
More precisely:
.. math::
P(n,c,l,m) \approx \sum_{z} \sum_{k_1} \sum_{k_2} P(c,k_1,k_2|z)P(z)P(n,l-k_1,m-k_2|z)
    See the paper `Shift-Invariant Probabilistic Latent Component Analysis`_
    by Paris Smaragdis and Bhiksha Raj (2007) for more details.
Note:
        If the `Vshape` argument is given, the model will try to infer the size of :meth:`W <torchnmf.plca.BaseComponent.W>`,
        :meth:`H <torchnmf.plca.BaseComponent.H>` and :meth:`Z <torchnmf.plca.BaseComponent.Z>`, overriding the corresponding arguments passed through to :meth:`BaseComponent <torchnmf.plca.BaseComponent>`.
Args:
Vshape (tuple, optional): size of target tensor V
rank (int, optional): size of hidden dimension
kernel_size (int or tuple, optional): size of the convolving kernel
**kwargs: arguments passed through to :meth:`BaseComponent <torchnmf.plca.BaseComponent>`
Shape:
- V: :math:`(N, C, L_{out}, M_{out})`
    - W: :math:`(C, R, \text{kernel\_size}[0], \text{kernel\_size}[1])`
- H: :math:`(N, R, L_{in}, M_{in})`
- Z: :math:`(R,)` where
.. math::
        L_{in} = L_{out} - \text{kernel\_size}[0] + 1
    .. math::
        M_{in} = M_{out} - \text{kernel\_size}[1] + 1
Examples::
>>> V = torch.rand(33, 50).unsqueeze(0).unsqueeze(0)
>>> m = SIPLCA2(V.shape, 16, 3)
>>> m.W.size()
torch.Size([1, 16, 3, 3])
>>> m.H.size()
torch.Size([1, 16, 31, 48])
>>> m.Z.size()
torch.Size([16])
>>> HZWt = m()
>>> HZWt.size()
torch.Size([1, 1, 33, 50])
.. _Shift-Invariant Probabilistic Latent Component Analysis:
https://paris.cs.illinois.edu/pubs/plca-report.pdf
"""
    def __init__(self,
                 Vshape: Optional[Iterable[int]] = None,
                 rank: Optional[int] = None,
                 kernel_size: _size_2_t = 1,
                 **kwargs):
if isinstance(Vshape, Iterabc):
kernel_size = _pair(kernel_size)
H, W = kernel_size
batch, channel, K, M = Vshape
rank = rank if rank else K
kwargs['W'] = (channel, rank,) + kernel_size
kwargs['H'] = (batch, rank, K - H + 1, M - W + 1)
super().__init__(rank, **kwargs)
@staticmethod
def reconstruct(H, W, Z):
pad_size = (W.shape[2] - 1, W.shape[3] - 1)
out = F.conv2d(H, W.flip((2, 3)) * Z.view(-1, 1, 1), padding=pad_size)
return out
class SIPLCA3(BaseComponent):
r"""Shift Invariant Probabilistic Latent Component Analysis across 3 dimensions (SI-PLCA 3D).
    Estimate two marginals :math:`P(c,k_1,k_2,k_3|z)` and :math:`P(n,l,m,o|z)`,
    which are the tensors W and H, and a prior :math:`P(z)`, which is the vector Z,
    that approximate the observed :math:`P(n,c,l,m,o)`, where :math:`P(n,c,l,m,o)` is obtained via ``V / V.sum()``
    so the total probability sums to 1.
More precisely:
.. math::
P(n,c,l,m,o) \approx \sum_{z} \sum_{k_1} \sum_{k_2} \sum_{k_3} P(c,k_1,k_2,k_3|z)P(z)P(n,l-k_1,m-k_2,o-k_3|z)
    See the paper `Shift-Invariant Probabilistic Latent Component Analysis`_
    by Paris Smaragdis and Bhiksha Raj (2007) for more details.
Note:
        If the `Vshape` argument is given, the model will try to infer the size of :meth:`W <torchnmf.plca.BaseComponent.W>`,
        :meth:`H <torchnmf.plca.BaseComponent.H>` and :meth:`Z <torchnmf.plca.BaseComponent.Z>`, overriding the corresponding arguments passed through to :meth:`BaseComponent <torchnmf.plca.BaseComponent>`.
Args:
Vshape (tuple, optional): size of target tensor V
rank (int, optional): size of hidden dimension
kernel_size (int or tuple, optional): size of the convolving kernel
**kwargs: arguments passed through to :meth:`BaseComponent <torchnmf.plca.BaseComponent>`
Shape:
- V: :math:`(N, C, L_{out}, M_{out}, O_{out})`
    - W: :math:`(C, R, \text{kernel\_size}[0], \text{kernel\_size}[1], \text{kernel\_size}[2])`
- H: :math:`(N, R, L_{in}, M_{in}, O_{in})`
- Z: :math:`(R,)` where
.. math::
        L_{in} = L_{out} - \text{kernel\_size}[0] + 1
    .. math::
        M_{in} = M_{out} - \text{kernel\_size}[1] + 1
    .. math::
        O_{in} = O_{out} - \text{kernel\_size}[2] + 1
Examples::
>>> V = torch.rand(3, 64, 64, 100).unsqueeze(0)
>>> m = SIPLCA3(V.shape, 8, (5, 5, 20))
>>> m.W.size()
torch.Size([3, 8, 5, 5, 20])
>>> m.H.size()
torch.Size([1, 8, 60, 60, 81])
>>> m.Z.size()
torch.Size([8])
>>> HZWt = m()
>>> HZWt.size()
torch.Size([1, 3, 64, 64, 100])
.. _Shift-Invariant Probabilistic Latent Component Analysis:
https://paris.cs.illinois.edu/pubs/plca-report.pdf
"""
    def __init__(self,
                 Vshape: Optional[Iterable[int]] = None,
                 rank: Optional[int] = None,
                 kernel_size: _size_3_t = 1,
                 **kwargs):
if isinstance(Vshape, Iterabc):
kernel_size = _triple(kernel_size)
D, H, W = kernel_size
batch, channel, N, K, M = Vshape
rank = rank if rank else K
kwargs['W'] = (channel, rank) + kernel_size
kwargs['H'] = (batch, rank, N - D + 1, K - H + 1, M - W + 1)
super().__init__(rank, **kwargs)
@staticmethod
def reconstruct(H, W, Z):
pad_size = (W.shape[2] - 1, W.shape[3] - 1, W.shape[4] - 1)
out = F.conv3d(H, W.flip((2, 3, 4)) *
Z.view(-1, 1, 1, 1), padding=pad_size)
return out