#!/usr/bin/env python
"""
dataset.py
Written by Tyler Sutterley (04/2026)
An xarray.Dataset extension for SMB and firn model data
PYTHON DEPENDENCIES:
numpy: Scientific Computing Tools For Python
https://numpy.org
https://numpy.org/doc/stable/user/numpy-for-matlab-users.html
pint: Python package to define, operate and manipulate physical quantities
https://pypi.org/project/Pint/
https://pint.readthedocs.io/en/stable
pyproj: Python interface to PROJ library
https://pypi.org/project/pyproj/
https://pyproj4.github.io/pyproj/
scipy: Scientific Tools for Python
https://docs.scipy.org/doc/
xarray: N-D labeled arrays and datasets in Python
https://docs.xarray.dev/en/stable/
UPDATE HISTORY:
Updated 04/2026: added combine_attrs to merge conflicts into a list
added grid cell area calculators for geographic and projected models
Written 04/2026
"""
import re
import pint
import pyproj
import warnings
import numpy as np
import xarray as xr
from typing import Any
from xarray.core.utils import equivalent
# suppress warnings
# NOTE(review): this filter is installed at import time and silences every
# UserWarning in the whole process, not only those raised by this module
# (presumably aimed at pyproj's CRS.to_dict() PROJ-string warnings -- confirm,
# and consider scoping it with warnings.catch_warnings() instead)
warnings.filterwarnings("ignore", category=UserWarning)
# public API of this module
# NOTE(review): "_transform" is exported here but is not defined in the code
# visible in this chunk; confirm it is defined elsewhere in the module
__all__ = [
    "DataTree",
    "Dataset",
    "DataArray",
    "combine_attrs",
    "equivalent_attrs",
    "get_variable",
    "register_datatree_subaccessor",
    "register_dataset_subaccessor",
    "register_dataarray_subaccessor",
    "_transform",
    "_coords",
]
# pint unit registry shared by all accessors in this module
__ureg__ = pint.UnitRegistry()
# add water and ice equivalents, defined as densities (water 1.0 g/cm^3,
# ice 0.917 g/cm^3) so that e.g. "cm we" is a mass per unit area
__ureg__.define("we = 1.0 * g / cm^3")
__ureg__.define("ie = 0.917 * g / cm^3")
# long-form aliases so "water", "water_equivalent", "ice", ... also parse
__ureg__.define("@alias we = water = water_equivalent")
__ureg__.define("@alias ie = ice = ice_equivalent")
# air equivalent for FAC (dimensionless factor of 1)
__ureg__.define("air = 1.0")
# default units for SMB and firn outputs, keyed by the unit group
# returned by the DataArray accessor ``group`` property
_default_units = {
    "mass density": "cm we",
}
@xr.register_datatree_accessor("fcorr")
class DataTree:
    """Accessor for extending an ``xarray.DataTree`` for SMB and firn data"""

    def __init__(self, dtree):
        # initialize DataTree
        self._dtree = dtree

    def _map_children(self, func):
        """
        Apply a function to the dataset of each child node of the ``DataTree``

        Parameters
        ----------
        func: callable
            Function taking and returning an ``xarray.Dataset``

        Returns
        -------
        dtree: xarray.DataTree
            Copy of the ``DataTree`` with each child node replaced
        """
        # create copy of datatree to hold the results
        dtree = self._dtree.copy()
        # iterate over the children of the original tree (never the copy
        # being modified); root-level data variables are not nodes and are
        # left untouched
        for key, node in self._dtree.children.items():
            dtree[key] = func(node.to_dataset())
        # return the datatree
        return dtree

    def assign_coords(
        self,
        x: np.ndarray,
        y: np.ndarray,
        crs: str | int | dict = 4326,
        **kwargs,
    ):
        """
        Assign new coordinates to each child node of the ``DataTree``

        Parameters
        ----------
        x: np.ndarray
            Updated x-coordinates
        y: np.ndarray
            Updated y-coordinates
        crs: str, int, or dict, default 4326 (WGS84 Latitude/Longitude)
            Coordinate reference system of coordinates
        kwargs: dict
            Keyword arguments for ``xarray.Dataset.assign_coords``

        Returns
        -------
        dtree: xarray.DataTree
            ``DataTree`` with updated coordinates
        """

        def _assign(ds):
            # assign new coordinates and record the reference system
            ds = ds.assign_coords(dict(x=x, y=y), **kwargs)
            ds.attrs["crs"] = crs
            return ds

        # assign new coordinates to each dataset
        return self._map_children(_assign)

    def coords_as(
        self,
        x: np.ndarray,
        y: np.ndarray,
        crs: str | int | dict = 4326,
        **kwargs,
    ):
        """
        Transform coordinates into ``DataArrays`` in the ``DataTree``
        coordinate reference system

        Parameters
        ----------
        x: np.ndarray
            Input x-coordinates
        y: np.ndarray
            Input y-coordinates
        crs: str, int, or dict, default 4326 (WGS84 Latitude/Longitude)
            Coordinate reference system of input coordinates
        kwargs: dict
            Keyword arguments for ``_coords`` (``type``, ``time``)

        Returns
        -------
        X: xarray.DataArray
            Transformed x-coordinates
        Y: xarray.DataArray
            Transformed y-coordinates
        """
        # convert coordinate reference system to that of the datatree
        # and format as xarray DataArray with appropriate dimensions
        X, Y = _coords(x, y, source_crs=crs, target_crs=self.crs, **kwargs)
        # return the transformed coordinates
        return X, Y

    def crop(self, *args, **kwargs):
        """
        Crop each child node of the ``DataTree`` to an input bounding box

        Parameters
        ----------
        args, kwargs
            Arguments forwarded to the ``Dataset`` accessor ``crop`` method

        Returns
        -------
        dtree: xarray.DataTree
            Cropped ``DataTree``
        """
        return self._map_children(lambda ds: ds.fcorr.crop(*args, **kwargs))

    def inpaint(self, **kwargs):
        """
        Inpaint over missing data in each child node of the ``DataTree``

        Parameters
        ----------
        kwargs: dict
            Keyword arguments forwarded to the ``Dataset`` accessor
            ``inpaint`` method

        Returns
        -------
        dtree: xarray.DataTree
            Inpainted ``DataTree``
        """
        return self._map_children(lambda ds: ds.fcorr.inpaint(**kwargs))

    def interp(
        self,
        x: np.ndarray,
        y: np.ndarray,
        **kwargs,
    ):
        """
        Interpolate each child node of the ``DataTree`` to new coordinates

        Parameters
        ----------
        x: np.ndarray
            Interpolation x-coordinates
        y: np.ndarray
            Interpolation y-coordinates
        kwargs: dict
            Keyword arguments forwarded to the ``Dataset`` accessor
            ``interp`` method

        Returns
        -------
        dtree: xarray.DataTree
            Interpolated ``DataTree``
        """
        return self._map_children(lambda ds: ds.fcorr.interp(x, y, **kwargs))

    @property
    def crs(self):
        """Coordinate reference system of the ``DataTree``"""
        # inherit CRS from the first child dataset, falling back to the
        # root node when the tree has no children
        node = next(iter(self._dtree.children.values()), self._dtree)
        return node.to_dataset().fcorr.crs
