# Source code for probreg.gmmtree

from __future__ import division, print_function

from collections import namedtuple
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np
import open3d as o3

from . import _gmmtree
from . import se3_op as so
from . import transformation as tf
from .log import log

EstepResult = namedtuple("EstepResult", ["moments"])
EstepResult.__doc__ = """Result of Expectation step.

    Attributes:
        moments: Per-node moments computed against the GMM tree;
            ``moments[i]`` corresponds to the i-th GMM tree node.
"""
MstepResult = namedtuple("MstepResult", ["transformation", "q"])
MstepResult.__doc__ = """Result of Maximization step.

    Attributes:
        transformation (tf.Transformation): Transformation from source to target.
        q (float): Sum of squared residuals of the M-step least-squares
            problem; used as the convergence criterion of the registration.
"""


class GMMTree:
    """GMM Tree registration.

    A GMM tree is built from the source point cloud. Each registration
    iteration transforms the target with the current estimate, computes
    per-node moments against the tree (E-step), and solves a least-squares
    problem for a 6-dof twist that is composed onto the estimate (M-step).

    Args:
        source (numpy.ndarray, optional): Source point cloud data.
        tree_level (int, optional): Maximum depth level of GMM tree.
        lambda_c (float, optional): Parameter that determines the pruning of GMM tree.
        lambda_s (float, optional): Tolerance parameter used when building the GMM tree.
        tf_init_params (dict, optional): Parameters to initialize transformation.
    """

    def __init__(
        self,
        source: Optional[np.ndarray] = None,
        tree_level: int = 2,
        lambda_c: float = 0.01,
        lambda_s: float = 0.001,
        tf_init_params: Optional[Dict] = None,
    ):
        self._source = source
        self._tree_level = tree_level
        self._lambda_c = lambda_c
        self._lambda_s = lambda_s
        self._tf_type = tf.RigidTransformation
        # ``None`` default instead of a shared ``{}`` literal (mutable-default pitfall).
        self._tf_result = self._tf_type(**(tf_init_params or {}))
        self._callbacks = []
        # The tree only exists once a source is given (here or via set_source).
        self._nodes = None
        if self._source is not None:
            self._nodes = self._build_tree(self._source)

    def _build_tree(self, source: np.ndarray):
        """Build the GMM tree for *source* using this instance's settings."""
        # 1.0e-4 is forwarded unchanged; its meaning is defined by
        # ``_gmmtree.build_gmmtree`` (NOTE(review): confirm what it controls).
        return _gmmtree.build_gmmtree(source, self._tree_level, self._lambda_s, 1.0e-4)

    def set_source(self, source: np.ndarray) -> None:
        """Set the source point cloud and (re)build the GMM tree from it."""
        self._source = source
        self._nodes = self._build_tree(source)

    def set_callbacks(self, callbacks):
        """Set the callables invoked after every registration iteration.

        Each callback is called with the inverse of the internal estimate,
        i.e. the same form of transformation that :meth:`registration` returns.
        """
        self._callbacks = callbacks

    def expectation_step(self, target: np.ndarray) -> EstepResult:
        """Compute per-node moments of *target* against the GMM tree.

        Args:
            target: Target points, already transformed by the current estimate.

        Returns:
            EstepResult: Moments aligned with the tree nodes.

        Raises:
            RuntimeError: If no source was set, so the tree has not been built.
        """
        if self._nodes is None:
            raise RuntimeError("GMM tree is not built; set a source first (set_source).")
        res = _gmmtree.gmmtree_reg_estep(target, self._nodes, self._tree_level, self._lambda_c)
        return EstepResult(res)

    def maximization_step(self, estep_res: EstepResult, trans_p: tf.Transformation) -> MstepResult:
        """Solve for a twist update from the E-step moments and apply it to *trans_p*.

        Every node with non-negligible weight contributes three rows to the
        system ``amat @ x = bmat``. The rows use the eigenvectors of the node's
        matrix ``self._nodes[i][2]``, scaled by ``sqrt(weight / eigenvalue)``.
        ``x`` is a 6-vector: the first three columns of ``amat`` hold cross-product
        terms, the last three the translation terms.

        Args:
            estep_res: Output of :meth:`expectation_step`; ``moments[i]`` pairs
                with ``self._nodes[i]``.
            trans_p: Current transformation estimate.

        Returns:
            MstepResult: The updated rigid transformation and ``q``, the sum of
            squared residuals of the least-squares system (a float).
        """
        moments = estep_res.moments
        n = len(moments)
        amat = np.zeros((n * 3, 6))
        bmat = np.zeros(n * 3)
        eps = np.finfo(np.float32).eps
        for i, m in enumerate(moments):
            # Nodes with (almost) no weight are skipped; their rows stay zero.
            if m[0] < eps:
                continue
            # assumes m[0] is the node weight and m[1] a weighted point sum, so
            # ``s`` is a weighted mean 3-vector -- TODO confirm against _gmmtree.
            lmd, nn = np.linalg.eigh(self._nodes[i][2])
            s = m[1] / m[0]
            # Scale column j (eigenvector j) by sqrt(weight / eigenvalue_j).
            nn = np.multiply(nn, np.sqrt(m[0] / lmd))
            sl = slice(3 * i, 3 * (i + 1))
            bmat[sl] = np.dot(nn.T, self._nodes[i][1]) - np.dot(nn.T, s)
            amat[sl, :3] = np.cross(s, nn.T)
            amat[sl, 3:] = nn.T
        x = np.linalg.lstsq(amat, bmat, rcond=-1)[0]
        # lstsq returns an *empty* residual array when amat is rank deficient
        # (e.g. many skipped nodes), which would break the convergence test in
        # registration(); compute the squared residual explicitly as a float.
        q = float(np.sum(np.square(amat.dot(x) - bmat)))
        rot, t = so.twist_mul(x, trans_p.rot, trans_p.t)
        return MstepResult(tf.RigidTransformation(rot, t), q)

    def registration(self, target: np.ndarray, maxiter: int = 20, tol: float = 1.0e-4) -> MstepResult:
        """Run EM iterations aligning *target* with the GMM tree of the source.

        Args:
            target: Target point cloud data.
            maxiter: Maximum number of EM iterations.
            tol: Stop once ``q`` changes by less than this between iterations.

        Returns:
            MstepResult: The inverse of the internal estimate and the last ``q``.
            If ``maxiter < 1`` no iteration runs and ``q`` is ``None``.
        """
        prev_q = None
        last_q = None
        for i in range(maxiter):
            t_target = self._tf_result.transform(target)
            estep_res = self.expectation_step(t_target)
            res = self.maximization_step(estep_res, self._tf_result)
            self._tf_result = res.transformation
            for c in self._callbacks:
                c(self._tf_result.inverse())
            log.debug("Iteration: {}, Criteria: {}".format(i, res.q))
            last_q = res.q
            if prev_q is not None and abs(last_q - prev_q) < tol:
                break
            prev_q = last_q
        return MstepResult(self._tf_result.inverse(), last_q)
