import numpy as np
import multiprocessing
import copy
import time
import enum
import warnings
import textwrap
from abc import ABC, abstractmethod
import matplotlib.pyplot as plt
import matplotlib
import astropy.io.fits as fits
import astropy.units as u
from .matrixDFT import MatrixFourierTransform
from . import utils
from . import conf
from . import accel_math
from .accel_math import xp, _scipy, _float, _complex
if accel_math._NUMEXPR_AVAILABLE:
import numexpr as ne
import logging
_log = logging.getLogger('poppy')
__all__ = ['Wavefront', 'OpticalSystem', 'CompoundOpticalSystem',
'OpticalElement', 'ArrayOpticalElement', 'FITSOpticalElement', 'Rotation', 'Detector']
# internal constants for types of plane
class PlaneType(enum.Enum):
    """Kinds of plane that a wavefront or optic can be associated with.

    Members
    -------
    unspecified : no plane type assigned
    pupil : pupil plane
    image : image plane
    detector : detector plane
    rotation : coordinate system rotation
    intermediate : arbitrary plane between pupil and image
    inversion : coordinate system inversion (flip axes, e.g. like going through focus)
    """

    unspecified = 0
    pupil = 1
    image = 2
    detector = 3
    rotation = 4
    intermediate = 5
    inversion = 6
# Short module-level aliases for the plane types used most often.
_PUPIL = PlaneType.pupil
_IMAGE = PlaneType.image
# specialized type of image plane
_DETECTOR = PlaneType.detector
# not a real optic, just a coordinate transform
_ROTATION = PlaneType.rotation
# for Fresnel propagation
_INTERMED = PlaneType.intermediate

# Conversion factor from radians to arcseconds: (180 degrees * 3600 arcsec/degree) / pi.
_RADIANStoARCSEC = (180.0 * 3600) / np.pi
def _wrap_propagate_for_multiprocessing(args):
    """Run one monochromatic propagation from a packed argument tuple.

    Internal helper for parallelizing calculations with Python's multiprocessing
    module. That module passes work between processes via pickle, and instance
    methods cannot be pickled on their own. The workaround used here is to ship
    the whole optical system together with the call arguments as a single tuple,
    unpack it in the receiving process, and only then call the instance method.

    Parameters
    ----------
    args : tuple
        ``(optical_system, wavelength, retain_intermediates, retain_final,
        normalize, usefftwflag)``, in that order.

    Returns
    -------
    Whatever ``optical_system.propagate_mono`` returns for the given wavelength.
    """
    (optical_system, wavelength, retain_intermediates,
     retain_final, normalize, usefftwflag) = args

    # The FFTW preference is passed in from the parent process.
    conf.use_fftw = usefftwflag

    # We are in a different Python interpreter process, so the FFTW wisdom
    # has to be loaded here too; resetting the flag forces a fresh load.
    if conf.use_fftw and accel_math._FFTW_AVAILABLE:
        utils._loaded_fftw_wisdom = False
        utils.fftw_load_wisdom()

    propagate_options = dict(
        retain_intermediates=retain_intermediates,
        retain_final=retain_final,
        normalize=normalize,
    )
    return optical_system.propagate_mono(wavelength, **propagate_options)
class BaseWavefront(ABC):
""" Abstract base class for wavefronts.
In general you should not need to use this class directly; use either
Wavefront or FresnelWavefront child classes for most purposes.
Parameters
----------
wavelength : float
Wavelength of light in meters
npix : int
Size parameter for wavefront array to create, per side.
diam : float, optional
For _PUPIL wavefronts, sets physical size corresponding to npix. Units are meters.
At most one of diam or pixelscale should be set when creating a wavefront.
pixelscale : float, optional
For PlaneType.image PLANE wavefronts, use this pixel scale.
oversample : int, optional
how much to oversample by in FFTs. Default is 2.
Note that final propagations to Detectors use a different algorithm
and, optionally, a separate oversampling factor.
dtype : numpy.dtype, optional
default is double complex.
"""
@utils.quantity_input(wavelength=u.meter, diam=u.meter)
def __init__(self, wavelength=1e-6 * u.meter, npix=1024, dtype=None, diam=1.0 * u.meter,
             oversample=2):
    """Create a uniform (all ones) wavefront located at an entrance pupil.

    Parameters
    ----------
    wavelength : astropy length quantity (default 1e-6 m)
        Wavelength of the light.
    npix : int
        Number of pixels per side of the square wavefront array.
    dtype : numpy.dtype, optional
        Data type of the wavefront array. When None, the type returned by
        ``_complex()`` is used.
    diam : astropy length quantity (default 1 m)
        Physical size of the array.
    oversample : int
        Oversampling factor recorded on the wavefront.
    """
    self.oversample = oversample
    self.wavelength = wavelength  # wavelength in meters (or other unit if specified)
    self.diam = diam  # array size in meters

    # Pixel scale, in arcsec/pixel or meters/pixel depending on plane type.
    self.pixelscale = None  # placeholder, replaced on the next line
    self.pixelscale = self.diam / (npix * u.pixel)

    # Assume we begin at an entrance pupil.
    self.planetype = PlaneType.pupil

    # One of 'array_center', 'pixel', 'corner'. Records where the coordinate
    # origin is in image planes, which depends on how the image plane was
    # produced (e.g. FFT implies pixel).
    self._image_centered = 'array_center'

    # The actual complex wavefront array.
    array_dtype = _complex() if dtype is None else dtype
    self.wavefront = xp.ones((npix, npix), dtype=array_dtype)

    # Is the wavefront padded for oversampling?
    self.ispadded = False

    # List of strings giving a descriptive history of actions performed on the
    # wavefront. Saved to FITS headers.
    self.history = [
        "Created wavefront: wavelength={0:.4g}, diam={1:.4g}".format(self.wavelength, self.diam),
        " using array size %s" % (self.wavefront.shape,),
    ]

    # Descriptive string for where the wavefront is instantaneously located.
    # Used mostly for titling displayed plots.
    self.location = 'Entrance Pupil'

    # For tracking stages in a calculation.
    self.current_plane_index = 0
def __str__(self):
    """Return a short multi-line summary: wavelength (in microns), array shape, and sampling."""
    # TODO add switches for image/pupil planes
    template = """Wavefront:
wavelength = {}
shape = {}
sampling = {}"""
    wavelength_microns = self.wavelength.to(u.micron)
    return template.format(wavelength_microns, self.wavefront.shape, self.pixelscale)
def copy(self):
    """Return an independent deep copy of this wavefront, as a different object."""
    # Imported under a local name so the lookup cannot be confused with this
    # method's own name.
    from copy import deepcopy
    return deepcopy(self)
def normalize(self):
    """Scale the wavefront in place so that its total intensity becomes 1.

    If the total intensity is zero, NaN or infinite, a warning is logged and
    the wavefront is left unchanged.
    """
    amplitude_scale = np.sqrt(self.total_intensity)

    if amplitude_scale == 0:
        _log.warning("Total intensity is zero when trying to normalize the wavefront. Cannot normalize.")
        return
    if not np.isfinite(amplitude_scale):
        _log.warning("Total intensity is NaN or Inf when trying to normalize the wavefront. Cannot normalize.")
        return

    # Intensity goes as amplitude squared, so divide the field by the square root.
    self.wavefront /= amplitude_scale
def __imul__(self, optic):
    """Multiply this wavefront in place by an OpticalElement or a scalar.

    Parameters
    ----------
    optic : OpticalElement, CoordinateTransform, or scalar
        * A CoordinateTransform leaves the wavefront untouched.
        * A scalar multiplies the wavefront array directly.
        * A Detector only updates the location label.
        * Any other OpticalElement multiplies the array by ``optic.get_phasor(self)``.

    Returns
    -------
    self

    Raises
    ------
    ValueError
        If ``optic`` is neither a scalar nor an OpticalElement.
    AssertionError
        If the optic's phasor is an array whose shape differs from the wavefront's.
    """
    # A coord transform doesn't actually affect the wavefront via multiplication,
    # but instead via forcing a call to rotate() or invert() in propagate_to...
    if isinstance(optic, CoordinateTransform):
        return self

    if np.isscalar(optic):
        self.wavefront *= optic
        self.history.append("Multiplied WF by scalar value " + str(optic))
        return self

    if not isinstance(optic, OpticalElement):
        raise ValueError('Wavefronts can only be *= multiplied by OpticalElements or scalar values')

    # Detectors don't modify a wavefront, but we do update the label.
    if isinstance(optic, Detector):
        self.location = 'at ' + optic.name
        return self

    phasor = optic.get_phasor(self)
    phasor_is_array = not np.isscalar(phasor) and phasor.size > 1
    if phasor_is_array:
        assert self.wavefront.shape == phasor.shape, "Phasor shape {} does not match wavefront shape {}".format(
            phasor.shape, self.wavefront.shape)
    self.wavefront *= phasor

    note = " Multiplied WF by phasor for " + str(optic)
    _log.debug(note)
    self.history.append(note)
    self.location = 'after ' + optic.name
    return self
def __mul__(self, optic):
    """Return a new wavefront equal to this one multiplied by an OpticalElement or scalar.

    The original wavefront is not modified: the product is computed on a copy
    using the in-place multiplication operator.
    """
    product = self.copy()
    product *= optic
    return product


# Multiplication is accepted with the wavefront on either side of the operator.
__rmul__ = __mul__
def __iadd__(self, wave):
    """Add another wavefront to this one, in place.

    Parameters
    ----------
    wave : wavefront of the same class as this one
        Must have the same array shape and an equal pixel scale.

    Returns
    -------
    self

    Raises
    ------
    ValueError
        If ``wave`` is not an instance of this wavefront's class, if the array
        shapes differ, if the pixel scales differ, or if the pixel scale units
        cannot be converted into one another.
    """
    if not isinstance(wave, self.__class__):
        raise ValueError('Wavefronts can only be summed with other Wavefronts of the same class.')

    if self.wavefront.shape != wave.wavefront.shape:
        raise ValueError('Wavefronts can only be added if they have the same size and shape: {} vs {} '.format(
            self.wavefront.shape, wave.wavefront.shape))

    # Compare the pixel scales after expressing the other one in this wavefront's unit.
    try:
        other_scale = wave.pixelscale.to(self.pixelscale.unit).value
        scales_agree = np.isclose(self.pixelscale.value, other_scale)
    except u.UnitConversionError:
        raise ValueError('Wavefronts can only be added if they have equivalent units: {} vs {}'.format(
            self.pixelscale.unit, wave.pixelscale.unit))
    if not scales_agree:
        raise ValueError('Wavefronts can only be added if they have the same pixelscale: {} vs {}'.format(
            self.pixelscale, wave.pixelscale))

    self.wavefront += wave.wavefront
    self.history.append("Summed with another wavefront!")
    return self
def __add__(self, wave):
    """Return a new wavefront that is the sum of this one and ``wave``.

    Neither operand is modified: the sum is accumulated into a copy of this
    wavefront using the in-place addition operator.
    """
    total = self.copy()
    total += wave
    return total
def as_fits(self, what='intensity', includepadding=False, **kwargs):
    """Return the wavefront as a FITS HDUList object.

    Parameters
    ----------
    what : string
        What kind of data to write (case-insensitive). Must be one of 'all',
        'parts', 'intensity', 'phase' or 'complex'. The default is intensity.
        'all' produces a (3, N, N) cube of intensity, amplitude and phase.
        'parts' produces a (2, N, N) cube of amplitude and phase.
        'complex' produces a (2, N, N) cube of the real and imaginary parts.
        'intensity' and 'phase' produce 2D arrays with the corresponding values.
    includepadding : bool
        Include any "padding" region, if present, in the returned FITS file?
    **kwargs
        Accepted but not used by this method.

    Returns
    -------
    astropy.io.fits.HDUList
        A single primary HDU. The header records which quantity is in each
        plane, wavelength, diffraction limit, oversampling, pixel scale
        information, and the wavefront history.

    Raises
    ------
    ValueError
        If ``what`` is not one of the recognized strings.
    """
    kind = what.lower()

    def get_unpadded(attribute_array):
        # Always work on a copy; strip the padding only for padded pupil planes.
        duplicate = attribute_array.copy()
        if self.planetype == PlaneType.pupil and self.ispadded and not includepadding:
            return utils.remove_padding(duplicate, self.oversample)
        return duplicate

    def to_host(array):
        # FITS HDULists cannot contain arrays which are in GPU memory; thus if
        # _USE_CUPY is true, the array is brought back into CPU memory with get().
        return array.get() if accel_math._USE_CUPY else array

    def stack_planes(planes):
        # Pack 2D arrays of identical shape into one (nplanes, ny, nx) cube.
        ny, nx = planes[0].shape[0], planes[0].shape[1]
        cube = xp.zeros((len(planes), ny, nx))
        for index, plane in enumerate(planes):
            cube[index, :, :] = plane
        return cube

    if kind == 'all':
        planes = [get_unpadded(self.intensity), get_unpadded(self.amplitude), get_unpadded(self.phase)]
        data = to_host(stack_planes(planes))
        plane_labels = ['Wavefront Intensity', 'Wavefront Amplitude', 'Wavefront Phase']
    elif kind == 'parts':
        planes = [get_unpadded(self.amplitude), get_unpadded(self.phase)]
        data = to_host(stack_planes(planes))
        plane_labels = ['Wavefront Amplitude', 'Wavefront Phase']
    elif kind == 'intensity':
        data = get_unpadded(to_host(self.intensity))
        plane_labels = ['Wavefront Intensity']
    elif kind == 'phase':
        data = get_unpadded(to_host(self.phase))
        plane_labels = ['Phase']
    elif kind == 'complex':
        planes = [get_unpadded(self.wavefront.real), get_unpadded(self.wavefront.imag)]
        data = to_host(stack_planes(planes))
        plane_labels = ['Real part of complex wavefront', 'Imaginary part of complex wavefront']
    else:
        raise ValueError("Unknown string for what to return: " + what)

    outfits = fits.HDUList(fits.PrimaryHDU(data))
    header = outfits[0].header
    for number, label in enumerate(plane_labels, start=1):
        header['PLANE{}'.format(number)] = label

    header['WAVELEN'] = (self.wavelength.to(u.meter).value, 'Wavelength in meters')
    header['DIFFLMT'] = ((self.wavelength / self.diam * u.radian).to(u.arcsec).value,
                         'Diffraction limit lambda/D in arcsec')
    header['OVERSAMP'] = (self.oversample, 'Oversampling factor for FFTs in computation')
    header['DET_SAMP'] = (self.oversample, 'Oversampling factor for MFT to detector plane')

    if self.planetype == PlaneType.image:
        header['PIXELSCL'] = (self.pixelscale.to(u.arcsec / u.pixel).value,
                              'Scale in arcsec/pix (after oversampling)')
        fov_arcsec = self.fov.to(u.arcsec).value
        if np.isscalar(fov_arcsec):
            header['FOV'] = (fov_arcsec, 'Field of view in arcsec (full array)')
        else:
            # Non-scalar field of view: element 0 is written as Y, element 1 as X.
            header['FOV_X'] = (fov_arcsec[1], 'Field of view in arcsec (full array), X direction')
            header['FOV_Y'] = (fov_arcsec[0], 'Field of view in arcsec (full array), Y direction')
        header['PIXUNIT'] = 'arcsecond'
    else:
        header['PIXELSCL'] = (self.pixelscale.to(u.meter / u.pixel).value, 'Pixel scale in meters/pixel')
        header['DIAM'] = (self.diam.to(u.meter).value, 'Pupil diameter in meters (not incl padding)')
        header['PIXUNIT'] = 'meter'

    for entry in self.history:
        header.add_history(entry)

    return outfits
def writeto(self, filename, overwrite=True, **kwargs):
    """Write the wavefront to a FITS file.

    Parameters
    ----------
    filename : string
        Name of the file to write.
    overwrite : bool, optional
        Overwrite an existing file of the same name? Default is True.
    **kwargs
        Passed through unchanged to `as_fits`; for example ``what`` selects
        which quantity is written.

    Returns
    -------
    None. The output is written to disk.
    """
    hdulist = self.as_fits(**kwargs)
    hdulist.writeto(filename, overwrite=overwrite)
    _log.info(" Wavefront saved to %s", filename)