@xr.register_dataset_accessor("fcorr")
class Dataset:
    """Accessor for extending an ``xarray.Dataset`` for SMB and firn data"""

    def __init__(self, ds):
        # initialize Dataset
        self._ds = ds

    def assign_coords(
        self,
        x: np.ndarray,
        y: np.ndarray,
        crs: str | int | dict = 4326,
        **kwargs,
    ):
        """
        Assign new coordinates to the ``Dataset``

        Parameters
        ----------
        x: np.ndarray
            Updated x-coordinates
        y: np.ndarray
            Updated y-coordinates
        crs: str, int, or dict, default 4326 (WGS84 Latitude/Longitude)
            Coordinate reference system of coordinates
        kwargs: dict
            Keyword arguments for ``xarray.Dataset.assign_coords``

        Returns
        -------
        ds: xarray.Dataset
            ``Dataset`` with updated coordinates
        """
        # assign new coordinates to dataset
        ds = self._ds.assign_coords(dict(x=x, y=y), **kwargs)
        # store the coordinate reference system as a dataset attribute
        ds.attrs["crs"] = crs
        # return the dataset
        return ds

    def cell_area(self):
        """
        Calculate the area of each grid cell in the ``Dataset``

        Grid spacing is taken from the first two x- and y-coordinates,
        so the grid is assumed to be regularly spaced

        Returns
        -------
        area: xarray.DataArray
            Area of each grid cell in the dataset (m^2)
        """
        from FirnCorr.spatial import scale_factors

        # parse the dataset coordinate reference system once
        crs = self.crs
        # get PROJ4 parameters for dataset projection
        params = crs.to_dict()
        # get geodetic parameters
        geod = crs.get_geod()
        # ellipsoid semi-major and semi-minor axes
        a_axis = geod.a
        b_axis = geod.b
        # ellipsoidal flattening
        flat = geod.f
        # first numerical eccentricity and its square
        e12 = geod.es
        ecc = np.sqrt(e12)
        # authalic radius (same area as ellipsoid)
        # arctanh(ecc)/ecc -> 1 as ecc -> 0, so for a sphere the authalic
        # radius is the semi-major axis (avoids evaluating 0/0 = NaN)
        if ecc > 0.0:
            rad_e = np.sqrt(
                0.5 * (a_axis**2 + b_axis**2 * np.arctanh(ecc) / ecc)
            )
        else:
            rad_e = a_axis
        # model coordinates
        x, y = self._x, self._y
        # coordinates and attributes for output DataArray
        coords = dict(y=self._ds.y, x=self._ds.x)
        attrs = dict(units="m^2", long_name="Grid Cell Area")
        # calculate areas based on the coordinate reference system
        if crs.is_geographic and params.get("proj") == "ob_tran":
            # rotated pole projection (assume spherical)
            _, gridy = np.meshgrid(np.radians(x), np.radians(y))
            # grid spacing in the x and y directions
            dx = np.abs(np.radians(x[1] - x[0]))
            dy = np.abs(np.radians(y[1] - y[0]))
            # calculate area of each grid cell
            area = (rad_e * dy) * (rad_e * dx * np.cos(gridy))
            # note: differs from RACMO as they calculate using equatorial radius
            attrs["note"] = "Multiply by scale to approximate RACMO cell areas"
            attrs["scale"] = np.round((a_axis**2) / (rad_e**2), decimals=4)
        elif crs.is_geographic:
            # geographic coordinates (assume equirectangular)
            _, gridy = np.meshgrid(np.radians(x), np.radians(y))
            # grid spacing in the x and y directions
            dx = np.abs(np.radians(x[1] - x[0]))
            dy = np.abs(np.radians(y[1] - y[0]))
            # radius of curvature in prime vertical direction (east-west)
            N = a_axis / np.sqrt(1.0 - e12 * np.sin(gridy) ** 2)
            # radius of curvature in meridional direction (north-south)
            M = a_axis * (1.0 - e12) / (1.0 - e12 * np.sin(gridy) ** 2) ** 1.5
            # calculate area of each grid cell
            area = (M * dy) * (N * np.cos(gridy) * dx)
        elif crs.is_projected and params.get("proj") == "stere":
            # stereographic projection
            # CRS.geodetic_crs can be None: fall back to WGS84
            geodetic_crs = getattr(crs, "geodetic_crs", None)
            if geodetic_crs is None:
                geodetic_crs = 4326
            # get latitude and true-scale latitude
            _, lat = self.to_geographic(crs=geodetic_crs)
            lat_ts = params.get("lat_ts", 90.0)
            # calculate scaling factors for area distortions
            ps_scale = scale_factors(lat, flat=flat, reference_latitude=lat_ts)
            # scaling factor to convert from axis units to meters
            axis_scale = self._axis_scale()
            # grid spacing in the x and y directions
            dx = axis_scale * np.abs(x[1] - x[0])
            dy = axis_scale * np.abs(y[1] - y[0])
            # calculate area of each grid cell
            area = ps_scale * dx * dy
        else:
            # projected coordinates (assume Cartesian)
            ny, nx = len(y), len(x)
            # scaling factor to convert from axis units to meters
            axis_scale = self._axis_scale()
            # grid spacing in the x and y directions
            dx = axis_scale * np.abs(x[1] - x[0])
            dy = axis_scale * np.abs(y[1] - y[0])
            # calculate area of each grid cell
            area = dx * dy * np.ones((ny, nx))
        # return area as xarray DataArray
        return xr.DataArray(area, coords=coords, dims=["y", "x"], attrs=attrs)

    def _axis_scale(self) -> float:
        """Scaling factor converting the coordinate axis units to meters"""
        axis_units = 1.0 * __ureg__.parse_units(self.axis_units)
        return axis_units.to(__ureg__.meter).magnitude

    def coords_as(
        self,
        x: np.ndarray,
        y: np.ndarray,
        crs: str | int | dict = 4326,
        **kwargs,
    ):
        """
        Transform coordinates into ``DataArrays`` in the ``Dataset``
        coordinate reference system

        Parameters
        ----------
        x: np.ndarray
            Input x-coordinates
        y: np.ndarray
            Input y-coordinates
        crs: str, int, or dict, default 4326 (WGS84 Latitude/Longitude)
            Coordinate reference system of input coordinates
        kwargs: dict
            Keyword arguments for ``_coords`` (``type``, ``time``)

        Returns
        -------
        X: xarray.DataArray
            Transformed x-coordinates
        Y: xarray.DataArray
            Transformed y-coordinates
        """
        # convert coordinate reference system to that of the dataset
        # and format as xarray DataArray with appropriate dimensions
        X, Y = _coords(x, y, source_crs=crs, target_crs=self.crs, **kwargs)
        # return the transformed coordinates
        return X, Y

    def crop(
        self,
        bounds: list | tuple,
        buffer: int | float = 0,
    ):
        """
        Crop ``Dataset`` to input bounding box

        Parameters
        ----------
        bounds: list, tuple
            Bounding box ``[min_x, max_x, min_y, max_y]``
        buffer: int or float, default 0
            Buffer to add to bounds for cropping

        Returns
        -------
        ds: xarray.Dataset
            Cropped ``Dataset``
        """
        # pad global grids along x-dimension (if necessary)
        lon_wrap = self.crs.to_dict().get("lon_wrap", 0)
        is_global = self.is_global
        if is_global and (lon_wrap == 180) and (np.min(bounds[:2]) < 0):
            # number of points to pad for global grids
            n = int(180 // (self._x[1] - self._x[0]))
            ds = self.pad(n=(n, 0))
        elif is_global and (lon_wrap == 0) and (np.max(bounds[:2]) > 180):
            # number of points to pad for global grids
            n = int(180 // (self._x[1] - self._x[0]))
            ds = self.pad(n=(0, n))
        else:
            # copy dataset
            ds = self._ds.copy()
        # load chunked datasets into memory before masking
        # Dataset.chunks is an empty mapping (never None) when the dataset
        # holds no chunked arrays, so test its truthiness
        if getattr(ds, "chunks", None):
            ds = ds.chunk(-1).compute()
        # unpack bounds and buffer
        xmin = bounds[0] - buffer
        xmax = bounds[1] + buffer
        ymin = bounds[2] - buffer
        ymax = bounds[3] + buffer
        # crop dataset to bounding box
        ds = ds.where(
            (ds.x >= xmin) & (ds.x <= xmax) & (ds.y >= ymin) & (ds.y <= ymax),
            drop=True,
        )
        # return the cropped dataset
        return ds

    def cumsum(self, **kwargs):
        """
        Calculate cumulative sum of ``Dataset`` along time dimension

        Parameters
        ----------
        kwargs: dict
            Keyword arguments for ``xarray.Dataset.cumsum``

        Returns
        -------
        ds: xarray.Dataset
            Cumulative sum of the ``Dataset``
        """
        # calculate cumulative sum along time dimension (NaNs propagate)
        ds = self._ds.cumsum(dim="time", skipna=False, **kwargs)
        # return the cumulative sum dataset
        return ds

    def gaussian_filter(
        self,
        sigma: float | list[float] = 1.5,
        **kwargs,
    ):
        """
        Apply Gaussian smoothing to the ``Dataset``

        Valid values are kept; missing values are filled with a normalized
        Gaussian average of the surrounding valid values

        Parameters
        ----------
        sigma: float or list, default 1.5
            Standard deviation for Gaussian kernel in x and y directions
        kwargs: dict
            Keyword arguments for ``scipy.ndimage.gaussian_filter``

        Returns
        -------
        ds: xarray.Dataset
            Smoothed ``Dataset``
        """
        # import gaussian filter function
        from scipy.ndimage import gaussian_filter

        # set default keyword arguments
        kwargs.setdefault("mode", "constant")
        kwargs.setdefault("cval", 0)
        # create copy of dataset
        ds = self._ds.copy(deep=True)
        # apply Gaussian smoothing to each variable in the dataset
        for v in ds.data_vars:
            # use a gaussian filter to smooth the valid-data mask
            mask = np.logical_not(ds[v].isnull().any(dim="time")).astype("f")
            kernel = gaussian_filter(mask, sigma=sigma, **kwargs)
            # positional indexing below assumes (time, y, x) ordering
            for i in range(ds.sizes["time"]):
                # replace fill values with zeros before smoothing data
                tmp = ds[v].isel(time=i).fillna(0.0)
                # smooth spatial field
                smooth = gaussian_filter(tmp, sigma=sigma, **kwargs)
                # normalize by the smoothed mask; cells where the kernel is
                # zero are set to NaN, so silence the division warnings
                with np.errstate(divide="ignore", invalid="ignore"):
                    scaled = xr.where(kernel != 0, smooth / kernel, np.nan)
                # replace valid values with original
                ds[v][i, :, :] = xr.where(mask, tmp, scaled)
        # return the smoothed dataset
        return ds

    def get(self, name: str):
        """
        Get variable in ``Dataset`` using a case-insensitive search

        Parameters
        ----------
        name: str
            Name of variable to find in dataset

        Returns
        -------
        var: xarray.DataArray or None
            Variable from dataset if found, otherwise None
        """
        return get_variable(self._ds, name)

    def grid_interp(
        self,
        x: np.ndarray,
        y: np.ndarray,
        method="linear",
        **kwargs,
    ):
        """
        Interpolate a regular or rectilinear ``Dataset`` to new coordinates

        Parameters
        ----------
        x: np.ndarray
            Interpolation x-coordinates
        y: np.ndarray
            Interpolation y-coordinates
        method: str, default 'linear'
            Interpolation method
        kwargs: dict
            Additional keyword arguments (ignored)

        Returns
        -------
        other: xarray.Dataset
            Interpolated ``Dataset``
        """
        # pad global grids along x-dimension (if necessary)
        # pad into a local variable: accessor instances are cached on the
        # parent object, so reassigning self._ds would leak the padded grid
        # into every later ``.fcorr`` call on the same dataset
        ds = self.pad(n=1) if self.is_global else self._ds
        # verify longitudinal convention for geographic models
        if self.crs.is_geographic:
            # (possibly padded) model x-coordinates and grid spacing
            model_x = ds.x.values
            dx = model_x[1] - model_x[0]
            # adjust input longitudes to be consistent with model
            if (np.min(x) < 0.0) & (model_x.max() > (180.0 + dx)):
                # input points convention (-180:180)
                # model convention (0:360)
                x = xr.where(x < 0.0, x + 360.0, x)
            elif (np.max(x) > 180.0) & (model_x.min() < (0.0 - dx)):
                # input points convention (0:360)
                # model convention (-180:180)
                x = xr.where(x > 180.0, x - 360.0, x)
        # interpolate dataset using built-in xarray methods
        other = ds.interp(x=x, y=y, method=method)
        # return xarray dataset
        return other

    def inpaint(self, **kwargs):
        """
        Inpaint over missing data in ``Dataset``

        Parameters
        ----------
        kwargs: dict
            Keyword arguments for :py:func:`FirnCorr.interpolate.inpaint`

        Returns
        -------
        ds: xarray.Dataset
            Interpolated ``Dataset``
        """
        # import inpaint function
        from FirnCorr.interpolate import inpaint

        # create copy of dataset
        ds = self._ds.copy()
        # inpaint each variable in the dataset
        for v in ds.data_vars:
            ds[v].values = inpaint(
                self._x, self._y, self._ds[v].values, **kwargs
            )
        # return the dataset
        return ds

    def interp(
        self,
        x: np.ndarray,
        y: np.ndarray,
        **kwargs,
    ):
        """
        Interpolate ``Dataset`` to new coordinates

        Parameters
        ----------
        x: np.ndarray
            Interpolation x-coordinates
        y: np.ndarray
            Interpolation y-coordinates
        method: str, default 'linear'
            Interpolation method
        extrapolate: bool, default False
            Flag to extrapolate values using nearest-neighbors
        cutoff: int or float, default np.inf
            Maximum distance for extrapolation
        kwargs: dict
            Additional keyword arguments for interpolation functions

        Returns
        -------
        other: xarray.Dataset
            Interpolated ``Dataset``
        """
        # set default keyword arguments
        kwargs.setdefault("method", "linear")
        extrapolate = kwargs.pop("extrapolate", False)
        cutoff = kwargs.pop("cutoff", np.inf)
        # use built-in xarray interpolation methods
        other = self.grid_interp(x, y, **kwargs)
        # extrapolate missing values using nearest-neighbors
        if extrapolate:
            # NOTE(review): extrap_like is not defined on this accessor in
            # the code visible here; presumably provided elsewhere -- confirm
            other = self.extrap_like(other, cutoff=cutoff)
        # return xarray dataset
        return other

    def pad(
        self,
        n: int = 1,
        chunks=None,
    ):
        """
        Pad ``Dataset`` along the x-dimension by wrapping values around
        from the opposite edge (periodic padding for global grids)

        The padded x-coordinates are extended by odd reflection of the
        existing coordinates

        Parameters
        ----------
        n: int or tuple, default 1
            Number of padding values to add on each side, or a
            ``(before, after)`` tuple
        chunks: dict, int, or None, default None
            Chunk sizes for rechunking the padded ``Dataset``

        Returns
        -------
        ds: xarray.Dataset
            Padded ``Dataset``
        """
        # (possibly) unchunk x-coordinates and pad to wrap at meridian
        x = xr.DataArray(self._x, dims="x").pad(
            x=n, mode="reflect", reflect_type="odd"
        )
        # pad dataset and re-assign x-coordinates
        ds = self._ds.copy()
        ds = ds.pad(x=n, mode="wrap").assign_coords(x=x)
        # rechunk dataset (if specified)
        if chunks is not None:
            ds = ds.chunk(chunks)
        # return the dataset
        return ds

    def to_anomaly(
        self,
        reference: str | None = None,
        climatology: list | None = None,
    ):
        """
        Convert ``Dataset`` to anomalies relative to a reference period

        Parameters
        ----------
        reference: str or None
            Method for referencing anomalies

                - ``'first'``: remove first time step
                - ``'mean'``: remove mean over a time range
        climatology: list, default None
            Time range for calculating mean reference, given either as
            years or as values convertible to ``datetime64[D]``
            (the end of the range is inclusive)

        Returns
        -------
        ds: xarray.Dataset
            Anomaly ``Dataset``

        Raises
        ------
        ValueError
            If ``reference`` is not a recognized method
        """
        # if referencing anomalies: change from absolute to relative values
        if reference == "first":
            # subtract first time step from all time steps
            z0 = self._ds.isel(time=0)
            ds = self._ds - z0
        elif reference == "mean":
            # get time range for calculating reference period
            if climatology is None:
                # default time range is the full range of the dataset
                tmin = self._ds["time"].values.min()
                tmax = self._ds["time"].values.max()
            elif isinstance(climatology[0], (int, float, np.integer)):
                # convert years to numpy datetime64 format
                # (numpy integer years must not be read as day counts)
                tmin = np.array(climatology[0] - 1970, dtype="datetime64[Y]")
                tmax = np.array(climatology[1] - 1970, dtype="datetime64[Y]")
            else:
                # verify that time range is in datetime64 format
                tmin, tmax = np.array(climatology, dtype="datetime64[D]")
            # subtract mean from all time steps
            # (tmax + 1 advances one unit of the datetime64 resolution so
            # the end of the range is inclusive)
            zmean = self._ds.where(
                (self._ds["time"] >= tmin) & (self._ds["time"] < tmax + 1),
                drop=True,
            )
            ds = self._ds - zmean.mean(dim="time")
            # add (actual) climatology attributes to variable
            ds.attrs["climatology"] = np.array(
                [zmean.time.values.min(), zmean.time.values.max()]
            ).astype("datetime64[D]")
        else:
            raise ValueError(f"Invalid reference method: {reference}")
        # return the anomaly dataset
        return ds

    def to_geographic(self, crs: str | int | dict = 4326):
        """
        Get latitude and longitude coordinates for the ``Dataset``

        Parameters
        ----------
        crs: str, int, or dict, default 4326 (WGS84 Latitude/Longitude)
            Coordinate reference system for geographic coordinates

        Returns
        -------
        lon: xarray.DataArray
            Longitude coordinates for the dataset
        lat: xarray.DataArray
            Latitude coordinates for the dataset
        """
        # target spatial reference
        target_crs = pyproj.CRS.from_user_input(crs)
        # create transformation
        transformer = pyproj.Transformer.from_crs(
            self.crs, target_crs, always_xy=True
        )
        # create meshgrid of points in original projection
        gridx, gridy = np.meshgrid(self._x, self._y)
        # convert coordinates to latitude and longitude
        lon, lat = transformer.transform(gridx, gridy)
        # convert to xarray DataArrays
        coords = dict(y=self._ds.y, x=self._ds.x)
        lon = xr.DataArray(lon, coords=coords, dims=["y", "x"])
        lat = xr.DataArray(lat, coords=coords, dims=["y", "x"])
        return lon, lat

    def to_units(
        self,
        units: str,
        value: float = 1.0,
    ):
        """Convert ``Dataset`` to specified units

        Parameters
        ----------
        units: str
            Output units
        value: float, default 1.0
            Scaling factor to apply

        Returns
        -------
        ds: xarray.Dataset
            Converted ``Dataset``
        """
        # create copy of dataset
        ds = self._ds.copy()
        # convert each variable in the dataset
        for k in ds.data_vars:
            ds[k] = ds[k].fcorr.to_units(units, value=value)
        # return the dataset
        return ds

    def to_base_units(self):
        """Convert ``Dataset`` to base units"""
        # create copy of dataset
        ds = self._ds.copy()
        # convert each variable in the dataset
        for k in ds.data_vars:
            ds[k] = ds[k].fcorr.to_base_units()
        # return the dataset
        return ds

    def to_default_units(self):
        """Convert ``Dataset`` to default units"""
        # create copy of dataset
        ds = self._ds.copy()
        # convert each variable in the dataset
        for k in ds.data_vars:
            ds[k] = ds[k].fcorr.to_default_units()
        # return the dataset
        return ds

    @property
    def crs(self):
        """Coordinate reference system of the ``Dataset``"""
        # return the CRS of the dataset
        # default is EPSG:4326 (WGS84)
        CRS = self._ds.attrs.get("crs", 4326)
        return pyproj.CRS.from_user_input(CRS)

    @property
    def is_global(self) -> bool:
        """Determine if ``Dataset`` covers a global domain"""
        # only geographic grids with at least two columns can be cyclic
        if not self.crs.is_geographic or len(self._x) < 2:
            return False
        # grid spacing in x-direction
        dx = self._x[1] - self._x[0]
        # check if global grid
        return bool(np.isclose(self._x[-1] - self._x[0], 360.0 - dx))

    @property
    def area_of_use(self) -> str | None:
        """Area of use from the ``Dataset`` CRS"""
        if self.crs.area_of_use is None:
            return None
        return self.crs.area_of_use.name.replace(".", "").lower()

    @property
    def axis_units(self) -> str:
        """Units of the coordinate axes from the ``Dataset`` CRS"""
        return self.crs.axis_info[0].unit_name

    @property
    def _x(self):
        """x-coordinates of the ``Dataset``"""
        return self._ds.x.values

    @property
    def _y(self):
        """y-coordinates of the ``Dataset``"""
        return self._ds.y.values
@xr.register_dataarray_accessor("fcorr")
class DataArray:
    """Accessor for extending an ``xarray.DataArray`` for SMB and firn data"""

    def __init__(self, da):
        # initialize DataArray
        self._da = da

    def to_units(
        self,
        units: str,
        value: float = 1.0,
    ):
        """Convert ``DataArray`` to specified units

        Parameters
        ----------
        units: str
            Output units
        value: float, default 1.0
            Scaling factor to apply

        Returns
        -------
        da: xarray.DataArray
            Converted ``DataArray`` with an updated ``units`` attribute
        """
        # convert to specified units
        conversion = value * self.quantity.to(units)
        da = self._da * conversion.magnitude
        da.attrs["units"] = str(conversion.units)
        return da

    def to_base_units(self, value=1.0):
        """Convert ``DataArray`` to base units

        Parameters
        ----------
        value: float, default 1.0
            Scaling factor to apply

        Returns
        -------
        da: xarray.DataArray
            Converted ``DataArray`` with an updated ``units`` attribute
        """
        # convert to base units
        conversion = value * self.quantity.to_base_units()
        da = self._da * conversion.magnitude
        da.attrs["units"] = str(conversion.units)
        return da

    def to_default_units(self, value=1.0):
        """Convert ``DataArray`` to default units

        Units are looked up by unit group; groups without a default keep
        their current units

        Parameters
        ----------
        value: float, default 1.0
            Scaling factor to apply

        Returns
        -------
        da: xarray.DataArray
            Converted ``DataArray``
        """
        # convert to default units
        default_units = _default_units.get(self.group, self.units)
        da = self.to_units(default_units, value=value)
        return da

    @property
    def units(self):
        """Units of the ``DataArray`` as ``pint`` units

        Raises
        ------
        ValueError
            If the ``units`` attribute is missing or is not a string
        AttributeError
            If the units are not defined in the unit registry
        """
        units = self._units
        if units is None:
            raise ValueError("DataArray has no attribute 'units'")
        try:
            return self._parse_units(units)
        except TypeError as exc:
            # units attribute is not a string
            raise ValueError(f"Unknown units: {units}") from exc
        except AttributeError as exc:
            # pint's UndefinedUnitError is a subclass of AttributeError
            raise AttributeError(f"Unknown units: {units}") from exc

    @property
    def quantity(self):
        """``Pint`` Quantity of the ``DataArray``"""
        return 1.0 * self.units

    @property
    def group(self):
        """Variable group of the ``DataArray``"""
        if self.units.is_compatible_with("m"):
            return "elevation"
        elif self.units.is_compatible_with("m/s"):
            return "velocity"
        elif self.units.is_compatible_with("g / cm^2"):
            return "mass density"
        elif self.units.is_compatible_with("g"):
            return "mass"
        elif self.units.is_compatible_with("degrees"):
            return "angle"
        else:
            raise ValueError(f"Unknown unit group: {self._units}")

    @staticmethod
    def _parse_units(units: str):
        """
        Convert units attributes to ``pint`` units

        Parameters
        ----------
        units: str
            Units attribute string (e.g. ``'kg m-2'``, ``'mm w.e.'``)

        Returns
        -------
        units: pint.Unit
            Parsed units
        """
        # fix the exponent notation in units string (e.g. m-2 -> m^-2)
        # the base must be a letter so that multi-digit numbers such as
        # "m^10" or "1000" are not split into exponents
        units = re.sub(
            r"([a-z])([-]?\d+)",
            lambda m: m.group(1) + r"^" + m.group(2),
            units,
            flags=re.IGNORECASE,
        )
        # remove "of" from units string
        units = re.sub(
            r"of\s(water|ice|air)",
            lambda m: m.group(1),
            units,
            flags=re.IGNORECASE,
        )
        # prepend "equivalent" with underscore to units string
        units = re.sub(
            r"\s+equivalent",
            "_equivalent",
            units,
            flags=re.IGNORECASE,
        )
        # delete periods between water or ice equivalent units
        units = re.sub(
            r"(w|i)\.e[q]?\.",
            lambda m: m.group(1) + "e",
            units,
            flags=re.IGNORECASE,
        )
        # add a space before water or ice equivalent units
        units = re.sub(
            r"([\w])(we|ie)\b",
            lambda m: m.group(1) + " " + m.group(2),
            units,
            flags=re.IGNORECASE,
        )
        # parse units string using pint
        return __ureg__.parse_units(units.lower())

    @property
    def _units(self):
        """Units attribute of the ``DataArray`` as a string"""
        return self._da.attrs.get("units")

    @property
    def _has_compatible_units(self):
        """Tests that units are compatible with known groups"""
        try:
            self.group
        except (TypeError, ValueError, AttributeError):
            return False
        return True
def combine_attrs(
    attrs_list: list[dict],
    context: str | None = None,
    **kwargs,
) -> dict:
    """
    Combine attributes from multiple datasets into a single dictionary
    merging conflicting values into a list

    The signature follows the ``combine_attrs`` callable convention of
    ``xarray`` (``attrs_list, context``)

    Parameters
    ----------
    attrs_list: list of dict
        List of attribute dictionaries from multiple datasets
    context: str or None, default None
        Context for the attributes being combined (not used)
    skip_keys: list of str, default ["units"]
        List of attribute keys to skip from comparison

    Returns
    -------
    result: dict
        Combined attributes dictionary
    """

    def _is_duplicate(value, seen):
        """Check if a value is equal to any previously collected value"""
        for other in seen:
            try:
                if bool(equivalent(value, other)):
                    return True
            except (TypeError, ValueError):
                # comparison is ambiguous: treat as distinct
                continue
        return False

    # set default keyword arguments
    skip_keys = kwargs.get("skip_keys", ["units"])
    # return an empty dictionary when no attributes are provided
    if not attrs_list:
        return {}
    # initialize combined attributes with the first dictionary in the list
    result = attrs_list[0].copy()
    # conflicting keys in first-seen order (dict used as an ordered set so
    # the output key order is deterministic)
    append_keys = {}
    # for each attribute key, check if values are equivalent
    for attrs in attrs_list:
        for key, value in attrs.items():
            # skip keys that have already been identified as conflicts
            # and keys that should be skipped from comparison
            if key in append_keys or key in skip_keys:
                continue
            # check if the attribute values are equivalent
            if not equivalent_attrs(result.get(key), value):
                append_keys[key] = None
    # combine conflicting attributes into lists
    for key in append_keys:
        # build list of values for this key across all datasets
        # lists and tuples are flattened, single values are appended
        combined_values = []
        for attrs in attrs_list:
            if key not in attrs:
                continue
            value = attrs[key]
            if isinstance(value, (list, tuple)):
                combined_values.extend(value)
            else:
                combined_values.append(value)
        # remove null values (None and empty strings) and duplicates
        # falsy data such as 0 or False is kept; duplicates are found by
        # equality so unhashable values (e.g. numpy arrays) are supported
        unique_values = []
        for value in combined_values:
            if value is None or (isinstance(value, str) and not value):
                continue
            if not _is_duplicate(value, unique_values):
                unique_values.append(value)
        # sort values when they are mutually orderable
        try:
            unique_values = sorted(unique_values)
        except (TypeError, ValueError):
            # mixed types or arrays: keep the order they were first seen
            pass
        # if only one unique value remains, simplify back to a single value
        if len(unique_values) == 1:
            result[key] = unique_values[0]
        else:
            result[key] = unique_values
    # return the combined attributes
    return result
def equivalent_attrs(a: Any, b: Any) -> bool:
    """
    Test whether two attribute values match

    String pairs are matched without regard to case; all other pairs are
    delegated to ``xarray``'s ``equivalent`` helper. Adapted from
    ``xarray.structure.merge.equivalent_attrs``

    Parameters
    ----------
    a: Any
        First attribute value
    b: Any
        Second attribute value

    Returns
    -------
    match: bool
        True when the values are equivalent; False when they differ or
        when the comparison is ambiguous (raises ``TypeError`` or
        ``ValueError``)
    """
    both_strings = isinstance(a, str) and isinstance(b, str)
    if both_strings:
        # case-insensitive comparison for text attributes
        return equivalent(a.casefold(), b.casefold())
    try:
        match = equivalent(a, b)
    except (TypeError, ValueError):
        # ambiguous comparison (e.g. incompatible array shapes)
        match = False
    return match
def get_variable(ds: xr.Dataset, name: str) -> xr.DataArray | None:
    """
    Get variable from a ``Dataset`` using a case-insensitive search

    An exact match always takes precedence over case-insensitive matches

    Parameters
    ----------
    ds: xarray.Dataset
        Dataset to search for variable
    name: str
        Name of the variable to find

    Returns
    -------
    var: xarray.DataArray or None
        Variable matching the input name, or None if not found

    Raises
    ------
    ValueError
        If more than one variable matches the name case-insensitively
    """
    # exact matches take precedence
    if name in ds.data_vars:
        return ds[name]
    # case-insensitive search for variable in dataset
    # (variable names may be any hashable: only strings can be casefolded)
    target = name.casefold()
    matches = [
        v
        for v in ds.data_vars
        if isinstance(v, str) and v.casefold() == target
    ]
    # check if variable is in dataset
    if not matches:
        return None
    if len(matches) > 1:
        raise ValueError(f"Ambiguous mapping of {name} in dataset")
    # return the variable from the dataset
    return ds[matches[0]]
def register_datatree_subaccessor(name):
    """
    Register a custom subaccessor on ``DataTree`` objects

    Parameters
    ----------
    name: str
        Name of the subaccessor

    Returns
    -------
    decorator: callable
        Result of ``xarray``'s private ``_register_accessor`` for the
        ``DataTree`` accessor class (presumably a class decorator that
        attaches the decorated class under ``name`` -- private API, confirm)
    """
    register = xr.core.extensions._register_accessor
    return register(name, DataTree)
def register_dataset_subaccessor(name):
    """
    Register a custom subaccessor on ``Dataset`` objects

    Parameters
    ----------
    name: str
        Name of the subaccessor

    Returns
    -------
    decorator: callable
        Result of ``xarray``'s private ``_register_accessor`` for the
        ``Dataset`` accessor class (presumably a class decorator that
        attaches the decorated class under ``name`` -- private API, confirm)
    """
    register = xr.core.extensions._register_accessor
    return register(name, Dataset)
def register_dataarray_subaccessor(name):
    """
    Register a custom subaccessor on ``DataArray`` objects

    Parameters
    ----------
    name: str
        Name of the subaccessor

    Returns
    -------
    decorator: callable
        Result of ``xarray``'s private ``_register_accessor`` for the
        ``DataArray`` accessor class (presumably a class decorator that
        attaches the decorated class under ``name`` -- private API, confirm)
    """
    register = xr.core.extensions._register_accessor
    return register(name, DataArray)
def _coords(
    x: np.ndarray,
    y: np.ndarray,
    source_crs: str | int | dict = 4326,
    target_crs: str | int | dict | None = None,
    **kwargs,
):
    """
    Transform coordinates into DataArrays in a new
    coordinate reference system

    Parameters
    ----------
    x: np.ndarray
        Input x-coordinates
    y: np.ndarray
        Input y-coordinates
    source_crs: str, int, or dict, default 4326 (WGS84 Latitude/Longitude)
        Coordinate reference system of input coordinates
    target_crs: str, int, or dict, default None
        Coordinate reference system of output coordinates
    type: str or None, default None
        Coordinate data type

        If not provided: must specify ``time`` parameter to auto-detect

            - ``None``: determined from input variable dimensions
            - ``'drift'``: drift buoys or satellite/airborne altimetry
            - ``'grid'``: spatial grids or images
            - ``'time series'``: time series at a single point
    time: np.ndarray or None, default None
        Time variable for determining coordinate data type

    Returns
    -------
    X: xarray.DataArray
        Transformed x-coordinates
    Y: xarray.DataArray
        Transformed y-coordinates

    Raises
    ------
    ValueError
        If neither ``type`` nor ``time`` is given for non-scalar
        coordinates, or if the coordinate data type is unknown
    """
    from FirnCorr.spatial import data_type

    # set default keyword arguments
    kwargs.setdefault("type", None)
    kwargs.setdefault("time", None)
    # scalar coordinates are always treated as a single point
    scalar = (np.ndim(x) == 0) and (np.ndim(y) == 0)
    # determine coordinate data type if possible
    if scalar:
        coord_type = "time series"
    elif kwargs["type"] is None:
        # must provide time variable to determine data type
        if kwargs["time"] is None:
            raise ValueError(
                "Must provide time parameter when type is not specified"
            )
        coord_type = data_type(x, y, np.ravel(kwargs["time"]))
    else:
        # use provided coordinate data type
        # and verify that it is lowercase
        coord_type = kwargs["type"].lower()
    # validate the data type before doing any transformation work
    if coord_type not in ("grid", "drift", "time series"):
        raise ValueError(f"Unknown coordinate data type: {coord_type}")
    # grids given as 1D axes (or a scalar and an axis) are expanded to 2D;
    # testing dimensionality rather than size also handles square grids
    if (coord_type == "grid") and (max(np.ndim(x), np.ndim(y)) <= 1):
        inx, iny = np.meshgrid(x, y)
    else:
        inx, iny = x, y
    # convert coordinates to a new coordinate reference system
    mx, my = _transform(
        inx,
        iny,
        source_crs=source_crs,
        target_crs=target_crs,
        direction="FORWARD",
    )
    # convert to xarray DataArray with appropriate dimensions
    if scalar:
        X = xr.DataArray(mx)
        Y = xr.DataArray(my)
    elif coord_type == "grid":
        X = xr.DataArray(mx, dims=("y", "x"))
        Y = xr.DataArray(my, dims=("y", "x"))
    elif coord_type == "drift":
        X = xr.DataArray(mx, dims=("time",))
        Y = xr.DataArray(my, dims=("time",))
    else:
        # time series at one or more stations
        X = xr.DataArray(mx, dims=("station",))
        Y = xr.DataArray(my, dims=("station",))
    # return the transformed coordinates
    return (X, Y)