r"""Utils for the ESPM package"""
import numpy as np
from scipy.sparse import lil_matrix, block_diag
from scipy.optimize import nnls
from espm.conf import SYMBOLS_PERIODIC_TABLE, NUMBER_PERIODIC_TABLE
import json
from exspy.misc.material import atomic_to_weight, density_of_mixture
from functools import wraps
import re
# Module-level registries of opened windows. `close_all` walks both lists,
# closes every entry and then resets the lists to empty.
# Entries of this list must expose a ``close()`` method (presumably Qt widgets - TODO confirm).
_qtg_widgets = []
# Entries of this list are passed to ``matplotlib.pyplot.close``.
_plt_figures = []
[docs]
def process_losses(losses):
    r""" Split a structured array of losses into one series per field, ready to be plotted

    Parameters
    ----------
    losses: np.ndarray
        Structured array of losses (output of `espm.estimators.NMFEstimator.get_losses` method).
        The field names are read from ``losses.dtype.names``.

    Returns
    -------
    values: np.ndarray
        Array holding one row per field and one column per record of ``losses``
    names: tuple
        Field names, in the same order as the rows of ``values``
    """
    names = losses.dtype.names
    # One growing column per field; each record contributes one entry to every column.
    series = [[] for _ in names]
    for record in losses:
        for column, entry in zip(series, record):
            column.append(entry)
    return np.array(series), names
[docs]
def create_laplacian_matrix(nx, ny=None):
    r"""
    Helper method building the laplacian matrix of an nx x ny pixel grid, used for the laplacian regularization.

    Pixels are numbered row by row. Every pixel is linked (value -1) to its horizontal and
    vertical neighbours and the diagonal holds the number of neighbours, so each row sums to zero.

    Parameters
    ----------
    :param nx: height of the original image (must be > 1)
    :param ny: width of the original image (must be > 1); defaults to nx when omitted

    Returns
    -------
    :rtype: scipy sparse matrix, in the format returned by ``scipy.sparse.block_diag`` when no ``format`` is given
    :return: the n x n laplacian matrix, where n = nx*ny
    """
    ny = nx if ny is None else ny
    assert(nx>1)
    assert(ny>1)

    def row_block(edge_degree):
        # Tridiagonal block linking the pixels of a single image row.
        # ``edge_degree`` is the neighbour count of the two pixels at the ends of the row;
        # the pixels in between have one more neighbour.
        block = lil_matrix((ny, ny), dtype=np.float32)
        block.setdiag([edge_degree] + [edge_degree + 1] * (ny - 2) + [edge_degree])
        block.setdiag(-1, k=1)
        block.setdiag(-1, k=-1)
        return block

    # First/last image rows have no neighbour above/below, hence one less link per pixel.
    border_row = row_block(2)
    inner_row = row_block(3)
    laplacian = block_diag([border_row] + [inner_row] * (nx - 2) + [border_row])
    # Off-diagonals at distance ny link each pixel to the one directly above/below it.
    laplacian.setdiag(-1, k=ny)
    laplacian.setdiag(-1, k=-ny)
    return laplacian
[docs]
def rescaled_DH(D,H) :
    r"""Rescale the matrices D and H such that the columns of H sums approximately to one.

    A scale vector s is obtained by solving ``H.T @ s = 1`` in the least-squares sense.
    If any entry of s is not strictly positive, s is recomputed with non-negative
    least squares and floored at 1e-10. The product ``D @ H`` is left unchanged by the rescaling.

    :param np.array 2D D: n x k matrix
    :param np.array 2D H: k x m matrix
    :return: D_rescale, H_rescale (``D @ diag(1/s)`` and ``diag(s) @ H``)
    :rtype: np.array 2D, np.array 2D
    """
    _, n_columns = H.shape
    target = np.ones((n_columns,))
    scales = np.linalg.lstsq(H.T, target, rcond=None)[0]
    if np.any(scales <= 0):
        # Fall back on a sign-constrained fit; the floor keeps 1/scales finite.
        nonneg_scales = nnls(H.T, target)[0]
        scales = np.maximum(nonneg_scales, 1e-10)
    undo_scaling = np.diag(1/scales)
    apply_scaling = np.diag(scales)
    return D@undo_scaling, apply_scaling@H