def display(self, what='intensity', nrows=1, row=1, showpadding=False,
            imagecrop=None, pupilcrop=None,
            colorbar=False, crosshairs=False, ax=None, title=None, vmin=None,
            vmax=None, vmax_wfe=None, scale=None, use_angular_coordinates=None,
            angular_coordinate_unit=u.arcsec):
    """Display wavefront on screen

    Parameters
    ----------
    what : string
        What to display. Must be one of 'intensity', 'amplitude', 'phase',
        'wfe', 'best', or 'both'.
        'intensity' shows the wavefront intensity and 'amplitude' the field
        amplitude. 'wfe' shows the wavefront error in nanometers; 'phase' is
        similar to 'wfe' but shows wavefront phase in radians at the given
        wavelength.
        'best' always shows intensity for image planes; elsewhere it shows the
        phase if the (masked) phase is nonzero, or else the intensity for a
        perfect pupil.
        'both' shows two panels, wavefront amplitude and wavefront phase.
    nrows : int
        Number of rows to display in current figure (used for
        showing steps in a calculation)
    row : int
        Which row to display this one in? If set to None, use the
        wavefront's self.current_plane_index
    vmin, vmax : floats
        Min and maximum values to display. When left unspecified, these default
        to [0, intens.max()] for linear (scale='linear') intensity plots,
        [1e-6*intens.max(), intens.max()] for logarithmic (scale='log') intensity
        plots, and [-0.25, 0.25] for phase plots. Phase panels always use a
        color scale symmetric about zero, limited to at most [-pi, pi].
    vmax_wfe : float
        Max value to display for the phase panel, *if* distinct from vmax for
        intensity, for the case of what='both'. The minimum of that panel is
        set to the negative of this to ensure a balanced display scale
        symmetric around zero. This parameter is ignored if the 'what'
        parameter is not equal to 'both'.
    scale : string
        'log' or 'linear', to define the desired display scale type for
        intensity. Default is log for image planes, linear otherwise.
    imagecrop : float or astropy Quantity, optional
        Crop the displayed image to a smaller region than the full array.
        A bare number is interpreted as arcseconds when angular coordinates
        are in use, and as meters otherwise. When angular coordinates are in
        use and this is left as None, the value of
        ``conf.default_image_display_fov`` (in arcseconds) is used. The crop
        is never larger than the array itself. For planes in linear physical
        coordinates the default is no cropping.
    pupilcrop : optional
        Accepted but not used by this method.
    showpadding : bool, optional
        For wavefronts that have been padded with zeros for oversampling,
        show the entire padded arrays, or just the good parts?
        Default is False, to show just the central region of interest.
    colorbar : bool
        Display colorbar
    crosshairs : bool
        Display a crosshairs indicator showing the axes centered on (0,0)
    ax : matplotlib Axes, optional
        Axes to display into. If not set, will create new axes.
        Not used when what='both', which always creates its own two subplots.
    title : string, optional
        Plot title. If not set, one is built from the displayed quantity and
        the wavefront's location. Not used when what='both'.
    use_angular_coordinates : bool, optional
        Should the axes be labeled in angular units, e.g. arcseconds?
        This is used by FresnelWavefront, where non-angular
        coordinates are possible everywhere. When using Fraunhofer
        propagation, this should be left as None so that the
        coordinates are inferred from the planetype attribute.
        (Default: None, infer coordinates from planetype)
    angular_coordinate_unit : astropy unit
        Unit to use for angular coordinates display; default is arcsecond.

    Returns
    -------
    matplotlib Axes, or a tuple of two Axes when what='both'.
        The current figure is modified.

    Raises
    ------
    ValueError
        If ``what`` is not one of the recognized strings.
    """
    if scale is None:
        scale = 'log' if self.planetype == PlaneType.image else 'linear'

    if row is None:
        row = self.current_plane_index

    intens = self.intensity.copy()

    # make a version of the phase where we try to mask out
    # areas with particularly low intensity
    phase = self.phase.copy()
    mean_intens = np.mean(intens[intens != 0])
    phase[intens < mean_intens / 100] = np.nan
    amp = self.amplitude

    y, x = self.coordinates()
    # GPU arrays don't work in matplotlib
    y = accel_math.ensure_not_on_gpu(y)
    x = accel_math.ensure_not_on_gpu(x)
    if self.planetype == PlaneType.pupil and self.ispadded and not showpadding:
        intens = utils.remove_padding(intens, self.oversample)
        phase = utils.remove_padding(phase, self.oversample)
        amp = utils.remove_padding(amp, self.oversample)
        y = utils.remove_padding(y, self.oversample)
        x = utils.remove_padding(x, self.oversample)

    if use_angular_coordinates is None:
        use_angular_coordinates = self.planetype == PlaneType.image

    if use_angular_coordinates and angular_coordinate_unit != u.arcsec:
        # Update pixel coordinates for non-arcsecond pixel scales, if requested
        x *= (1*u.arcsec).to_value(angular_coordinate_unit)
        y *= (1*u.arcsec).to_value(angular_coordinate_unit)

    pixelscale_unit = angular_coordinate_unit/u.pixel if use_angular_coordinates else u.m/u.pixel
    unit_label = str(pixelscale_unit*u.pixel)

    # extent specifications need to include the *full* data region, including the half pixel
    # on either side outside of the pixel center coordinates. And remember to swap Y and X.
    # Recall that for matplotlib,
    #    extent = [xmin, xmax, ymin, ymax]
    # Recall that we define pixel coordinates to be at the *center* of the pixel, so we
    # compute here the coordinates at the outside of those pixels.
    # This is needed to get the coordinates right when displaying very small arrays
    halfpix = self.pixelscale.to_value(pixelscale_unit) * 0.5
    extent = np.asarray([x.min() - halfpix, x.max() + halfpix, y.min() - halfpix, y.max() + halfpix])

    # implement semi-intelligent selection of what to display, if the user wants
    if what == 'best':
        if self.planetype == PlaneType.image:
            what = 'intensity'  # always show intensity for image planes
        elif phase[(np.isfinite(phase))].sum() == 0:
            what = 'intensity'  # for perfect pupils
        # FIXME re-implement this in some better way that doesn't depend on
        # optic positioning in the plot grid!
        # elif int(row) > 2:
        #     what = 'intensity'  # show intensity for coronagraphic downstream propagation.
        else:
            what = 'phase'  # for aberrated pupils

    # compute plot parameters for the subplot grid
    nc = int(np.ceil(np.sqrt(nrows)))
    nr = int(np.ceil(float(nrows) / nc))
    if (nrows - nc * (nc - 1) == 1) and (nr > 1):  # avoid just one alone on a row by itself...
        nr -= 1
        nc += 1

    # prepare color maps and normalizations for intensity and phase
    if vmax is None:
        vmax = 0.25 if what == 'phase' else intens.max()
    if scale == 'linear':
        if vmin is None:
            # Documented phase default is [-0.25, 0.25]; a positive 0.25 here would
            # make vmin equal to vmax and give a degenerate normalization.
            vmin = -0.25 if what == 'phase' else 0
        norm_inten = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax)
        cmap_inten = copy.copy(getattr(matplotlib.cm, conf.cmap_pupil_intensity))
        cmap_inten.set_bad('0.0')
    else:
        if vmin is None:
            vmin = vmax * 1e-6
        norm_inten = matplotlib.colors.LogNorm(vmin=vmin, vmax=vmax)
        cmap_inten = copy.copy(getattr(matplotlib.cm, conf.cmap_sequential))
        cmap_inten.set_bad(cmap_inten(0))
    cmap_phase = copy.copy(getattr(matplotlib.cm, conf.cmap_diverging))
    cmap_phase.set_bad('0.3')

    if what == 'both':
        vmx = vmax_wfe if vmax_wfe is not None else np.clip(max(vmax, np.abs(vmin)), -np.pi, np.pi)
    else:
        # we can apply here a prior that the phase is always in the range of
        # -pi to +pi, and we should display with a balanced color scale.
        vmx = np.clip(max(vmax, np.abs(vmin)), -np.pi, np.pi)
    norm_phase = matplotlib.colors.Normalize(vmin=-vmx, vmax=vmx)

    def wrap_lines_title(title):
        # Helper fn to add line breaks in plot titles,
        # tweaked to put in particular places for aesthetics
        for prep in ['after', 'before']:
            if prep in title:
                # partition() splits at the first occurrence only, so a title
                # containing the word more than once does not raise.
                part1, _, part2 = title.partition(prep)
                return part1 + prep + "\n" + "\n".join(textwrap.wrap(part2, 30))
        return "\n".join(textwrap.wrap(title, 30))

    # now display the chosen selection..
    if what == 'intensity':
        if ax is None:
            ax = plt.subplot(nr, nc, int(row))

        utils.imshow_with_mouseover(
            intens,
            ax=ax,
            extent=extent,
            norm=norm_inten,
            cmap=cmap_inten,
            origin='lower'
        )
        if title is None:
            title = wrap_lines_title("Intensity " + self.location)
        ax.set_title(title)
        ax.set_xlabel(unit_label)
        if colorbar:
            plt.colorbar(ax.images[0], ax=ax, orientation='vertical', shrink=0.8)
        plot_axes = [ax]
        to_return = ax
    elif what == 'phase':
        # Display wavefront phase in radians.
        if ax is None:
            ax = plt.subplot(nr, nc, int(row))
        utils.imshow_with_mouseover(
            phase,
            ax=ax,
            extent=extent,
            norm=norm_phase,
            cmap=cmap_phase,
            origin='lower'
        )
        if title is None:
            title = wrap_lines_title("Phase " + self.location)
        ax.set_title(title)
        ax.set_xlabel(unit_label)
        if colorbar:
            plt.colorbar(ax.images[0], ax=ax, orientation='vertical', shrink=0.8, label='Phase [radians]')

        plot_axes = [ax]
        to_return = ax
    elif what == 'wfe':
        # Display wavefront error in optical path difference.
        # Show units of nanometers, for convenience

        # Set up WFE arrays, similar to how we set up phase and amp
        wfe = self.wfe.to(u.nanometer).value.copy()
        if self.planetype == PlaneType.pupil and self.ispadded and not showpadding:
            wfe = utils.remove_padding(wfe, self.oversample)
        wfe[intens < mean_intens / 100] = np.nan

        # Symmetric color scale spanning the largest absolute WFE value.
        vmx = np.nanmax(np.abs(wfe))
        norm_wfe = matplotlib.colors.Normalize(vmin=-vmx, vmax=vmx)

        if ax is None:
            ax = plt.subplot(nr, nc, int(row))
        utils.imshow_with_mouseover(
            wfe,
            ax=ax,
            extent=extent,
            norm=norm_wfe,
            cmap=cmap_phase,
            origin='lower'
        )
        if title is None:
            title = wrap_lines_title("WFE " + self.location)
        ax.set_title(title)
        ax.set_xlabel(unit_label)
        if colorbar:
            plt.colorbar(ax.images[0], ax=ax, orientation='vertical', shrink=0.8, label='WFE [nm]')

        plot_axes = [ax]
        to_return = ax
    elif what == 'both':
        ax1 = plt.subplot(nrows, 2, (row * 2) - 1)
        plt.imshow(amp, extent=extent, cmap=cmap_inten, norm=norm_inten, origin='lower')
        ax1.set_title("Wavefront amplitude")
        ax1.set_ylabel(unit_label)
        ax1.set_xlabel(unit_label)
        if colorbar:
            plt.colorbar(orientation='vertical', ax=ax1, shrink=0.8)

        ax2 = plt.subplot(nrows, 2, row * 2)
        plt.imshow(phase, extent=extent, cmap=cmap_phase, norm=norm_phase, origin='lower')
        if colorbar:
            plt.colorbar(orientation='vertical', ax=ax2, shrink=0.8, label='Phase [radians]')

        ax2.set_xlabel(unit_label)
        ax2.set_title("Wavefront phase")

        plot_axes = [ax1, ax2]
        to_return = (ax1, ax2)
    elif what == 'amplitude':
        if ax is None:
            ax = plt.subplot(nr, nc, int(row))
        utils.imshow_with_mouseover(
            amp,
            ax=ax,
            extent=extent,
            norm=norm_inten,
            cmap=cmap_inten,
            origin='lower'
        )
        if title is None:
            title = wrap_lines_title("Amplitude " + self.location)
        ax.set_title(title)
        ax.set_xlabel(unit_label)
        if colorbar:
            plt.colorbar(ax.images[0], ax=ax, orientation='vertical', shrink=0.8)
        plot_axes = [ax]
        to_return = ax
    else:
        raise ValueError("Invalid value for what to display; must be: "
                         "'intensity', 'amplitude', 'phase', 'wfe', 'best', or 'both'.")

    # now apply axes cropping and/or overplots, if requested.
    for ax in plot_axes:
        if crosshairs:
            ax.axhline(0, ls=":", color='white')
            ax.axvline(0, ls=":", color='white')

        if use_angular_coordinates and imagecrop is None:
            # the value from the configuration is implicitly in arcseconds
            imagecrop = conf.default_image_display_fov * u.arcsec

        if imagecrop is not None:
            # this is either a bare float, in which case interpret it as arcseconds or meters
            # depending on the mode, else a quantity already in either angular or pixel units.
            if not isinstance(imagecrop, u.Quantity):
                imagecrop *= u.arcsec if use_angular_coordinates else u.meter
            imagecrop_value = imagecrop.to_value(pixelscale_unit*u.pixel)

            # Never crop to a region larger than the displayed array.
            cropsize_x = min((imagecrop_value / 2, intens.shape[1] / 2. * self.pixelscale.to_value(pixelscale_unit)))
            cropsize_y = min((imagecrop_value / 2, intens.shape[0] / 2. * self.pixelscale.to_value(pixelscale_unit)))
            ax.set_xbound(-cropsize_x, cropsize_x)
            ax.set_ybound(-cropsize_y, cropsize_y)

        ax.xaxis.set_major_locator(matplotlib.ticker.MaxNLocator(5))
        ax.yaxis.set_major_locator(matplotlib.ticker.MaxNLocator(5))

    plt.draw()
    return to_return
def _display_after_optic(self, optic, default_nplanes=2, **kwargs):
    """Display this wavefront using any display hints attached to an optic.

    Convenience function used during propagations. Optional attributes on the
    optic (``wavefront_display_hint``, ``wavefront_display_vmax_hint``,
    ``wavefront_display_vmin_hint``, ``wavefront_display_imagecrop``) and on
    the wavefront (``_display_hint_expected_nplanes``) configure the call to
    `display`. The plot row is taken from the wavefront's current plane index,
    and no colorbar is drawn.

    Parameters
    ----------
    optic : OpticalElement instance
        An optic that might have display hint information attached
    default_nplanes : int
        How many rows to use for the display, if this is not
        already annotated onto this wavefront object itself.
    **kwargs
        Passed through to `display`.

    Returns
    -------
    The plot axes instance returned by `display`.
    """
    hinted_options = {
        'what': getattr(optic, 'wavefront_display_hint', 'best'),
        'vmax': getattr(optic, 'wavefront_display_vmax_hint', None),
        'vmin': getattr(optic, 'wavefront_display_vmin_hint', None),
        'imagecrop': getattr(optic, 'wavefront_display_imagecrop', None),
        'nrows': getattr(self, '_display_hint_expected_nplanes', default_nplanes),
    }

    # row=None makes display() use the wavefront's current plane index.
    ax = self.display(row=None, colorbar=False, **hinted_options, **kwargs)

    if hasattr(optic, 'display_annotate'):
        optic.display_annotate(optic, ax)  # atypical calling convention needed empirically
    return ax
# add convenient properties for intensity, phase, amplitude, total_flux
@property
def amplitude(self):
    """Electric field amplitude of the wavefront: the absolute value of the complex array."""
    field = self.wavefront
    return np.abs(field)
@property
def intensity(self):
    """Electric field intensity of the wavefront (i.e. field amplitude squared)"""
    if not accel_math._USE_NUMEXPR:
        return np.abs(self.wavefront) ** 2

    # The expression string refers to the local variable `w` by name,
    # so this local must keep exactly that name.
    w = self.wavefront
    return ne.evaluate("real(abs(w))**2")
@property
def phase(self):
    """Phase of the wavefront, in radians: the complex argument of each array element."""
    field = self.wavefront
    return np.angle(field)
@property
def wfe(self):
    """Wavefront error of the wavefront, as an astropy.Quantity.

    Computed as the phase multiplied by wavelength / (2 pi), so the result
    carries the units of the wavelength.
    """
    length_per_radian = self.wavelength / (2 * np.pi)
    if accel_math._USE_CUPY:
        # since this returns an astropy quantity, the array must not be in GPU memory
        phase_radians = accel_math.ensure_not_on_gpu(self.phase)
    else:
        phase_radians = self.phase
    return phase_radians * length_per_radian
@property
def shape(self):
    """Shape of the wavefront array, as a tuple."""
    array = self.wavefront
    return array.shape
@property
def dtype(self):
    """Data type of the wavefront array."""
    array = self.wavefront
    return array.dtype
@property
def total_intensity(self):
    """Integrated intensity over the entire spatial/angular extent of the wavefront.

    This is the sum of the `intensity` array over all of its elements.
    """
    per_pixel = self.intensity
    return per_pixel.sum()
# methods for wavefront propagation:
@abstractmethod
def propagate_to(self, optic):
    """Propagate this wavefront to the given optic.

    Abstract placeholder: the base class does nothing here, and concrete
    wavefront subclasses must provide the implementation.

    Parameters
    ----------
    optic : the optic to propagate to
    """
    return None
