# Source code for xicsrt.tools.xicsrt_math_jax

# -*- coding: utf-8 -*-
"""
.. Authors
    Novimir Pablant <npablant@pppl.gov>
    James Kring <jdk0026@tigermail.auburn.edu>
    Yevgeniy Yakusevich <eugenethree@gmail.com>

A set of mathematical functions with jax acceleration. Many of these functions
are exact copies or slight modifications of the functions in xicsrt_math. Other
functions are specific to this module.

Programming Notes
-----------------

This module was developed to support some specific modeling work by N. Pablant
and is not used in any of the built-in xicsrt code. There is no plan to support
jax generally within xicsrt, so I am not really sure of the best way to handle
this module for the moment. Maybe move it into xicsrt_contrib?
"""

import jax.numpy as np

def toarray_1d(a):
    """
    Return ``a`` as an array that has at least one dimension.

    Scalars (0-d inputs) are reshaped to a length-1 array; anything that is
    already 1d or higher is returned as-is after ``asarray`` conversion.
    This mirrors numpy's ``atleast_1d`` with less overhead, and works with
    jax arrays.
    """
    arr = np.asarray(a)
    return arr.reshape(1) if arr.ndim == 0 else arr
def vector_angle(a, b):
    """
    Find the angle (in radians) between two vectors.

    Parameters
    ----------
    a, b : ndarray
        Either two 1d vectors, or two 2d arrays whose rows are paired
        element-wise (row i of ``a`` with row i of ``b``).

    Returns
    -------
    angle : ndarray
        Angle in radians in the range [0, pi]. 0-d for 1d input, shape
        (N,) for 2d input with N rows.

    Raises
    ------
    Exception
        If the inputs are not both 1d or both 2d.
    """
    # Use logical ``and``: with ``&`` the test parses as the chained
    # comparison ``a.ndim == (2 & b.ndim) == 2``.
    if a.ndim == 2 and b.ndim == 2:
        # Normalize each row on its own. A single norm over the whole array
        # would scale every row by the same (wrong) factor.
        a_hat = a / np.linalg.norm(a, axis=1, keepdims=True)
        b_hat = b / np.linalg.norm(b, axis=1, keepdims=True)
        # Row-wise dot product: sum over the shared vector component axis.
        dot = np.einsum('ij,ij->i', a_hat, b_hat)
    elif a.ndim == 1 and b.ndim == 1:
        dot = np.dot(a / np.linalg.norm(a), b / np.linalg.norm(b))
    else:
        raise Exception('Input must have 1 or 2 dimensions.')
    # Rounding can push |dot| slightly above 1 for (anti)parallel vectors,
    # which would make arccos return NaN.
    dot = np.clip(dot, -1.0, 1.0)
    angle = np.arccos(dot)
    return angle
def vector_rotate(a, b, theta):
    """
    Rotate vector ``a`` about the axis ``b`` by the angle ``theta`` (radians).

    ``a`` may be a single 1d vector or a 2d array whose rows are vectors;
    ``b`` is the rotation axis and need not be normalized.

    The method splits ``a`` into a part parallel to the unit axis (left
    unchanged by the rotation) and a perpendicular part. The perpendicular
    part is rotated in the plane it spans with ``b_hat x perp``.

    Raises
    ------
    Exception
        If ``a`` is neither 1d nor 2d.
    """
    if a.ndim not in (1, 2):
        raise Exception('Input array must be 1d (vector) or 2d (array of vectors)')

    axis_hat = b / np.linalg.norm(b)

    if a.ndim == 1:
        parallel = axis_hat * np.dot(a, axis_hat)
    else:
        # Per-row projection length onto the axis, then the outer product
        # rebuilds the parallel component of each row.
        proj_len = np.einsum('ij,j->i', a, axis_hat)
        parallel = np.einsum('i,j->ij', proj_len, axis_hat)

    perpendicular = a - parallel
    # Perpendicular to both the axis and ``perpendicular``, with the same
    # length as ``perpendicular``.
    binormal = np.cross(axis_hat, perpendicular)
    return parallel + perpendicular * np.cos(theta) + binormal * np.sin(theta)
def sinusoidal_spiral(phi, b, r0, theta0):
    """
    Evaluate the radius of a sinusoidal spiral at polar angle ``phi``.

    Computes::

        r = r0 * (sin(theta0 + (b-1)*phi) / sin(theta0)) ** (1/(b-1))

    so that ``r == r0`` at ``phi == 0``.

    Notes
    -----
    ``b == 1`` makes the exponent ``1/(b-1)`` divide by zero. That case is
    not handled here.
    """
    n = b - 1
    ratio = np.sin(theta0 + n * phi) / np.sin(theta0)
    return r0 * ratio ** (1 / n)
def point_to_external(point_local, orientation, origin):
    """
    Convert a point (or rows of points) from local to external coordinates.

    The local point is first rotated into the external frame with
    ``vector_to_external`` and then translated by ``origin``.
    """
    rotated = vector_to_external(point_local, orientation)
    return rotated + origin
def point_to_local(point_external, orientation, origin):
    """
    Convert a point (or rows of points) from external to local coordinates.

    ``origin`` is subtracted first. The resulting offset is then rotated
    into the local frame with ``vector_to_local``.
    """
    offset = point_external - origin
    return vector_to_local(offset, orientation)
def vector_to_external(vector, orientation):
    """
    Rotate a vector (1d) or rows of vectors (2d) from local to external
    coordinates.

    Equivalent to ``vector @ orientation``. The rows of ``orientation`` are
    therefore presumably the local axes expressed in external coordinates.

    Raises
    ------
    Exception
        If ``vector`` is not 1d or 2d.
    """
    # einsum subscripts keyed by the dimensionality of ``vector``.
    subscripts = {
        1: 'ij,i->j',
        2: 'ij,ki->kj',
    }.get(vector.ndim)
    if subscripts is None:
        raise Exception('vector.ndim must be 1 or 2')
    return np.einsum(subscripts, orientation, vector)
def vector_to_local(vector, orientation):
    """
    Rotate a vector (1d) or rows of vectors (2d) from external to local
    coordinates.

    Equivalent to ``vector @ orientation.T`` (for a 1d vector,
    ``orientation @ vector``). This is the transpose of the transform used
    by ``vector_to_external``.

    Raises
    ------
    Exception
        If ``vector`` is not 1d or 2d.
    """
    # einsum subscripts keyed by the dimensionality of ``vector``; the
    # 'ji' ordering applies the transpose of ``orientation``.
    subscripts = {
        1: 'ji,i->j',
        2: 'ji,ki->kj',
    }.get(vector.ndim)
    if subscripts is None:
        raise Exception('vector.ndim must be 1 or 2')
    return np.einsum(subscripts, orientation, vector)