[docs]
def bin_spim(data,n,m):
    r"""
    Bin a spectrum image spatially by summing rectangular tiles of pixels.

    Take a 3D array of size (x,y,k) [px, py, e]
    Returns a 3D array of size (n,m,k) [new_px, new_py, e]

    Each output pixel is the sum of a tile of x//n by y//m input pixels; trailing rows and
    columns that do not fill a complete tile are ignored.
    """
    tile_height = data.shape[0]//n
    tile_width = data.shape[1]//m
    n_channels = data.shape[2]
    tile_sums = []
    for row in range(n):
        top = row*tile_height
        for col in range(m):
            left = col*tile_width
            tile = data[top:top+tile_height, left:left+tile_width]
            # Collapse the two spatial axes, keep the last (energy) axis.
            tile_sums.append(np.sum(tile, axis=(0,1)))
    return np.reshape(np.array(tile_sums), (n, m, n_channels))
[docs]
def number_to_symbol_dict(func):
    r"""
    Decorator converting the keys of the ``elements_dict`` keyword argument to chemical symbols.

    Takes a dict of elements (a.k.a chemical composition) with atomic numbers as keys (e.g. 26 for Fe)
    and hands to the decorated function a dict of elements with symbols as keys (e.g. Fe for iron).
    Keys that already are chemical symbols are kept as they are.

    The conversion only applies when ``elements_dict`` is passed by keyword. When it is absent,
    the call is forwarded untouched so that the decorated function can rely on its own default.

    :param func: function accepting an ``elements_dict`` keyword argument
    :return: the wrapped function
    :raises ValueError: if a key is neither a chemical symbol nor an atomic number
    """
    @wraps(func)
    def inner(*args, **kwargs):
        # Nothing to convert: do not fail with a KeyError, let func use its default value.
        if "elements_dict" not in kwargs:
            return func(*args, **kwargs)
        elts_dict = kwargs["elements_dict"]
        with open(NUMBER_PERIODIC_TABLE, "r") as f:
            NPT = json.load(f)["table"]
        new_dict = {}
        for key, value in elts_dict.items():
            if is_symbol(key):
                new_dict[key] = value
            elif is_number(key):
                # The table is indexed by the decimal string of the atomic number;
                # normalising through int() accepts 26, "26", 26.0 or numpy integers alike.
                new_dict[NPT[str(int(key))]["symbol"]] = value
            else:
                raise ValueError("Input has to be either atomic number, either chemical symbols")
        kwargs["elements_dict"] = new_dict
        return func(*args, **kwargs)
    return inner
[docs]
def symbol_to_number_dict(func):
    r"""
    Decorator converting the keys of the ``elements_dict`` keyword argument to atomic numbers.

    Takes a dict of elements (a.k.a chemical composition) with symbols as keys (e.g. Fe for iron)
    and hands to the decorated function a dict of elements with atomic numbers as keys (e.g. 26 for iron).
    Keys that already are numbers are converted with ``int``.

    The conversion only applies when ``elements_dict`` is passed by keyword. When it is absent,
    the call is forwarded untouched so that the decorated function can rely on its own default.

    :param func: function accepting an ``elements_dict`` keyword argument
    :return: the wrapped function
    :raises ValueError: if a key is neither an atomic number nor a chemical symbol
    """
    @wraps(func)
    def inner(*args, **kwargs):
        # Nothing to convert: do not fail with a KeyError, let func use its default value.
        if "elements_dict" not in kwargs:
            return func(*args, **kwargs)
        elts_dict = kwargs["elements_dict"]
        with open(SYMBOLS_PERIODIC_TABLE, "r") as f:
            SPT = json.load(f)["table"]
        new_dict = {}
        for key, value in elts_dict.items():
            if is_number(key):
                new_dict[int(key)] = value
            elif is_symbol(key):
                new_dict[SPT[key]["number"]] = value
            else:
                raise ValueError("Input has to be either atomic number, either chemical symbols")
        kwargs["elements_dict"] = new_dict
        return func(*args, **kwargs)
    return inner