def _resample_wavefront_pixelscale(self, detector):
    """Resample this wavefront, in place, onto a detector's pixel grid.

    The complex array is interpolated from its current sampling onto
    coordinates defined by the detector's shape and pixel scale. Without
    CuPy this uses ``scipy.interpolate.RectBivariateSpline`` on the real and
    imaginary parts separately; with CuPy it uses ``ndimage.map_coordinates``.
    The spline order is taken from the detector's ``interp_order`` attribute.
    The result is divided by the pixel scale ratio (labelled in the code as
    enforcing conservation of energy).

    Parameters
    ----------
    detector : Detector class instance
        Detector that defines the desired pixel scale, shape and
        interpolation order.

    Returns
    -------
    None. The wavefront object is modified: ``wavefront`` is replaced by the
    resampled array, ``pixelscale`` is set to the detector's, and ``ispadded``
    is set to False.
    """
    import scipy.interpolate

    target_unit = detector.pixelscale.unit
    pixscale_ratio = (self.pixelscale / detector.pixelscale).decompose().value

    _log.info("Resampling wavefront to detector with {} pixels and {}. Zoom factor is {:.5f}".format(
        detector.shape, detector.pixelscale, pixscale_ratio))
    _log.debug("Wavefront pixel scale: {:.3f}".format(self.pixelscale.to(target_unit)))
    _log.debug("Desired detector pixel scale: {:.3f}".format(detector.pixelscale))
    _log.debug("Wavefront FOV: {} pixels, {:.3f}".format(
        self.shape, self.shape[0]*u.pixel*self.pixelscale.to(target_unit)))
    _log.debug("Desired detector FOV: {} pixels, {:.3f}".format(
        detector.shape, detector.shape[0]*u.pixel*detector.pixelscale))

    # Provide a 2-pixel margin around the image to reduce interpolation errors at the edge.
    # NOTE(review): earlier comments described cropping to "detector size + margin";
    # as written the target shape is the wavefront's own shape plus the margin. Confirm intent.
    margin = 2
    crop_shape = [margin + axis_length for axis_length in self.wavefront.shape]
    cropped_wf = utils.pad_or_crop_to_shape(self.wavefront, crop_shape)

    pixscale_in = self.pixelscale.to(u.m/u.pix).value
    pixscale_out = detector.pixelscale.to(u.m / u.pix).value
    spline_order = detector.interp_order

    if accel_math._USE_CUPY:
        # cupyx does not have RectBivariateSpline or interp2d so wavefront resampling
        # is implemented with map_coordinates.
        # Note, carefully handle the offset-by-one to be consistent with
        # the use of arange in the non-GPU branch; avoid fencepost error.
        wf_xmax = pixscale_in * cropped_wf.shape[0]/2
        x, y = xp.ogrid[-wf_xmax:wf_xmax-pixscale_in:cropped_wf.shape[0]*1j,
                        -wf_xmax:wf_xmax-pixscale_in:cropped_wf.shape[1]*1j]

        det_xmax = pixscale_out * detector.shape[0]/2
        newx, newy = xp.mgrid[-det_xmax:det_xmax-pixscale_out:detector.shape[0]*1j,
                              -det_xmax:det_xmax-pixscale_out:detector.shape[1]*1j]

        # Convert the detector coordinates into fractional array indices of the input grid.
        x0 = x[0, 0]
        y0 = y[0, 0]
        dx = x[1, 0] - x0
        dy = y[0, 1] - y0
        ivals = (newx - x0)/dx
        jvals = (newy - y0)/dy

        coords = xp.array([ivals, jvals])
        new_wf = _scipy.ndimage.map_coordinates(cropped_wf, coords, order=spline_order)
    else:
        def make_axis(npix, step):
            """ Helper function to make coordinate axis for interpolation """
            return step * xp.arange(-npix // 2, npix // 2, dtype=xp.float64)

        # Input and output axes for interpolation. The interpolated wavefront will be evaluated
        # directly onto the detector axis, so don't need to crop afterwards.
        x_in = make_axis(crop_shape[0], pixscale_in)
        y_in = make_axis(crop_shape[1], pixscale_in)
        x_out = make_axis(detector.shape[0], pixscale_out)
        y_out = make_axis(detector.shape[1], pixscale_out)

        def resample_component(component):
            # Fit a spline to one real-valued component on the input grid and
            # evaluate it on the output grid.
            spline = scipy.interpolate.RectBivariateSpline(x_in, y_in, component,
                                                           kx=spline_order, ky=spline_order)
            return spline(x_out, y_out)

        # Interpolate real and imaginary parts separately
        real_resampled = resample_component(cropped_wf.real)
        imag_resampled = resample_component(cropped_wf.imag)
        new_wf = xp.array(real_resampled + 1j * imag_resampled)

    # enforce conservation of energy:
    new_wf *= 1. / pixscale_ratio

    self.ispadded = False  # if a pupil detector, avoid auto-cropping padded pixels on output
    self.wavefront = new_wf
    self.pixelscale = detector.pixelscale
@utils.quantity_input(Xangle=u.arcsec, Yangle=u.arcsec)
def tilt(self, Xangle=0.0, Yangle=0.0):
    """ Tilt a wavefront in X and Y.

    Recall from Fourier optics (although this is straightforwardly rederivable by drawing triangles)
    that for a wavefront tilted by some angle theta in radians, that a point r meters from the center of
    the pupil has:

        extra_pathlength = sin(theta) * r
        extra_waves = extra_pathlength/ wavelength = r * sin(theta) / wavelength

    So we calculate the U and V arrays (corresponding to r for the pupil, in meters from the center)
    and then multiply by the appropriate trig factors for the angle.

    The sign convention is chosen such that positive Yangle tilts move the star upwards in the
    array at the focal plane. (This is sort of an inverse of what physically happens in the propagation
    to or through focus, but we're ignoring that here and trying to just work in sky coords)

    Each axis is centered using its own array length, so the tilt is centered
    correctly for non-square arrays as well as square ones.

    Parameters
    ----------
    Xangle, Yangle : float
        tilt angles, specified in arcseconds

    Raises
    ------
    NotImplementedError
        If the wavefront is currently at an image plane.
    """
    if self.planetype == PlaneType.image:
        raise NotImplementedError("Are you sure you want to tilt a wavefront in an _IMAGE plane?")

    if np.abs(Xangle) > 0 or np.abs(Yangle) > 0:
        xangle_rad = Xangle.to(u.radian).value
        yangle_rad = Yangle.to(u.radian).value

        if isinstance(self.pixelscale, u.Quantity):
            pixelscale = self.pixelscale.to(u.m / u.pixel).value
        else:
            pixelscale = self.pixelscale

        # V is the row (Y) coordinate and U the column (X) coordinate. Center each
        # on the middle of its own axis: using the row count for both would offset
        # the X coordinates whenever the array is not square.
        n_rows = self.wavefront.shape[0]
        n_cols = self.wavefront.shape[1]
        V, U = xp.indices(self.wavefront.shape, dtype=_float())
        V -= (n_rows - 1) / 2.0
        V *= pixelscale
        U -= (n_cols - 1) / 2.0
        U *= pixelscale

        # SIGN CONVENTION: Following Wyant and Creath:
        # Tilt in a wavefront affects the image by causing a shift of its center location in the Gaussian
        # image plane. A tilt causing a positive OPD change in the +x direction will cause the image to
        # shift in the -x direction
        tiltphasor = xp.exp(-2.0j * np.pi * (U * xangle_rad + V * yangle_rad) / self.wavelength.to(u.meter).value)
        self.wavefront *= tiltphasor
        self.history.append("Tilted wavefront by "
                            "X={:2.2}, Y={:2.2} arcsec".format(Xangle, Yangle))
    else:
        _log.warning("Wavefront.tilt() called, but requested tilt was zero. No change.")
def rotate(self, angle=0.0):
    """Rotate the wavefront array in place.

    Angles that are exact multiples of 90 degrees are applied with ``rot90``
    (no interpolation); every other angle goes through ``ndimage.rotate``
    with ``reshape=False``, so the array keeps its shape.

    Parameters
    ----------
    angle : float
        Angle to rotate, in degrees counterclockwise.
    """
    quarter_turns, leftover = np.divmod(angle, 90)

    if leftover == 0:
        # rotation is a multiple of 90: exact quarter-turn rotation
        def turn(plane):
            return xp.rot90(plane, k=-quarter_turns)  # negative = CCW
    else:
        # arbitrary free rotation with interpolation
        def turn(plane):
            return _scipy.ndimage.rotate(plane, -angle, reshape=False)  # negative = CCW

    # The ndimage rotate function does not work for complex numbers,
    # so the real and imaginary parts are rotated separately and recombined.
    field = self.wavefront
    self.wavefront = turn(field.real) + 1j * turn(field.imag)

    self.history.append('Rotated by {:.2f} degrees, CCW'.format(angle))
def invert(self, axis='both'):
    """Flip the wavefront array along one or both axes.

    This models the inversion of axes signs that happens for instance when a beam
    passes through a focus.

    Parameters
    ----------
    axis : string
        either 'both', 'x', or 'y' (case-insensitive), for which axes to invert

    Raises
    ------
    ValueError
        If `axis` is not one of the accepted names.
    """
    backwards = slice(None, None, -1)
    # Index expression applied to the array for each accepted axis name.
    # 'x' reverses the second array axis, 'y' the first.
    flips = {
        'both': (backwards, backwards),
        'x': (slice(None), backwards),
        'y': (backwards,),
    }
    selector = flips.get(axis.lower())
    if selector is None:
        raise ValueError("Invalid/unknown value for the 'axis' parameter. Must be 'x', 'y', or 'both'.")

    self.wavefront = self.wavefront[selector]
    self.history.append('Inverted axis direction for {} axes'.format(axis.upper()))
@abstractmethod
def coordinates(self):
    """Return Y, X coordinates for this wavefront, in the manner of numpy.indices().

    Abstract: this base version has no implementation and returns None;
    concrete wavefront classes are expected to override it.
    """
[docs]
class Wavefront(BaseWavefront):
    """ Wavefront in the Fraunhofer approximation: a monochromatic wavefront that
    can be transformed between pupil and image planes only, not to intermediate planes

    In a pupil plane, a wavefront object `wf` has

        * `wf.diam`, a diameter in meters
        * `wf.pixelscale`, a scale in meters/pixel

    In an image plane, it has

        * `wf.fov`, a field of view in arcseconds
        * `wf.pixelscale`, a scale in arcsec/pixel

    Use the `wf.propagate_to()` method to transform a wavefront between conjugate planes. This will update those
    properties as appropriate.

    By default, `Wavefronts` are created in a pupil plane. Set `pixelscale=#` to make an image plane instead.

    Parameters
    ----------
    wavelength : float
        Wavelength of light in meters
    npix : int
        Size parameter for wavefront array to create, per side.
    diam : float, optional
        For pupil-plane wavefronts, sets physical size corresponding to npix. Units are meters.
        At most one of diam or pixelscale should be set when creating a wavefront.
    pixelscale : float, optional
        For image-plane wavefronts, use this pixel scale (arcsec/pixel).
        Supplying it is what makes the new wavefront start in an image plane.
    oversample : int, optional
        how much to oversample by in FFTs. Default is 2.
        Note that final propagations to Detectors use a different algorithm
        and, optionally, a separate oversampling factor.
    dtype : numpy.dtype, optional
        default is double complex.
    """

    @utils.quantity_input(wavelength=u.meter, diam=u.meter, pixelscale=u.arcsec / u.pixel)
    def __init__(self, wavelength=1e-6 * u.meter, npix=1024, dtype=None, diam=8.0 * u.meter,
                 oversample=2, pixelscale=None):
        super().__init__(wavelength=wavelength, npix=npix, dtype=dtype,
                         diam=diam, oversample=oversample)

        if pixelscale is not None:
            # An explicit pixel scale means the caller wants an image-plane wavefront.
            self.pixelscale = pixelscale  # arcsec/pix
            self.planetype = PlaneType.image
        else:
            # Default: pupil plane, with the diameter spread evenly over npix pixels.
            self.pixelscale = self.diam / (npix * u.pixel)  # meters/pix
            self.planetype = PlaneType.pupil

        # later used to track MFT vs FFT pixel coord centering in coordinates()
        self._last_transform_type = None
        # Field of view in arcsec. Applies to an image plane only.
        self.fov = None
[docs]
def propagate_to(self, optic):
    """Propagates a wavefront object to the next optic in the list.
    Modifies this wavefront object itself.

    Transformations between pupil and detector planes use MFT or inverse MFT.
    Transformations between pupil and other (non-detector) image planes use FFT or inverse FFT, unless
    explicitly tagged to use MFT via a propagation hint.
    Transformations from any frame through a rotation or coordinate transform plane simply transform the
    wavefront accordingly.

    Parameters
    ----------
    optic : OpticalElement
        The optic to propagate to. Used for determining the appropriate optical plane.

    Raises
    ------
    NotImplementedError
        For propagation from an image plane directly to a detector.
    """
    target = optic.planetype

    # Guard clause: same plane type means no transform, at most a resampling onto detector pixels.
    if self.planetype == target:
        if isinstance(optic, Detector):
            _log.debug(" Resampling wavefront to match detector pixelation.")
            self._resample_wavefront_pixelscale(optic)
        else:
            _log.debug(" Wavefront and optic %s already at same plane type, no propagation needed." % optic.name)
        self.current_plane_index += 1
        return

    msg = " Propagating wavefront to %s. " % str(optic)
    _log.debug(msg)
    self.history.append(msg)

    # Coordinate transforms leave us 'after' the optic; real propagations stop 'before' it.
    if target == PlaneType.rotation:  # rotate
        self.rotate(optic.angle)
        where = 'after '
    elif target == PlaneType.inversion:  # invert coordinates
        self.invert(axis=optic.axis)
        where = 'after '
    elif ((target == PlaneType.detector or getattr(optic, 'propagation_hint', None) == 'MFT')
          and self.planetype == PlaneType.pupil):
        # from pupil to detector in image plane: use MFT
        self._propagate_mft(optic)
        where = 'before '
    elif (target == PlaneType.pupil and self.planetype == PlaneType.image and
          self._last_transform_type == 'MFT'):
        # inverse MFT detector to pupil
        # n.b. transforming PlaneType.pupil -> PlaneType.detector results in self.planetype == PlaneType.image
        # while setting _last_transform_type to MFT
        self._propagate_mft_inverse(optic)
        where = 'before '
    elif self.planetype == PlaneType.image and target == PlaneType.detector:
        raise NotImplementedError('image plane directly to detector propagation (resampling!) not implemented yet')
    else:
        self._propagate_fft(optic)  # FFT pupil to image or image to pupil
        where = 'before '

    self.location = where + optic.name
    self.current_plane_index += 1
def _propagate_fft(self, optic):
    """ Propagate from pupil to image or vice versa using a padded FFT

    The wavefront's plane type, pixel scale and (for image planes) field of view
    are updated *before* the transform itself is executed.

    Parameters
    ----------
    optic : OpticalElement
        The optic to propagate to. Used for determining the appropriate optical plane.

    Raises
    ------
    NotImplementedError
        If the wavefront / optic plane types are anything other than
        pupil -> image or image -> pupil. Raised before the wavefront is modified.
    """
    pupil_to_image = self.planetype == PlaneType.pupil and optic.planetype == PlaneType.image
    image_to_pupil = self.planetype == PlaneType.image and optic.planetype == PlaneType.pupil
    if not (pupil_to_image or image_to_pupil):
        # Previously this case fell through both branches below and crashed with an
        # UnboundLocalError, after the wavefront had already been padded.
        raise NotImplementedError(
            "FFT propagation is only defined between pupil and image planes; "
            "cannot propagate from {} to {}".format(self.planetype, optic.planetype))

    if self.oversample > 1 and not self.ispadded:  # add padding for oversampling, if necessary
        assert self.oversample == optic.oversample, "Unexpected sampling inconsistency in _propagate_fft!"
        self.wavefront = utils.pad_to_oversample(self.wavefront, self.oversample)
        self.ispadded = True
        if optic.verbose:
            _log.debug(" Padded WF array for oversampling by %dx" % self.oversample)
        self.history.append(" Padded WF array for oversampling by %dx" % self.oversample)

    # Set up for computation - figure out direction & normalization
    if pupil_to_image:
        # SIGN CONVENTION: plus signs in exponent for basic forward propagation, with
        # phase increasing with time. This convention differs from prior poppy version < 1.0.
        # A "forward propagation" in the optical sense therefore corresponds to what numpy labels
        # as an "inverse" FFT.
        propagation_forward = True
        fft_forward = False

        # (pre-)update state:
        self.planetype = PlaneType.image
        self.pixelscale = (self.wavelength / self.diam * u.radian / self.oversample).to(u.arcsec) / u.pixel
        self.fov = self.wavefront.shape[0] * u.pixel * self.pixelscale
        self.history.append(' FFT {}, to IMAGE plane scale={:.4f}'.format(self.wavefront.shape, self.pixelscale))
    else:
        # image -> pupil
        # SIGN CONVENTION: plus signs in exponent for basic forward propagation, with
        # phase increasing with time. This convention differs from prior poppy version < 1.0.
        # A "backwards propagation" in the optical sense therefore corresponds to what numpy labels
        # as a "forward" FFT.
        propagation_forward = False
        fft_forward = True

        # (pre-)update state:
        self.planetype = PlaneType.pupil
        self.pixelscale = self.diam * self.oversample / (self.wavefront.shape[0] * u.pixel)
        self.history.append(' FFT {}, to PUPIL scale={:.4f}'.format(self.wavefront.shape, self.pixelscale))

    # do FFT
    if conf.enable_flux_tests:
        _log.debug("\tPre-FFT total intensity: " + str(self.total_intensity))
    if conf.enable_speed_tests:
        t0 = time.time()  # pragma: no cover

    self.wavefront = accel_math.fft_2d(self.wavefront, forward=fft_forward)

    if propagation_forward:
        # FFT produces pixel-centered images by default, unless the _image_centered param
        # has already been set by an FQPM_FFT_aligner class
        if self._image_centered != 'corner':
            self._image_centered = 'pixel'

    self._last_transform_type = 'FFT'

    if conf.enable_speed_tests:  # pragma: no cover
        t1 = time.time()
        _log.debug("\tTIME %f s\t for the FFT" % (t1 - t0))

    if conf.enable_flux_tests:
        _log.debug("\tPost-FFT total intensity: " + str(self.total_intensity))
def _propagate_mft(self, det):
    """ Compute from pupil to an image using the Soummer et al. 2007 MFT algorithm

    Any oversampling padding on the pupil array is removed first, and the pupil
    shape and pixel scale are remembered on the instance so that a later inverse
    MFT can use them. After the transform the wavefront is marked as being in
    an image plane, with its field of view taken from the detector.

    Parameters
    ----------
    det : OpticalElement, must be of type DETECTOR
        The target optical plane to propagate to. (An optic whose
        `propagation_hint` attribute is 'MFT' is also accepted.)

    Raises
    ------
    NotImplementedError
        If the resulting pixel scales differ between the two axes.
    """
    assert self.planetype == PlaneType.pupil
    assert (det.planetype == PlaneType.detector or
            getattr(det, 'propagation_hint', None) == 'MFT')

    if self.ispadded:
        # pupil plane is padded - trim that out since it's not needed
        self.wavefront = utils.remove_padding(self.wavefront, self.oversample)
        self.ispadded = False
    self._preMFT_pupil_shape = self.wavefront.shape  # save for possible inverseMFT
    self._preMFT_pupil_pixelscale = self.pixelscale  # save for possible inverseMFT

    # the arguments for the matrixDFT are
    # - wavefront (assumed to fill the input array)
    # - focal plane size in lambda/D units
    # - number of pixels on a side in focal plane array.

    # extract everything from Quantities to regular scalars here
    lam_d = ((self.wavelength / self.diam) * u.radian).to(u.arcsec).value

    det_fov_lam_d = det.fov_arcsec.to(u.arcsec).value / lam_d
    det_calc_size_pixels = det.fov_pixels.to(u.pixel).value * det.oversample

    mft = MatrixFourierTransform(centering='ADJUSTABLE', verbose=False)
    # pixelscale here is only used for the log/history message below
    pixelscale = det.pixelscale if det.pixelscale is not None else det.fov_arcsec/det.fov_pixels
    if not np.isscalar(det_fov_lam_d):  # hasattr(det_fov_lam_d,'__len__'):
        # two-element field of view: report both axes
        msg = ' Propagating w/ MFT: {:.4f} fov=[{:.3f},{:.3f}] lam/D npix={} x {}'.format(
            pixelscale / det.oversample, det_fov_lam_d[0], det_fov_lam_d[1],
            det_calc_size_pixels[0], det_calc_size_pixels[1])
    else:
        msg = ' Propagating w/ MFT: {:.4f} fov={:.3f} lam/D npix={:d}'.format(
            pixelscale / det.oversample, det_fov_lam_d, int(det_calc_size_pixels))
    _log.debug(msg)
    self.history.append(msg)

    if det.offset is not None:
        _log.debug(' offset= '+str( det.offset))

    _log.debug(' MFT method = ' + mft.centering)

    self.wavefront = mft.perform(self.wavefront, det_fov_lam_d, det_calc_size_pixels,
                                 offset=None if det.offset is None else det.offset * det._offset_sign) # sign flip intentional, see note in Detector class
    _log.debug(" Result wavefront: at={0} shape={1} ".format(
        self.location, str(self.shape)))
    # Record the new state; propagate_to() checks _last_transform_type to decide
    # whether a later image -> pupil step should use the inverse MFT.
    self._last_transform_type = 'MFT'

    self.planetype = PlaneType.image
    self.fov = det.fov_arcsec
    self.pixelscale = det.fov_arcsec / det_calc_size_pixels / u.pixel

    if not np.isscalar(self.pixelscale.value):
        # check for rectangular arrays
        if np.isclose(self.pixelscale[0], self.pixelscale[1]):
            self.pixelscale = self.pixelscale[0]
            # we're in a rectangular array with same scale in both directions, so treat pixelscale as a scalar
        else:
            raise NotImplementedError(
                'Different pixel scales in X and Y directions (i.e. non-square pixels) not yet supported.')
def _propagate_mft_inverse(self, pupil, pupil_npix=None):
    """ Compute from an image to a pupil using the Soummer et al. 2007 MFT algorithm

    This allows transformation back from an arbitrarily-sampled 'detector' plane to a pupil.

    This is only used if transforming back from a 'detector' type plane to a pupil, for instance
    inside the semi-analytic coronagraphy algorithm, but is not used in more typical propagations.

    Parameters
    ----------
    pupil : OpticalElement
        The pupil-plane optic to propagate to. If it has an array shape (first
        dimension other than 1), its shape and pixel scale define the physical
        diameter of the output pupil; otherwise this wavefront's own `diam` is used.
    pupil_npix : int, optional
        Number of pixels per side for the output pupil array. If None, it is taken
        from the optic's shape when available, else from the pupil shape saved
        before the preceding forward MFT. When given explicitly, only the sampling
        is overridden; the physical diameter is still chosen as described above.
    """
    assert self.planetype == PlaneType.image
    assert pupil.planetype == PlaneType.pupil

    # the arguments for the matrixDFT are
    # - wavefront (assumed to fill the input array)
    # - focal plane size in lambda/D units
    # - number of pixels on a side in focal plane array.

    # Try to transform to whatever the intrinsic scale of the next pupil is.
    # but if this ends up being a scalar (meaning it is an AnalyticOptic) then
    # just go back to our own prior shape and pixel scale.
    optic_has_sampling = pupil.shape is not None and pupil.shape[0] != 1

    # The physical extent must be known whether or not the caller supplied pupil_npix.
    # Previously it was only assigned on the pupil_npix=None path, so passing an
    # explicit pupil_npix crashed with an UnboundLocalError further down.
    if optic_has_sampling:
        next_pupil_diam = pupil.shape[0]*pupil.pixelscale*u.pixel
    else:
        next_pupil_diam = self.diam

    if pupil_npix is None:
        if optic_has_sampling:
            # Use next optic's shape, extent, and pixelscale to define the target sampling
            pupil_npix = pupil.shape[0]
            _log.debug("Got post-invMFT pupil npix from next optic: {} pix, {} diam".format(pupil_npix, next_pupil_diam))
        else:
            # Use the prior pupil's shape, extent, and pixelscale to define the target sampling
            pupil_npix = self._preMFT_pupil_shape[0]
            _log.debug("Got post-invMFT pupil npix from pre-MFT pupil: {} pix, {} diam ".format(pupil_npix, self.diam))
    else:
        _log.debug("Using caller-specified post-invMFT pupil npix: {} pix, {} diam".format(pupil_npix, next_pupil_diam))

    # extract everything from Quantities to regular scalars here
    lam_d = (self.wavelength / next_pupil_diam * u.radian).to(u.arcsec).value
    det_fov_lam_d = self.fov.to(u.arcsec).value / lam_d

    mft = MatrixFourierTransform(centering='ADJUSTABLE', verbose=False)

    # these can be either scalar or 2-element lists/tuples/ndarrays
    msg_pixscale = ('{0:.4f}'.format(self.pixelscale) if np.isscalar(self.pixelscale.value) else
                    '{0:.4f} x {1:.4f} arcsec/pix'.format(self.pixelscale.value[0], self.pixelscale.value[1]))
    msg_det_fov = ('{0:.4f} lam/D'.format(det_fov_lam_d) if np.isscalar(det_fov_lam_d) else
                   '{0:.4f} x {1:.4f} lam/D'.format(det_fov_lam_d[0], det_fov_lam_d[1]))

    msg = ' Propagating w/ InvMFT: scale={0} fov={1} npix={2:d} x {2:d}'.format(
        msg_pixscale, msg_det_fov, pupil_npix)
    _log.debug(msg)
    self.history.append(msg)

    self.wavefront = mft.inverse(self.wavefront, det_fov_lam_d, pupil_npix)
    self._last_transform_type = 'InvMFT'

    # We are now back in a pupil plane spanning next_pupil_diam across the new array.
    self.planetype = PlaneType.pupil
    self.pixelscale = next_pupil_diam / self.wavefront.shape[0] / u.pixel
    self.diam = next_pupil_diam
# note: the following are implemented as static methods to
# allow for reuse outside of this class in the Zernike polynomial
# caching mechanisms. See zernike.py.
[docs]
@staticmethod
def pupil_coordinates(shape, pixelscale):
    """Utility function to generate coordinates arrays for a pupil
    plane wavefront

    Parameters
    ----------
    shape : tuple of ints
        Shape of the wavefront array
    pixelscale : float or 2-tuple of floats
        the pixel scale in meters/pixel, optionally different in
        X and Y. An astropy Quantity is converted to meters/pixel;
        anything else is used as given.

    Returns
    -------
    Y, X : arrays
        Coordinates measured from the exact center of the array,
        scaled by the pixel scale.
    """
    y, x = xp.indices(shape, dtype=_float())

    if isinstance(pixelscale, u.Quantity):
        scale_mpix = pixelscale.to(u.meter / u.pixel).value
    else:
        scale_mpix = pixelscale

    if np.isscalar(scale_mpix):
        pixel_scale_x = pixel_scale_y = scale_mpix
    else:
        pixel_scale_x, pixel_scale_y = scale_mpix

    if not accel_math._USE_NUMEXPR:
        half_y = (shape[0] - 1) / 2.0
        half_x = (shape[1] - 1) / 2.0
        y -= half_y
        x -= half_x
        return pixel_scale_y * y, pixel_scale_x * x

    # numexpr resolves the names inside the expression strings from this function's
    # local variables, so y, x, ny, nx and pixel_scale_* must keep exactly these names.
    ny, nx = shape
    return (ne.evaluate("pixel_scale_y * (y - (ny-1)/2)"),
            ne.evaluate("pixel_scale_x * (x - (nx-1)/2)"))
[docs]
@staticmethod
def image_coordinates(shape, pixelscale, last_transform_type, image_centered):
    """Utility function to generate coordinates arrays for an image
    plane wavefront

    Parameters
    ----------
    shape : tuple of ints
        Shape of the wavefront array
    pixelscale : astropy Quantity, float, or 2-element sequence
        the pixel scale in arcsec/pixel, optionally different in
        X and Y (given in that order). A Quantity is converted to
        arcsec/pixel; a plain number is taken to already be in
        arcsec/pixel (mirroring how `pupil_coordinates` treats plain numbers).
    last_transform_type : string
        Was the last transformation on the Wavefront an FFT
        or an MFT?
    image_centered : string
        Was POPPY trying to keep the center of the image on
        a pixel ('pixel'), crosshairs ('array_center'), or corner ('corner')?

    Returns
    -------
    Y, X : arrays
        Coordinates scaled by the pixel scale (arcsec).
    """
    y, x = xp.indices(shape, dtype=_float())
    if isinstance(pixelscale, u.Quantity):
        pixelscale_arcsecperpix = pixelscale.to(u.arcsec / u.pixel).value
    else:
        # plain number(s): accepted for consistency with pupil_coordinates
        pixelscale_arcsecperpix = pixelscale
    if not np.isscalar(pixelscale_arcsecperpix):
        pixel_scale_x, pixel_scale_y = pixelscale_arcsecperpix
    else:
        pixel_scale_x, pixel_scale_y = pixelscale_arcsecperpix, pixelscale_arcsecperpix

    # in most cases, the x and y values are centered around the exact center of the array.
    # This is not true in general for FFT-produced image planes where the center is in the
    # middle of one single pixel (the 0th-order term of the FFT), even though that means that
    # the PSF center is slightly offset from the array center.
    # On the other hand, if we used the FQPM FFT Aligner optic, then that forces the PSF center
    # to the exact center of an array.

    # The following are just relevant for the FFT-created images, not for the Detector MFT
    # image at the end.
    if last_transform_type == 'FFT':
        # FFT array sizes will always be even, right?
        if image_centered == 'pixel':
            # so this goes to an integer pixel
            y -= shape[0] / 2.0
            x -= shape[1] / 2.0
        elif image_centered == 'array_center' or image_centered == 'corner':
            # and this goes to a pixel center
            y -= (shape[0] - 1) / 2.0
            x -= (shape[1] - 1) / 2.0
        # NOTE(review): any other image_centered value leaves the FFT coordinates
        # un-centered (raw indices); kept as-is to preserve existing behavior.
    else:
        # MFT produced images are always exactly centered.
        y -= (shape[0] - 1) / 2.0
        x -= (shape[1] - 1) / 2.0

    return pixel_scale_y * y, pixel_scale_x * x
[docs]
def coordinates(self):
    """ Return Y, X coordinates for this wavefront, in the manner of numpy.indices()

    This function knows about the offset resulting from FFTs. Use it whenever computing anything
    measured in wavefront coordinates.

    Returns
    -------
    Y, X : array_like
        Wavefront coordinates in either meters or arcseconds for pupil and image, respectively

    Raises
    ------
    RuntimeError
        If the wavefront is in neither a pupil nor an image plane.
    """
    plane = self.planetype
    if plane == PlaneType.image:
        centering_info = (self._last_transform_type, self._image_centered)
        return Wavefront.image_coordinates(self.shape, self.pixelscale, *centering_info)
    if plane == PlaneType.pupil:
        return type(self).pupil_coordinates(self.shape, self.pixelscale)
    raise RuntimeError("Unknown plane type (should be pupil or image!)")
[docs]
@classmethod
def from_fresnel_wavefront(cls, fresnel_wavefront, verbose=False):
    """Convert a Fresnel type wavefront object to a Fraunhofer one

    Note, this function implicitly assumes this wavefront is at a
    pupil plane, so the resulting Fraunhofer wavefront will have
    pixelscale in meters/pix rather than arcsec/pix.

    Parameters
    ----------
    fresnel_wavefront : Wavefront
        The (Fresnel-type) wavefront to be converted.
    verbose : bool
        If True, print the input and output pixel scales and the output shape.

    Returns
    -------
    Wavefront
        New Fraunhofer wavefront carrying the input's array contents
        (padded or cropped to fit), history, plane index and location.
    """
    source = fresnel_wavefront

    # Generate a Fraunhofer wavefront with the same sampling
    diameter = (source.wavefront.shape[0]//source.oversample) * source.pixelscale*u.pixel
    converted = Wavefront(diam=diameter,
                          npix=source.shape[0]//source.oversample,
                          oversample=source.oversample,
                          wavelength=source.wavelength)
    if verbose:
        print(source.pixelscale, converted.pixelscale, converted.shape)

    # Deal with metadata
    converted.history = source.history.copy()
    converted.history.extend([
        "Converted to Fraunhofer propagation",
        " Fraunhofer array pixel scale = {:.4g}, oversample = {}".format(converted.pixelscale, converted.oversample),
    ])

    # Copy over the contents of the array
    converted.wavefront = utils.pad_or_crop_to_shape(source.wavefront, converted.shape)

    # Copy over misc internal info
    if hasattr(source, '_display_hint_expected_nplanes'):
        converted._display_hint_expected_nplanes = source._display_hint_expected_nplanes
    converted.current_plane_index = source.current_plane_index
    converted.location = source.location

    return converted
@property
def _on_gpu(self):
    """Whether this wavefront's array lives in GPU memory.

    Returns True only when CuPy acceleration is enabled and the wavefront
    array is a ``cupy.ndarray``; otherwise False.
    """
    if accel_math._USE_CUPY:
        # Import lazily so cupy is only required when it is actually in use.
        import cupy as cp
        return isinstance(self.wavefront, cp.ndarray)
    return False
[docs]
def rotate(self, angle=0.0):
    """Rotate a wavefront by some amount, using spline interpolation
    (or exact rotation, for multiples of 90 degrees)

    Note, if the wavefront is zero-padded this step will unpad it, as an efficiency to
    avoid rotating large arrays of zeros unnecessarily.

    Parameters
    ----------
    angle : float
        Angle to rotate, in degrees counterclockwise.
    """
    padded = self.ispadded
    if padded:
        # pupil plane is padded - trim out the zeros since it's not needed in the rotation
        # If needed in later steps, the padding will be re-added automatically
        trimmed = utils.remove_padding(self.wavefront, self.oversample)
        self.wavefront = trimmed
        self.ispadded = False

    # The rotation itself is done by the parent class implementation.
    super().rotate(angle)
# ------ core Optical System classes -------
class BaseOpticalSystem(ABC):
    """ Abstract Base class for optical systems

    Parameters
    ----------
    name : string
        descriptive name of optical system
    oversample : int
        Either how many times *above* Nyquist we should be
        (for pupil or image planes), or how many times a fixed
        detector pixel will be sampled. E.g. `oversample=2` means
        image plane sampling lambda/4*D (twice Nyquist) and
        detector plane sampling 2x2 computed pixels per real detector
        pixel.  Default is 2.
    verbose : bool
        whether to be more verbose with log output while computing
    npix : int, optional
        Stored as-is on the instance; default None.
    pupil_diameter : astropy.Quantity of dimension length
        Diameter of entrance pupil. Defaults to size of first optical element
        if unspecified, or else 1 meter.
    """

    def __init__(self, name="unnamed system", verbose=True, oversample=2,
                 npix=None, pupil_diameter=None):
        # Settings supplied by the caller
        self.name = name
        self.verbose = verbose
        self.oversample = oversample
        self.npix = npix
        self.pupil_diameter = pupil_diameter

        # Internal state, filled in later
        self.planes = []  # List of OpticalElements
        self.source_offset_r = 0  # off-axis tilt of the source, in ARCSEC
        self.source_offset_theta = 0  # in degrees CCW
        self.intermediate_wfs = None

        if verbose:
            _log.info("Initialized OpticalSystem: " + name)
def __getitem__(self, num):
    """Return the entry of `self.planes` selected by `num` (delegates to list indexing)."""
    return self.planes[num]
def __len__(self):
    """Return the number of planes currently in `self.planes`."""
    return len(self.planes)
def _add_plane(self, optic, index=None, logstring=""):
    """Utility helper for adding a generic plane to `self.planes`.

    Parameters
    ----------
    optic : object
        The plane to add; its `name` attribute is used in the log message.
    index : int, optional
        Position at which to insert the plane. If None, the plane is appended.
    logstring : str
        Short description of the kind of plane, used in the log message.

    Returns
    -------
    The `optic` that was passed in.
    """
    if index is not None:
        self.planes.insert(index, optic)
    else:
        self.planes.append(optic)

    if self.verbose:
        _log.info("Added {}: {}".format(logstring, optic.name))

    return optic
def add_rotation(self, angle=0.0, index=None, *args, **kwargs):
    """
    Add a clockwise or counterclockwise rotation around the optical axis

    Any extra positional or keyword arguments are passed on to `Rotation`.

    Parameters
    ----------
    angle : float
        Rotation angle, counterclockwise. By default in degrees.
    index : int
        Index into the optical system's planes for where to add the new optic. Defaults to
        appending the optic to the end of the plane list.

    Returns
    -------
    poppy.Rotation
        The rotation added to the optical system
    """
    rotation_plane = Rotation(*args, angle=angle, **kwargs)
    return self._add_plane(rotation_plane, index=index, logstring="rotation plane")
def add_inversion(self, index=None, *args, **kwargs):
    """
    Add a coordinate inversion of the wavefront, for instance
    a flip in the sign of the X and Y axes due to passage through a focus.

    Any extra positional or keyword arguments are passed on to `CoordinateInversion`.

    Parameters
    ----------
    index : int
        Index into the optical system's planes for where to add the new optic. Defaults to
        appending the optic to the end of the plane list.

    Returns
    -------
    poppy.CoordinateInversion
        The inversion added to the optical system
    """
    inversion_plane = CoordinateInversion(*args, **kwargs)
    return self._add_plane(inversion_plane, index=index,
                           logstring="coordinate inversion plane")
def add_detector(self, pixelscale, oversample=None, index=None, **kwargs):
    """ Add a Detector object to an optical system.
    By default, use the same oversampling as the rest of the optical system,
    but the user can override to a different value if desired by setting `oversample`.

    Other arguments are passed to the init method for Detector().

    Parameters
    ----------
    pixelscale : float
        Pixel scale in arcsec/pixel (or m/pixel for Fresnel optical systems)
    oversample : int, optional
        Oversampling factor for *this detector*, relative to hardware pixel size.
        Optionally distinct from the default oversampling parameter of the OpticalSystem.
    index : int
        Index into the optical system's planes for where to add the new optic. Defaults to
        appending the optic to the end of the plane list.

    Returns
    -------
    poppy.Detector
        The detector added to the optical system
    """
    # Fall back to the system-wide oversampling; assume 1 if the attribute is not
    # present, which is needed for compatibility use in subclass FresnelOpticalSystem.
    effective_oversample = getattr(self, 'oversample', 1) if oversample is None else oversample

    detector = Detector(pixelscale, oversample=effective_oversample, **kwargs)
    description = "detector with pixelscale={} and oversampling={}".format(pixelscale, effective_oversample)
    return self._add_plane(detector, index=index, logstring=description)
@abstractmethod
def propagate(self, wavefront):
    """Propagate a wavefront through this system
    and return the output wavefront.

    Abstract: this base version has no implementation and returns None;
    subclasses are expected to override it.
    """
@abstractmethod
@utils.quantity_input(wavelength=u.meter)
def input_wavefront(self, wavelength=1e-6*u.meter):
    """Create an input wavefront suitable for propagation

    Abstract: this base version has no implementation;
    subclasses are expected to override it.
    """
@utils.quantity_input(wavelength=u.meter)
def calc_psf(self, wavelength=1e-6,
             weight=None,
             save_intermediates=False,
             save_intermediates_what='all',
             display=False,
             return_intermediates=False,
             return_final=False,
             source=None,
             normalize='first',
             display_intermediates=False,
             progressbar=False,
             inwave=None):
    """Calculate a PSF, either multi-wavelength or monochromatic.

    The wavelength coverage computed will be:
    - multi-wavelength PSF over some weighted sum of wavelengths (if you provide a `source` argument)
    - monochromatic (if you provide just a `wavelength` argument)

    Parameters
    ----------
    wavelength : float or Astropy.Quantity, optional
        wavelength in meters, or some other length unit if specified as an astropy.Quantity. Either
        scalar for monochromatic calculation or list or ndarray for multiwavelength calculation.
    weight : float, optional
        weight by which to multiply each wavelength. Must have same length as
        wavelength parameter. Defaults to 1s if not specified.
        Weights are normalized to sum to 1 before being applied.
    save_intermediates : bool, optional
        whether to output intermediate optical planes to disk. Default is False
    save_intermediates_what : string, optional
        What to save - phase, intensity, amplitude, complex, parts, all. Default is all.
    return_intermediates: bool, optional
        return intermediate wavefronts as well as PSF?
    return_final: bool, optional
        return the complex wavefront at the last surface propagation as well as the PSF.
        Useful for getting complex PSF without memory usage of `return_intermediates`
    source : dict
        a dict containing 'wavelengths' and 'weights' list. If given, these
        replace the `wavelength` and `weight` arguments.
    normalize : string, optional
        How to normalize the PSF. See the documentation for propagate_mono() for details.
    display : bool, optional
        whether to plot the results when finished or not.
    display_intermediates: bool, optional
        Display intermediate optical planes? Default is False. This option is incompatible with
        parallel calculations using `multiprocessing`. (If calculating in parallel, it will have no effect.)
    progressbar : bool
        Optionally display a progress bar indicator for status
        while iterating over wavelengths. Note, this requires the
        optional dependency package 'tqdm', which is not included as
        a requirement.
    inwave : optional
        Passed unchanged to `propagate_mono()` in the single-process code path.
        It is not forwarded to the workers in the multiprocessing code path.
        (Presumably a pre-built input wavefront - confirm against `propagate_mono`.)

    Returns
    -------
    outfits :
        a fits.HDUList
    intermediate_wfs : list of `poppy.Wavefront` objects (optional)
        Only returned if `return_intermediates` is specified.
        A list of `poppy.Wavefront` objects representing the wavefront at intermediate optical planes.
        The 0th item is "before first optical plane", 1st is "after first plane and before second plane", and so on.
    final_wfs : `poppy.Wavefront` object (optional)
        Only returned if `return_final` is specified.
        `poppy.Wavefront` objects representing the wavefront at the last of the optical planes.

    Raises
    ------
    ValueError
        If the numbers of wavelengths and weights differ.
    """

    tstart = time.time()
    if source is not None:
        wavelength = source['wavelengths']
        weight = source['weights']

    # Make sure the wavelength is unit-y
    if not isinstance(wavelength, u.Quantity):
        wavelength = np.asarray(wavelength) * u.meter

    # ensure wavelength is a quantity which is iterable:
    # (the check for a quantity of type length is applied in the decorator)
    if np.isscalar(wavelength.value):
        wavelength = np.asarray([wavelength.value], dtype=_float()) * wavelength.unit

    if weight is None:
        weight = [1.0] * len(wavelength)

    if len(tuple(wavelength)) != len(tuple(weight)):
        raise ValueError("Input source has different number of weights and wavelengths...")

    # loop over wavelengths
    if self.verbose:
        _log.info("Calculating PSF with %d wavelengths" % (len(wavelength)))
    outfits = None
    intermediate_wfs = None
    if save_intermediates or return_intermediates:
        _log.info("User requested saving intermediate wavefronts in call to poppy.calc_psf")
        retain_intermediates = True
    else:
        retain_intermediates = False

    normwts = np.asarray(weight, dtype=_float())
    normwts /= normwts.sum()

    _USE_FFTW = (conf.use_fftw and accel_math._FFTW_AVAILABLE)
    if _USE_FFTW:
        utils.fftw_load_wisdom()

    if conf.use_multiprocessing and len(wavelength) > 1:  # ######## Parallelized computation ############
        # Avoid a Mac OS incompatibility that can lead to hard-to-reproduce crashes.
        # see issues #23 and #176

        if _USE_FFTW:
            _log.warning('IMPORTANT WARNING: Python multiprocessing and fftw3 do not appear to play well together. '
                         'This may crash intermittently')
            _log.warning(' We suggest you set poppy.conf.use_fftw to False if you want to use multiprocessing().')
        if display:
            _log.warning('Display during calculations is not supported for multiprocessing mode. '
                         'Please set poppy.conf.use_multiprocessing = False if you want to use display=True.')
            _log.warning('(Plot the returned PSF with poppy.utils.display_psf.)')

        if return_intermediates:
            _log.warning('Memory usage warning: When preserving intermediate planes in multiprocessing mode, '
                         'memory usage scales with the number of planes times number of wavelengths. Disable '
                         'use_multiprocessing if you are running out of memory.')
        if save_intermediates:
            _log.warning('Saving intermediate steps does not take advantage of multiprocess parallelism. '
                         'Set save_intermediates=False for improved speed.')

        # do *NOT* just blindly try to create as many processes as one has CPUs, or one per wavelength either
        # This is a memory-intensive task so that can end up swapping to disk and thrashing IO
        nproc = conf.n_processes if conf.n_processes > 1 \
            else utils.estimate_optimal_nprocesses(self, nwavelengths=len(wavelength))
        nproc = min(nproc, len(wavelength))  # never try more processes than wavelengths.
        # be sure to cast nproc to int below; will fail if given a float even if of integer value

        # Use forkserver method (requires Python >= 3.4) for more robustness, instead of just Pool
        # Resolves https://github.com/mperrin/poppy/issues/23
        ctx = multiprocessing.get_context('forkserver')
        pool = ctx.Pool(int(nproc))

        # build a single iterable containing the required function arguments
        _log.info("Beginning multiprocessor job using {0} processes".format(nproc))
        worker_arguments = [(self, wlen, retain_intermediates, return_final, normalize, _USE_FFTW)
                            for wlen in wavelength]
        try:
            results = pool.map(_wrap_propagate_for_multiprocessing, worker_arguments)
            _log.info("Finished multiprocessor job")
        finally:
            # Always release the worker pool, even if one of the workers raised;
            # otherwise the worker processes are leaked on failure.
            pool.close()

        # Sum all the results up into one array, using the weights
        outfits, intermediate_wfs = results[0]
        outfits[0].data *= normwts[0]
        for idx, wavefront in enumerate(intermediate_wfs):
            intermediate_wfs[idx] *= normwts[0]
        _log.info("got results for wavelength channel {} / {} ({:g} meters)".format(
            0, len(tuple(wavelength)), wavelength[0]))
        for i in range(1, len(normwts)):
            mono_psf, mono_intermediate_wfs = results[i]
            wave_weight = normwts[i]
            _log.info("got results for wavelength channel {} / {} ({:g} meters)".format(
                i, len(tuple(wavelength)), wavelength[i]))
            outfits[0].data += mono_psf[0].data * wave_weight
            for idx, wavefront in enumerate(mono_intermediate_wfs):
                intermediate_wfs[idx] += wavefront * wave_weight
        outfits[0].header.add_history("Multiwavelength PSF calc using {} processes completed.".format(nproc))

    else:  # ######### single-threaded computations (may still use multi cores if FFTW enabled ######
        if display:
            plt.clf()
        iterate_wrapper = utils.get_progressbar_wrapper(progressbar, nwaves=len(wavelength))
        for wlen, wave_weight in iterate_wrapper(zip(wavelength, normwts)):
            mono_psf, mono_intermediate_wfs = self.propagate_mono(
                wlen,
                retain_intermediates=retain_intermediates,
                retain_final=return_final,
                display_intermediates=display_intermediates,
                normalize=normalize,
                inwave=inwave
            )

            if outfits is None:
                # for the first wavelength processed, set up the arrays where we accumulate the output
                outfits = mono_psf
                outfits[0].data *= wave_weight
                intermediate_wfs = mono_intermediate_wfs
                for wavefront in intermediate_wfs:
                    wavefront *= wave_weight  # modifies Wavefront in-place
            else:
                # for subsequent wavelengths, scale and add the data to the existing arrays
                outfits[0].data += mono_psf[0].data * wave_weight
                for idx, wavefront in enumerate(mono_intermediate_wfs):
                    intermediate_wfs[idx] += wavefront * wave_weight

    # Display WF if requested.
    # Note - don't need to display here if we are showing all steps already
    if display and not display_intermediates:
        cmap = copy.copy(getattr(matplotlib.cm, conf.cmap_sequential))
        cmap.set_bad('0.3')
        halffov_x = outfits[0].header['PIXELSCL'] * outfits[0].data.shape[1] / 2
        halffov_y = outfits[0].header['PIXELSCL'] * outfits[0].data.shape[0] / 2
        extent = [-halffov_x, halffov_x, -halffov_y, halffov_y]
        unit = "arcsec"
        vmax = outfits[0].data.max()
        vmin = vmax / 1e4
        norm = matplotlib.colors.LogNorm(vmin=vmin, vmax=vmax)  # vmin=1e-8,vmax=1e-1)
        plt.xlabel(unit)

        utils.imshow_with_mouseover(outfits[0].data, extent=extent, norm=norm, cmap=cmap,
                                    origin='lower')

    if save_intermediates:
        _log.info('Saving intermediate wavefronts:')
        for idx, wavefront in enumerate(intermediate_wfs):
            filename = 'wavefront_plane_{:03d}.fits'.format(idx)
            wavefront.writeto(filename, what=save_intermediates_what)
            _log.info(' saved {} to {} ({} / {})'.format(save_intermediates_what, filename,
                                                         idx, len(intermediate_wfs)))

    tstop = time.time()
    tdelta = tstop - tstart
    _log.info(" Calculation completed in {0:.3f} s".format(tdelta))
    outfits[0].header.add_history("Calculation completed in {0:.3f} seconds".format(tdelta))

    if _USE_FFTW and conf.autosave_fftw_wisdom:
        utils.fftw_save_wisdom()

    # TODO update FITS header for oversampling here if detector is different from regular?
    # Record the wavelengths and the weights as supplied (not the normalized ones) in the header.
    waves = np.asarray([w.to_value(u.meter) for w in wavelength])
    wts = np.asarray(weight)
    mnwave = (waves * wts).sum() / wts.sum()
    outfits[0].header['WAVELEN'] = (mnwave, 'Weighted mean wavelength in meters')
    outfits[0].header['NWAVES'] = (waves.size, 'Number of wavelengths used in calculation')
    for i in range(waves.size):
        outfits[0].header['WAVE' + str(i)] = (waves[i], "Wavelength " + str(i))
        outfits[0].header['WGHT' + str(i)] = (wts[i], "Wavelength weight " + str(i))
    ffttype = "pyFFTW" if _USE_FFTW else "numpy.fft"
    outfits[0].header['FFTTYPE'] = (ffttype, 'Algorithm for FFTs: numpy or fftw')
    outfits[0].header['NORMALIZ'] = (normalize, 'PSF normalization method')

    if self.verbose:
        _log.info("PSF Calculation completed.")

    # Logical 'or' (not bitwise '|'): these are independent boolean flags.
    if return_intermediates or return_final:
        return outfits, intermediate_wfs
    else:
        return outfits
@utils.quantity_input(wavelength=u.meter)
def propagate_mono(self,
                   wavelength=1e-6 * u.meter,
                   normalize='first',
                   retain_intermediates=False,
                   retain_final=False,
                   display_intermediates=False,
                   inwave=None):
    """Propagate a monochromatic wavefront through the optical system. Called from within `calc_psf`.

    Returns a tuple with a `fits.HDUList` object and a list of intermediate `Wavefront`s (empty if
    `retain_intermediates=False` and `retain_final=False`).

    Parameters
    ----------
    wavelength : astropy.units.Quantity of dimension length, or float
        Wavelength of the light to propagate. The ``quantity_input`` decorator is applied
        with ``wavelength=u.meter``.
    normalize : string, {'first', 'last', 'exit_pupil', 'first=2'}
        How to normalize the wavefront? This value is passed unchanged to `propagate`.

        * 'first' = set total flux = 1 after the first optic, presumably a pupil
        * 'last' = set total flux = 1 after the entire optical system.
        * 'exit_pupil' = set total flux = 1 at the last pupil of the optical system.
        * 'first=2' = set total flux = 2 after the first optic (used for debugging only)
    retain_intermediates : bool
        Should intermediate steps in the calculation be retained? Default: False.
        If True, the second return value of the method will be a list of `poppy.Wavefront` objects
        representing intermediate optical planes from the calculation.
    retain_final : bool
        Should the final complex wavefront be retained? Default: False.
        If True, the second return value of the method will be a single element list
        (for consistency with retain intermediates) containing a `poppy.Wavefront` object
        representing the final optical plane from the calculation.
        Overridden by retain_intermediates.
    display_intermediates : bool
        Should intermediate steps in the calculation be displayed on screen? Default: False.
    inwave : optional
        Passed through unchanged to ``self.input_wavefront``.

    Returns
    -------
    final_wf : fits.HDUList
        The final result of the monochromatic propagation as a FITS HDUList
    intermediate_wfs : list
        A list of `poppy.Wavefront` objects representing the wavefront at intermediate optical planes.
        (n.b. This will be empty if `retain_intermediates` is False and singular if retain_final is True.)
    """
    if conf.enable_speed_tests:  # pragma: no cover
        t_start = time.time()
    if self.verbose:
        _log.info(" Propagating wavelength = {0:g}".format(wavelength))

    wavefront = self.input_wavefront(wavelength, inwave=inwave)

    kwargs = {'normalize': normalize,
              'display_intermediates': display_intermediates,
              'return_intermediates': retain_intermediates}

    # `propagate` returns either a bare wavefront or a (wavefront, list) tuple depending
    # on `return_intermediates`, so the call has to be unpacked accordingly.
    if retain_intermediates:
        wavefront, intermediate_wfs = self.propagate(wavefront, **kwargs)
    else:
        wavefront = self.propagate(wavefront, **kwargs)
        # Plain truthiness rather than the bitwise `&` used previously: `True & 2` is 0 and
        # `True & None` raises TypeError, so any non-bool flag value was mishandled.
        intermediate_wfs = [wavefront] if retain_final else []

    if conf.enable_speed_tests:  # pragma: no cover
        t_stop = time.time()
        _log.debug("\tTIME %f s\tfor propagating one wavelength" % (t_stop - t_start))

    return wavefront.as_fits(), intermediate_wfs
def display(self, **kwargs):
    """ Show every displayable plane of this optical system, one subplot row per plane.

    Detector planes, and planes whose `_suppress_display` flag is set, are left out.
    All keyword arguments are forwarded unchanged to each plane's own `display()` method.
    """
    visible = []
    for candidate in self.planes:
        if isinstance(candidate, Detector) or candidate._suppress_display:
            continue
        visible.append(candidate)

    total_rows = len(visible)
    for row_number, plane in enumerate(visible, start=1):
        _log.info("Displaying plane {0:s} in row {1:d} of {2:d}".format(plane.name, row_number, total_rows))
        plane.display(nrows=total_rows, row=row_number, **kwargs)
[docs]
class OpticalSystem(BaseOpticalSystem):
    """ A class representing a series of optical elements,
    either Pupil, Image, or Detector planes, through which light
    can be propagated.

    The difference between
    Image and Detector planes is that Detectors have fixed pixels
    in terms of arcsec/pixel regardless of wavelength (computed via
    MFT) while Image planes have variable pixels scaled in terms of
    lambda/D. Pupil planes are some fixed size in meters, of course.

    Parameters
    ----------
    name : string
        descriptive name of optical system
    oversample : int
        Either how many times *above* Nyquist we should be
        (for pupil or image planes), or how many times a fixed
        detector pixel will be sampled. E.g. `oversample=2` means
        image plane sampling lambda/4*D (twice Nyquist) and
        detector plane sampling 2x2 computed pixels per real detector
        pixel.  Default is 2.
    verbose : bool
        whether to be more verbose with log output while computing
    pupil_diameter : astropy.Quantity of dimension length
        Diameter of entrance pupil. Defaults to size of first optical element
        if unspecified, or else 1 meter.
    """

    # Methods for adding or manipulating optical planes:

    def add_pupil(self, optic=None, function=None, index=None, **kwargs):
        """ Add a pupil plane optic to the optical system.

        The pupil plane optic can be specified either

            1) from file(s) giving transmission and/or OPD
                [set arguments `transmission=filename` and/or `opd=filename`]
            2) from an already-created :py:class:`OpticalElement` object
                [set `optic=that object`]

        If neither an optic nor any keyword arguments are given, a placeholder
        `ScalarTransmission` optic is added.

        Parameters
        ----------
        optic : poppy.OpticalElement, optional
            An already-created :py:class:`OpticalElement` object you would like to add
        function : string, optional
            Deprecated. Supplying it emits a DeprecationWarning, and string-named optics
            are no longer instantiated (NotImplementedError is raised).
        opd, transmission : string, optional
            Filenames of FITS files describing the desired optic.
        index : int
            Index into the optical system's planes for where to add the new optic. Defaults to
            appending the optic to the end of the plane list.

        Returns
        -------
        poppy.OpticalElement subclass
            The pupil optic added (either `optic` passed in, or a new OpticalElement created)

        Raises
        ------
        NotImplementedError
            If the optic is specified as a string (via `optic` or `function`).
        TypeError
            If `optic` is neither None, a string, nor an OpticalElement.

        Any provided keyword parameters are passed to :py:class:`FITSOpticalElement`.
        """
        if function is not None:
            # stacklevel=2 attributes the warning to the caller's line rather than this one.
            warnings.warn("The function argument to add_pupil is deprecated. Please provide an Optic object instead.",
                          DeprecationWarning, stacklevel=2)
        if optic is None and function is not None:
            # ease of use: 'function' input and providing 'optic' parameter as a string are synonymous.
            optic = function

        if isinstance(optic, OpticalElement):
            # OpticalElement object provided.
            # We can use it directly, but make sure the plane type is set.
            optic.planetype = PlaneType.pupil
        elif isinstance(optic, str):
            raise NotImplementedError('Setting optics based on strings is now deprecated.')
        elif optic is None and len(kwargs) > 0:
            # create optic from files specified in kwargs
            optic = FITSOpticalElement(planetype=PlaneType.pupil, oversample=self.oversample, **kwargs)
        elif optic is None and len(kwargs) == 0:
            # create empty optic.
            from . import optics
            optic = optics.ScalarTransmission()  # placeholder optic, transmission=100%
            optic.planetype = PlaneType.pupil
        else:
            raise TypeError("Not sure how to handle an Optic input of the provided type, {0}".format(
                str(optic.__class__)))

        return self._add_plane(optic, index=index, logstring="pupil plane")

    def add_image(self, optic=None, function=None, index=None, **kwargs):
        """ Add an image plane optic to the optical system

        That image plane optic can be specified either

            1) from file(s) giving transmission or OPD
                [set arguments `transmission=filename` and/or `opd=filename`]
            2) from an analytic function
                [set `function` to one of 'CircularOcculter', 'BarOcculter', 'fieldstop',
                'BandLimitedCoron', or 'FQPM', and set additional kwargs to define shape etc.]
            3) from an already-created OpticalElement object
                [set `optic=that object`]

        If none of these is given, a placeholder `ScalarTransmission` optic is added.

        Parameters
        ----------
        optic : poppy.OpticalElement or string
            An already-created OpticalElement you would like to add. A string is treated
            as if it had been passed as `function`.
        function : string
            Name of some analytic function to add.
            Optional `kwargs` can be used to set the parameters of that function.
            Allowable function names are CircularOcculter, BarOcculter, fieldstop,
            BandLimitedCoron, FQPM
        opd, transmission : string
            Filenames of FITS files describing the desired optic.
        index : int
            Index into the optical system's planes for where to add the new optic. Defaults to
            appending the optic to the end of the plane list.

        Returns
        -------
        poppy.OpticalElement subclass
            The image plane optic added (either `optic` passed in, or a new OpticalElement created)

        Raises
        ------
        ValueError
            If `function` names an unknown analytic mask type.
        """
        if isinstance(optic, str):
            function = optic
            optic = None

        if optic is None:
            from . import optics
            if function == 'CircularOcculter':
                fn = optics.CircularOcculter
            elif function == 'BarOcculter':
                fn = optics.BarOcculter
            elif function == 'fieldstop':
                fn = optics.FieldStop
            elif function == 'BandLimitedCoron':
                fn = optics.BandLimitedCoron
            elif function == 'FQPM':
                fn = optics.IdealFQPM
            elif function is not None:
                raise ValueError("Analytic mask type '%s' is unknown." % function)
            elif len(kwargs) > 0:  # create image from files specified in kwargs
                fn = FITSOpticalElement
            else:
                fn = optics.ScalarTransmission  # placeholder optic, transmission=100%

            optic = fn(oversample=self.oversample, **kwargs)
            optic.planetype = PlaneType.image
        else:
            optic.planetype = PlaneType.image
            optic.oversample = self.oversample  # these need to match...

        return self._add_plane(optic, index=index, logstring="image plane")

    def describe(self):
        """ Print out a string table describing all planes in an optical system"""
        print(str(self) + "\n\t" + "\n\t".join([str(p) for p in self.planes]))

    # methods for dealing with wavefronts:

    def propagate(self,
                  wavefront,
                  normalize='none',
                  return_intermediates=False,
                  display_intermediates=False):
        """ Core low-level routine for propagating a wavefront through an optical system

        This is a **linear operator** that acts on an input complex wavefront to give an
        output complex wavefront.

        Parameters
        ----------
        wavefront : Wavefront instance
            Wavefront to propagate through this optical system
        normalize : string
            How to normalize the wavefront? Compared case-insensitively. Any value other
            than those listed below (such as the default 'none') results in no normalization.

            * 'first' = set total flux = 1 after the first optic, presumably a pupil
            * 'last' = set total flux = 1 after the entire optical system.
            * 'exit_pupil' = set total flux = 1 at the last pupil of the optical system.
        display_intermediates : bool
            Should intermediate steps in the calculation be displayed on screen? Default: False.
        return_intermediates : bool
            Should intermediate steps in the calculation be returned? Default: False.
            If True, the second return value of the method will be a list of `poppy.Wavefront` objects
            representing intermediate optical planes from the calculation.

        Returns
        -------
        A wavefront, and optionally also the intermediate wavefronts after
        each step of propagation.

        Raises
        ------
        ValueError
            If `wavefront` is not a Wavefront, or if ``normalize='exit_pupil'`` is requested
            for a system that contains planes but no pupil plane.
        """
        if not isinstance(wavefront, Wavefront):
            raise ValueError("First argument to propagate must be a Wavefront.")

        planes = self.planes
        norm_mode = normalize.lower()

        # The index of the exit pupil does not change during propagation, so find it once here
        # instead of re-deriving it for every plane. Plane indices on the wavefront are 1-based
        # after each optic, hence the +1.
        last_pupil_plane_index = None
        if norm_mode == 'exit_pupil':
            pupil_indices = [i for i, p in enumerate(planes) if p.planetype is PlaneType.pupil]
            if pupil_indices:
                last_pupil_plane_index = pupil_indices[-1] + 1
            elif planes:
                raise ValueError("normalize='exit_pupil' requires at least one pupil plane "
                                 "in the optical system.")

        intermediate_wfs = []

        # note: 0 is 'before first optical plane; 1 = 'after first plane and before second plane' and so on
        for optic in planes:
            if conf.enable_speed_tests:  # pragma: no cover
                s0 = time.time()

            # The actual propagation:
            wavefront.propagate_to(optic)
            wavefront *= optic

            # Normalize if appropriate:
            if norm_mode == 'first' and wavefront.current_plane_index == 1:  # set entrance plane to 1.
                wavefront.normalize()
                _log.debug("normalizing at first plane (entrance pupil) to 1.0 total intensity")
            elif norm_mode == 'first=2' and wavefront.current_plane_index == 1:
                # this undocumented option is present only for testing/validation purposes
                wavefront.normalize()
                wavefront *= np.sqrt(2)
            elif norm_mode == 'exit_pupil':  # normalize the last pupil in the system to 1
                if wavefront.current_plane_index == last_pupil_plane_index:
                    wavefront.normalize()
                    _log.debug("normalizing at exit pupil (plane {0}) to 1.0 total intensity".format(
                        wavefront.current_plane_index))
            elif norm_mode == 'last' and wavefront.current_plane_index == len(planes):
                wavefront.normalize()
                _log.debug("normalizing at last plane to 1.0 total intensity")

            # Optional outputs:
            if conf.enable_flux_tests:
                _log.debug("  Flux === " + str(wavefront.total_intensity))

            if conf.enable_speed_tests:  # pragma: no cover
                s1 = time.time()
                _log.debug(f"\tTIME {s1 - s0:.4f} s\t for propagating past optic '{optic.name}'.")

            if return_intermediates:  # save intermediate wavefront, summed for polychromatic if needed
                intermediate_wfs.append(wavefront.copy())

            if display_intermediates:
                wavefront._display_after_optic(optic, default_nplanes=len(self))

        if return_intermediates:
            return wavefront, intermediate_wfs
        else:
            return wavefront

    def _propagation_info(self):
        """ Provide some summary information on the optical propagation calculations that
        would be done for a given optical system

        Right now this mostly is checking whether a given propagation makes use of FFTs or not,
        since the padding for oversampled FFTS majorly affects the max memory used for multiprocessing
        estimation

        Returns
        -------
        dict
            With keys 'steps' (one label per plane after the first: 'rotation', 'MFT', 'invMFT',
            'resample' or 'FFT'), 'output_shape' (shape of the last plane multiplied by its
            oversample factor) and 'output_size' (product of the first two output_shape entries).
        """
        steps = []
        for i, p in enumerate(self.planes):
            if i == 0:
                continue  # no propagation needed for first plane

            previous_type = self.planes[i - 1].planetype
            if p.planetype == PlaneType.rotation:
                steps.append('rotation')
            elif previous_type == PlaneType.pupil and p.planetype == PlaneType.detector:
                steps.append('MFT')
            elif previous_type == PlaneType.pupil and p.planetype == PlaneType.image:
                # A pupil -> image step directly after an MFT step is labelled as an inverse MFT.
                if i > 1 and steps[-1] == 'MFT':
                    steps.append('invMFT')
                else:
                    steps.append('FFT')
            elif previous_type == PlaneType.image and p.planetype == PlaneType.detector:
                steps.append('resample')
            else:
                steps.append('FFT')

        output_shape = [a * self.planes[-1].oversample for a in self.planes[-1].shape]
        output_size = output_shape[0] * output_shape[1]

        return {'steps': steps, 'output_shape': output_shape, 'output_size': output_size}
[docs]
class CompoundOpticalSystem(OpticalSystem):
    """ A concatenation of two or more optical systems,
    acting as a single larger optical system.

    This can be used to combine together multiple existing
    OpticalSystem instances, including mixed lists of both
    Fraunhofer and Fresnel type systems.
    """

    def __init__(self, optsyslist=None, name=None, **kwargs):
        """ Create combined optical system,

        Parameters
        ----------
        optsyslist : List of OpticalSystem and/or FresnelOpticalSystem instances.
            Must be non-empty, and every item must be a BaseOpticalSystem instance.
        name : string, optional
            Descriptive name. Defaults to a name stating how many systems are contained.

        Any other keyword arguments are passed to the superclass initializer.

        Raises
        ------
        ValueError
            If `optsyslist` is missing, empty, or contains an item that is not an optical system.
        """
        # validate the input optical systems make sense
        if optsyslist is None:
            raise ValueError("Missing required optsyslist argument to CompoundOpticalSystem")
        elif len(optsyslist) == 0:
            raise ValueError("The provided optsyslist argument is an empty list. Must contain at least 1 optical system.")
        for item in optsyslist:
            if not isinstance(item, BaseOpticalSystem):
                raise ValueError("All items in the optical system list must be OpticalSystem instances, not "+repr(item))

        if name is None:
            name = "CompoundOpticalSystem containing {} systems".format(len(optsyslist))
        super().__init__(name=name, **kwargs)

        self.optsyslist = optsyslist

    def _add_plane(self, *args, **kwargs):
        """ Always raises RuntimeError: planes must be added to the constituent systems instead."""
        raise RuntimeError("Adding individual optical elements is disallowed for CompoundOpticalSystems."
                           " Add to an OpticalSystem instead.")

    def __len__(self):
        # The length of a compound optical system is the sum of the lengths of the individual systems.
        # Built-in sum always yields a Python int; np.sum gives a float 0.0 for an empty list,
        # which makes len() raise TypeError.
        return sum(len(optsys) for optsys in self.optsyslist)

    def propagate(self,
                  wavefront,
                  normalize='none',
                  return_intermediates=False,
                  display_intermediates=False):
        """ Core low-level routine for propagating a wavefront through an optical system

        Each contained optical system is propagated through in turn, converting the wavefront
        between Fraunhofer and Fresnel types whenever the next system requires the other type.
        The `normalize`, `return_intermediates` and `display_intermediates` arguments are
        passed unchanged to every contained system.

        See docstring of OpticalSystem.propagate for details
        """
        from poppy.fresnel import FresnelOpticalSystem, FresnelWavefront

        intermediate_wfs = []

        # helper function for logging:
        def loghistory(wavefront, msg):
            _log.debug(msg)
            wavefront.history.append(msg)

        for i, optsys in enumerate(self.optsyslist):
            # If necessary, convert wavefront type.
            if (isinstance(optsys, FresnelOpticalSystem) and
                    not isinstance(wavefront, FresnelWavefront)):
                wavefront = FresnelWavefront.from_wavefront(wavefront)
                loghistory(wavefront, "CompoundOpticalSystem: Converted wavefront to Fresnel type")
            elif (not isinstance(optsys, FresnelOpticalSystem) and
                    isinstance(wavefront, FresnelWavefront)):
                wavefront = Wavefront.from_fresnel_wavefront(wavefront)
                loghistory(wavefront, "CompoundOpticalSystem: Converted wavefront to Fraunhofer type")

            # Propagate
            loghistory(wavefront, "CompoundOpticalSystem: Propagating through system {}: {}".format(i+1, optsys.name))
            retval = optsys.propagate(wavefront,
                                      normalize=normalize,
                                      return_intermediates=return_intermediates,
                                      display_intermediates=display_intermediates)

            # Deal with returned item(s) as appropriate
            if return_intermediates:
                wavefront, intermediate_wfs_i = retval
                intermediate_wfs += intermediate_wfs_i
            else:
                wavefront = retval

        if return_intermediates:
            return wavefront, intermediate_wfs
        else:
            return wavefront

    @property
    def planes(self):
        """ A merged list containing all the planes in all the included optical systems """
        merged = []
        for osys in self.optsyslist:
            merged.extend(osys.planes)
        return merged

    @planes.setter
    def planes(self, value):
        # needed for compatibility with superclass init; the assigned value is deliberately ignored
        pass
# ------ core Optical Element Classes ------
[docs]
class OpticalElement(object):
""" Base class for all optical elements, whether from FITS files or analytic functions.
If instantiated on its own, this just produces a null optical element (empty space,
i.e. an identity function on transmitted wavefronts.) Use one of the many subclasses to
create a nontrivial optic.
The OpticalElement class follows the behavior of the Wavefront class, using units
of meters/pixel in pupil space and arcsec/pixel in image space.
The internal implementation of this class represents an optic with an array
for the electric field amplitude transmissivity (or reflectivity), plus an
array for the optical path difference in units of meters. This
representation was chosen since most typical optics of interest will have
wavefront error properties that are independent of wavelength. Subclasses
particularly the AnalyticOpticalElements extend this paradigm with optics
that have wavelength-dependent properties.
The get_phasor() function is used to obtain the complex phasor for any desired
wavelength based on the amplitude and opd arrays. Those can individually be
obtained from the get_transmission() and get_opd() functions.
Parameters
----------
name : string
descriptive name for optic
verbose : bool
whether to be more verbose in log outputs while computing
planetype : int
either poppy.PlaneType.image or poppy.PlaneType.pupil
oversample : int
how much to oversample beyond Nyquist.
interp_order : int
the order (0 to 5) of the spline interpolation used if the optic is resized.
"""
def __init__(self, name="unnamed optic", verbose=True, planetype=PlaneType.unspecified,
oversample=1, interp_order=3):
self.name = name
""" string. Descriptive Name of this optic"""
self.verbose = verbose
self.planetype = planetype # pupil or image
self.oversample = oversample # oversampling factor, none by default
self.ispadded = False # are we padded w/ zeros for oversampling the FFT?
self._suppress_display = False # should we avoid displaying this optic on screen?
# (useful for 'virtual' optics like FQPM aligner)
self.amplitude = xp.asarray([1.])
self.opd = xp.asarray([0.])
self.pixelscale = None
self.interp_order = interp_order
[docs]
def get_transmission(self, wave):
""" Return the electric field amplitude transmission, given a wavelength.
Parameters
----------
wave : float or obj
either a scalar wavelength or a Wavefront object
Returns
-------
ndarray giving electric field amplitude transmission between 0 - 1.0
"""
return xp.asarray(self.amplitude)
[docs]
def get_opd(self, wave):
""" Return the optical path difference, given a wavelength.
In this base class instance, the wavefront parameter 'wave' is not used,
and the .opd attribute of the optic is returned directly.
Subclasses may change this behavior, for instance to evaluate
optical aberrations on the sampling defined for that wavefront,
or to compute the wavelength-dependent aberrations of a refractive optic.
Parameters
----------
wave : float or obj
either a scalar wavelength or a Wavefront object
Returns
-------
ndarray giving OPD in meters
"""
return xp.asarray(self.opd)
[docs]
def get_phasor(self, wave):
""" Compute a complex phasor from an OPD, given a wavelength.
The returned value should be the complex phasor array as appropriate for
multiplying by the wavefront amplitude.
Parameters
----------
wave : float or obj
either a scalar wavelength or a Wavefront object
"""
if isinstance(wave, BaseWavefront):
wavelength = wave.wavelength
else:
wavelength = wave
scale = 2. * np.pi / wavelength.to(u.meter).value
# set the self.phasor attribute:
# first check whether we need to interpolate to do this.
float_tolerance = 0.001 # how big of a relative scale mismatch before resampling?
if self.pixelscale is not None and hasattr(wave, 'pixelscale') and abs(
wave.pixelscale - self.pixelscale) / self.pixelscale >= float_tolerance:
_log.debug("Non-matching pixel scales for wavefront and optic. Need to interpolate. "
"Pixelscales: wave {}, optic {}".format(wave.pixelscale, self.pixelscale))
if hasattr(self, '_resampled_scale') and abs(
self._resampled_scale - wave.pixelscale) / self._resampled_scale >= float_tolerance:
# we already did this same resampling, so just re-use it!
self.phasor = self._resampled_amplitude * xp.exp(1j * self._resampled_opd * scale)
else:
# raise NotImplementedError("Need to implement resampling.")
zoom = (self.pixelscale / wave.pixelscale).decompose().value
original_opd = self.get_opd(wave)
resampled_opd = _scipy.ndimage.zoom(original_opd, zoom, output=original_opd.dtype, order=self.interp_order)
original_amplitude = self.get_transmission(wave)
resampled_amplitude = _scipy.ndimage.zoom(original_amplitude, zoom, output=original_amplitude.dtype, order=self.interp_order)
_log.debug("resampled optic to match wavefront via spline interpolation by a" +
" zoom factor of {:.3g}".format(zoom))
_log.debug("resampled optic shape: {} wavefront shape: {}".format(resampled_amplitude.shape,
wave.shape))
lx, ly = resampled_amplitude.shape
# crop down to match size of wavefront:
lx_w, ly_w = wave.amplitude.shape
border_x = np.abs(lx - lx_w) // 2
border_y = np.abs(ly - ly_w) // 2
if (self.pixelscale * self.amplitude.shape[0] < wave.pixelscale * wave.amplitude.shape[0]) or (
self.pixelscale * self.amplitude.shape[1] < wave.pixelscale * wave.amplitude.shape[0]):
_log.warning("After resampling, optic phasor shape " + str(np.shape(resampled_opd)) +
" is smaller than input wavefront " + str(
(lx_w, ly_w)) + "; will zero-pad the rescaled array.")
self._resampled_opd = xp.zeros([lx_w, ly_w])
self._resampled_amplitude = xp.zeros([lx_w, ly_w])
self._resampled_opd[border_x:border_x + resampled_opd.shape[0],
border_y:border_y + resampled_opd.shape[1]] = resampled_opd
self._resampled_amplitude[border_x:border_x + resampled_opd.shape[0],
border_y:border_y + resampled_opd.shape[1]] = resampled_amplitude
_log.debug("padded an optic with a {:d} x {:d} border to "
"optic to match the wavefront".format(border_x, border_y))
else:
self._resampled_opd = resampled_opd[border_x:border_x + lx_w, border_y:border_y + ly_w]
self._resampled_amplitude = resampled_amplitude[border_x:border_x + lx_w, border_y:border_y + ly_w]
_log.debug("trimmed a border of {:d} x {:d} pixels from "
"optic to match the wavefront".format(border_x, border_y))
self.phasor = self._resampled_amplitude * xp.exp(1j * self._resampled_opd * scale)
else:
# compute the phasor directly, without any need to rescale.
if accel_math._USE_NUMEXPR:
trans = self.get_transmission(wave)
opd = self.get_opd(wave)
self.phasor = ne.evaluate("trans * exp(1j * opd * scale)")
else:
self.phasor = self.get_transmission(wave) * xp.exp(1j * self.get_opd(wave) * scale)
# check whether we need to pad or crop the array before returning or not.
# note: do not pad the phasor if it's just a scalar!
if self.phasor.size != 1 and self.phasor.shape != wave.shape:
# pad to match the wavefront sampling, from whatever sized array we started with.
# Allows more flexibility for differently sized FITS arrays, so long as they all have the
# same pixel scale as checked above!
return utils.pad_or_crop_to_shape(self.phasor, wave.shape)
else:
return self.phasor
[docs]
@utils.quantity_input(opd_vmax=u.meter, wavelength=u.meter)
def display(self, nrows=1, row=1, what='intensity', crosshairs=False, ax=None, colorbar=True,
colorbar_orientation=None, title=None, opd_vmax=0.5e-6 * u.meter,
wavelength=1e-6 * u.meter,
npix=512, grid_size=None):
"""Display plots showing an optic's transmission and OPD.
Parameters
----------
what : str
What to display: 'intensity', 'amplitude', 'phase', 'opd',
or 'both' (meaning intensity and OPD in two subplots)
ax : matplotlib.Axes instance
Axes to display into
nrows, row : integers
number of rows and row index for subplot display
crosshairs : bool
Display crosshairs indicating the center?
colorbar : bool
Show colorbar?
colorbar_orientation : bool
Desired orientation, horizontal or vertical?
Default is horizontal if only 1 row of plots, else vertical
opd_vmax : float
Max absolute value for OPD image display, in meters.
title : string
Plot label
wavelength : float, default 1 micron
For optics with wavelength-dependent behavior, evaluate at this
wavelength for display.
npix : integer
For optics without a fixed pixel sampling, evaluate onto this many
pixels for display.
grid_size : float
For optics without a fixed pixel sampling, evaluate onto this large
a spatial or angular extent for display. Specify in units of
arcsec for image plane optics, meters for all other optics.
If unspecified, a default value will be chosen instead, possibly
from the ._default_display_size attribute, if present.
"""
if colorbar_orientation is None:
colorbar_orientation = "horizontal" if nrows == 1 else 'vertical'
if self.planetype is PlaneType.pupil:
cmap_amp = copy.copy(getattr(matplotlib.cm, conf.cmap_pupil_intensity))
else:
cmap_amp = copy.copy(getattr(matplotlib.cm, conf.cmap_sequential))
cmap_amp.set_bad('0.0')
cmap_opd = copy.copy(getattr(matplotlib.cm, conf.cmap_diverging))
cmap_opd.set_bad('0.3')
norm_amp = matplotlib.colors.Normalize(vmin=0, vmax=1)
opd_vmax_m = opd_vmax.to(u.meter).value
norm_opd = matplotlib.colors.Normalize(vmin=-opd_vmax_m, vmax=opd_vmax_m)
# TODO infer correct units from pixelscale's units?
units = "[arcsec]" if self.planetype == PlaneType.image else "[meters]"
if self.pixelscale is not None:
if self.pixelscale.unit.is_equivalent(u.arcsec/u.pix):
units = "[arcsec]"
if nrows > 1:
# for display inside an optical system, we repurpose the units display to label the plane
units = self.name + "\n" + units
# and wrap long lines if necessary
if len(units) > 20:
units = "\n".join(textwrap.wrap(units, 20))
# Create a wavefront object to use when evaluating/sampling the optic.
if self.pixelscale is not None and self.shape is not None:
# This optic has an inherent sampling. The display wavefront's sampling is
# irrelevant; we get the native pixel scale opd and amplitude regardless and
# display that.
temp_wavefront = Wavefront(wavelength, npix=2)
disp_pixelscale = self.pixelscale
disp_shape = self.shape
else:
# this optic does not have an inherent sampling. Set up the display wavefront based on
# the parameters to this function call, and/or object attributes for defaults.
# The syntax for how to do that depends on image plane vs other kinds of planes.
# This code is partially duplicative of AnalyticOpticalElement.sample()
if self.planetype == PlaneType.image:
if grid_size is not None:
fov = grid_size if isinstance(grid_size, u.Quantity) else grid_size * u.arcsec
elif hasattr(self, '_default_display_size'):
fov = self._default_display_size
else:
fov = 4 * u.arcsec
pixel_scale = fov / (npix * u.pixel)
temp_wavefront = Wavefront(wavelength=wavelength, npix=npix, pixelscale=pixel_scale)
else:
if grid_size is not None:
diam = grid_size if isinstance(grid_size, u.Quantity) else grid_size * u.meter
elif hasattr(self, '_default_display_size'):
diam = self._default_display_size
elif hasattr(self, 'pupil_diam'):
diam = self.pupil_diam * 1
else:
diam = 1.0 * u.meter
temp_wavefront = Wavefront(wavelength=wavelength, npix=npix, diam=diam)
_log.info("Computing {0} for {1} sampled onto {2} pixel grid with "
"pixelscale {3}".format(what, self.name, npix, temp_wavefront.pixelscale))
disp_pixelscale = temp_wavefront.pixelscale
disp_shape = temp_wavefront.shape
# Determine the extent of the image in physical units, for axes labels.
_log.debug("Display pixel scale = {} ".format(disp_pixelscale))
if disp_pixelscale.decompose().unit == u.m / u.pix:
halfsize = disp_pixelscale.to(u.m / u.pix).value * disp_shape[0] / 2
elif disp_pixelscale.decompose().unit == u.radian / u.pix:
halfsize = disp_pixelscale.to(u.arcsec / u.pix).value * disp_shape[0] / 2
else:
raise RuntimeError("Pixelscale units not recognized in display; "
"must be equivalent to arcsec/pix or m/pix")
extent = [-halfsize, halfsize, -halfsize, halfsize]
# Evaluate the wavefront at the desired sampling and pixel scale.
ampl = self.get_transmission(temp_wavefront)
opd = self.get_opd(temp_wavefront).copy()
opd[(ampl == 0)] = np.nan
# define a helper function for the actual plotting - we do it this way so
# we can call it twice if the 'both' option is chosen. This avoids the complexities of the
# earlier version of this function which called itself recursively to show both.
def optic_display_helper(plot_array, ax, title, is_opd=False):
if is_opd:
cmap = cmap_opd
norm = norm_opd
cb_values = np.array([-1, -0.5, 0, 0.5, 1]) * opd_vmax_m
cb_label = 'OPD [meters]'
else:
cmap = cmap_amp
norm = norm_amp
cb_values = [0, 0.25, 0.5, 0.75, 1.0]
cb_label = 'Transmission amplitude'
utils.imshow_with_mouseover(plot_array, ax=ax, extent=extent, cmap=cmap, norm=norm,
origin='lower')
ax.set_title(title)
ax.set_ylabel(units)
ax.xaxis.set_major_locator(matplotlib.ticker.MaxNLocator(nbins=4, integer=True))
ax.yaxis.set_major_locator(matplotlib.ticker.MaxNLocator(nbins=4, integer=True))
if colorbar:
cb = plt.colorbar(mappable=ax.images[0], ax=ax, orientation=colorbar_orientation, ticks=cb_values)
cb.set_label(cb_label)
if crosshairs:
ax.axhline(0, ls=":", color='k')
ax.axvline(0, ls=":", color='k')
if hasattr(self, 'display_annotate'):
self.display_annotate(self, ax) # atypical calling convention needed empirically
# since Python doesn't seem to automatically pass
# self as first argument for functions added at
# run time as attributes?
if what == 'both':
ax1 = plt.subplot(nrows, 2, row * 2 - 1)
optic_display_helper(ampl**2, ax1, 'Transmittance for '+self.name, False)
ax2 = plt.subplot(nrows, 2, row * 2)
optic_display_helper(opd, ax2, 'OPD for '+self.name, True)
ax2.set_ylabel('') # suppress redundant label which duplicates the intensity plot's label
if title is not None:
plt.suptitle(title)
return ax1, ax2
else:
if what == 'amplitude':
plot_array = ampl
default_title = 'Transmissivity'
is_opd = False
elif what == 'intensity':
plot_array = ampl ** 2
default_title = "Transmittance"
is_opd = False
elif what == 'opd':
plot_array = opd
default_title = "OPD"
is_opd = True
else:
raise ValueError("Invalid value for 'what' parameter. Must be one of {'amplitude', 'intensity', 'opd', 'both'}.")
# now we plot whichever was chosen...
if ax is None:
if nrows > 1:
ax = plt.subplot(nrows, 2, row * 2 - 1)
else:
ax = plt.subplot(1, 1, 1)
if nrows == 1:
if title is None:
title = default_title + " for " + self.name
# do the actual plot! (for all cases except 'both')
optic_display_helper(plot_array, ax, title, is_opd)
return ax
def __str__(self):
    """Return a short, human-readable description of this optic.

    Pupil planes are described by name only. Image planes also report the
    array size and pixel scale, or "(Analytic)" when ``self.pixelscale`` is
    None. Any other plane type gets a generic "Optic: <name>" description.
    """
    if self.planetype == PlaneType.pupil:
        return "Pupil plane: {} ".format(self.name)
    elif self.planetype == PlaneType.image:
        if self.pixelscale is not None:
            # self.shape is in numpy (Y, X) order; report width x height rather than
            # repeating the first axis twice, so non-square arrays are described correctly.
            dims = self.shape
            desc = "({}x{} pixels, scale={} arcsec/pixel)".format(dims[-1], dims[0],
                                                                  self.pixelscale)
        else:
            desc = "(Analytic)"
        return "Image plane: %s %s" % (self.name, desc)
    else:
        return "Optic: " + self.name
@property
def shape(self):
    """ Shape of this optic's ``amplitude`` array, or None when no amplitude attribute exists """
    try:
        amplitude = self.amplitude
    except AttributeError:
        # no amplitude has been assigned to this optic
        return None
    return amplitude.shape
[docs]
class ArrayOpticalElement(OpticalElement):
    """ An optic defined directly from in-memory arrays of amplitude transmission and/or OPD.

    This is a thin convenience layer over OpticalElement. Whichever of the two arrays is
    omitted is filled in with a neutral array of matching shape: zeros for a missing OPD,
    ones for a missing transmission. If neither is supplied, neither attribute is assigned
    by this initializer.

    Parameters
    ----------
    opd : array, optional
        Optical path difference array, stored as ``self.opd``.
    transmission : array, optional
        Amplitude transmission array, stored as ``self.amplitude``.
    pixelscale : optional
        If not None, stored as ``self.pixelscale``.
    **kwargs
        Passed through unchanged to ``OpticalElement.__init__``.
    """

    def __init__(self, opd=None, transmission=None, pixelscale=None, **kwargs):
        super().__init__(**kwargs)

        opd_given = opd is not None
        transmission_given = transmission is not None

        if opd_given:
            self.opd = opd

        if transmission_given:
            self.amplitude = transmission
            if not opd_given:
                # transmission only: assume a flat (zero) OPD of the same shape
                self.opd = xp.zeros_like(transmission)
        elif opd_given:
            # OPD only: assume uniform unit transmission of the same shape
            self.amplitude = xp.ones_like(opd)

        if pixelscale is not None:
            self.pixelscale = pixelscale
[docs]
class FITSOpticalElement(OpticalElement):
    """ Defines an arbitrary optic, based on amplitude transmission and/or OPD FITS files.

    This optic could be a pupil or field stop, an aberrated mirror, a phase mask, etc.
    The FITSOpticalElement class follows the behavior of the Wavefront class, using units
    of meters/pixel in pupil space and arcsec/pixel in image space.

    The interface is **very** flexible. You can define a FITSOpticalElement either from

    * a single FITS file giving the amplitude transmission (in which case phase is zero)
    * a single FITS file giving the OPD (in which case transmission is 1 everywhere)
    * two FITS files specifying both transmission and OPD.

    The FITS file argument(s) can be supplied either as

    1. a string giving the path to a file on disk,
    2. a FITS HDUlist object, or
    3. in the case of OPDs, a tuple consisting of a path to a datacube and an integer index of
       a slice in that datacube.

    A better interface for slice selection in datacubes is the transmission_index
    and opd_index keyword parameters listed below, but the tuple interface is
    retained for back compatibility with existing code.

    The FITS file header must provide information on the pixel scale, preferentially
    via a PIXELSCL keyword, or else you must supply the pixel scale directly. See
    more information just below in the parameter documentation for pixelscale.

    If neither transmission nor opd is given, a warning is logged and a "null" optic
    is created (amplitude [1.], OPD [0.], no pixel scale, name "-empty-"); none of the
    transformations below are applied in that case.

    Parameters
    ----------
    name : string
        descriptive name for optic
    transmission, opd : string or fits HDUList
        Either FITS filenames *or* actual fits.HDUList objects for the
        transmission (from 0-1) and opd (in meters)
    transmission_index, opd_index : ints, optional
        If the input transmission or OPD files are datacubes, provide a scalar
        index here for which cube slice should be used. If a datacube is supplied
        without an index, slice 0 is used.
    opdunits : string
        units for the OPD file. Can be 'meter(s)', 'micron(s)', 'micrometer(s)',
        'nanometer(s)', 'radian(s)', or the abbreviations 'm', 'um', 'nm'
        (case-insensitive). If this keyword is not set explicitly, the BUNIT keyword
        in the OPD FITS header is used, and an exception is raised if it is missing.
        OPDs given in radians are kept as wavelength-independent phase and converted
        to meters on demand in get_opd.
    planetype : PlaneType
        either PlaneType.image or PlaneType.pupil
    oversample : int
        how much to oversample beyond Nyquist.
    flip_x, flip_y : bool
        Should the FITS file be inverted in either of these axes after being
        loaded? Useful for matching coordinate system orientations. If a flip
        is specified, it takes place prior to any shift or rotation operations.
    shift : tuple of floats, optional
        2-tuple containing X and Y fractional shifts for the pupil, each between
        -0.5 and 0.5 as a fraction of the array size. These shifts
        are implemented by rounding them to the nearest integer pixel, and doing
        integer pixel shifts on the data array, without interpolation. If a
        shift is specified, it takes place after any rotation operations.
    shift_x, shift_y : floats, optional
        Alternate way of specifying shifts, given in meters of shift per each axis.
        This is consistent with how AnalyticOpticalElement classes specify shifts.
        The shift is converted to whole pixels using the pixel scale (truncating
        toward zero). If a shift is specified, it takes place after any rotation
        operations. If both shift and shift_x/shift_y are specified, an error is raised.
    rotation : float
        Rotation for that optic, in degrees counterclockwise. Multiples of 90 degrees
        are applied exactly; other angles are implemented using spline interpolation
        via the scipy.ndimage.rotate function.
    pixelscale : optional str or float
        By default, poppy will attempt to determine the appropriate pixel scale
        by examining the FITS header, checking keywords "PIXELSCL", "PUPLSCAL" and/or 'PIXSCALE'.
        PIXELSCL is the default and should be preferred for new files; the latter two are
        kept for back-compatibility with earlier format input files,
        for pupil and image planes respectively. If you would like to override
        and use a different keyword, provide that as a string here. Alternatively,
        you can just set a floating point value directly too (in meters/pixel
        or arcsec/pixel, respectively, for pupil or image planes).

    Raises
    ------
    TypeError
        If transmission or opd is of a type that cannot be interpreted.
    ValueError
        If the arrays are not 2-D and square, the OPD units are not recognized,
        the pixelscale cannot be interpreted, or a fractional shift exceeds +-0.5.
    LookupError
        If no pixel scale keyword can be found in the FITS headers.
    RuntimeError
        If both shift and shift_x/shift_y are given.


    *NOTE:* All mask files must be *squares*.

    Also, please note that the adopted convention is for the spectral throughput
    (transmission) to be given in appropriate units for acting on the *amplitude*
    of the electric field. Thus for example an optic with a uniform transmission
    of 0.5 will reduce the electric field amplitude to 0.5 relative to the input,
    and thus reduce the total power to 0.25. This distinction only matters in the
    case of semitransparent (grayscale) masks.

    """

    def __init__(self, name="unnamed optic", transmission=None, opd=None, opdunits=None,
                 rotation=None, pixelscale=None, planetype=None,
                 transmission_index=None, opd_index=None,
                 shift=None, shift_x=None, shift_y=None,
                 flip_x=False, flip_y=False,
                 **kwargs):

        OpticalElement.__init__(self, name=name, **kwargs)
        self.opd_file = None
        self.amplitude_file = None
        self.amplitude_header = None
        self.opd_header = None
        self._opd_in_radians = False
        self.planetype = planetype

        _log.debug("Trans: " + str(transmission))
        _log.debug("OPD: " + str(opd))

        # ---- Load amplitude transmission file. ---
        if opd is None and transmission is None:  # no input files, so just make a scalar
            _log.warning("No input files specified. You should set transmission=filename or opd=filename.")
            _log.warning("Creating a null optical element. Are you sure that's what you want to do?")
            self.amplitude = xp.asarray([1.])
            self.opd = xp.asarray([0.])
            self.pixelscale = None
            self.name = "-empty-"
        else:
            # load transmission file.
            if transmission is not None:
                if isinstance(transmission, str):
                    self.amplitude_file = transmission
                    self.amplitude, self.amplitude_header = fits.getdata(self.amplitude_file, header=True)
                    self.amplitude = self.amplitude.astype('=f8')  # ensure native byte order, see #213
                    self.amplitude = xp.asarray(self.amplitude)  # sets to CuPy array if xp is cupy
                    if self.name == 'unnamed optic':
                        self.name = 'Optic from ' + self.amplitude_file
                    _log.info(self.name + ": Loaded amplitude transmission from " + self.amplitude_file)
                elif isinstance(transmission, fits.HDUList):
                    self.amplitude_file = 'supplied as fits.HDUList object'
                    self.amplitude = transmission[0].data.astype('=f8')  # ensure native byte order, see #213
                    self.amplitude = xp.asarray(self.amplitude)  # sets to CuPy array if xp is cupy
                    self.amplitude_header = transmission[0].header.copy()
                    if self.name == 'unnamed optic':
                        self.name = 'Optic from fits.HDUList object'
                    _log.info(self.name + ": Loaded amplitude transmission from supplied fits.HDUList object")
                else:
                    raise TypeError('Not sure how to use a transmission parameter of type ' + str(type(transmission)))

                # check for datacube?
                if len(self.amplitude.shape) > 2:
                    if transmission_index is None:
                        _log.info("The supplied pupil amplitude is a datacube but no slice was specified. "
                                  "Defaulting to use slice 0.")
                        transmission_index = 0
                    self.amplitude_slice_index = transmission_index
                    self.amplitude = self.amplitude[self.amplitude_slice_index, :, :]
                    _log.debug(" Datacube detected, using slice ={0}".format(self.amplitude_slice_index))
            else:
                _log.debug("No transmission supplied - will assume uniform throughput = 1 ")
                # if transmission is none, wait until after OPD is loaded, below, and then create a matching
                # amplitude array uniformly filled with 1s.

            # ---- Load OPD file. ---
            if opd is None:
                # if only amplitude set, create an array of 0s with same size.
                self.opd = xp.zeros(self.amplitude.shape)
                opdunits = 'meter'  # doesn't matter, it's all zeros, but this will indicate no need to rescale below.

            elif isinstance(opd, fits.HDUList):
                # load from fits HDUList
                self.opd_file = 'supplied as fits.HDUList object'
                self.opd = opd[0].data.astype('=f8')
                self.opd = xp.asarray(self.opd)  # sets to CuPy array if xp is cupy
                self.opd_header = opd[0].header.copy()
                if self.name == 'unnamed optic':
                    self.name = 'OPD from supplied fits.HDUList object'
                _log.info(self.name + ": Loaded OPD from supplied fits.HDUList object")
            elif isinstance(opd, str):
                # load from regular FITS filename
                self.opd_file = opd
                self.opd, self.opd_header = fits.getdata(self.opd_file, header=True)
                self.opd = self.opd.astype('=f8')
                self.opd = xp.asarray(self.opd)  # sets to CuPy array if xp is cupy
                if self.name == 'unnamed optic':
                    self.name = 'OPD from ' + self.opd_file
                _log.info(self.name + ": Loaded OPD from " + self.opd_file)

            elif len(opd) == 2 and isinstance(opd[0], str):
                # if OPD is specified as a 2-element iterable, treat the first element as the filename
                # and 2nd as the slice of a cube.
                self.opd_file = opd[0]
                self.opd_slice = opd[1]
                self.opd, self.opd_header = fits.getdata(self.opd_file, header=True)
                self.opd = self.opd.astype('=f8')
                self.opd = xp.asarray(self.opd[self.opd_slice, :, :])  # sets to CuPy array if xp is cupy
                if self.name == 'unnamed optic':
                    self.name = 'OPD from %s, plane %d' % (self.opd_file, self.opd_slice)
                _log.info(self.name + ": Loaded OPD from %s, plane %d" % (self.opd_file, self.opd_slice))
            else:
                # report the type of the offending *opd* argument (not the transmission)
                raise TypeError('Not sure how to use an OPD parameter of type ' + str(type(opd)))

            # check for datacube?
            if len(self.opd.shape) > 2:
                if opd_index is None:
                    _log.info("The supplied pupil OPD is a datacube but no slice was specified. "
                              "Defaulting to use slice 0.")
                    opd_index = 0
                self.opd_slice = opd_index
                self.opd = self.opd[self.opd_slice, :, :]
                _log.debug(" Datacube detected, using slice ={0}".format(self.opd_slice))

            if transmission is None:
                _log.info("No info supplied on amplitude transmission; assuming uniform throughput = 1")
                self.amplitude = xp.ones(self.opd.shape)

            if opdunits is None:
                try:
                    opdunits = self.opd_header['BUNIT']
                except KeyError:
                    _log.error("No opdunits keyword supplied, and BUNIT keyword not found in header. "
                               "Cannot determine OPD units")
                    raise Exception("No opdunit keyword supplied, and BUNIT keyword not found in header. "
                                    "Cannot determine OPD units.")

            # normalize and drop any trailing 's'
            opdunits = opdunits.lower()
            if opdunits.endswith('s'):
                opdunits = opdunits[:-1]

            # rescale OPD to meters if necessary
            if opdunits in ('meter', 'm'):
                pass
            elif opdunits in ('micron', 'um', 'micrometer'):
                self.opd *= 1e-6
            elif opdunits in ('nanometer', 'nm'):
                self.opd *= 1e-9
            elif opdunits == 'radian':
                # leave the data as phase; get_opd converts to meters per wavelength
                self._opd_in_radians = True
            else:
                raise ValueError(
                    "Got opdunits (or BUNIT header keyword) {}. Valid options "
                    "are meter, micron, nanometer, or radian.".format(repr(opdunits))
                )

            if self.opd_header is not None and not self._opd_in_radians:
                self.opd_header['BUNIT'] = 'meter'

            if len(self.opd.shape) != 2 or self.opd.shape[0] != self.opd.shape[1]:
                _log.debug('OPD shape: ' + str(self.opd.shape))
                raise ValueError("OPD image must be 2-D and square")

            if len(self.amplitude.shape) != 2 or self.amplitude.shape[0] != self.amplitude.shape[1]:
                raise ValueError("Pupil amplitude image must be 2-D and square")

            # (squareness of both arrays is already guaranteed by the two checks above)
            assert self.amplitude.shape == self.opd.shape, "Amplitude and OPD FITS file shapes are incompatible."

            # ---- transformation: inversion ----
            # if an inversion is specified and we're not a null (scalar) optic then do the inversion:
            if flip_y and len(self.amplitude.shape) == 2:
                self.amplitude = self.amplitude[::-1]
                self.opd = self.opd[::-1]
                _log.debug("Inverted optic in the Y axis")
            if flip_x and len(self.amplitude.shape) == 2:
                self.amplitude = self.amplitude[:, ::-1]
                self.opd = self.opd[:, ::-1]
                _log.debug("Inverted optic in the X axis")

            # ---- transformation: rotation ----
            # If a rotation is specified and we're NOT a null (scalar) optic, then do the rotation:
            if rotation is not None and len(self.amplitude.shape) == 2:
                k, remainder = np.divmod(rotation, 90)
                if remainder == 0:
                    # rotation is a multiple of 90
                    self.amplitude = xp.rot90(self.amplitude, k=-k)  # negative = CCW
                    self.opd = xp.rot90(self.opd, k=-k)
                else:
                    # arbitrary free rotation with interpolation
                    # do rotation with interpolation, but try to clean up some of the artifacts afterwards.
                    # this is imperfect at best, of course...
                    self.amplitude = _scipy.ndimage.rotate(xp.asarray(self.amplitude), -rotation,  # negative = CCW
                                                           reshape=False).clip(min=0, max=1.0)
                    wnoise = (self.amplitude < 1e-3) & (self.amplitude > 0)
                    self.amplitude[wnoise] = 0
                    self.opd = _scipy.ndimage.rotate(self.opd, -rotation, reshape=False)  # negative = CCW
                _log.info(" Rotated optic by %f degrees counter clockwise." % rotation)
                self._rotation = rotation

            # ---- Determine the pixel scale for this image. ----
            _MISSING_PIXELSCALE_MSG = ("No FITS header keyword for pixel scale found "
                                       "(tried: {}). Supply pixelscale as a float in "
                                       "meters/px or arcsec/px, or as a string specifying which "
                                       "header keyword to use.")

            def _find_pixelscale_in_headers(keywords, headers):
                """
                Loops through provided possible FITS header keywords and a list of FITS
                header objects (may contain Nones), returning the first
                (keyword, header value) pair found
                """
                for keyword in keywords:
                    for header in headers:
                        if header is not None and keyword in header:
                            return keyword, header[keyword]
                raise LookupError(_MISSING_PIXELSCALE_MSG.format(', '.join(keywords)))

            # The following logic is convoluted for historical back compatibility.
            # All new files should use PIXELSCL. But we still allow reading in
            # older files with PIXSCALE or PUPLSCAL.
            # This code can probably be simplified.
            if pixelscale is None and self.planetype is None:
                # we don't know which keywords might be present yet, so check for both keywords
                # in both header objects (at least one must be non-None at this point!)
                _log.debug(" Looking for 'PUPLSCAL' or 'PIXSCALE' or 'PIXELSCL' in FITS headers to set "
                           "pixel scale")
                keyword, self.pixelscale = _find_pixelscale_in_headers(
                    ('PUPLSCAL', 'PIXSCALE', 'PIXELSCL'),
                    (self.amplitude_header, self.opd_header)
                )
                if keyword == 'PUPLSCAL':
                    self.planetype = PlaneType.pupil
                else:
                    self.planetype = PlaneType.image
            elif pixelscale is None and self.planetype == PlaneType.image:
                # the planetype tells us which header keyword to check when a keyword is
                # not provided (PIXSCALE for image planes)...
                _, self.pixelscale = _find_pixelscale_in_headers(
                    ('PIXELSCL', 'PIXSCALE'),
                    (self.amplitude_header, self.opd_header)
                )
            elif pixelscale is None and (self.planetype == PlaneType.pupil or self.planetype == _INTERMED):
                # ... likewise for pupil planes
                _, self.pixelscale = _find_pixelscale_in_headers(
                    ('PIXELSCL', 'PUPLSCAL',),
                    (self.amplitude_header, self.opd_header)
                )
            elif isinstance(pixelscale, str):
                # If provided as a keyword string, check for it using the same helper function
                _log.debug(" Getting pixel scale from FITS keyword:" + pixelscale)
                _, self.pixelscale = _find_pixelscale_in_headers(
                    (pixelscale,),
                    (self.opd_header, self.amplitude_header)
                )
            else:
                # pixelscale had better be a floating point value here.
                try:
                    _log.debug(" Getting pixel scale from user-provided float value: " +
                               str(pixelscale))
                    self.pixelscale = float(pixelscale)
                except ValueError:
                    raise ValueError("pixelscale=%s is neither a FITS keyword string "
                                     "nor a floating point value." % str(pixelscale))
            # now turn the pixel scale into a Quantity
            if self.planetype == PlaneType.image:
                self.pixelscale *= u.arcsec / u.pixel
            else:  # pupil or any other types of plane
                self.pixelscale *= u.meter / u.pixel

            # ---- transformation: shift ----
            # if a shift is specified and we're NOT a null (scalar) optic, then do the shift
            # This has to happen after the pixelscale has been determined, for the shift_x/shift_y path.
            if shift is not None and (shift_x is not None or shift_y is not None):
                raise RuntimeError("You cannot specify both the shift and shift_x/shift_y parameters simultaneously.")
            elif ((shift is not None) or (shift_x is not None or shift_y is not None)) and len(self.amplitude.shape) == 2:
                if shift_x is not None or shift_y is not None:
                    # determine shift using the shift_x and shift_y parameters
                    if shift_x is None:
                        shift_x = 0
                    if shift_y is None:
                        shift_y = 0
                    # convert meters to whole pixels; int() truncates toward zero
                    rollx = int(shift_x/self.pixelscale.to(u.m/u.pixel).value)
                    rolly = int(shift_y/self.pixelscale.to(u.m/u.pixel).value)
                    _log.info("Requested optic shift of ({:6.3f}, {:6.3f}) meters".format(shift_x, shift_y))
                    _log.info("Actual shift applied = ({:6.3f}, {:6.3f}) pixels".format(rollx, rolly))

                elif shift is not None:
                    # determine shift using the shift tuple
                    if abs(shift[0]) > 0.5 or abs(shift[1]) > 0.5:
                        raise ValueError("You have asked for an implausibly large shift. Remember, "
                                         "shifts should be specified as decimal values between -0.5 and 0.5, "
                                         "a fraction of the total optic diameter. ")
                    rolly = int(np.round(self.amplitude.shape[0] * shift[1]))  # remember Y,X order for shape,
                    # but X,Y order for shift
                    rollx = int(np.round(self.amplitude.shape[1] * shift[0]))
                    _log.info("Requested optic shift of ({:6.3f}, {:6.3f}) fraction of pupil ".format(*shift))
                    _log.info("Actual shift applied = (%6.3f, %6.3f) " % (
                        rollx * 1.0 / self.amplitude.shape[1], rolly * 1.0 / self.amplitude.shape[0]))

                # record the shift actually applied, as a fraction of the array size in (X, Y) order
                self._shift = (rollx * 1.0 / self.amplitude.shape[1], rolly * 1.0 / self.amplitude.shape[0])

                self.amplitude = _scipy.ndimage.shift(self.amplitude, (rolly, rollx))
                self.opd = _scipy.ndimage.shift(self.opd, (rolly, rollx))

    @property
    def pupil_diam(self):
        """Diameter of the pupil (if this is a pupil plane optic)

        Computed as the pixel scale times the number of pixels along the first array axis.
        """
        return self.pixelscale * (self.amplitude.shape[0] * u.pixel)

    def get_opd(self, wave):
        """ Return the optical path difference, given a wavelength.

        When the OPD map is defined in terms of wavelength-independent
        phase, as in the case of the vector apodizing phase plate
        coronagraph of Snik et al. (Proc. SPIE, 2012), it is converted
        to optical path difference in meters at the given wavelength for
        consistency with the rest of POPPY.

        Parameters
        ----------
        wave : float, astropy.units.Quantity, or obj
            either a scalar wavelength or a Wavefront object. A wavelength given
            as a plain number without units is interpreted as meters.

        Returns
        -------
        ndarray giving OPD in meters

        """
        if isinstance(wave, BaseWavefront):
            wavelength = wave.wavelength
        else:
            wavelength = wave

        # casts to xp.asarray here needed to support the GPU case
        if self._opd_in_radians:
            # u.Quantity converts an existing Quantity to meters, and attaches meters to a
            # bare number, so both documented input forms work here.
            wavelength_m = u.Quantity(wavelength, u.m).value
            return xp.asarray(self.opd * wavelength_m / (2 * np.pi))
        return xp.asarray(self.opd)
class CoordinateTransform(OpticalElement):
    """ Performs a coordinate transformation (rotation or axes inversion)
    in the optical train.

    This is not an actual optic itself but a placeholder to indicate
    when a coordinate transform should take place.

    You should generally not need to use this class or its subclasses directly;
    rather use the OpticalSystem add_rotation or add_inversion functions to
    insert these as needed into optical systems.

    Parameters
    ----------
    name : string
        Descriptive name, also used as the text label drawn by display().
    hide : bool
        Should this optic be displayed or hidden when showing the
        planes of an OpticalSystem?

    """

    def __init__(self, name='Coordinate transform', hide=False, **kwargs):
        OpticalElement.__init__(self, name=name, **kwargs)
        self._suppress_display = hide

    def get_phasor(self, wave):
        """ Return the scalar 1.0, i.e. no change to the wavefront amplitude or phase. """
        return 1.0  # no change in wavefront
        # returning this is necessary to allow the multiplication in propagate_mono to be OK

    def display(self, nrows=1, row=1, ax=None, **kwargs):
        """ Draw a text label with this transform's name, in place of an optic image.

        Parameters
        ----------
        nrows, row : int
            Used to select the left-hand subplot of row `row` in an (nrows x 2) grid,
            when no axes is supplied.
        ax : matplotlib Axes, optional
            Axes to draw into. If None, one is created via plt.subplot.
        **kwargs
            Accepted for call compatibility with other display methods; ignored here.

        Returns
        -------
        The matplotlib Axes that was drawn into.
        """
        if ax is None:
            ax = plt.subplot(nrows, 2, row * 2 - 1)
        # Draw on the requested axes explicitly; plt.text would write to whatever
        # axes happens to be current, which need not be the one passed in.
        ax.text(0.3, 0.3, self.name)
        return ax
[docs]
class Rotation(CoordinateTransform):
    """ Performs a rotation of the axes in the optical train.

    This is not an actual optic itself, of course, but can be used to model
    a rotated optic by applying a Rotation before and/or after light is incident
    on that optic.


    This is basically a placeholder to indicate the need for a rotation at a
    given part of the optical train. The actual rotation computation is performed
    in the Wavefront object's propagation routines.


    Parameters
    ----------
    angle : float
        Rotation angle, counterclockwise. By default in degrees.
        Stored internally (as ``self.angle``) in degrees regardless of `units`.
    units : 'degrees' or 'radians'
        Units for the rotation angle.
    hide : bool
        Should this optic be displayed or hidden when showing the
        planes of an OpticalSystem?

    Raises
    ------
    ValueError
        If units is neither 'degrees' nor 'radians'.

    """

    def __init__(self, angle=0.0, units='degrees', hide=False, **kwargs):
        if units == 'radians':
            # convert radians -> degrees (factor is 180/pi, not pi/180); build a new
            # value rather than modifying the caller's argument in place
            angle = angle * (180.0 / np.pi)
        elif units != 'degrees':
            raise ValueError("Unknown value for units='%s'. Must be degrees or radians." % units)
        self.angle = angle

        CoordinateTransform.__init__(self, name="Rotation by %.2f degrees" % angle,
                                     planetype=PlaneType.rotation, hide=hide, **kwargs)

    def __str__(self):
        return "Rotation by %f degrees counter clockwise" % self.angle
class CoordinateInversion(CoordinateTransform):
    """ Marker element indicating that coordinate axes should be inverted.

    This element only records which axis to flip;
    the actual inversion happens in Wavefront.propagate_to

    Parameters
    ----------
    name : string
        Descriptive name for this element.
    axis : string
        either 'both', 'x', or 'y', for which axes to invert.
        Stored lower-cased as ``self.axis``.
    hide : bool
        Should this optic be displayed or hidden when showing the
        planes of an OpticalSystem?
    """

    def __init__(self, name='Coordinate inversion', axis='both', hide=False, **kwargs):
        # normalize case up front so later comparisons can assume lower case
        normalized_axis = axis.lower()
        self.axis = normalized_axis

        CoordinateTransform.__init__(
            self,
            name=name,
            planetype=PlaneType.inversion,
            hide=hide,
            **kwargs
        )

    def __str__(self):
        return f"Coordinate Inversion in {self.axis} axis"
# ------ Detector ------
[docs]
class Detector(OpticalElement):
    """ A Detector is a specialized type of OpticalElement that forces a wavefront
    onto a specific fixed pixelization of an Image plane.

    This class is in effect just a metadata container for the desired sampling;
    all the machinery for transformation of a wavefront to that sampling happens
    within Wavefront.

    Note that this is *not* in any way a representation of real noisy detectors;
    no model for read noise, imperfect sensitivity, etc is included whatsoever.

    Parameters
    ----------
    name : string
        Descriptive name
    pixelscale : float or astropy.units.Quantity
        Pixel scale, either in angular units such as arcsec/pixel, or
        (for Fresnel optical systems only) in physical units such as micron/pixel.
        Units should be specified as astropy Quantities. If pixelscale is given as
        a float without an explicit unit, it will be interpreted as in arcsec/pixel.
        Note, this value may be further subdivided by specifying the oversample
        parameter > 1.
    fov_pixels, fov_arcsec : float or astropy.units.Quantity
        The field of view may be specified either in arcseconds or by a number
        of pixels. Either is acceptable and the pixel scale is used to convert
        as needed. You may specify a non-square FOV by providing two elements in
        an iterable. Note that this follows the usual Python convention of
        ordering axes (Y,X), so put your desired Y axis size first.
        For Fresnel optical systems, if specifying pixelscale in microns/pixel then
        you must specify fov_pixels rather than fov_arcsec.
        If both are given, fov_pixels takes precedence.
    oversample : int
        Oversampling factor beyond the detector pixel scale. The returned array will
        have sampling that much finer than the specified pixelscale.
    offset : 2-tuple of floats
        Offset (Y,X) in *pixels* for shifting the detector relative to the notional center of the output beam.
        This has similar effect to shifting the source, but with opposite sign.
        In other words, shifting a light source +1 arcsec in Y should have the same effect as
        shifting the detector -1 arcsec in Y.

    Raises
    ------
    ValueError
        If neither fov_pixels nor fov_arcsec is given, the resulting FOV is not
        positive, the offset does not have 2 elements or is a Quantity not in pixels,
        or the pixelscale has unusable units.
    """

    # Note, pixelscale argument is intentionally not included in the quantity_input decorator; that is
    # specially handled. See the _handle_pixelscale_units_flexibly method
    @utils.quantity_input(fov_pixels=u.pixel, fov_arcsec=u.arcsec)
    def __init__(self, pixelscale=1 * (u.arcsec / u.pixel), fov_pixels=None, fov_arcsec=None, oversample=1,
                 name="Detector",
                 offset=None,
                 **kwargs):
        OpticalElement.__init__(self, name=name, planetype=PlaneType.detector, **kwargs)
        self.pixelscale = self._handle_pixelscale_units_flexibly(pixelscale, fov_pixels)
        self.oversample = oversample

        if offset is not None:
            if len(offset) != 2:
                # message order matches the documented (Y, X) convention for offsets
                raise ValueError("If a detector offset is specified, it must be a tuple or list with 2 elements, "
                                 "giving the (Y, X) offsets.")
            # The offset is specified in pixels, so this can have units of pixels,
            # or else if an integer or float, that's considered as implicitly a number of pixels
            if isinstance(offset, u.Quantity):
                try:
                    offset = offset.to_value(u.pixel)
                except u.UnitConversionError:
                    raise ValueError(f"A detector offset must be specified in units of detector pixels, not '{offset.unit}'")
            offset = np.asarray(offset)  # ensure it's an ndarray, not just a list or tuple

        # A note on sign convention for detector offset: (This is regrettably confusing.)
        # The implementation in matrixDFT has the sense of "how much should the source be offset",
        # i.e. an offset of +5 pix moves the source by +5 pix.
        # However, physically we would like the opposite sign convention: Moving the detector by +5 pix
        # should move the source by -5 pix. This is implemented by a sign flip multiplication by -1
        # which is applied in the _propagate_mft methods. That could just be a hard-coded -1,
        # but we choose to implement as a named variable to help make this logic clear later to readers of this code:
        self.offset = offset
        self._offset_sign = -1

        if fov_pixels is None and fov_arcsec is None:
            raise ValueError("Either fov_pixels or fov_arcsec must be specified!")
        elif fov_pixels is not None:
            self.fov_pixels = np.round(fov_pixels)
        else:
            # set field of view to closest value possible to requested,
            # consistent with having an integer number of pixels
            self.fov_pixels = np.round((fov_arcsec.to(u.arcsec) / self.pixelscale).to(u.pixel))
        # in both cases the stored FOV is recomputed from the whole number of pixels
        self.fov_arcsec = self.fov_pixels * self.pixelscale

        if np.any(self.fov_pixels <= 0):
            raise ValueError("FOV in pixels must be a positive quantity. Invalid: " + str(self.fov_pixels))

        self.amplitude = 1
        self.opd = 0

    @property
    def shape(self):
        """ Detector size in whole pixels.

        Returns an (n, n) tuple of ints when fov_pixels is a scalar; otherwise the first
        two elements of fov_pixels as an integer ndarray.
        """
        fpix = self.fov_pixels.to(u.pixel).value
        # have to cast back to int since Quantities are all float internally
        return (int(fpix), int(fpix)) if np.isscalar(fpix) else fpix.astype(int)[0:2]

    def __str__(self):
        return "Detector plane: {} ({}x{} pixels, {:.3f})".format(self.name, self.shape[1], self.shape[0], self.pixelscale)

    @staticmethod
    def _handle_pixelscale_units_flexibly(pixelscale, fov_pixels):
        """ The unit conventions for pixelscale are tricky; deal with that.

        For historical reasons and API simplicity, the Detector class can be
        used with pixels in angular units (arcsec/pixel) or physical units (micron/pixel).
        The regular @utils.quantity_input decorator won't support that, so we handle it here.

        Parameters
        ----------
        pixelscale : number or astropy.units.Quantity
            The requested pixel scale. A bare number is interpreted as arcsec/pixel.
        fov_pixels : any
            Only checked for being None: required to be non-None when pixelscale
            is given in physical (length per pixel) units.

        Returns
        -------
        astropy.units.Quantity, in either angular or physical units per pixel.

        Raises
        ------
        ValueError
            If pixelscale cannot be given units, has incompatible units, or is in
            physical units without fov_pixels being specified.
        """
        # This code is adapted from utils.BackCompatibleQuantityInput

        arcsec_per_pixel = u.arcsec/u.pixel
        micron_per_pixel = u.micron/u.pixel

        # Case 1: pixelscale given without units. Treat it as angular units in arcsec/pixel
        if not isinstance(pixelscale, u.Quantity):
            try:
                new_pixelscale = pixelscale * arcsec_per_pixel
            except (ValueError, TypeError):
                raise ValueError("Argument '{0}' to function '{1}'"
                                 " must be a number (not '{3}'), and convertable to"
                                 " units='{2}'.".format('pixelscale', 'Detector.__init__',
                                                        arcsec_per_pixel, pixelscale))
        # Case 2: pixelscale compatible with angular units. Treat it as such.
        elif pixelscale.unit.is_equivalent(arcsec_per_pixel):
            new_pixelscale = pixelscale
        # Case 3: pixelscale compatible with physical units. Treat it as such. Also, in
        # this case, the user *must* specify a value for fov_pixels (since the other option
        # would be fov_arcsec, and that would make no sense in this case).
        elif pixelscale.unit.is_equivalent(micron_per_pixel):
            new_pixelscale = pixelscale
            if fov_pixels is None:
                raise ValueError("If you specify the detector pixelscale in microns/pixel or "
                                 "other linear units (not angular), then you must specify the "
                                 "field of view via fov_pixels=<some integer>.")
        # Case 4: some other units. Raise an error.
        else:
            raise ValueError("Argument '{0}' to function '{1}'"
                             " must be a number or quantity convertable to"
                             " units=arcsec/pixel or micron/pixel. Note, make sure your pixelscale units are specified per pixel!".format('pixelscale',
                                                                                                                                           'Detector.__init__',
                                                                                                                                           pixelscale))

        return new_pixelscale