{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Regularized Poisson NMF on a toy dataset\n", "\n", "This notebook is part of the ESPM package. It is available on [GitHub](https://github.com/adriente/espm/blob/main/notebooks/toy-ML.ipynb).\n", "\n", "In this notebook, we show how to solve a regularized Poisson Non-negative Matrix Factorization (NMF) on a toy dataset. The general goal of this problem is to separate a matrix $X$ into two matrices $D$ and $H$ such that $X \\approx DH$.\n", "We assume that $D$ is the product of two matrices $G$ and $W$, i.e. $D = GW$. The matrix $G$ is assumed to be known, so $W$ is the matrix we want to find. In the field of electron microscopy, the matrix $GW$ represents the phases, the matrix $H$ is the matrix of maps (or weights) and $X$ is the data matrix. Each row of $H$ is a map, i.e. a 2D image.\n", "Furthermore, we assume that $W$ and $H$ are non-negative, or more generally greater than a small positive value $\\epsilon$.\n", "\n", "The measured data $X$ follow a Poisson distribution, i.e. $X_{ij} \\sim \\text{Poisson}(n_p\\dot{X}_{ij})$. In the context of electron microscopy, $n_p$ is the number of photons per pixel. $\\dot{X} = G \\dot{W} \\dot{H}$ corresponds to a noiseless measurement, where $\\dot{W}$ and $\\dot{H}$ are the ground truth matrices.\n", "\n", "The sizes of the matrices are:\n", "\n", "* $X$ is (n, p),\n", "* $W$ is (m, k),\n", "* $H$ is (k, p),\n", "* $G$ is (n, m).\n", "\n", "In general, m is assumed to be much smaller than n and $G$ is assumed to be known. The parameter `shape_2d` defines the shape of the images stored in the rows of $H$ and $X$, i.e. `shape_2d[0]*shape_2d[1] = p`.\n", "\n", "Mathematically, the problem can be formulated as:\n", "$$\\hat{W}, \\hat{H} = \\arg\\min_{W\\geq\\epsilon, H\\geq\\epsilon, \\sum_i H_{ij} = 1} D_{GKL}(X \\| GWH) + \\lambda \\operatorname{tr} ( H \\Delta H^\\top) + \\mu \\sum_{i,j} \\log (H_{ij} + \\epsilon_{reg})$$\n", "\n", "Here $D_{GKL}$ is the fidelity term, i.e. the generalized KL divergence\n", "\n", "$$D_{GKL}(X \\| Y) = \\sum_{i,j} X_{ij} \\log \\frac{X_{ij}}{Y_{ij}} - X_{ij} + Y_{ij} $$\n", "\n", "The loss is regularized using two terms on $H$: a Laplacian regularization and a log regularization, with regularization parameters $\\lambda$ and $\\mu$ respectively.\n", "The Laplacian regularization is defined as:\n", "\n", "$$ \\lambda \\operatorname{tr} ( H \\Delta H^\\top) $$\n", "\n", "where $\\Delta$ is the (p, p) Laplacian operator of the pixel grid (it can be created using the function `espm.utils.create_laplacian_matrix`). **Note that the rows of the matrices $H$ and $X$ are assumed to be images.**\n", "\n", "The log regularization is defined as:\n", "\n", "$$ \\mu \\sum_{i,j} \\log (H_{ij} + \\epsilon_{reg}) $$\n", "\n", "where $\\epsilon_{reg}$ controls the slope of the log regularization at 0 (the slope is $1/\\epsilon_{reg}$). This term acts similarly to an L1 penalty but penalizes large values less.\n", "\n", "Finally, we assume $W,H\\geq \\epsilon$ and that the columns of $H$ sum to 1: $$\\sum_i H_{ij} = 1,$$ i.e. the k weights of every pixel add up to 1. This is done by passing the parameter `simplex_H=True` to the class `espm.estimators.SmoothNMF`. This constraint is essential as it prevents the algorithm from finding a solution where all the maps tend to 0 because of the regularization terms. Note that the same constraint could be applied to $W$ using `simplex_W=True`. These two constraints should not be activated at the same time!\n", "\n", "\n", "In this notebook, we will use the class `espm.estimators.SmoothNMF` to solve the problem.\n" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Imports and function definition\n", "\n", "Let's start by importing the necessary libraries." 
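] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Before doing so, the objective defined above can be evaluated with NumPy alone, which makes the role of each term concrete. The next cell is a minimal sketch under our own assumptions, not the `espm` implementation. The helpers `gkl_divergence`, `grid_laplacian` and `objective` are hypothetical: `grid_laplacian` builds the graph Laplacian $\\Delta = D - A$ of a 4-neighbour pixel grid, which may differ from the matrix returned by `espm.utils.create_laplacian_matrix`. The Laplacian term is written $\\operatorname{tr}(H \\Delta H^\\top)$ so that it acts on the rows (images) of an $H$ of shape (k, p). The log term uses $\\log(H_{ij} + \\epsilon_{reg})$. The random data use made-up sizes and $n_p = 1$." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "def gkl_divergence(X, Y, eps=1e-12):\n", "    # D_GKL(X || Y) = sum_ij X_ij log(X_ij / Y_ij) - X_ij + Y_ij,\n", "    # using the convention 0 log 0 = 0 where X_ij = 0 (eps guards against Y_ij = 0)\n", "    X = np.asarray(X, dtype=float)\n", "    Y = np.asarray(Y, dtype=float)\n", "    ratio = np.where(X > 0, X / np.maximum(Y, eps), 1.0)\n", "    return float(np.sum(X * np.log(ratio) - X + Y))\n", "\n", "def grid_laplacian(nx, ny):\n", "    # Hypothetical helper: graph Laplacian D - A of an nx-by-ny pixel grid\n", "    # with 4-neighbour connectivity (pixels indexed in row-major order)\n", "    p = nx * ny\n", "    A = np.zeros((p, p))\n", "    for r in range(nx):\n", "        for c in range(ny):\n", "            i = r * ny + c\n", "            if r + 1 < nx:  # neighbour in the next row\n", "                A[i, i + ny] = A[i + ny, i] = 1.0\n", "            if c + 1 < ny:  # neighbour in the next column\n", "                A[i, i + 1] = A[i + 1, i] = 1.0\n", "    return np.diag(A.sum(axis=1)) - A\n", "\n", "def objective(X, G, W, H, Delta, lambda_L, mu, eps_reg=1.0):\n", "    # D_GKL(X || GWH) + lambda_L tr(H Delta H^T) + mu sum_ij log(H_ij + eps_reg)\n", "    fidelity = gkl_divergence(X, G @ W @ H)\n", "    laplacian = np.trace(H @ Delta @ H.T)\n", "    log_term = np.sum(np.log(H + eps_reg))\n", "    return float(fidelity + lambda_L * laplacian + mu * log_term)\n", "\n", "# Tiny random problem with made-up sizes (not the toy sample used below)\n", "rng = np.random.default_rng(0)\n", "shape_demo = (4, 5)  # image shape, so p = 20\n", "n_demo, m_demo, k_demo = 30, 6, 3\n", "G_demo = rng.uniform(size=(n_demo, m_demo))\n", "W_demo = rng.uniform(size=(m_demo, k_demo))\n", "H_demo = rng.uniform(size=(k_demo, shape_demo[0] * shape_demo[1]))\n", "H_demo /= H_demo.sum(axis=0, keepdims=True)  # columns of H sum to 1\n", "X_demo = rng.poisson(G_demo @ W_demo @ H_demo).astype(float)  # Poisson noise with n_p = 1\n", "Delta_demo = grid_laplacian(*shape_demo)\n", "\n", "print(objective(X_demo, G_demo, W_demo, H_demo, Delta_demo, lambda_L=2, mu=0.05))" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "With `lambda_L=0` and `mu=0`, `objective` reduces to the generalized KL divergence, which is 0 only when $X = GWH$. We now import the libraries used in the rest of the notebook."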
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Some useful modules for notebooks\n", "%load_ext autoreload\n", "%autoreload 2\n", "%matplotlib inline" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "from espm.estimators import SmoothNMF\n", "from espm.measures import find_min_angle, ordered_mse, ordered_mae, ordered_r2\n", "from espm.models import ToyModel\n", "from espm.weights import generate_weights as gw\n", "from espm.datasets.base import generate_spim_sample\n", "from espm.utils import process_losses" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Now we define the parameters that will be used to generate the data.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "\n", "seed = 42 # for reproducibility\n", "\n", "m = 15 # Number of columns of G (first dimension of W)\n", "n = 200 # Length of the phases\n", "\n", "n_poisson = 300 # Average number of Poisson counts per pixel (this number is split over the m dimension)\n", "\n", "densities = np.random.default_rng(seed).uniform(0.1, 2.0, 3) # Random densities, seeded for reproducibility\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def get_toy_sample():\n", "    model_params = {\"L\": n, \"C\": m, \"K\": 3, \"seed\": seed}\n", "    misc_params = {\"N\": n_poisson, \"seed\": seed, 'densities' : densities, \"model\": \"ToyModel\"}\n", "\n", "    toy_model = ToyModel(**model_params)\n", "    toy_model.generate_phases()\n", "    phases = toy_model.phases.T\n", "    weights = gw.generate_weights(\"toy_problem\", None)\n", "\n", "    sample = generate_spim_sample(phases, weights, model_params, misc_params, seed=seed)\n", "    return sample\n", "\n", "def to_vec(X):\n", "    # Reshape an array of shape (h, w, c) into a (c, h*w) matrix\n", "    n_channels = X.shape[2]\n", "    return X.transpose(2, 0, 1).reshape(n_channels, -1)\n", "\n", "sample = get_toy_sample()\n", "\n", "GW = sample[\"GW\"].T\n", "G = sample[\"G\"]\n", "H = to_vec(sample[\"H\"])\n", "X = to_vec(sample[\"X\"])\n", "Xdot = to_vec(sample[\"Xdot\"])\n", "shape_2d = sample[\"shape_2d\"]\n", "\n" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Let us look at the dimensions of our problem." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(f\"\"\"\n", "- X: {X.shape}\n", "- Xdot: {Xdot.shape}\n", "- G: {G.shape}\n", "- GW: {GW.shape}\n", "- H: {H.shape}\n", "- shape_2d: {shape_2d}\n", "\"\"\")" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Here `Xdot` contains the noiseless data, so that the ground truth `H` and `GW` satisfy:\n", "$$\\dot{X} = GWH$$" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "np.testing.assert_allclose(Xdot, GW @ H)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Let us plot the ground truth maps $H$ (sometimes also called weights). Here the variable `shape_2d` is used to reshape the rows of $H$ into images." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "\n", "vmin, vmax = 0,1\n", "cmap = plt.cm.gray_r\n", "plt.figure(figsize=(10, 3))\n", "for i, hdot in enumerate(H):\n", "    plt.subplot(1, len(H), i+1)\n", "    plt.imshow(hdot.reshape(shape_2d), cmap=cmap, vmin=vmin, vmax=vmax)\n", "    plt.axis(\"off\")\n", "    plt.title(f\"GT - Map {i+1}\")" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "We can also plot the phases corresponding to these maps. The phases correspond to the matrix $GW$." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "e = np.linspace(0,1, GW.shape[0])\n", "plt.plot(e, GW)\n", "plt.title(\"GT - Phases\")\n", "plt.xlabel(\"Frequency [normalized]\")\n", "plt.legend([f\"phase {i+1}\" for i in range(GW.shape[1])]);" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "The matrix $G$ is assumed to be known a priori. Here we plot the first 5 columns of $G$." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "l = np.linspace(0, 1,n)\n", "plt.plot(l, G[:,:5]);" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Note that the function `get_toy_sample` did not return $W$ but only $GW$.\n", "\n", "Since $G$ is known, $W$ can be recovered from the ground truth `GW` as the least-squares solution of the linear system `G @ W = GW`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "W = np.linalg.lstsq(G, GW, rcond=None)[0]  # solves G @ W = GW in the least-squares sense, shape (m, k)\n", "W.shape" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Solving the problem" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "#### Parameters\n", "\n", "Let us define the different hyperparameters of the problem. Feel free to change them and see how the results change." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Let us first define our regularisation parameters." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "lambda_L = 2 # Smoothness of the maps\n", "mu = 0.05 # Sparsity of the maps" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "We can additionally add a simplex constraint to the problem by setting `simplex_H=True`. This adds the constraint that the columns of $H$ sum to 1. This constraint should be activated for the regularizers to work. Practically, it prevents the algorithm from shrinking $H$ towards 0 while inflating $W$ to compensate." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "simplex_H = True\n", "simplex_W = False\n" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Finally, we can decide whether to specify $G$ or not. If the matrix $G$ is not specified, the algorithm directly estimates $GW$ instead of $W$." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "Gused = G" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Note that without regularisation, i.e. with the parameters\n", "```python\n", "lambda_L = 0\n", "mu = 0\n", "Gused = None\n", "simplex_H = False\n", "simplex_W = False\n", "```\n", "we recover the classical Poisson/KL NMF problem, which our algorithm then solves with the multiplicative updates (MU) of Lee and Seung (2001)." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Let us define the parameters for the algorithm. Here the class `espm.estimators.SmoothNMF` inherits from scikit-learn's `NMF` class." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "\n", "K = len(H) # Number of components (let us assume that we know it)\n", "params = {}\n", "params[\"tol\"]=1e-6 # Tolerance for the stopping criterion.\n", "params[\"max_iter\"] = 200 # Maximum number of iterations before timing out.\n", "params[\"hspy_comp\"] = False # Should be set to True if the HyperSpy (hspy) data format is used.\n", "params[\"verbose\"] = 1 # Verbosity level.\n", "params[\"eval_print\"] = 10 # Print the evaluation every eval_print iterations.\n", "params[\"epsilon_reg\"] = 1 # Parameter epsilon_reg of the log regularization.\n", "params[\"linesearch\"] = False # Use a line search to accelerate convergence.\n", "params[\"shape_2d\"] = shape_2d # Shape of the 2D maps\n", "params[\"n_components\"] = K # Number of components\n", "params[\"normalize\"] = True # Normalize the data. It helps to keep the regularization parameters lambda_L and mu in a reasonable range.\n", "\n", "estimator = SmoothNMF(mu=mu, lambda_L=lambda_L, G = Gused, simplex_H=simplex_H, simplex_W=simplex_W, **params)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "The method `fit_transform` solves the problem and returns the matrix $GW$." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "GW_est = estimator.fit_transform(X)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "We can plot the different losses to see how the algorithm converges." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "\n", "losses = estimator.get_losses()\n", "\n", "values, names = process_losses(losses)\n", "\n", "plt.figure(figsize=(10, 6))\n", "for name, value in zip(names[:4], values[:4]):\n", "    plt.plot(value, label=name)\n", "plt.xlabel(\"iteration\")\n", "plt.yscale(\"log\")\n", "plt.legend()\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "losses[0]" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "We can easily recover $W$ and $H$ from the estimator." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "H_est = estimator.H_\n", "W_est = estimator.W_" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Let us compute some metrics to evaluate the quality of the solution. We will use the metrics defined in the module `espm.measures`.\n", "\n", "We first compute the angles between the phases `GW` and the estimated phases `GW_est` using the function `find_min_angle`. This function also returns the indices that match each ground truth phase to an estimated one, since there is no reason for the estimated components to come out in the same order as the ground truth." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "angles, true_inds = find_min_angle(GW.T, GW_est.T, unique=True, get_ind=True)\n", "print(\"angles : \", angles)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Now, we can compute the Mean Squared Error (MSE) between the maps `H` and the estimated maps `H_est`, matched using `true_inds`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "mse = ordered_mse(H, H_est, true_inds)\n", "print(\"mse : \", mse)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Finally, let us plot the results." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "nbsphinx-thumbnail" ] }, "outputs": [], "source": [ "\n", "def plot_results(Ddot, D, Hdotflat, Hflat):\n", "    fontsize = 30\n", "    scale = 15\n", "    aspect_ratio = 1.4\n", "    marker_list = [\"-o\",\"-s\",\"->\",\"-<\",\"-^\",\"-v\",\"-d\"]\n", "    mark_space = 20\n", "    # cmap = plt.cm.hot_r \n", "    cmap = plt.cm.gray_r\n", "    vmax = 1\n", "    vmin = 0\n", "    K = len(Hdotflat)  # number of components, taken from the argument rather than the global H\n", "    L = len(D)\n", "    \n", "    angles, true_inds = find_min_angle(Ddot.T, D.T, unique=True, get_ind=True)\n", "    mse = ordered_mse(Hdotflat, Hflat, true_inds)\n", "    mae = ordered_mae(Hdotflat, Hflat, true_inds)\n", "    r2 = ordered_r2(Hdotflat, Hflat, true_inds)\n", "\n", "\n", "    # 3 rows (true maps, reconstructed maps, spectra) and one column per component\n", "    fig, axes = plt.subplots(3, K, figsize = (scale/3 * K * aspect_ratio, scale))\n", "    x = np.linspace(0,1, num = L)\n", "    for i in range(K): \n", "        axes[2,i].plot(x,Ddot.T[i,:],'g-',label='ground truth',linewidth=4)\n", "        axes[2,i].plot(x,D[:,true_inds[i]],'r--',label='reconstructed',linewidth=4)\n", "        axes[2,i].set_title(\"{:.2f} deg\".format(angles[i]),fontsize = fontsize-2)\n", "        axes[2,i].set_xlim(0,1)\n", "\n", "        axes[1,i].imshow((Hflat[true_inds[i],:]).reshape(shape_2d),vmin = vmin, vmax = vmax , cmap=cmap)\n", "        axes[1,i].set_title(\"R2: {:.2f}\".format(r2[true_inds[i]]),fontsize = fontsize-2)\n", "        # axes[i,1].set_ylim(0.0,1.0)\n", "        axes[1,i].tick_params(axis = \"both\",labelleft = False, labelbottom = False,left = False, bottom = False)\n", "\n", "        im = axes[0,i].imshow(Hdotflat[i].reshape(shape_2d),vmin = vmin, vmax = vmax, cmap=cmap)\n", "        axes[0,i].set_title(\"Phase {}\".format(i),fontsize = fontsize)\n", "        axes[0,i].tick_params(axis = \"both\",labelleft = False, labelbottom = False,left = False, bottom = False)\n", "    axes[2,0].legend()\n", "\n", "    rows = [\"True maps\",\"Reconstructed maps\",\"Spectra\"]\n", "\n", "    for ax, row in zip(axes[:,0], rows):\n", "        ax.set_ylabel(row, rotation=90, fontsize=fontsize)\n", "\n", "\n", "    fig.subplots_adjust(right=0.84)\n", "    # put colorbar at the desired position\n", "    cbar_ax = fig.add_axes([0.85, 0.5, 0.01, 0.3])\n", "    fig.colorbar(im,cax=cbar_ax)\n", "\n", "    # fig.tight_layout()\n", "\n", "    print(\"angles : \", angles)\n", "    print(\"mse : \", mse)\n", "    print(\"mae : \", mae)\n", "    print(\"r2 : \", r2)\n", "\n", "    return fig\n", "\n", "fig = plot_results(GW, GW_est, H, H_est)\n", "plt.show()\n" ] } ], "metadata": { "kernelspec": { "display_name": "espm", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.9" }, "vscode": { "interpreter": { "hash": "8d9f0750808bbb89d6238996d547b1a0ab25cae94b245e8c4572708d26b443fd" } } }, "nbformat": 4, "nbformat_minor": 4 }