[docs]
def symbol_to_number_list(func):
    r"""
    Decorator converting the entries of the ``elements`` keyword argument to atomic numbers.

    Takes a list of elements given as chemical symbols (e.g. Fe for iron)
    and hands to the decorated function a list of atomic numbers (e.g. 26 for iron).
    Entries that already are numbers are converted with ``int``. The order of the list is preserved.

    The conversion only applies when ``elements`` is passed by keyword. When it is absent,
    the call is forwarded untouched so that the decorated function can rely on its own default.

    :param func: function accepting an ``elements`` keyword argument
    :return: the wrapped function
    :raises ValueError: if an entry is neither an atomic number nor a chemical symbol
    """
    @wraps(func)
    def inner(*args, **kwargs):
        # Nothing to convert: do not fail with a KeyError, let func use its default value.
        if "elements" not in kwargs:
            return func(*args, **kwargs)
        elts_list = kwargs["elements"]
        with open(SYMBOLS_PERIODIC_TABLE, "r") as f:
            SPT = json.load(f)["table"]
        new_list = []
        for key in elts_list:
            if is_number(key):
                new_list.append(int(key))
            elif is_symbol(key):
                new_list.append(SPT[key]["number"])
            else:
                raise ValueError("Input has to be either atomic number, either chemical symbols")
        kwargs["elements"] = new_list
        return func(*args, **kwargs)
    return inner
[docs]
def number_to_symbol_list(func):
    r"""
    Decorator converting the entries of the ``elements`` keyword argument to chemical symbols.

    Takes a list of elements given as atomic numbers (e.g. 26 for iron)
    and hands to the decorated function a list of chemical symbols (e.g. Fe for iron).
    Entries that already are chemical symbols are kept as they are. The order of the list is preserved.

    The conversion only applies when ``elements`` is passed by keyword. When it is absent,
    the call is forwarded untouched so that the decorated function can rely on its own default.

    :param func: function accepting an ``elements`` keyword argument
    :return: the wrapped function
    :raises ValueError: if an entry is neither an atomic number nor a chemical symbol
    """
    @wraps(func)
    def inner(*args, **kwargs):
        # Nothing to convert: do not fail with a KeyError, let func use its default value.
        if "elements" not in kwargs:
            return func(*args, **kwargs)
        elts_list = kwargs["elements"]
        with open(NUMBER_PERIODIC_TABLE, "r") as f:
            NPT = json.load(f)["table"]
        new_list = []
        for key in elts_list:
            if is_number(key):
                # The table is indexed by the decimal string of the atomic number;
                # normalising through int() accepts 26, "26", 26.0 or numpy integers alike.
                new_list.append(NPT[str(int(key))]["symbol"])
            elif is_symbol(key):
                new_list.append(key)
            else:
                raise ValueError("Input has to be either atomic number, either chemical symbols")
        kwargs["elements"] = new_list
        return func(*args, **kwargs)
    return inner
[docs]
@number_to_symbol_dict
def atomic_to_weight_dict (*,elements_dict = {}) :
    r"""
    Wrapper to the atomic_to_weight function imported from exspy. Takes a dict of chemical composition expressed in atomic fractions.
    Returns a dict of chemical composition expressed in weight fractions
    (the values returned by atomic_to_weight divided by 100).
    An empty composition is returned as it is.
    """
    if not elements_dict :
        return elements_dict
    symbols = list(elements_dict)
    atomic_fractions = [elements_dict[symbol] for symbol in symbols]
    weight_fractions = atomic_to_weight(atomic_fractions, symbols)/100
    # Pair each symbol with the weight fraction found at the same position.
    return {symbol: weight_fractions[index] for index, symbol in enumerate(symbols)}
