# Source code for probreg.gauss_transform

from __future__ import division, print_function

from typing import Optional

import numpy as np

from . import _ifgt


def _gauss_transform_direct(source: np.ndarray, target: np.ndarray, weights: np.ndarray, h: float) -> np.ndarray:
    r"""Evaluate the Gauss transform by direct (brute-force) summation.

    For every target point ``target[i]`` this computes

        \sum_{j} weights[j] * \exp{ - \frac{||target[i] - source[j]||^2}{h^2} }

    Args:
        source: 2-D array with one source point per row.
        target: 2-D array with one target point per row (same number of
            columns as ``source``).
        weights: 1-D array with one weight per source row.
        h: Bandwidth of the Gaussian.

    Returns:
        1-D array with one transformed value per target row. An empty target
        (zero rows) yields an empty array.
    """
    h2 = h * h

    def _transform_one(t: np.ndarray) -> np.ndarray:
        # Squared Euclidean distance from this target point to every source row.
        sq_dist = np.sum(np.square(t - source), axis=1)
        return np.dot(weights, np.exp(-sq_dist / h2))

    target = np.asanyarray(target)
    if target.shape[0] == 0:
        # np.apply_along_axis refuses to iterate over a zero-length axis.
        return np.zeros(0)
    return np.apply_along_axis(_transform_one, 1, target)


class Direct(object):
    """Gauss transform evaluated by direct summation over all source points.

    Args:
        source (numpy.ndarray): Source data.
        h (float): Bandwidth parameter of the Gaussian.
    """

    def __init__(self, source, h):
        self._h = h
        self._source = source

    def compute(self, target: np.ndarray, weights: np.ndarray) -> np.ndarray:
        """Evaluate the weighted Gauss transform of the stored source at ``target``.

        Args:
            target (numpy.ndarray): Target data.
            weights (numpy.ndarray): One weight per source point.
        """
        src, bandwidth = self._source, self._h
        return _gauss_transform_direct(src, target, weights, bandwidth)
class GaussTransform(object):
    """Calculate Gauss Transform

    The evaluation backend is chosen once, at construction: direct summation
    when ``h < sw_h``, otherwise the IFGT implementation from ``_ifgt``.

    Args:
        source (numpy.ndarray): Source data.
        h (float): Bandwidth parameter of the Gaussian.
        eps (float): Small floating point used in Gauss Transform.
        sw_h (float): Value of the bandwidth parameter to switch between direct method and IFGT.
    """

    def __init__(self, source: np.ndarray, h: float, eps: float = 1.0e-4, sw_h: float = 0.01):
        # Number of source points; length of the default all-ones weight vector.
        self._m = source.shape[0]
        if h < sw_h:
            # NOTE(review): eps is not used on this path; it is only forwarded to IFGT.
            self._impl = Direct(source, h)
        else:
            self._impl = _ifgt.Ifgt(source, h, eps)

    def compute(self, target: np.ndarray, weights: Optional[np.ndarray] = None):
        """Compute gauss transform

        Args:
            target (numpy.ndarray): Target data.
            weights (numpy.ndarray, optional): Weights of Gauss Transform. A 1-D
                array gives a single transform; a 2-D array is treated as a stack
                of weight vectors (one per row). Any array-like is accepted.
                Defaults to all ones (one per source point).

        Returns:
            numpy.ndarray: The backend's ``compute`` result for 1-D weights, or
            the row-wise stack of such results for 2-D weights.

        Raises:
            ValueError: If ``weights`` is neither 1-D nor 2-D.
        """
        if weights is None:
            weights = np.ones(self._m)
        else:
            # `.ndim` below needs an array; asanyarray leaves ndarrays untouched.
            weights = np.asanyarray(weights)
        if weights.ndim == 1:
            return self._impl.compute(target, weights)
        elif weights.ndim == 2:
            return np.r_[[self._impl.compute(target, w) for w in weights]]
        else:
            raise ValueError("weights.ndim must be 1 or 2.")