Source code for ebltable.interpolate

from __future__ import absolute_import, division, print_function
import numpy as np
import warnings
from collections.abc import Iterable
from astropy.table import Table, Column
from astropy.io import fits
from astropy.units import Unit
from scipy.interpolate import RectBivariateSpline


class GridInterpolator(object):
    """
    Base class for 2D grid interpolation.

    Wraps :class:`scipy.interpolate.RectBivariateSpline` and optionally
    performs the interpolation in log10 space along any of the x, y and Z
    axes. Exact zeros on a logarithmic axis are replaced with a tiny positive
    number before taking the logarithm.
    """

    # Replacement for exact zeros on a logarithmic axis (log10(0) is -inf).
    _LOG_FLOOR = 1e-40

    def __init__(self, x, y, Z, logx=False, logy=False, logZ=False, kx=1, ky=1, **kwargs):
        """
        Initialize the class.

        Parameters
        ----------
        x: array-like
            Array with values of x axis of grid, n-dimensional
        y: array-like
            Array with values of y axis of grid, m-dimensional
        Z: array-like
            Array with grid (z) values, n x m-dimensional
        logx: bool
            Use logarithmic interpolation over x axis. Default: False
        logy: bool
            Use logarithmic interpolation over y axis. Default: False
        logZ: bool
            Use logarithmic interpolation over Z axis. Default: False
        kx: int
            Order of spline interpolation in x direction. Default: 1
        ky: int
            Order of spline interpolation in y direction. Default: 1
        kwargs: dict
            Additional kwargs passed to
            :class:`scipy.interpolate.RectBivariateSpline`.
            They are remembered and re-used whenever the spline is rebuilt
            through one of the ``x``, ``y`` or ``Z`` setters.

        Notes
        -----
        The input arrays are never modified; logarithmic axes are stored
        as transformed copies.
        """
        self._logx = logx
        self._logy = logy
        self._logZ = logZ
        self._kx = kx
        self._ky = ky
        # Keep the extra spline options so that the setters rebuild the
        # spline with the same configuration as the constructor.
        self._spline_kwargs = dict(kwargs)

        self._x = self._to_grid_space(x, logx)
        self._y = self._to_grid_space(y, logy)
        self._Z = self._to_grid_space(Z, logZ)
        self._build_spline()

    @classmethod
    def _to_grid_space(cls, values, use_log):
        """
        Return ``values`` as stored internally.

        For a linear axis the input object is returned unchanged. For a
        logarithmic axis a float copy is made, exact zeros are replaced by
        ``_LOG_FLOOR`` and log10 of the result is returned, so the caller's
        array is left untouched.
        """
        if not use_log:
            return values
        # np.array(...) copies, so the zero replacement below cannot leak
        # into the caller's data; float dtype keeps 1e-40 from truncating.
        transformed = np.array(values, dtype=float)
        transformed[transformed == 0.] = cls._LOG_FLOOR
        return np.log10(transformed)

    def _build_spline(self):
        """(Re)create the spline from the currently stored grid."""
        # getattr guards instances whose __init__ was bypassed.
        extra = getattr(self, '_spline_kwargs', {})
        self._spline = RectBivariateSpline(self._x, self._y, self._Z,
                                           kx=self._kx, ky=self._ky, **extra)

    @property
    def x(self):
        """x axis of the grid as stored (log10 values if ``logx`` is set)."""
        return self._x

    @property
    def y(self):
        """y axis of the grid as stored (log10 values if ``logy`` is set)."""
        return self._y

    @property
    def Z(self):
        """Grid values as stored (log10 values if ``logZ`` is set)."""
        return self._Z

    @x.setter
    def x(self, x):
        self._x = self._to_grid_space(x, self._logx)
        self._build_spline()

    @y.setter
    def y(self, y):
        self._y = self._to_grid_space(y, self._logy)
        self._build_spline()

    @Z.setter
    def Z(self, Z):
        self._Z = self._to_grid_space(Z, self._logZ)
        self._build_spline()

    @staticmethod
    def _read_ascii(file_name):
        """
        Read in a model file from an arbitrary ASCII file.

        Parameters
        ----------
        file_name: str,
            full path to optical depth model file, with a
            (n+1) x (m+1) dimensional table.
            The zeroth column contains the x values,
            first row contains the y values.
            The remaining values are the Z values of the grid.
            The [0,0] entry will be ignored.

        Returns
        -------
        tuple with x, y and Z values
        """
        data = np.loadtxt(file_name)
        x = data[1:, 0]
        y = data[0, 1:]
        Z = data[1:, 1:]
        return x, y, Z

    @staticmethod
    def _read_fits(file_name, hdu_name_grid, hdu_name_x, xcol_name, ycol_name, Zcol_name, xtarget_unit):
        """
        Read in a model grid from a fits file.

        Parameters
        ----------
        file_name: str,
            full path to fits file containing the grid values
        hdu_name_grid: str
            Name of the HDU extension containing the Grid values Z and y axis values
        hdu_name_x: str,
            Name of the HDU extension containing the x axis values
        xcol_name: str,
            name of x column in hdu_name_x extension
        ycol_name: str,
            name of y column in hdu_name_grid extension
        Zcol_name: str,
            name of Z column in hdu_name_grid extension
        xtarget_unit: str,
            name of target unit for x values

        Returns
        -------
        tuple with x, y and Z values. The x values are converted to
        ``xtarget_unit`` and the Z column is returned transposed.
        """
        t = Table.read(file_name, hdu=hdu_name_grid)
        y = t[ycol_name].data
        Z = t[Zcol_name].data

        t2 = Table.read(file_name, hdu=hdu_name_x)
        x = t2[xcol_name].data * t2[xcol_name].unit

        return x.to(xtarget_unit).value, y, Z.T

    def _write_fits(self, filename, x, y, hdu_name_grid, hdu_name_x, xunit,
                    xcol_name, ycol_name, Zcol_name, xtarget_unit=None, overwrite=True):
        """
        Write Z values to a fits file using the astropy table environment.

        The Z values are obtained by evaluating the interpolation on the
        given x and y values.

        Parameters
        ----------
        filename: str,
            full file path for output fits file
        x: array-like
            x values for interpolation
        y: array-like
            y values for interpolation
        hdu_name_grid: str
            Name of the HDU extension containing the Grid values Z and y axis values
        hdu_name_x: str,
            Name of the HDU extension containing the x axis values
        xunit: str,
            name of unit for x values
        xcol_name: str,
            name of x column in hdu_name_x extension
        ycol_name: str,
            name of y column in hdu_name_grid extension
        Zcol_name: str,
            name of Z column in hdu_name_grid extension
        xtarget_unit: str
            name of unit the x values are stored in; defaults to ``xunit``
        overwrite: bool
            Overwrite existing file.
        """
        if xtarget_unit is None:
            xtarget_unit = xunit

        t = Table([y, self.evaluate(x, y)], names=(ycol_name, Zcol_name))
        t2 = Table()
        # Unit(...).to(...) yields the plain conversion factor xunit -> xtarget_unit
        t2[xcol_name] = Column(x * Unit(xunit).to(xtarget_unit), unit=xtarget_unit)

        hdulist = fits.HDUList([fits.PrimaryHDU(), fits.table_to_hdu(t), fits.table_to_hdu(t2)])
        hdulist[1].name = hdu_name_grid
        hdulist[2].name = hdu_name_x
        hdulist.writeto(filename, overwrite=overwrite)

    def evaluate(self, x, y):
        """
        Evaluate Spline for some x and y values

        Parameters
        ----------
        x: float or array-like
            x coordinates for evaluation, n-dimensional
        y: float or array-like
            y coordinates for evaluation, m-dimensional

        Returns
        -------
        Interpolated Z values for input x and y values (m x n dimensional).
        If n or m are equal to one, drop this axis.

        Warns
        -----
        RuntimeWarning
            If any y value lies below the lowest y value of the grid.
        """
        # Accept scalars, 0-d arrays, lists, tuples and arrays alike.
        x = np.atleast_1d(x)
        y = np.atleast_1d(y)

        # Move the evaluation points into the space the grid is stored in.
        if self._logx:
            x = np.log10(x)
        if self._logy:
            y = np.log10(y)

        # Range check in grid space, so that a logarithmic y axis is compared
        # log-to-log; the minimum is reported in linear units.
        y_min = self._y[0]
        if np.any(y < y_min):
            y_min_linear = np.power(10., y_min) if self._logy else y_min
            warnings.warn(f"Warning: a y value is below interpolation range, y min = {y_min_linear:.2f}",
                          RuntimeWarning)

        # Spline interpolation requires sorted lists
        # alternative would be to calculate the spline with grid=False
        # but this takes longer in my tests
        order_x = np.argsort(x)
        order_y = np.argsort(y)
        on_sorted_grid = self._spline(x[order_x], y[order_y]).transpose()

        # Scatter the values back to the order in which x and y were given.
        result = np.zeros((y.shape[0], x.shape[0]))
        result[np.ix_(order_y, order_x)] = on_sorted_grid

        if self._logZ:
            result = np.power(10., result)

        return np.squeeze(result)