Source code for pysap.base.transform
# -*- coding: utf-8 -*-
##########################################################################
# pySAP - Copyright (C) CEA, 2017 - 2018
# Distributed under the terms of the CeCILL-B license, as published by
# the CEA-CNRS-INRIA. Refer to the LICENSE file or to
# http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
# for details.
##########################################################################
"""
Wavelet transform base module.
"""
# System import
from pprint import pprint
import uuid
import os
from warnings import warn
# Package import
import pysap
from .utils import with_metaclass
from pysap.plotting import plot_transform
# Third party import
import numpy
[docs]class MetaRegister(type):
""" Simple Python metaclass registry pattern.
"""
REGISTRY = {}
def __new__(cls, name, bases, attrs):
""" Allocation.
Parameters
----------
name: str
the name of the class.
bases: tuple
the base classes.
attrs:
the attributes defined for the class.
"""
new_cls = type.__new__(cls, name, bases, attrs)
if name in cls.REGISTRY:
raise ValueError(
"'{0}' name already used in registry.".format(name))
if name not in ("WaveletTransformBase", "ISAPWaveletTransformBase",
"PyWaveletTransformBase"):
cls.REGISTRY[name] = new_cls
return new_cls
[docs]class WaveletTransformBase(with_metaclass(MetaRegister)):
""" Data structure representing a signal wavelet decomposition.
Available transforms are define in 'pysap.transform'.
"""
def __init__(self, nb_scale, verbose=0, dim=2, use_wrapping=False,
**kwargs):
""" Initialize the WaveletTransformBase class.
Parameters
----------
data: numpy.ndarray
the input data.
nb_scale: int
the number of scale of the decomposition that includes the
approximation scale.
verbose: int, default 0
control the verbosity level.
dim: int, default 2
define the data dimension.
use_wrapping: bool, default False
if set, in the case of ISAP, use the command lines rather than the
bindings.
"""
# Wavelet transform parameters
self.nb_scale = nb_scale
self.name = None
self.bands_names = None
self.nb_band_per_scale = None
self.is_decimated = None
self.data_dim = dim
self.use_wrapping = use_wrapping
# Data that can be decalred afterward
self._data = None
self._image_metadata = {}
self._data_shape = None
self._iso_shape = None
self._analysis_data = None
self._analysis_shape = None
self._analysis_header = None
self._analysis_buffer_shape = None
self.verbose = verbose
self.kwargs = kwargs
# Transformation
self._init_transform(**self.kwargs)
def __reduce__(self):
""" The interface to pickle dump call.
Return
------
reduced_instance: tuple,
two or five items long tuple to define the init of a pickled
instance.
"""
return (self.__class__, (self.nb_scale, self.verbose))
def __getitem__(self, given):
""" Access the analysis designated scale/band coefficients.
Parameters
----------
given: int, slice or tuple
the scale and band indices.
Returns
-------
coeffs: numpy.ndarray or list
the analysis coefficients.
"""
# Convert given index to generic scale/band index
if not isinstance(given, tuple):
given = (given, slice(None))
# Check that we have a valid given index
if len(given) != 2:
raise ValueError("Expect a scale/band int or 2-uplet index.")
# Check some data are stored in the structure
if self._analysis_data is None:
raise ValueError("Please specify first the decomposition "
"coefficients array.")
# Handle multi-dim slice object
if isinstance(given[0], slice):
start = given[0].start or 0
stop = given[0].stop or self.nb_scale
step = given[0].step or 1
coeffs = [self.__getitem__((index, given[1]))
for index in range(start, stop, step)]
elif isinstance(given[1], slice):
start = given[1].start or 0
stop = given[1].stop or self.nb_band_per_scale[given[0]]
step = given[1].step or 1
coeffs = [self.band_at(given[0], index)
for index in range(start, stop, step)]
else:
coeffs = [self.band_at(given[0], given[1])]
# Format output
if len(coeffs) == 1:
coeffs = coeffs[0]
return coeffs
def __setitem__(self, given, array):
""" Set the analysis designated scale/band coefficients.
Parameters
----------
given: tuple
the scale and band indices.
array: numpy.ndarray
the specific scale/band data as an array.
"""
# Check that we have a valid given index
if len(given) != 2:
raise ValueError("Expect a scale/band int or 2-uplet index.")
# Check given index
if isinstance(given[0], slice) or isinstance(given[1], slice):
raise ValueError("Expect a scale/band int index (no slice).")
# Check some data are stored in the structure
if self._analysis_data is None:
raise ValueError("Please specify first the decomposition "
"coefficients array.")
# Handle multidim slice object
if isinstance(given[0], slice):
start = given[0].start or 0
stop = given[0].stop or self.nb_scale
step = given[0].step or 1
coeffs = [self.__getitem__((index, given[1]))
for index in range(start, stop, step)]
elif isinstance(given[1], slice):
start = given[1].start or 0
stop = given[1].stop or self.nb_band_per_scale[given[0]]
step = given[1].step or 1
coeffs = [self.band_at(given[0], index)
for index in range(start, stop, step)]
else:
coeffs = [self.band_at(given[0], given[1])]
# Format output
if len(coeffs) == 1:
coeffs = coeffs[0]
return coeffs
##########################################################################
# Properties
##########################################################################
[docs] def _set_data(self, data):
""" Set the input data array.
Parameters
----------
data: numpy.ndarray or pysap.Image
input data/signal.
"""
if self.verbose > 0 and self._data is not None:
print("[info] Replacing existing input data array.")
# Ensure that the shape is square except when the family is pywt
if self.__family__ != 'pywt' and \
not all([e == data.shape[0] for e in data.shape]):
raise ValueError("Expect a square shape data.")
if data.ndim != self.data_dim:
raise ValueError("This wavelet can only be applied on {0}D "
"square images".format(self.data_dim))
if self.is_decimated and not (data.shape[0] // 2**(self.nb_scale) > 0):
raise ValueError("Can't decimate the data with the specified "
"number of scales.")
if isinstance(data, pysap.Image):
self._data = data.data
self._image_metadata = data.metadata
else:
self._data = data
self._data_shape = self._data.shape
self._iso_shape = self._data_shape[0]
if self.use_wrapping:
self._set_transformation_parameters()
self._compute_transformation_parameters()
[docs] def _get_data(self):
""" Get the input data array.
Returns
-------
data: numpy.ndarray
input data/signal.
"""
return self._data
[docs] def _set_analysis_data(self, analysis_data):
""" Set the decomposition coefficients array.
Parameters
----------
analysis_data: list
decomposition coefficients array.
"""
if self.verbose > 0 and self._analysis_data is not None:
print("[info] Replacing existing decomposition coefficients "
"array.")
if len(analysis_data) != sum(self.nb_band_per_scale):
raise ValueError("The wavelet coefficients do not correspond to "
"the wavelet transform parameters.")
self._analysis_data = analysis_data
[docs] def _get_analysis_data(self):
""" Get the decomposition coefficients array.
Returns
-------
analysis_data: numpy.ndarray
decomposition coefficients array.
"""
return self._analysis_data
[docs] def _set_analysis_header(self, analysis_header):
""" Set the decomposition coefficients header.
Parameters
----------
analysis_header: dict
decomposition coefficients array.
"""
if self.verbose > 0 and self._analysis_header is not None:
print("[info] Replacing existing decomposition coefficients "
"header.")
self._analysis_header = analysis_header
[docs] def _get_analysis_header(self):
""" Get the decomposition coefficients header.
Returns
-------
analysis_header: dict
decomposition coefficients header.
"""
return self._analysis_header
[docs] def _get_info(self):
""" Return the transformation information. This iformation is only
available when using the Python bindings.
"""
if hasattr(self.trf, "info"):
self.trf.info()
data = property(_get_data, _set_data)
analysis_data = property(_get_analysis_data, _set_analysis_data)
analysis_header = property(_get_analysis_header, _set_analysis_header)
info = property(_get_info)
##########################################################################
# Public members
##########################################################################
[docs] @classmethod
def bands_shapes(cls, bands_lengths, ratio=None):
""" Return the different bands associated shapes given there lengths.
Parameters
----------
bands_lengths: numpy.ndarray (<nb_scale>, max(<nb_band_per_scale>, 0))
array holding the length between two bands of the data
vector per scale.
ratio: numpy.ndarray, default None
a array containing ratios for eeach scale and each band.
Returns
-------
bands_shapes: list of list of 2-uplet (<nb_scale>, <nb_band_per_scale>)
structure holding the shape of each bands at each scale.
"""
if ratio is None:
ratio = numpy.ones_like(bands_lengths)
bands_shapes = []
for band_number, scale_data in enumerate(bands_lengths):
scale_shapes = []
for scale_number, scale_padd in enumerate(scale_data):
shape = (
int(numpy.sqrt(
scale_padd * ratio[band_number, scale_number])),
int(numpy.sqrt(
scale_padd / ratio[band_number, scale_number])))
scale_shapes.append(shape)
bands_shapes.append(scale_shapes)
return bands_shapes
[docs] def show(self):
""" Display the different bands at the different decomposition scales.
"""
warn(
'The show method has been deprecated and will be removed in a '
+ 'future release. In the future Please use the Transform.data '
+ 'and/or Transform.analysis_data attributes and your plotting '
+ 'package of choice. To use this deprecated function you will '
+ 'need to install pyqtgraph manually.',
FutureWarning
)
plot_transform(self)
[docs] def analysis(self, **kwargs):
""" Decompose a real or complex signal using ISAP.
Fill the instance 'analysis_data' and 'analysis_header' parameters.
Parameters
----------
kwargs: dict (optional)
the parameters that will be passed to
'pysap.extensions.mr_tansform'.
"""
# Checks
if self._data is None:
raise ValueError("Please specify first the input data.")
# Analysis
if numpy.iscomplexobj(self._data):
analysis_data_real, self.analysis_header = self._analysis(
self._data.real, **kwargs)
analysis_data_imag, _ = self._analysis(
self._data.imag, **kwargs)
if isinstance(analysis_data_real, numpy.ndarray):
self._analysis_data = (
analysis_data_real + 1.j * analysis_data_imag)
else:
self._analysis_data = [
re + 1.j * ima
for re, ima in zip(analysis_data_real, analysis_data_imag)]
else:
self._analysis_data, self._analysis_header = self._analysis(
self._data, **kwargs)
[docs] def synthesis(self):
""" Reconstruct a real or complex signal from the wavelet coefficients
using ISAP.
Returns
-------
data: pysap.Image
the reconstructed data/signal.
"""
# Checks
if self._analysis_data is None:
raise ValueError("Please specify first the decomposition "
"coefficients array.")
if self.use_wrapping and self._analysis_header is None:
raise ValueError("Please specify first the decomposition "
"coefficients header.")
# Message
if self.verbose > 1:
print("[info] Synthesis header:")
pprint(self._analysis_header)
# Reorganize the coefficents with ISAP convention
# TODO: do not backup the list of bands
if self.use_wrapping:
analysis_buffer = numpy.zeros(
self._analysis_buffer_shape, dtype=self.analysis_data[0].dtype)
for scale, nb_bands in enumerate(self.nb_band_per_scale):
for band in range(nb_bands):
self._set_linear_band(scale, band, analysis_buffer,
self.band_at(scale, band))
_saved_analysis_data = self._analysis_data
self._analysis_data = analysis_buffer
self._analysis_data = [self.unflatten_fct(self)]
# Synthesis
if numpy.iscomplexobj(self._analysis_data[0]):
data_real = self._synthesis(
[arr.real for arr in self._analysis_data],
self._analysis_header)
data_imag = self._synthesis(
[arr.imag for arr in self._analysis_data],
self._analysis_header)
data = data_real + 1.j * data_imag
else:
data = self._synthesis(
self._analysis_data, self._analysis_header)
# TODO: remove this code asap
if self.use_wrapping:
self._analysis_data = _saved_analysis_data
return pysap.Image(data=data, metadata=self._image_metadata)
[docs] def band_at(self, scale, band):
""" Get the band at a specific scale.
Parameters
----------
scale: int
index of the scale.
band: int
index of the band.
Returns
-------
band_data: nd-arry
the requested band data array.
"""
# Message
if self.verbose > 1:
print("[info] Accessing scale '{0}' and band '{1}'...".format(
scale, band))
# Get the band array
index = numpy.sum(self.nb_band_per_scale[:scale]).astype(int) + band
band_data = self.analysis_data[index]
return band_data
##########################################################################
# Private members
##########################################################################
[docs] def _init_transform(self):
""" Define the transform.
Attributes
----------
trf: object
the transformation.
"""
raise NotImplementedError("Abstract method should not be declared "
"in derivate classes.")
[docs] def _get_linear_band(self, scale, band, analysis_data):
""" Access the desired band data from a 1D linear analysis buffer.
Parameters
----------
scale: int
index of the scale.
band: int
index of the band.
analysis_data: numpy.ndarray (N, )
the analysis buffer.
Returns
-------
band_data: nd-arry (M, )
the requested band buffer.
"""
# Compute selected scale/band start/stop indices
start_scale_padd = self.scales_padds[scale]
start_band_padd = (
self.bands_lengths[scale, :band + 1].sum()
- self.bands_lengths[scale, band])
start_padd = start_scale_padd + start_band_padd
stop_padd = start_padd + self.bands_lengths[scale, band]
# Get the band array
band_data = analysis_data[start_padd: stop_padd].reshape(
self.bands_shapes[scale][band])
return band_data
[docs] def _set_linear_band(self, scale, band, analysis_data, band_data):
""" Set the desired band data in a 1D linear analysis buffer.
Parameters
----------
scale: int
index of the scale.
band: int
index of the band.
analysis_data: numpy.ndarray (N, )
the analysis buffer.
band_data: numpy.ndarray (M, M)
the band data to be added in the analysis buffer.
Returns
-------
analysis_data: nd-arry (N, )
the updated analysis buffer.
"""
# Compute selected scale/band start/stop indices
start_scale_padd = self.scales_padds[scale]
start_band_padd = (
self.bands_lengths[scale, :band + 1].sum()
- self.bands_lengths[scale, band])
start_padd = start_scale_padd + start_band_padd
stop_padd = start_padd + self.bands_lengths[scale, band]
# Get the band array
analysis_data[start_padd: stop_padd] = band_data.flatten()
return analysis_data
[docs] def _set_transformation_parameters(self):
""" Define the transformation class parameters.
Attributes
----------
name: str
the name of the decomposition.
bands_names: list of str
the name of the different bands.
flatten_fct: callable
a function used to reorganize the ISAP decomposition coefficients,
see 'pysap/extensions/formating.py' module for more details.
unflatten_fct: callable
a function used to reorganize the decomposition coefficients using
ISAP convention, see 'pysap/extensions/formating.py' module for
more details.
is_decimated: bool
True if the decomposition include a decimation of the
band number of coefficients.
nb_band_per_scale: numpy.ndarray (<nb_scale>, )
vector of int holding the number of band per scale.
bands_lengths: numpy.ndarray (<nb_scale>, max(<nb_band_per_scale>, 0))
array holding the length between two bands of the data
vector per scale.
bands_shapes: list of list of 2-uplet (<nb_scale>, <nb_band_per_scale>)
structure holding the shape of each bands at each scale.
isap_transform_id: int
the label of the ISAP transformation.
"""
raise NotImplementedError("Abstract method should not be declared "
"in derivate classes.")
[docs] def _compute_transformation_parameters(self):
""" Compute information in order to split scale/band flatten data.
Attributes
----------
scales_lengths: numpy.ndarray (<nb_scale>, )
the length of each band.
scales_padds: numpy.ndarray (<nb_scale> + 1, )
the index of the data associated to each scale.
"""
if self.bands_lengths is None:
raise ValueError(
"The transformation parameters have not been set.")
self.scales_lengths = self.bands_lengths.sum(axis=1)
self.scales_padds = numpy.zeros((self.nb_scale + 1, ), dtype=int)
self.scales_padds[1:] = self.scales_lengths.cumsum()
[docs] def _analysis(self, data, **kwargs):
""" Decompose a real signal using ISAP.
Parameters
----------
data: numpy.ndarray
a real array to be decomposed.
kwargs: dict (optional)
the parameters that will be passed to
'pysap.extensions.mr_tansform'.
Returns
-------
analysis_data: numpy.ndarray
the decomposition coefficients.
analysis_header: dict
the decomposition associated information.
"""
raise NotImplementedError("Abstract method should not be declared "
"in derivate classes.")
[docs] def _synthesis(self, analysis_data, analysis_header):
""" Reconstruct a real signal from the wavelet coefficients using ISAP.
Parameters
----------
analysis_data: list of numpy.ndarray
the wavelet coefficients array.
analysis_header: dict
the wavelet decomposition parameters.
Returns
-------
data: numpy.ndarray
the reconstructed data array.
"""
raise NotImplementedError("Abstract method should not be declared "
"in derivate classes.")