Source code for xicsrt.optics._ShapeMesh

# -*- coding: utf-8 -*-
"""
.. Authors
    Novimir Pablant <npablant@pppl.gov>
    James Kring <jdk0026@tigermail.auburn.edu>
    Yevgeniy Yakusevich <eugenethree@gmail.com>
"""
import numpy as np

from xicsrt.util import profiler
from xicsrt.tools.xicsrt_doc import dochelper
from xicsrt.optics._ShapeObject import ShapeObject

from scipy.spatial import Delaunay
from scipy.spatial import cKDTree
from scipy.interpolate import CloughTocher2DInterpolator as Interpolator

[docs] @dochelper class ShapeMesh(ShapeObject): """ A shape that uses a mesh grid instead of an analytical shape. This implementation assumes that the mesh surface normal is generally in the local z = [0, 0, 1] direction; both triangulation and interpolation are done in the x-y plane and are therefore best behaved when this assumption holds. It is recommended to build optics in local coordinates with this assumption and set config['trace_local'] = True. Supplying a well behaved Delaunay Triangulation via config['mesh_delaunay'] can help to make the mesh work while raytracing in global coordinates, but problems will still be encountered if the surface normal is more aligned with the x or y directions. **Programming Notes** Raytracing of mesh optics is fundamentally slow, because of the need to find which mesh face is intersected by each ray. For the simplest implementations this requires testing each ray against each mesh face leading to the speed scaling as the number of mesh faces (or equivalently num_faces^2). Some optimization of the basic calculation been completed. The mesh_intersect_1 method implements the Möller–Trumbore algorithm and is the fastest pure python algorithm found so far. To further improve performance this class (optionally) also makes use of pre-selection with a coarse mesh. First the intersection with the coarse mesh is found for each ray. Then only the 8 nearby faces on the full mesh are checked for the final intersection location. This method improves the performance to (num_faces_coarse + 8). The current algorithm for pre-selection (mesh refinement) is not perfect in that the nearby faces are not always appropriately chosen leading to a small number of rays being 'lost'. These errors can be minimized by increasing the resolution of the coarse mesh and ensuring that the grid spacing is approximately equal in the x and y directions. Further performance improvement could be gained by using numba or jax. This would allow the Möller–Trumbore algorithm to be implemented as a loop (instead of being vectorized) where the calculation can be terminated early when no hit is found. .. Todo:: XicsrtOpticMesh: Improve the pre-selection (mesh refinement algorithm) to eliminate ray losses. The current method is as follows: 1. Calculate intersection with coarse grid. 2. Find the point on the fine grid closest to the intersection. 3. Test all faces on the fine grid that contain this point. The problem is that the closest point may not always be part of the face that actually has the intersection. This can happen if the fine and coarse grid have very different densities, but also even in the perfect case if the ray hits near the edge of a face and the grid density is not even in the x and y directions. What is needed is a better selection of nearby faces. There is also a potential to improve computational speed slightly by testing fewer faces on the fine grid. """
[docs] def default_config(self): """ mesh_points mesh_normals mesh_faces mesh_coarse_points mesh_coarse_normals mesh_coarse_faces mesh_interpolate mesh_refine """ config = super().default_config() config['mesh_points'] = None config['mesh_normals'] = None config['mesh_faces'] = None config['mesh_coarse_points'] = None config['mesh_coarse_normals'] = None config['mesh_coarse_faces'] = None config['mesh_interpolate'] = None config['mesh_refine'] = None return config
[docs] def check_param(self): super().check_param() if self.param['mesh_interpolate'] is None: self.param['mesh_interpolate'] = (self.param['mesh_normals'] is not None) elif self.param['mesh_interpolate']: if self.param['mesh_normals'] is None: raise Exception('Surface normal vectors must be defined in order to use mesh interpolation.') if self.param['mesh_refine'] is None: if self.param['mesh_coarse_points'] is not None: self.param['mesh_refine'] = True spread_x = np.max(self.param['mesh_points'][:, 0]) - np.min(self.param['mesh_points'][:, 0]) spread_y = np.max(self.param['mesh_points'][:, 1]) - np.min(self.param['mesh_points'][:, 1]) spread_z = np.max(self.param['mesh_points'][:, 2]) - np.min(self.param['mesh_points'][:, 2]) if (spread_z > spread_x) or (spread_z > spread_y): self.log.warning('Mesh is not oriented with the surface normals near the local z direction.\n' 'This may lead to unexpected and incorrect results.')
[docs] def initialize(self): super().initialize() self.mesh_initialize()
[docs] def intersect(self, rays): """ Calculate ray intersections with the mesh. """ profiler.start('mesh_intersect') if not self.param['mesh_refine']: xloc, mask, hits = self.mesh_intersect_1(rays, self.param['mesh']) else: xloc_c, mask_c, hits_c = self.mesh_intersect_1( rays, self.param['mesh_coarse'], ) num_rays_coarse = np.sum(mask_c) faces_idx, faces_mask = self.find_near_faces(xloc_c, self.param['mesh'], mask_c) xloc, mask, hits = self.mesh_intersect_2( rays, self.param['mesh'], mask_c, faces_idx, faces_mask, ) num_rays_fine = np.sum(mask) num_rays_lost = num_rays_coarse - num_rays_fine if not num_rays_lost == 0: self.log.warning(f'Rays lost in mesh refinement: {num_rays_lost:0.0f} of {num_rays_coarse:6.2e}') if self.param['mesh_interpolate']: xloc, norm = self.mesh_interpolate(xloc, self.param['mesh'], mask) else: norm = self.mesh_normals(hits, self.param['mesh'], mask) profiler.stop('mesh_intersect') return xloc, norm, mask
[docs] def mesh_interpolate(self, X, mesh, mask): profiler.start('mesh_interpolate') # Interpolate the z-coordinate of the intersection. # Here I assume that the surface normal is generally aligned in the # z direction. In that case the x and y intersections will be quite # accurate, and a correction in the z intersection will help correct # for the flatness of each triangle in the triangulation. X[:, 2] = mesh['interp']['z'](X[:, 0], X[:, 1]) normals = np.empty(X.shape) normals[:, 0] = mesh['interp']['normal_x'](X[:, 0], X[:, 1]) normals[:, 1] = mesh['interp']['normal_y'](X[:, 0], X[:, 1]) normals[:, 2] = mesh['interp']['normal_z'](X[:, 0], X[:, 1]) profiler.start('normalize') normals = np.einsum( 'i,ij->ij' ,1.0/np.linalg.norm(normals, axis=1) ,normals ,optimize=True) profiler.stop('normalize') profiler.stop('mesh_interpolate') return X, normals
[docs] def _mesh_precalc(self, points, normals, faces): profiler.start('_mesh_precalc') output = {} output['points'] = points output['normals'] = normals output['faces'] = faces # Perform 2D Delaunay triangulation using the x and y locations. # This will work fine in most cases, but may cause problems if the # optic is oriented so that the normal is in the x or y direction. # # It is recommended that mesh optics be built using a local # coordinate system that makes the x and y coordinates sensible for # 2d interpolation. delaunay = Delaunay(points[:, 0:2]) # If pre-triangulated faces were provided, use those for the ray # intersection Calculations if faces is None: faces = delaunay.simplices output['faces'] = faces if self.param['mesh_interpolate']: # Create a set of interpolators. profiler.start('Create Interpolators') interp = {} output['interp'] = interp interp['z'] = Interpolator(delaunay, points[:, 2].flatten()) interp['normal_x'] = Interpolator(delaunay, normals[:, 0].flatten()) interp['normal_y'] = Interpolator(delaunay, normals[:, 1].flatten()) interp['normal_z'] = Interpolator(delaunay, normals[:, 2].flatten()) profiler.stop('Create Interpolators') # Copying these makes the code easier to read, # but may increase memory usage for dense meshes. p0 = points[faces[..., 0], :] p1 = points[faces[..., 1], :] p2 = points[faces[..., 2], :] # Calculate the normals at each face. faces_center = np.mean(np.array([p0, p1, p2]), 0) faces_normal = np.cross((p0 - p1), (p2 - p1)) faces_normal /= np.linalg.norm(faces_normal, axis=1)[:, None] output['faces_center'] = faces_center output['faces_normal'] = faces_normal # Generate a tree for the points. points_tree = cKDTree(points) output['points_tree'] = points_tree # Build up a lookup table for the faces around each point. # This is currently slow for large arrays. points_idx = np.arange(len(points)) p_faces_idx, p_faces_mask = \ self.find_point_faces(points_idx, faces) output['p_faces_idx'] = p_faces_idx output['p_faces_mask'] = p_faces_mask # centers_tree = cKDTree(points) # output['centers_tree'] = centers_tree profiler.stop('_mesh_precalc') return output
[docs] def mesh_initialize(self): """ Pre-calculate a number of mesh properties that are needed in the other mesh methods. """ profiler.start('mesh_initialize') dummy = self._mesh_precalc( self.param['mesh_points'] , self.param['mesh_normals'] , self.param['mesh_faces']) self.param['mesh'] = {} for key in dummy: self.param['mesh'][key] = dummy[key] if self.param['mesh_coarse_points'] is not None: dummy = self._mesh_precalc( self.param['mesh_coarse_points'] , self.param['mesh_coarse_normals'] , self.param['mesh_coarse_faces']) self.param['mesh_coarse'] = {} for key in dummy: self.param['mesh_coarse'][key] = dummy[key] profiler.stop('mesh_initialize')
[docs] def mesh_intersect_1(self, rays, mesh): """ Find the intersection of rays with the mesh using the Möller–Trumbore algorithm. """ profiler.start('mesh_intersect_1') O = rays['origin'] D = rays['direction'] m = rays['mask'].copy() X = np.full(D.shape, np.nan, dtype=np.float64) # Copying these makes the code easier to read, # but may increase memory usage for dense meshes. p0 = mesh['points'][mesh['faces'][..., 0], :] p1 = mesh['points'][mesh['faces'][..., 1], :] p2 = mesh['points'][mesh['faces'][..., 2], :] epsilon = 1e-15 num_rays = len(m) hits = np.empty(num_rays, dtype=np.int64) m_temp = np.empty(num_rays, dtype=bool) m_temp_2 = np.zeros(num_rays, dtype=bool) for ii in range(mesh['faces'].shape[0]): m_temp[:] = m edge1 = p1[ii, :] - p0[ii, :] edge2 = p2[ii, :] - p0[ii, :] h = np.cross(D, edge2) f = np.einsum('i,ji->j', edge1, h, optimize=True) m_temp &= ~((f > -epsilon) & (f < epsilon)) if not np.any(m_temp): continue f = 1.0 / f s = O - p0[ii, :] u = f * np.einsum('ij,ij->i', s, h, optimize=True) m_temp &= ~((u < 0.0) | (u > 1.0)) if not np.any(m_temp): continue q = np.cross(s, edge1) v = f * np.einsum('ij,ij->i', D, q, optimize=True) m_temp &= ~((v < 0.0) | (u + v > 1.0)) if not np.any(m_temp): continue t = f * np.einsum('i,ji->j', edge2, q, optimize=True) # Update overall hit array and hit mask. m_temp_2[m_temp] = m_temp[m_temp] hits[m_temp] = ii X[m_temp] = O[m_temp] + t[m_temp, None] * D[m_temp, :] # Update the mask not to include any rays that didn't hit the mesh. m &= m_temp_2 profiler.stop('mesh_intersect_1') return X, m, hits
[docs] def mesh_intersect_2( self, rays, mesh, mask, faces_idx, faces_mask, ): """ Check for ray intersection with a list of mesh faces. Programming Notes ----------------- Because of the mesh indexing, the arrays here have different dimensions than in mesh_intersect_1, and need a different vectorization. At the moment I am using a less efficient mesh intersection method. This should be updated to use the same method as mesh_intersect_1, but with the proper vectorization. """ profiler.start('mesh_intersect_2') O = rays['origin'] D = rays['direction'] m = mask.copy() X = np.full(D.shape, np.nan, dtype=np.float64) num_rays = len(m) hits = np.empty(num_rays, dtype=np.int64) epsilon = 1e-15 # Copying these makes the code easier to read, # but may increase memory usage for dense meshes. faces = mesh['faces'][faces_idx] p0 = mesh['points'][faces[..., 0], :] p1 = mesh['points'][faces[..., 1], :] p2 = mesh['points'][faces[..., 2], :] n = mesh['faces_normal'][faces_idx] # distance = np.dot((p0 - O), n) / np.dot(D, n) t0 = p0[:, :, :] - O[None, :, :] t1 = np.einsum('ijk, ijk -> ij', t0, n, optimize=True) t2 = np.einsum('jk, ijk -> ij', D, n, optimize=True) dist = t1 / t2 t3 = np.einsum('jk,ij -> ijk', D, dist, optimize=True) intersect = t3 + O # Pre-calculate some vectors and do the calculation in a # single step in an attempt to optimize this calculation. a = intersect - p0 b = intersect - p1 c = intersect - p2 diff = (np.linalg.norm(np.cross(b, c), axis=2) + np.linalg.norm(np.cross(c, a), axis=2) + np.linalg.norm(np.cross(a, b), axis=2) - np.linalg.norm(np.cross((p0 - p1), (p0 - p2)), axis=2) ) # .. ToDo: # For now hard code the floating point tolerance. # A better way of handling floating point errors is needed. test = (diff < 1e-10) & (dist >= 0) & faces_mask m &= np.any(test, axis=0) # idx_hits tells us which of the 8 faces had a hit. idx_hits = np.argmax(test[:, m], axis=0) # Now index the faces_idx to git the actual face number. hits[m] = faces_idx[idx_hits, m] X[m] = intersect[idx_hits, m, :] profiler.stop('mesh_intersect_2') return X, m, hits
[docs] def mesh_normals(self, hits, mesh, mask): m = mask normals = np.zeros((len(m), 3), dtype=np.float64) normals[m, :] = mesh['faces_normal'][hits[m], :] return normals
[docs] def mesh_get_index(self, hits, faces): """ Match faces to face indexes, with a loop over faces. """ profiler.start('mesh_get_index') idx_hits = np.empty(hits.shape[0], dtype=np.int32) for ii, ff in enumerate(faces): m_temp = np.all(np.equal(ff, hits), axis=1) idx_hits[m_temp] = ii profiler.stop('mesh_get_index') return idx_hits
[docs] def find_point_faces(self, p_idx, faces, mask=None): """ Find all of the the faces that include a given mesh point. """ profiler.start('find_point_faces') if mask is None: mask = np.ones(p_idx.shape, dtype=np.bool_) m = mask p_faces_idx = np.zeros((8, len(m)), dtype=np.int32) p_faces_mask = np.zeros((8, len(m)), dtype=np.bool_) for ii_p in p_idx: ii_f = np.nonzero(np.equal(faces, p_idx[ii_p]))[0] faces_num = len(ii_f) p_faces_idx[:faces_num, ii_p] = ii_f p_faces_mask[:faces_num, ii_p] = True profiler.stop('find_point_faces') return p_faces_idx, p_faces_mask
[docs] def find_near_faces(self, X, mesh, mask): m = mask profiler.start('find_near_faces') idx = mesh['points_tree'].query(X[m])[1] faces_idx = np.zeros((8, len(m)), dtype=np.int32) faces_mask = np.zeros((8, len(m)), dtype=np.bool_) faces_idx[:, m] = mesh['p_faces_idx'][:, idx] faces_mask[:, m] = mesh['p_faces_mask'][:, idx] profiler.stop('find_near_faces') return faces_idx, faces_mask