def registration_gmmtree(
    source: Union[np.ndarray, o3.geometry.PointCloud],
    target: Union[np.ndarray, o3.geometry.PointCloud],
    maxiter: int = 20,
    tol: float = 1.0e-4,
    callbacks: Optional[List[Callable]] = None,
    **kwargs: Any,
) -> MstepResult:
    """GMMTree registration

    Args:
        source (numpy.ndarray or open3d.geometry.PointCloud): Source point cloud data.
        target (numpy.ndarray or open3d.geometry.PointCloud): Target point cloud data.
        maxiter (int, optional): Maximum number of iterations of the EM algorithm.
        tol (float, optional): Tolerance for termination.
        callbacks (:obj:`list` of :obj:`function`, optional): Called after each
            iteration with the current transformation estimate:
            `callback(probreg.Transformation)`.

    Keyword Args:
        tree_level (int, optional): Maximum depth level of GMM tree.
        lambda_c (float, optional): Parameter that determines the pruning of GMM tree.
        lambda_s (float, optional): Tolerance parameter used when building the GMM tree.
        tf_init_params (dict, optional): Parameters to initialize transformation.

    Returns:
        MstepResult: Result of the registration (transformation, q)
    """

    def _to_array(x):
        # Open3D point clouds expose their coordinates via ``.points``.
        return np.asarray(x.points if isinstance(x, o3.geometry.PointCloud) else x)

    gt = GMMTree(_to_array(source), **kwargs)
    # ``None`` default avoids a shared mutable list being stored on the instance.
    gt.set_callbacks([] if callbacks is None else callbacks)
    return gt.registration(_to_array(target), maxiter, tol)