Source code for sorcha.ephemeris.pixel_dict

import numpy as np
import healpy as hp
import numba

from collections import defaultdict

from sorcha.ephemeris.simulation_geometry import *
from sorcha.ephemeris.simulation_constants import *


@numba.njit(fastmath=True)
[docs] def lagrange3(t0, t1, t2, t): """Calculate the coefficients for second-order Lagrange interpolation for measured points at times t0, t1, and t2 and for an array of times t. These coefficients can be reused for any number of input vectors. Parameters ---------- t0 : float Time t0 t1 : float Time t1 t2 : float Time t2 t : 1D array Times for the interpolation Returns ------- L0 : 1D array interpolation coefficient at t0 L1 : 1D array interpolation coefficient at t1 L2 : 1D array interpolation coefficient at t2 """ L0 = (t - t1) * (t - t2) / ((t0 - t1) * (t0 - t2)) L1 = (t - t0) * (t - t2) / ((t1 - t0) * (t1 - t2)) L2 = (t - t0) * (t - t1) / ((t2 - t0) * (t2 - t1)) return L0, L1, L2
[docs] class PixelDict: """ Class with methods needed during the ephemerides generation Interfaces directly with the ASSIST+Rebound simulation objects as well as healpix """ def __init__( self, jd_tdb, sim_dict, ephem, obsCode, observatory, picket_interval=1.0, nside=128, nested=True, n_sub_intervals=101, ): """ Initialization function for the class. Computes the initial positions required for the ephemerides interpolation Parameters ---------- jd_tdb: float Reference time for the initialization sim_dict: dictionary dictionary of ASSIST simulation objects ephem: Ephem ASSIST Ephem object obsCode: str MPC Observatory code observatories: Observatory Observatory object picket_interval : float The interval (days) between picket calculations. This is 1 day by default nside : integer The nside value used for the HEALPIx calculations. Must be a power of 2 (1, 2, 4, ...) nside=128 is current default. nested: boolean Defines the ordering scheme for the healpix ordering. True (default) means a NESTED ordering n_sub_intervals: int Number of sub-intervals for the Lagrange interpolation (default: 101) """
[docs] self.nside = nside
[docs] self.picket_interval = picket_interval
[docs] self.n_sub_intervals = n_sub_intervals
[docs] self.obsCode = obsCode
[docs] self.nested = nested
[docs] self.sim_dict = sim_dict
[docs] self.ephem = ephem
[docs] self.observatory = observatory
# Set the three times and compute the observatory position # at those times # Using a quadratic isn't very general, but that can be # improved later
[docs] self.t0 = jd_tdb
[docs] self.r_obs_0 = self.get_observatory_position(self.t0)
[docs] self.tp = self.t0 + picket_interval
[docs] self.r_obs_p = self.get_observatory_position(self.tp)
[docs] self.tm = self.t0 - picket_interval
[docs] self.r_obs_m = self.get_observatory_position(self.tm)
# Initialize the dictionary of positions
[docs] self.pixel_dict = defaultdict(list)
[docs] self.rho_hat_m_dict = self.get_all_object_unit_vectors(self.r_obs_m, self.tm)
[docs] self.rho_hat_0_dict = self.get_all_object_unit_vectors(self.r_obs_0, self.t0)
[docs] self.rho_hat_p_dict = self.get_all_object_unit_vectors(self.r_obs_p, self.tp)
self.compute_pixel_traversed()
[docs] def get_observatory_position(self, t): """ Computes the barycentric position of the observatory (in au) Parameters ---------- t : float Epoch for the position vector Returns ------- : array (3,) Barycentric position of the observatory (x,y,z) """ et = (t - spice.j2000()) * 24 * 60 * 60 r_obs = self.observatory.barycentricObservatory(et, self.obsCode) / AU_KM return r_obs
[docs] def get_object_unit_vectors(self, desigs, r_obs, t, lt0=0.01): """ Computes the unit vector (in the equatorial sphere) that point towards the object - observatory vector for a list of objects, at a given time Parameters ---------- desigs: list List of designations (consistent with the simulation dictionary) r_obs: array (3 entries) Observatory location t: float Time of the observation lt0: float Initial guess (in days) for light-time correction (default: 0.01 days) Returns ------- rho_hat_dict: dict Dictionary of unit vectors """ rho_hat_dict = {} for k in desigs: v = self.sim_dict[k] sim, ex = v["sim"], v["ex"] # Get the topocentric unit vectors rho, rho_mag, lt, r_ast, v_ast = integrate_light_time( sim, ex, t - self.ephem.jd_ref, r_obs, lt0=lt0 ) rho_hat = rho / rho_mag rho_hat_dict[k] = rho_hat return rho_hat_dict
[docs] def get_all_object_unit_vectors(self, r_obs, t, lt0=0.01): """ Computes the unit vector (in the equatorial sphere) that point towards the object - observatory vector for *all* objects, at a given time Parameters ---------- r_obs: array (3 entries) Observatory location t: float Time of the observation lt0: float Initial guess (in days) for light-time correction (default: 0.01 days) Returns ------- rho_hat_dict: dict Dictionary of unit vectors """ desigs = self.sim_dict.keys() return self.get_object_unit_vectors(desigs, r_obs, t, lt0=lt0)
[docs] def get_interp_factors(self, tm, t0, tp, n_sub_intervals): """ Computes the Lagrange interpolation factors at a set of 3 times for an equally spaced grid of points with a chosen number of sub-intervals Parameters ---------- tm: float First reference time t0: float Second reference time tp: float Third reference time n_sub_intervals: int Number of sub-intervals for the Lagrange interpolation (default: 101) Returns ------- Lm: 2D array Lagrange coefficients at tm L0: 2D array Lagrange coefficients at t0 Lp: 2D array Lagrange coefficient at tp """ times = np.linspace(tm, tp, n_sub_intervals) Lm, L0, Lp = lagrange3(tm, t0, tp, times) Lm = Lm[:, np.newaxis] L0 = L0[:, np.newaxis] Lp = Lp[:, np.newaxis] return Lm, L0, Lp
[docs] def interpolate_unit_vectors(self, desigs, jd_tdb): """ Interpolates the unit vectors for a list of designations towards the new target time Parameters ---------- desigs: list List of designations (consistent with the simulation dictionary) jd_tdb: float Target time Returns ------- unit_vector_dict: dict Dictionary of unit vectors """ # Update the table of unit vectors if needed. # Should not normally need to, if this routine is being # called properly self.update_pickets(jd_tdb) Lm, L0, Lp = lagrange3(self.tm, self.t0, self.tp, jd_tdb) unit_vector_dict = {} for k in desigs: rho_hat_m = self.rho_hat_m_dict[k] rho_hat_0 = self.rho_hat_0_dict[k] rho_hat_p = self.rho_hat_p_dict[k] # Interpolate the unit vectors over a finer sampled set of times vec = rho_hat_m * Lm + rho_hat_0 * L0 + rho_hat_p * Lp unit_vector_dict[k] = vec return unit_vector_dict
[docs] def compute_pixel_traversed(self): """ Computes the healpix pixels traversed by all the objects during between times tm and tp """ # These don't need to be recomputed, if the interval stays the same Lm, L0, Lp = self.get_interp_factors(self.tm, self.t0, self.tp, self.n_sub_intervals) self.pixel_dict = defaultdict(list) for k, v in self.sim_dict.items(): rho_hat_m = self.rho_hat_m_dict[k] rho_hat_0 = self.rho_hat_0_dict[k] rho_hat_p = self.rho_hat_p_dict[k] # Interpolate the unit vectors over a finer sampled set of times vec = rho_hat_m * Lm + rho_hat_0 * L0 + rho_hat_p * Lp # Find the healpix locations pixels = hp.vec2pix(self.nside, vec[:, 0], vec[:, 1], vec[:, 2], nest=self.nested) pixels = list(set(pixels)) # Add the neighboring pixels pixels = set(hp.get_all_neighbours(self.nside, pixels, nest=self.nested).flatten()) # Add the pixels that the object traverse, and the neighbors, to the pixel_dict for pix in pixels: self.pixel_dict[pix].append(k)
[docs] def update_pickets(self, jd_tdb): """ Updates the picket interpolation vectors for the new reference time Parameters ---------- jd_tdb: float Target time """ if abs(jd_tdb - self.t0) > 0.5 * self.picket_interval: # Need to update if abs(jd_tdb - self.t0) <= 1.5 * self.picket_interval: # Can compute just one new set and shift the others if jd_tdb <= self.tm: # shift earlier self.tp = self.t0 self.r_obs_p = self.r_obs_0 self.rho_hat_p_dict = self.rho_hat_0_dict self.t0 = self.tm self.r_obs_0 = self.r_obs_m self.rho_hat_0_dict = self.rho_hat_m_dict self.tm = self.t0 - self.picket_interval self.r_obs_m = self.get_observatory_position(self.tm) self.rho_hat_m_dict = self.get_all_object_unit_vectors(self.r_obs_m, self.tm) else: # shift later self.tm = self.t0 self.r_obs_m = self.r_obs_0 self.rho_hat_m_dict = self.rho_hat_0_dict self.t0 = self.tp self.r_obs_0 = self.r_obs_p self.rho_hat_0_dict = self.rho_hat_p_dict self.tp = self.t0 + self.picket_interval self.r_obs_p = self.get_observatory_position(self.tp) self.rho_hat_p_dict = self.get_all_object_unit_vectors(self.r_obs_p, self.tp) else: # Need to compute three new sets n = round((jd_tdb - self.t0) / self.picket_interval) # This is repeated code self.t0 += n * self.picket_interval self.r_obs_0 = self.get_observatory_position(self.t0) self.rho_hat_0_dict = self.get_all_object_unit_vectors(self.r_obs_0, self.t0) self.tp = self.t0 + self.picket_interval self.r_obs_p = self.get_observatory_position(self.tp) self.rho_hat_p_dict = self.get_all_object_unit_vectors(self.r_obs_p, self.tp) self.tm = self.t0 - self.picket_interval self.r_obs_m = self.get_observatory_position(self.tm) self.rho_hat_m_dict = self.get_all_object_unit_vectors(self.r_obs_m, self.tm) self.compute_pixel_traversed() else: pass
[docs] def get_designations(self, jd_tdb, ra, dec, ang_fov): """ Get the object designations that are within an angular radius of a topocentric unit vector at a given time. Parameters ---------- jd_tdb: float Target time ra: float right ascension (degrees) dec: float declination (degrees) ang_fov: float Field of view radius Returns ------- desigs : list List of designations """ # Update the table of unit vectors if needed. self.update_pickets(jd_tdb) pixels = get_hp_neighbors(ra, dec, ang_fov, nside=self.nside, nested=self.nested) desigs = set() for pix in pixels: desigs.update(self.pixel_dict[pix]) return desigs