# -*- coding: utf-8 -*-
"""
.. Authors
    Novimir Pablant <npablant@pppl.gov>
"""
import numpy as np
import logging
import importlib
import glob
import os
from copy import deepcopy
import importlib.util
from xicsrt.util import profiler
class Dispatcher():
    """
    A class to help find, initialize and then dispatch calls to
    raytracing objects.

    A dispatcher is used within XICSRT to find and instantiate objects based
    on their specification within the config dictionary. These objects are
    then tracked within the dispatcher, allowing methods to be called on all
    objects sequentially.
    """

    def __init__(self, config=None, section=None):
        """
        Store the configuration and build the object search path.

        Parameters
        ----------
        config : dict
            Full configuration dictionary. It must contain a 'general'
            entry; the optional 'pathlist' and 'pathlist_default' lists
            found there are concatenated (in that order) into
            ``self.pathlist``. Despite the ``None`` default a config is
            required: passing ``None`` fails when 'general' is looked up.
        section : str
            Key of the entry in `config` that holds the per-object
            configurations handled by this dispatcher.
        """
        self.log = logging.getLogger(self.__class__.__name__)

        self.config = config
        self.section = section

        general = config['general']
        pathlist = []
        pathlist.extend(general.get('pathlist', []))
        pathlist.extend(general.get('pathlist_default', []))
        self.pathlist = pathlist

        # Containers keyed by object name.
        self.objects = dict()
        self.meta = dict()
        self.image = dict()
        self.history = dict()

    def instantiate(self, names=None):
        """
        Instantiate the named objects and store them in ``self.objects``.

        Parameters
        ----------
        names : str or iterable of str, optional
            Name(s) of entries in ``config[section]`` to instantiate. When
            omitted, every entry of the section is instantiated.

        Programming Notes
        -----------------
        The value of ``config['general']['strict_config_check']`` is passed
        to every object as its ``strict`` argument.
        """
        if names is None:
            names = self.config[self.section].keys()
        elif isinstance(names, str):
            names = [names]

        strict = self.config['general']['strict_config_check']
        obj_info = self.find_xicsrt_objects(self.pathlist)

        for key in names:
            self.objects[key] = self._instantiate_single(
                obj_info,
                self.config[self.section][key],
                strict=strict,
            )

    def find_xicsrt_objects(self, pathlist):
        """
        Return a dictionary with all the XICSRT objects found in the given
        list of paths. Objects are identified by looking for python files
        that start with '_Xicsrt' prefix.

        Parameters
        ----------
        pathlist : iterable of str
            Directories to search.

        Returns
        -------
        dict
            Maps each object name (the file name without its extension and
            without the leading underscore) to a dictionary with the keys
            'filepath' and 'name'.

        Programming Notes
        -----------------
        If a given path does not exist glob will just return an empty list.
        For this reason no path existence checking is needed (unless we want
        to raise a user friendly error).

        If the same object name is found in more than one path, the entry
        from the path that appears later in `pathlist` is the one retained.
        """
        output = dict()
        for path in pathlist:
            for filepath in glob.glob(os.path.join(path, '_Xicsrt*.py')):
                filename = os.path.basename(filepath)
                # Drop the extension, then the leading underscore.
                name = os.path.splitext(filename)[0][1:]
                output[name] = {
                    'filepath': filepath,
                    'name': name,
                }
        return output

    def _instantiate_single(self, obj_info, config, strict=None):
        """
        Instantiate a single object from its configuration.

        Parameters
        ----------
        obj_info : dict
            Available objects as returned by `find_xicsrt_objects`.
        config : dict
            Configuration of the object. ``config['class_name']`` selects
            the entry of `obj_info` to load.
        strict : bool, optional
            Forwarded to the object constructor.

        Returns
        -------
        object
            The new instance, created with ``initialize=False``.

        Raises
        ------
        Exception
            If ``config['class_name']`` is not a key of `obj_info`.
        """
        class_name = config['class_name']
        if class_name not in obj_info:
            raise Exception(
                'Could not find {} in available objects.'.format(class_name))
        info = obj_info[class_name]

        # Load the module directly from its file; the class is expected to
        # have the same name as the object.
        spec = importlib.util.spec_from_file_location(
            info['name'], info['filepath'])
        mod = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(mod)
        cls = getattr(mod, info['name'])

        return cls(config, initialize=False, strict=strict)

    def get_object(self, name):
        """
        Return the object with the given name, instantiating it first if it
        is not yet tracked by this dispatcher.
        """
        if name not in self.objects:
            self.instantiate(name)
        return self.objects[name]

    def check_config(self, *args, **kwargs):
        """Call ``check_config`` on every object, forwarding all arguments."""
        for obj in self.objects.values():
            obj.check_config(*args, **kwargs)

    def check_param(self, *args, **kwargs):
        """Call ``check_param`` on every object, forwarding all arguments."""
        for obj in self.objects.values():
            obj.check_param(*args, **kwargs)

    def get_config(self, *args, **kwargs):
        """
        Call ``get_config`` on every object, forwarding all arguments.

        Returns
        -------
        dict
            The value returned by each object, keyed by object name.
        """
        config = dict()
        for key, obj in self.objects.items():
            config[key] = obj.get_config(*args, **kwargs)
        return config

    def setup(self, *args, **kwargs):
        """Call ``setup`` on every object, forwarding all arguments."""
        for obj in self.objects.values():
            obj.setup(*args, **kwargs)

    def initialize(self, *args, **kwargs):
        """Call ``initialize`` on every object, forwarding all arguments."""
        for obj in self.objects.values():
            obj.initialize(*args, **kwargs)

    def generate_rays(self, keep_meta=None, keep_history=None):
        """
        Generates rays from all sources.

        Parameters
        ----------
        keep_meta : bool, optional
            Store ``np.sum(rays['mask'])`` as ``self.meta[name]['num_out']``.
            Defaults to True.
        keep_history : bool, optional
            Store a deep copy of the rays in ``self.history[name]``.
            Defaults to False.

        Returns
        -------
        rays
            The value returned by the ``generate_rays`` method of the source.

        Raises
        ------
        Exception
            If the dispatcher tracks no objects.
        NotImplementedError
            If the dispatcher tracks more than one object.
        """
        if keep_meta is None:
            keep_meta = True
        if keep_history is None:
            keep_history = False

        if not self.objects:
            raise Exception('No ray sources defined.')
        if len(self.objects) != 1:
            raise NotImplementedError(
                'Multiple ray sources are not currently supported.')

        # Exactly one object is present at this point.
        for key, obj in self.objects.items():
            rays = obj.generate_rays()

            if keep_meta:
                self.meta[key] = dict()
                self.meta[key]['num_out'] = np.sum(rays['mask'])

            if keep_history:
                self.history[key] = deepcopy(rays)

        return rays

    def trace(self, rays, keep_meta=None, keep_history=None, keep_images=None):
        """
        Perform raytracing for each object in sequence.

        The rays returned by ``trace_global`` of one object are passed on as
        the input of the next object.

        Parameters
        ----------
        rays
            Rays to trace; handed to the first object.
        keep_meta : bool, optional
            Store ``np.sum(rays['mask'])`` as ``self.meta[name]['num_out']``
            after each object. Defaults to True.
        keep_history : bool, optional
            Store a deep copy of the rays after each object in
            ``self.history[name]``. Defaults to False.
        keep_images : bool, optional
            Store the result of ``make_image(rays)`` of each object in
            ``self.image[name]``. Defaults to False.

        Returns
        -------
        rays
            The rays returned by the last object, or the input rays when
            the dispatcher tracks no objects.
        """
        if keep_meta is None:
            keep_meta = True
        if keep_history is None:
            keep_history = False
        if keep_images is None:
            keep_images = False

        profiler.start('Dispatcher: raytrace')

        for key, obj in self.objects.items():
            profiler.start('Dispatcher: trace_global')
            rays = obj.trace_global(rays)
            profiler.stop('Dispatcher: trace_global')

            if keep_meta:
                self.meta[key] = dict()
                self.meta[key]['num_out'] = np.sum(rays['mask'])

            if keep_history:
                self.history[key] = deepcopy(rays)

            if keep_images:
                profiler.start('Dispatcher: collect')
                self.image[key] = obj.make_image(rays)
                profiler.stop('Dispatcher: collect')

        profiler.stop('Dispatcher: raytrace')

        return rays

    def apply_filters(self, filters):
        """
        Attach the requested filter objects to each tracked object.

        For every object whose config has a non-None 'filters' entry, each
        object of `filters` whose name appears in that entry is appended to
        the object's ``filter_objects`` list. Objects without a 'filters'
        entry (or with ``None``) are skipped.

        Parameters
        ----------
        filters : Dispatcher
            A filter dispatcher object that contains filter objects. `self`
            should be a source or optics dispatcher object.
        """
        for obj in self.objects.values():
            # Skip only this object; later objects may still request filters.
            if 'filters' not in obj.config or obj.config['filters'] is None:
                continue

            requested = obj.config['filters']
            for filter_name, filter_obj in filters.objects.items():
                if filter_name in requested:
                    obj.filter_objects.append(filter_obj)