import numpy as np
from espm.estimators.updates import (
    multiplicative_step_h,
    multiplicative_step_w,
    multiplicative_step_hq,
    proj_grad_step_h,
    proj_grad_step_w,
    gradH,
    gradW,
    estimate_Lipschitz_bound_h,
    estimate_Lipschitz_bound_w,
)
from espm.measures import trace_xtLx, log_reg
from espm.estimators import NMFEstimator
from espm.estimators.surrogates import diff_surrogate, quadratic_surrogate
from espm.conf import dicotomy_tol, sigmaL
from copy import deepcopy
# from espm.measures import KL_loss_surrogate, KLdiv_loss, log_reg, log_surrogate
class SmoothNMF(NMFEstimator):
r"""SmoothNMF - NMF with a smooth regularization term
    We encourage you to read the example available in the documentation:
https://espm.readthedocs.io/en/latest/introduction/notebooks/toy-problem.html
The corresponding notebook is available on github:
https://github.com/adriente/espm/blob/main/notebooks/toy-ML.ipynb
The class `SmoothNMF` implements the regularized NMF algorithm. It solves problems of the form:
.. math::
\dot{W}, \dot{H} = \arg \min_{W \geq \epsilon, H \geq \epsilon} D_{GKL}(X || GWH) +
\lambda_L tr(H \Delta H^\top) + \mu \sum_{ij} \log(H_{ij} + \epsilon_{reg})
where:
- `D_{GKL}` is the Generalized KL divergence loss function defined as:
.. math::
D_{GKL}(X || Y) = \sum_{i,j} X_{ij} \log \frac{X_{ij}}{Y_{ij}} - X_{ij} + Y_{ij}
    See the documentation of the class :class:`espm.estimators.NMFEstimator` for more details.
- `\Delta` is the Laplacian operator (it can be created using the function `create_laplacian_matrix` from the `utils` module).
    - `\epsilon_{reg}` is the slope of the log regularization/sparsity at 0 (you probably want to leave this at 1).
- `\lambda_L` is a regularization parameter, which encourages smoothness in the columns of `H`.
- `\mu` is a regularization parameter, which is similar to an L1 sparsity penalty.
The size of:
- `X` is `(n, p)`
- `W` is `(m, k)`
- `H` is `(k, p)`
- `G` is `(n, m)`
    The rows of the matrices `H` and `X` (of length `p`) are assumed to be flattened images, i.e. the column index runs over pixels; this is what the smoothness regularization acts on.
The parameter `shape_2d` defines the shape of the images, i.e., `shape_2d[0]*shape_2d[1] = p`.
Parameters
----------
    lambda_L : float, default=0.0
Regularization parameter for the smooth regularization term.
linesearch : bool, default=False
If True, use a line search to find the step size.
    mu : float or array-like of float, default=0
Regularization parameter for the log regularization/sparsity term.
epsilon_reg : float, default=1
Slope of the log regularization/sparsity at 0.
    algo : str, default="log_surrogate"
        Optimization algorithm used for the updates. Can be "log_surrogate", "l2_surrogate", "projected_gradient", or "bmd".
simplex_H : bool, default=False
If True, force the solution of H to be in the simplex.
simplex_W : bool, default=True
If True, force the solution of W to be in the simplex.
    dicotomy_tol : float, default=1e-3
        Tolerance of the dichotomy (bisection) search used when simplex constraints are enabled.
    gamma : float, default=None
        Initial value of the step size. If None, it is set to the constant ``sigmaL`` for the surrogate-based algorithms, or to Lipschitz bound estimates of the gradients (one for `H`, one for `W`) for the projected gradient algorithm.
**kwargs : dict
Additional parameters for the `NMFEstimator` class.
"""
    loss_names_ = NMFEstimator.loss_names_ + ["log_reg_loss", "Lapl_reg_loss", "gamma"]
    # The estimator's own parameters are declared explicitly in __init__ (rather than captured in *args/**kwargs) to stay scikit-learn compliant; remaining kwargs are forwarded to NMFEstimator.
def __init__(self, lambda_L = 0.0, linesearch=False, mu=0, epsilon_reg=1, algo="log_surrogate", dicotomy_tol=dicotomy_tol, gamma=None, **kwargs):
        super().__init__(**kwargs)
self.lambda_L = lambda_L
self.linesearch = linesearch
self.mu = mu
self.epsilon_reg = epsilon_reg
self.dicotomy_tol = dicotomy_tol
assert algo in ["l2_surrogate", "log_surrogate", "projected_gradient", "bmd"]
self.algo = algo
self.gamma = gamma
self.check_params()
    def check_params(self):
assert self.algo in ["l2_surrogate", "log_surrogate", "projected_gradient", "bmd"], "The algorithm must be 'l2_surrogate', 'log_surrogate', 'bmd' or 'projected_gradient'"
        assert self.lambda_L >= 0 and self.epsilon_reg > 0.0 and np.all(np.array(self.mu)>=0), "lambda_L and mu must be non-negative and epsilon_reg must be strictly positive"
        assert not (self.simplex_H and self.simplex_W), "Only one of simplex_H and simplex_W can be True"
        if self.linesearch:
            assert not self.l2, "The linesearch option is not compatible with l2=True"
            assert self.lambda_L > 0, "The regularization parameter lambda_L must be non-zero when using linesearch"
if self.algo=="l2_surrogate":
assert not self.l2, "The l2 parameter must be False when using l2_surrogate"
def _iteration(self, W, H):
# KL_surr = KL_loss_surrogate(self.X_, W, H, H, eps=0)
# log_surr = log_surrogate(H, H, mu=self.mu, epsilon=self.epsilon_reg)
# print("loss before:", KL_surr, log_surr, log_surr+KL_surr)
if self.n_iter_ == 0:
if self.gamma is None:
if self.algo in ["l2_surrogate", "log_surrogate", "bmd"]:
self.gamma_ = sigmaL
else:
gamma_W = estimate_Lipschitz_bound_w(self.log_shift, self.X_, self.G_, k=self.n_components)
gamma_H = estimate_Lipschitz_bound_h(self.log_shift,
self.X_,
self.G_,
k=self.n_components,
lambda_L=self.lambda_L,
mu=self.mu,
epsilon_reg=self.epsilon_reg)
self.gamma_ = [gamma_H, gamma_W]
else:
self.gamma_ = deepcopy(self.gamma)
# 1. Update for H
if self.linesearch:
Hold = H.copy()
if self.algo=="l2_surrogate":
H = multiplicative_step_hq(self.X_,
self.G_,
W,
H,
simplex_H=self.simplex_H,
log_shift=self.log_shift,
safe=self.debug,
dicotomy_tol=self.dicotomy_tol,
lambda_L=self.lambda_L,
L=self.L_,
sigmaL=self.gamma_,
fixed_H=self.fixed_H)
elif self.algo=="log_surrogate":
H = multiplicative_step_h(self.X_,
self.G_,
W,
H,
simplex_H=self.simplex_H,
mu=self.mu,
log_shift=self.log_shift,
epsilon_reg=self.epsilon_reg,
safe=self.debug,
dicotomy_tol=self.dicotomy_tol,
lambda_L=self.lambda_L,
L=self.L_,
l2=self.l2,
fixed_H=self.fixed_H,
sigmaL=self.gamma_)
elif self.algo=="projected_gradient":
H = proj_grad_step_h(self.X_,
self.G_,
W,
H,
simplex_H=self.simplex_H,
mu=self.mu,
log_shift=self.log_shift,
epsilon_reg=self.epsilon_reg,
safe=self.debug,
dicotomy_tol=self.dicotomy_tol,
lambda_L=self.lambda_L,
L=self.L_,
l2=self.l2,
fixed_H=self.fixed_H,
gamma=self.gamma_[0])
elif self.algo=="bmd":
H = multiplicative_step_h(self.X_,
self.G_,
W,
H,
simplex_H=self.simplex_H,
mu=self.mu,
log_shift=self.log_shift,
epsilon_reg=self.epsilon_reg,
safe=self.debug,
dicotomy_tol=self.dicotomy_tol,
lambda_L=self.lambda_L,
L=self.L_,
l2=self.l2,
fixed_H=self.fixed_H,
sigmaL=self.gamma_,
use_bregman=True)
else:
raise ValueError("Unknown algorithm")
if self.linesearch:
if self.algo in ["l2_surrogate", "log_surrogate", "bmd"]:
d = diff_surrogate(Hold, H, L=self.L_, sigmaL=self.gamma_, algo=self.algo)
if d>0:
self.gamma_ = self.gamma_ / 1.05
else:
self.gamma_ = self.gamma_ * 1.5
else:
gradf_xt = gradH(self.X_,
self.G_,
W,
Hold,
mu= self.mu,
lambda_L=self.lambda_L,
L=self.L_,
epsilon_reg=self.epsilon_reg,
log_shift=self.log_shift,
safe=self.debug)
f_xt = self.loss(W, Hold, X = self.X_, average=False)
f_x = self.loss(W, H, X = self.X_, average=False)
g_xxt = quadratic_surrogate(H, Hold, f_xt, gradf_xt, self.gamma_[0])
d = g_xxt - f_x
if d>0:
self.gamma_[0] = self.gamma_[0] / 1.05
else:
self.gamma_[0] = self.gamma_[0] * 1.5
# 2. Update for W
if self.algo in ["l2_surrogate", "log_surrogate"]:
W = multiplicative_step_w(self.X_,
self.G_,
W,
H,
log_shift=self.log_shift,
safe=self.debug,
l2=self.l2,
simplex_W=self.simplex_W,
fixed_W=self.fixed_W,
physics_model=self.physics_model_)
elif self.algo=="bmd":
W = multiplicative_step_w(self.X_,
self.G_,
W,
H,
log_shift=self.log_shift,
safe=self.debug,
l2=self.l2,
simplex_W=self.simplex_W,
fixed_W=self.fixed_W,
use_bregman=True,
physics_model=self.physics_model_)
else:
if self.linesearch:
Wold = W.copy()
W = proj_grad_step_w(self.X_,
self.G_,
W,
H,
log_shift=self.log_shift,
safe=self.debug,
gamma=self.gamma_[1],
simplex_W=self.simplex_W)
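            # Same step-size adaptation as for H, applied to the quadratic model of the W sub-problem.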
if self.linesearch:
gradf_xt = gradW(self.X_, self.G_, Wold, H, log_shift=self.log_shift, safe=self.debug)
f_xt = self.loss(Wold, H, X = self.X_, average=False)
f_x = self.loss(W, H, X = self.X_, average=False)
g_xxt = quadratic_surrogate(W, Wold, f_xt, gradf_xt, self.gamma_[1])
d = g_xxt - f_x
if d>0:
self.gamma_[1] = self.gamma_[1] / 1.05
else:
self.gamma_[1] = self.gamma_[1] * 1.5
# KL_surr = KL_loss_surrogate(self.X_, W, H, Hold, eps=0)
# log_surr = log_surrogate(H, Hold, mu=self.mu, epsilon=self.epsilon_reg)
# print("surrogate before:", KL_surr, log_surr, log_surr+KL_surr)
# KL_surr = KL_loss_surrogate(self.X_, W, H, H, eps=0)
# log_surr = log_surrogate(H, H, mu=self.mu, epsilon=self.epsilon_reg)
# print("loss after:", KL_surr, log_surr, log_surr+KL_surr)
return W, H
def loss(self, W, H, average=True, X = None):
"""Compute the loss function."""
lkl = super().loss(W, H, average=average, X = X)
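        # The two penalty terms below and the current step size are appended to
        # detailed_loss_, matching the entries added to loss_names_.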
reg = log_reg(H, self.mu, self.epsilon_reg, average=False)
if average:
reg = reg / self.GWH_numel_
self.detailed_loss_.append(reg)
l2 = 0.5 * self.lambda_L * trace_xtLx(self.L_, H.T, average=False)
if average:
l2 = l2 / self.GWH_numel_
self.detailed_loss_.append(l2)
if isinstance(self.gamma_, list):
self.detailed_loss_.append(self.gamma_[0])
else:
self.detailed_loss_.append(self.gamma_)
return lkl + reg + l2