[docs]
@number_to_symbol_dict
def approx_density(atomic_fraction = False,*,elements_dict = {}) :
    r"""
    Wrapper to the density_of_mixture function imported from exspy. Takes a dict of chemical composition expressed in weight fractions
    (or in atomic fractions when atomic_fraction is True, in which case it is first converted with atomic_to_weight_dict).
    Returns an approximated density. An empty composition yields 1.0.
    """
    if not elements_dict :
        return 1.0
    composition = elements_dict
    if atomic_fraction :
        composition = atomic_to_weight_dict(elements_dict = composition)
    symbols = list(composition)
    weights = [composition[symbol] for symbol in symbols]
    return density_of_mixture(weights, symbols)
[docs]
def arg_helper(params, d_params, replace = True):
    r""" Complete params with the entries of d_params it is missing.

    Every key of d_params absent from params is added with its default value. When both the
    value in params and the default are dictionaries, they are merged recursively.
    The resulting params is then passed to `check_keys` together with d_params.
    params is modified in place and also returned.

    Parameters
    ----------
    params : dict
        Dictionary of parameters to be checked.
    d_params : dict
        Dictionary of default parameters.
    replace : bool
        Forwarded to `check_keys` (and to the recursive calls).

    Returns
    -------
    params : dict
        Dictionary of parameters with the default parameters added if not present.
    """
    for name, default in d_params.items():
        current = params.setdefault(name, default)
        if isdict(current) and isdict(default):
            # Nested options: fill in the missing sub-keys as well.
            params[name] = arg_helper(current, default, replace=replace)
    check_keys(params, d_params, replace = replace)
    return params
[docs]
def check_keys(params, d_params, upperkeys = '',toprint = True, replace = True):
r""" Compare the keys of params with the keys of d_params and report the unexpected ones.
Every key of params that is absent from d_params triggers a printed warning (when toprint is True).
This function never adds keys to params: the only assignment it performs overwrites an existing
params[key] with d_params[key], on the branch taken when replace is False.
Parameters
----------
params : dict
Dictionary of parameters to be checked.
d_params : dict
Dictionary of default parameters.
upperkeys : str
String of the upper keys. It prefixes the key name in the warning message and is extended with ['key'] in the recursive calls.
toprint : bool
If True, print the warning.
replace : bool
If True, the values already in params are left untouched. If False, params[key] is overwritten with d_params[key] on the corresponding branch.
Returns
-------
bool
True (the only value this function returns explicitly).
Examples
--------
>>> params = {'a':1,'b':2,'d':4}
>>> d_params = {'a':1,'b':2,'c':3}
>>> check_keys(params,d_params)
Warning! Optional argument: ['d'] specified by user but not used
True
"""
keys = set(d_params.keys())
for key in params.keys():
if key not in keys:
# Key supplied by the user that has no default counterpart: only warn about it.
if toprint :
print('Warning! Optional argument: {}[\'{}\'] specified by user but not used'.format(upperkeys,key))
else:
if isdict(params[key]):
# Previous version of this branch, kept for reference:
# if not(isdict(d_params[key])):
# print('Warning! Optional argument: {}{} is not supposed to be a dictionary'.format(upperkeys,key))
# else:
# check_keys(params[key],d_params[key],upperkeys=upperkeys+'[\'{}\']'.format(key))
if isdict(d_params[key]):
# NOTE(review): the recursive check appears to run only when toprint is True - confirm this is intended.
if toprint :
check_keys(params[key],d_params[key],upperkeys=upperkeys+'[\'{}\']'.format(key), toprint = toprint, replace = replace)
# NOTE(review): indentation was lost in this copy of the file; this `else` presumably pairs with
# `if isdict(params[key])` (i.e. handles non-dictionary values) - confirm against the upstream file.
else:
if replace :
# params[key] keeps the value it already has
pass
else :
# params[key] is overwritten with the value found in d_params
# useful in EDS_espm to keep the original metadata
params[key] = d_params[key]
return True
[docs]
def isdict(p):
    r"""Return True if the variable is a dictionary.

    Subclasses of ``dict`` (e.g. ``collections.OrderedDict``) are dictionaries too.

    :param p: variable to check
    :type p: any
    :return: True if p is a dictionary
    :rtype: bool
    """
    # isinstance, unlike an exact type comparison, also accepts dict subclasses.
    return isinstance(p, dict)
