# -*- coding: utf-8 -*-
"""
Created on Sat Jul 30 2022
@author: Yuchen Wang
Built-in matchers.
"""
import numpy as np
from .utils import find_idx, find_eq, find_dup
from .utils import DTypeMismatchError, DTypeUnsupportedError
from astropy.coordinates import SkyCoord
import astropy.units as u
from astropy.units import UnitTypeError
import warnings
from collections.abc import Iterable
class UnsafeMatchingWarning(Warning):
pass
# def __init__(self, data, **kwargs)
[docs]
class DuplicationWarning(UnsafeMatchingWarning):
pass
[docs]
class ExactMatcher():
'''
Used to match `pyttop.table.Data` objects `data1` to `data`.
Match records with exact values.
This should be passed to method `data.match()`.
See `help(data.match)`.
Parameters
----------
value : str or Iterable
Specify values for `data` used to match catalogs. Possible inputs are:
- str, name of the field used for matching.
- Iterable, values for `data`. `len(value)` should be equal to `len(data)`.
value1 : str or Iterable, optional
Specify values for `data1` used to match catalogs. Possible inputs are:
- str, name of the field used for matching.
- Iterable, values for `data1`. `len(value1)` should be equal to `len(data1)`.
If not given and ``value`` is a string, ``value1`` set to the same as ``value``.
'''
def __init__(self, value, value1=None):
self.value = value
self.value1 = value1
self.value_name, self.value1_name = f'"{value}"' if isinstance(value, str) else value, f'"{value1}"' if isinstance(value1, str) else value1 # before evaluation, let the input be the names
if self.value1 is None:
if isinstance(self.value, str):
self.value1 = self.value
else:
raise TypeError("argument missing: 'value1'")
def get_values(self, data, data1, verbose=True):
self.data, self.data1 = data, data1
valuetype, value1type = type(self.value), type(self.value1)
if isinstance(self.value, str):
self.value = data[self.value]
elif isinstance(self.value, Iterable):
self.value = np.asanyarray(self.value)
# if not isinstance(self.value, np.ndarray): # Column, MaskedArray, etc. are instances of np.ndarray but will be converted by np.array(), so we need this condition
# self.value = np.array(self.value)
else:
raise TypeError(f"expected str or Iterable for 'value', got '{type(self.value)}'")
if isinstance(self.value1, str):
self.value1 = data1[self.value1]
elif isinstance(self.value1, Iterable):
self.value1 = np.asanyarray(self.value1)
# if not isinstance(self.value1, np.ndarray): # Column, MaskedArray, etc. are instances of np.ndarray but will be converted by np.array(), so we need this condition
# self.value1 = np.array(self.value1)
else:
raise TypeError(f"expected str or Iterable for 'value1', got '{type(self.value1)}'")
if hasattr(self.value, 'name'):
self.value_name = f'"{self.value.name}"'
else:
self.value_name = valuetype
if hasattr(self.value1, 'name'):
self.value1_name = f'"{self.value1.name}"'
else:
self.value1_name = value1type
dup_vals = find_dup(self.value)
if dup_vals.size > 0:
warnings.warn(f"Duplications found for data '{data.name}' while matching '{data1.name}' to it: the same row of '{data1.name}' may be matched to multiple rows in '{data.name}'.",
stacklevel=3, category=DuplicationWarning)
dup_vals = find_dup(self.value1)
if dup_vals.size > 0:
warnings.warn(f"Duplications found for data '{data1.name}' while matching to '{data.name}': there may be multiple rows in '{data1.name}' that can be matched to a row in '{data.name}', and only one will be returned by the matcher.",
stacklevel=3, category=DuplicationWarning)
missings = [] # whether the coord is missing
not_missing_ids = [] # the indices of those that are not missing
for valuei, datai in [[self.value, data], [self.value1, data1]]:
if np.ma.is_masked(valuei): #datai.t.masked:
# NOTE: it should not matter whether datai.t is masked; it is valuei that matters. A table that is not "masked" can have masked colums; valuei can also be user-specified rather than from datai.t
missingi = valuei.mask
else:
missingi = np.full(len(datai), False)
not_missing_idi = np.arange(len(datai), dtype=int)[~missingi]
missings.append(missingi)
not_missing_ids.append(not_missing_idi)
self.missing, self.missing1 = missings
self.not_missing_id, self.not_missing_id1 = not_missing_ids
def match(self):
def dtypestr(dtype):
return f"{dtype.name} ('{dtype}')"
l = len(self.missing)
idx = np.full(self.missing.shape, -l-1)
matched = np.full(self.missing.shape, False)
try:
idx_nm, matched_nm = find_idx(self.value1[~self.missing1], self.value[~self.missing])
except DTypeUnsupportedError as e:
if e.argname == 'array':
name = self.value1_name
data = self.data1
elif e.argname == 'values':
name = self.value_name
data = self.data
else:
raise ValueError(f"Unexpected argname: '{e.argname}'")
raise TypeError(
'ExactMatcher only supports integers or strings. '
f'{name} for {data._short_name} is {dtypestr(e.dtype)}'
# f"got {self.value_name} -> '{self.value.dtype}' and {self.value1_name} -> '{self.value1.dtype}'"
) from e
except DTypeMismatchError as e:
# related checks:
# self.value.dtype != self.value1.dtype: # may be too strict
# self.value.dtype.kind != self.value1.dtype.kind
# other alternatives: np.can_cast, np.promote_types
tips = ''
if {e.kind0, e.kind1} == {'U', 'S'}: # mixing unicode and bytes
# allowing this may cause silent errors, for example:
# find_idx(np.array(["text", "tex"], dtype="U"), np.array([b"text"], dtype="S"))
# -> [-3], [False] # not found
tips = (
# "\nTo convert between 'U' (unicode) and 'S' (bytes): "
# "np.char.decode(x, encoding=...) for bytes->text, or np.char.encode(x, encoding=...) for text->bytes."
"\nTo convert between str and bytes: "
"np.char.decode(x, encoding=...) for bytes->str, or np.char.encode(x, encoding=...) for str->bytes."
# ", where encoding is usually 'utf-8'."
)
raise TypeError(
'For safety, ExactMatcher only supports matching the same dtype kind. '
f'{self.value_name} for "{self.data._short_name}" is {dtypestr(self.value.dtype)}, '
f'and {self.value1_name} for "{self.data1._short_name}" is {dtypestr(self.value1.dtype)}. '
f'{tips}'
) from e
matched[~self.missing] = matched_nm
idx[matched] = self.not_missing_id1[idx_nm[matched_nm]]
return idx, matched
def __repr__(self):
return f'ExactMatcher({self.value_name}, {self.value1_name})'
[docs]
class SkyMatcher():
'''
Used to match `pyttop.table.Data` objects `data1` to `data`.
Match records with nearest coordinates.
This should be passed to method `data.match()`.
See `help(data.match)`.
Parameters
----------
thres : number, optional
Threshold in arcsec. The default is 1.
coord : str or astropy.coordinates.SkyCoord, optional
Specify coordinate for the base data. Possible inputs are:
- astropy.coordinates.SkyCoord (recommended), the coordinate object.
- str, should be like 'RA-DEC', which specifies the column name for RA and Dec.
- None (default), will try ['ra', 'RA'] and ['DEC', 'Dec', 'dec'].
The default is None.
coord1 : str or astropy.coordinates.SkyCoord, optional
Specify coordinate for the matched data. Possible inputs are:
- astropy.coordinates.SkyCoord (recommended), the coordinate object.
- str, should be like 'RA-DEC', which specifies the column name for RA and Dec.
- None (default), will try ['ra', 'RA'] and ['DEC', 'Dec', 'dec'].
The default is None.
unit : astropy.units.core.Unit or list/tuple/array of it
If astropy.coordinates.SkyCoord object is not given for coord,
this is used to specify the unit of coord.
The default is astropy.units.deg.
unit1 : astropy.units.core.Unit or list/tuple/array of it
If astropy.coordinates.SkyCoord object is not given for coord1,
this is used to specify the unit of coord1.
The default is astropy.units.deg.
Notes
-----
The data columns for RA, Dec may already have units (e.g. ``data.t['RA'].unit``).
In this case, any input for ``unit`` or ``unit1`` is ignored, and the units recorded
in the columns are used.
'''
def __init__(self, thres=1, coord=None, coord1=None, unit=u.deg, unit1=u.deg):
self.thres = thres
self.coord = coord
self.coord1 = coord1
self.unit = unit
self.unit1 = unit1
def get_values(self, data, data1, verbose=True):
# TODO: this method has not been debugged!
# USE WITH CAUTION!
ra_names = np.array(['ra', 'RA'])
dec_names = np.array(['DEC', 'Dec', 'dec'])
coords = []
missings = [] # whether the coord is missing
not_missing_ids = [] # the indices of those that are not missing
for coordi, datai, uniti in [[self.coord, data, self.unit], [self.coord1, data1, self.unit1]]:
if coordi is None or isinstance(coordi, str):
if coordi is None: # auto decide ra, dec
found_ra = np.isin(ra_names, datai.colnames)
if not np.any(found_ra):
raise KeyError(f'RA for {datai.name} not found.')
self.ra_name = ra_names[np.where(found_ra)][0]
ra = datai.t[self.ra_name]
found_dec = np.isin(dec_names, datai.colnames)
if not np.any(found_dec):
raise KeyError(f'Dec for {datai.name} not found.')
self.dec_name = dec_names[np.where(found_dec)][0]
dec = datai.t[self.dec_name]
if verbose: print(f"[SkyMatcher] Data {datai.name}: found RA name '{self.ra_name}' and Dec name '{self.dec_name}'.")
else: # type(coordi) is str:
self.ra_name, self.dec_name = coordi.split('-')
ra = datai.t[self.ra_name]
dec = datai.t[self.dec_name]
# check missing values for ra and dec
# TODO: below NOT TESTED
missingi = np.full(len(datai), False)
if np.ma.is_masked(ra):
missingi |= ra.mask
if np.ma.is_masked(dec): # datai.t.masked or
missingi |= dec.mask
not_missing_idi = np.arange(len(datai), dtype=int)[~missingi]
try:
coordi = SkyCoord(ra=ra[~missingi], dec=dec[~missingi], unit=uniti)
except UnitTypeError as e:
info = e.args[0]
which_coor = self.ra_name if 'Longitude' in info else self.dec_name
got_unit = info.split('set it to ')[-1]
raise UnitTypeError(f"Unrecognized unit for column '{which_coor}': expected units equivalent to 'rad', got {got_unit}"\
f" Try manually setting {datai.__repr__()}.t['{which_coor}'].unit") from e
elif type(coordi) is SkyCoord:
self.ra_name, self.dec_name = None, None
coordi = coordi
missingi = np.full(len(datai), False)
not_missing_idi = np.arange(len(datai), dtype=int)[~missingi]
else:
raise TypeError(f"Unsupported type for coord/coord1: expected str or astropy.coordinates.SkyCoord, got {type(coordi)}")
coords.append(coordi)
missings.append(missingi)
not_missing_ids.append(not_missing_idi)
self.coord, self.coord1 = coords
self.missing, self.missing1 = missings
self.not_missing_id, self.not_missing_id1 = not_missing_ids
# check duplicates for coordinates
_, counts = np.unique(np.stack([self.coord.ra, self.coord.dec]), axis=1, return_counts=True)
if np.any(counts != 1):
warnings.warn(f"Duplications found for data '{data.name}' while matching '{data1.name}' to it: the same row of '{data1.name}' may be matched to multiple rows in '{data.name}'.",
stacklevel=3, category=DuplicationWarning)
_, counts = np.unique(np.stack([self.coord1.ra, self.coord1.dec]), axis=1, return_counts=True)
if np.any(counts != 1):
warnings.warn(f"Duplications found for data '{data1.name}' while matching to '{data.name}': there may be multiple rows in '{data1.name}' that can be matched to a row in '{data.name}', and only one will be returned by the matcher.",
stacklevel=3, category=DuplicationWarning)
def match(self):
l = len(self.missing)
idx = np.full(self.missing.shape, -l-1)
matched = np.full(self.missing.shape, False)
idx_nm, d2d, d3d = self.coord.match_to_catalog_sky(self.coord1)
idx[~self.missing] = self.not_missing_id1[idx_nm]
matched[~self.missing] = d2d.arcsec < self.thres
return idx, matched
[docs]
def explore(self, data, data1):
'''
Plot as simple histogram to
check the distribution of the minimum (2-d) sky separation.
Parameters
----------
data : ``pyttop.table.Data``
The base data of the match.
data1 : ``pyttop.table.Data``
The data to be matched to ``data1``.
Returns
-------
None.
'''
self.get_values(data, data1)
idx, d2d, d3d = self.coord.match_to_catalog_sky(self.coord1)
import matplotlib.pyplot as plt
plt.figure()
plt.hist(np.log10(d2d.arcsec), bins=min((200, len(data)//20)), histtype='step', linewidth=1.5, log=True)
plt.axvline(np.log10(self.thres), color='r', linestyle='--')
plt.xlabel('lg (d / arcsec)')
plt.title(f"Min. distance to '{data1.name}' objects for each '{data.name}' object\nthreshold={self.thres}\"")
return d2d.arcsec
def __repr__(self):
# TODO: show more information here
return f'<SkyMatcher with thres={self.thres}>'
[docs]
class IdentityMatcher():
'''
Used to match ``pyttop.table.Data`` objects ``data1`` to ``data``.
Directly match records row by row, i.e. row #1 matched to row #1, row #2 matched to row #2, etc.
Only possible if ``len(data1) == len(data)``.
This should be passed to method `data.match()`.
See ``help(data.match)``.
'''
def __init__(self):
pass
def get_values(self, data, data1, verbose=True):
if len(data) != len(data1):
raise ValueError(f'IdentityMatcher can only be used to match data with the same number of rows ({len(data)} != {len(data1)})')
self.len = len(data)
def match(self):
idx = np.arange(self.len)
matched = np.full((self.len,), True)
return idx, matched
def __repr__(self):
return '<IdentityMatcher>'