# -*- coding: utf-8 -*-
"""
.. Authors:
Novimir Pablant <npablant@pppl.gov>
Define the :class:`ShapeMeshTorus` class.
"""
import numpy as np
from scipy.spatial import Delaunay
from xicsrt.util import profiler
from xicsrt.tools.xicsrt_doc import dochelper
from xicsrt.tools import xicsrt_math as xm
from xicsrt.optics._ShapeMesh import ShapeMesh
[docs]
@dochelper
class ShapeMeshTorus(ShapeMesh):
"""
A toroidal crystal implemented using a mesh-grid.
This class meant to be used for three reasons:
- A toroidal optic shape usable with large radii of curvature
- As an example and template for how to implement a mesh-grid optic.
- As a verification of the mesh-grid implementation.
The analytical :class:`ShapeTorus` object will be much faster.
**Programming Notes**
This optic is built in local coordinates with the mesh surface normal
generally in the local z = [0, 0, 1] direction and with
config['trace_local'] = True. This is recommended because the mesh
implementation performs triangulation and interpolation in the local
x-y plane.
"""
[docs]
def default_config(self):
"""
radius_major: float (1.0)
The radius of curvature of the crystal in the toroidal (xaxis)
direction. This is not the same as the geometric major radius of the
axis of a toroid, which in our case would be r_major-r_minor.
radius_minor: float (0.1)
The radius of the curvature of the crystal in the poloidal (yaxis)
direction. This is the same as the geometric minor radius of a
toroid.
normal_method: str ('analytic')
Specify how to calculate the normal vectors at each of the grid
points. Supported values are 'analytic' and 'fd'.
When set to 'fd' a finite difference method will be used. This is
primarily here as an example for cases in which the surface position
is easy to calculate but where surface normals are difficult. A
better option in these cases however is to use auto-differentiation
to calculate analytical derivatives (for example using jax).
mesh_size : (float, float) ((11,11))
The number of mesh points in the x and y directions.
mesh_coarse_size : (float, float) ((5,5))
The number of mesh points in the x and y directions.
"""
config = super().default_config()
config['mesh_refine'] = True
config['mesh_size'] = (11,11)
config['mesh_coarse_size'] = (5,5)
config['mesh_xsize'] = None
config['mesh_ysize'] = None
config['radius_major'] = 1.0
config['radius_minor'] = 0.1
config['normal_method'] = 'analytic'
# The meshgrid is defined in local coordinates.
config['trace_local'] = True
return config
[docs]
def setup(self):
super().setup()
self.log.debug('Yo mama was here.')
# Calculate the angles that define the physical mesh size
if (xsize := self.param['mesh_xsize']) is None:
xsize = self.param['xsize']
if (ysize := self.param['mesh_ysize']) is None:
ysize = self.param['ysize']
half_major = np.arcsin(xsize/2/(self.param['radius_major']))
half_minor = np.arcsin(ysize/2/self.param['radius_minor'])
self.param['angle_major'] = [-1*half_major, half_major]
self.param['angle_minor'] = [-1*half_minor, half_minor]
# Generate the fine mesh.
mesh_points, mesh_normals, mesh_faces = self.generate_mesh(self.param['mesh_size'])
self.param['mesh_points'] = mesh_points
self.param['mesh_normals'] = mesh_normals
self.param['mesh_faces'] = mesh_faces
# Generate the coarse mesh.
mesh_points, mesh_normals, mesh_faces = self.generate_mesh(self.param['mesh_coarse_size'])
self.param['mesh_coarse_points'] = mesh_points
self.param['mesh_coarse_normals'] = mesh_normals
self.param['mesh_coarse_faces'] = mesh_faces
# Calculate final width and height of the optic for debugging.
mesh_local = self.param['mesh_points']
mesh_xsize = np.max(mesh_local[:,0])-np.min(mesh_local[:,0])
mesh_ysize = np.max(mesh_local[:,1])-np.min(mesh_local[:,1])
self.log.debug(f"Mesh xsize x ysize: {mesh_xsize:0.3f}x{mesh_ysize:0.3f}")
[docs]
def torus(self, a, b):
"""
Return a 3D surface coordinate given a set of two angles.
"""
C0 = np.array([0.0, 0.0, 0.0])
C0_zaxis = np.array([0.0, 0.0, 1.0])
C0_xaxis = np.array([1.0, 0.0, 0.0])
maj_r = self.param['radius_major']
min_r = self.param['radius_minor']
C0_yaxis = np.cross(C0_xaxis, C0_zaxis)
O = C0 + maj_r * C0_zaxis
C_norm = xm.vector_rotate(C0_zaxis, C0_yaxis, a)
C = O - maj_r * C_norm
Q = C + C_norm * min_r
axis = np.cross(C_norm, C0_yaxis)
X_norm = xm.vector_rotate(C_norm, axis, b)
X = Q - X_norm * min_r
return X, X_norm
[docs]
def shape(self, a, b):
return self.torus(a, b)
[docs]
def shape_fd(self, a, b, delta=None):
profiler.start('finite difference')
if delta is None: delta = 1e-8
xyz, _ = self.torus(a, b)
xyz1, _ = self.torus(a + delta, b)
xyz2, _ = self.torus(a, b + delta)
vec1 = xyz1 - xyz
vec2 = xyz2 - xyz
norm_fd = np.cross(vec1, vec2)
norm_fd /= np.linalg.norm(norm_fd)
profiler.stop('finite difference')
return xyz, norm_fd
[docs]
def shape_jax(self, a, b):
raise NotImplementedError()
[docs]
def calculate_mesh(self, a, b):
profiler.start('calculate_mesh')
num_a = len(a)
num_b = len(b)
aa, bb = np.meshgrid(a, b, indexing='ij')
xx = np.empty((num_a, num_b))
yy = np.empty((num_a, num_b))
zz = np.empty((num_a, num_b))
normal_xx = np.empty((num_a, num_b))
normal_yy = np.empty((num_a, num_b))
normal_zz = np.empty((num_a, num_b))
# ------------------------------------------------
# Now calculate the xyz values at each grid point.
for ii_a in range(num_a):
for ii_b in range(num_b):
a = aa[ii_a, ii_b]
b = bb[ii_a, ii_b]
# Temporary for development.
if self.param['normal_method'] == 'analytic':
xyz, norm = self.shape(a, b)
elif self.param['normal_method'] == 'fd':
xyz, norm = self.shape_fd(a, b)
elif self.param['normal_method'] == 'jax':
xyz, norm = self.shape_jax(a, b)
else:
raise Exception(f"normal_method {self.param['normal_method']} unknown.")
xx[ii_a, ii_b] = xyz[0]
yy[ii_a, ii_b] = xyz[1]
zz[ii_a, ii_b] = xyz[2]
normal_xx[ii_a, ii_b] = norm[0]
normal_yy[ii_a, ii_b] = norm[1]
normal_zz[ii_a, ii_b] = norm[2]
profiler.stop('calculate_mesh')
return xx, yy, zz, normal_xx, normal_yy, normal_zz
[docs]
def generate_mesh(self, mesh_size=None):
"""
This method creates the meshgrid for the crystal
"""
profiler.start('generate_mesh')
# --------------------------------
# Setup the basic grid parameters.
a_range = self.param['angle_major']
b_range = self.param['angle_minor']
num_a = mesh_size[0]
num_b = mesh_size[1]
self.log.debug(f'num_a, num_b: {num_a}, {num_b}, total: {num_a*num_b}')
a = np.linspace(a_range[0], a_range[1], num_a)
b = np.linspace(b_range[0], b_range[1], num_b)
xx, yy, zz, normal_xx, normal_yy, normal_zz = \
self.calculate_mesh(a, b)
aa, bb = np.meshgrid(a, b, indexing='ij')
angles_2d = np.stack((aa.flatten(), bb.flatten()), axis=0).T
points = np.stack((xx.flatten(), yy.flatten(), zz.flatten())).T
normals = np.stack((normal_xx.flatten(), normal_yy.flatten(), normal_zz.flatten())).T
delaunay = Delaunay(angles_2d)
faces = delaunay.simplices
# It's also possible to triangulate using the x,y coordinates.
# This does not work well for the toroidal shape.
#
# delaunay = Delaunay(points[:, 0:2])
# faces = delaunay.simplices
profiler.stop('generate_mesh')
return points, normals, faces