[docs]
def is_symbol (i) :
    r""" Return True if i is a chemical symbol, i.e. if it belongs to the list returned by `symbol_list`

    :param i: variable to check
    :type i: any
    :return: True if i is a chemical symbol
    :rtype: bool
    """
    return i in symbol_list()
[docs]
def is_number(i):
    r""" Return True if i is a number, i.e. if ``int(i)`` succeeds

    Any input is accepted: values that ``int`` cannot convert (non-numeric strings, None,
    containers, infinite or NaN floats) give False instead of raising.

    :param i: variable to check
    :type i: any
    :return: True if i is a number
    :rtype: bool
    """
    try:
        int(i)
    except (TypeError, ValueError, OverflowError):
        # TypeError: unsupported type (None, list, ...); ValueError: non-numeric string or NaN;
        # OverflowError: infinite float.
        return False
    return True
[docs]
def symbol_list () :
    r""" Return the list of the chemical symbols found in the periodic table file

    The file located at NUMBER_PERIODIC_TABLE is read at every call; the "symbol" field of every
    entry of its "table" section is collected, in the order of the file.

    :return: list of chemical symbols
    :rtype: list
    """
    with open(NUMBER_PERIODIC_TABLE,"r") as table_file :
        table = json.load(table_file)["table"]
    return [entry["symbol"] for entry in table.values()]
[docs]
def close_all():
    r"""Close all opened windows.

    Calls ``close()`` on every registered widget and ``matplotlib.pyplot.close`` on every
    registered figure, then resets both module-level registries to empty lists.
    """
    global _qtg_widgets, _plt_figures
    # Imported here so that matplotlib is only loaded when windows actually get closed.
    import matplotlib.pyplot as plt

    for opened_widget in _qtg_widgets:
        opened_widget.close()
    _qtg_widgets = []

    for opened_figure in _plt_figures:
        plt.close(opened_figure)
    _plt_figures = []
[docs]
def get_explained_intensity_W(G, W, H):
    r""" Compute the explained intensity of each element of W.

    The intensity explained by ``W[i, j]`` is the sum of all the entries of the rank-one term
    ``G[:, i] * W[i, j] * H[j, :]``. This sum factorizes into
    ``W[i, j] * sum(G[:, i]) * sum(H[j, :])``, so the whole matrix is obtained as
    ``W * outer(column sums of G, row sums of H)`` without any Python loop.

    :param np.array 2D G: G matrix of the ESpM-NMF decomposition
    :param np.array 2D W: W matrix of the ESpM-NMF decomposition
    :param np.array 2D H: H matrix of the ESpM-NMF decomposition
    :return: np.array 2D of floats with the same shape as W
    :raises IndexError: if G has fewer columns than W has rows, or H has fewer rows than W has columns
    """
    n_rows, n_cols = W.shape
    if G.shape[1] < n_rows or H.shape[0] < n_cols:
        # Without this check a single column/row would silently broadcast.
        raise IndexError("G and H do not cover all the rows and columns of W")
    # Only the columns of G and the rows of H that are matched by an entry of W contribute.
    g_totals = np.sum(G[:, :n_rows], axis=0)
    h_totals = np.sum(H[:n_cols, :], axis=1)
    return np.asarray(W * np.outer(g_totals, h_totals), dtype=np.float64)