Source code for piglot.utils.reductions

"""Module for defining reduction functions for responses."""
from typing import Callable, Dict, Any, Union
from abc import ABC, abstractmethod
import numpy as np
import torch
from torch.autograd.gradcheck import gradcheck, GradcheckError
from piglot.utils.assorted import read_custom_module


class Reduction(ABC):
    """Interface for functions that collapse a response into a reduced value.

    Subclasses implement :meth:`reduce_torch`; :meth:`reduce` is a NumPy
    convenience wrapper around it and :meth:`test_reduction` is a self-check
    for batch handling and differentiability.
    """

    @abstractmethod
    def reduce_torch(
        self,
        time: torch.Tensor,
        data: torch.Tensor,
        params: torch.Tensor,
    ) -> torch.Tensor:
        """Reduce the input data to a single value (with gradients).

        Parameters
        ----------
        time : torch.Tensor
            Time points of the response.
        data : torch.Tensor
            Data points of the response.
        params : torch.Tensor
            Parameters for the given responses.

        Returns
        -------
        torch.Tensor
            Reduced value of the data.
        """

    def reduce(self, time: np.ndarray, data: np.ndarray, params: np.ndarray) -> np.ndarray:
        """Reduce the input data to a single value, using NumPy arrays.

        The arrays are wrapped as tensors, passed to :meth:`reduce_torch`, and
        the result is converted back to a NumPy array.

        Parameters
        ----------
        time : np.ndarray
            Time points of the response.
        data : np.ndarray
            Data points of the response.
        params : np.ndarray
            Parameters for the given responses.

        Returns
        -------
        np.ndarray
            Reduced value of the data.
        """
        tensors = [torch.from_numpy(array) for array in (time, data, params)]
        reduced = self.reduce_torch(*tensors)
        # force=True lets the conversion succeed even if the result tracks gradients
        return reduced.numpy(force=True)

    def test_reduction(self) -> None:
        """Check that the reduction handles batches and passes a gradient check.

        Raises
        ------
        ValueError
            If :meth:`reduce_torch` raises for any probe input, if the reduced
            tensor does not have the input shape without its last dimension,
            or if the numerical gradient check fails.
        """
        # Stage 1: the reduction must consume exactly the trailing dimension
        for n_params in (2, 6):
            for full_shape in ((4,), (2, 4), (3, 2, 4), (6, 3, 2, 4)):
                batch_shape = full_shape[:-1]
                # Tile the 1D time/parameter vectors over the leading batch dimensions
                tiling = batch_shape + (1,)
                time = torch.arange(full_shape[-1]).repeat(tiling)
                data = torch.randn(*full_shape)
                params = torch.randn(n_params).repeat(tiling)
                try:
                    reduced = self.reduce_torch(time, data, params)
                except Exception as exc:
                    raise ValueError(f"Test failed for reduction {type(self)}.") from exc
                if reduced.shape != batch_shape:
                    raise ValueError(
                        f"Bad shape after reduction for {type(self)}. "
                        f"While reducing a tensor of shape {full_shape}, "
                        f"got {reduced.shape} instead of {batch_shape}."
                    )

        # Stage 2: gradients with respect to time, data and params (double precision)
        grad_inputs = tuple(
            torch.tensor(values, requires_grad=True, dtype=torch.float64)
            for values in (
                [[0, 1], [1, 2]],  # time
                [[2, 3], [4, 3]],  # data
                [[1, 2], [3, 4]],  # params
            )
        )
        try:
            passed = gradcheck(self.reduce_torch, grad_inputs)
        except GradcheckError as exc:
            raise ValueError(f"Gradient check failed for {type(self)}.") from exc
        if not passed:
            raise ValueError(f"Gradient check failed for {type(self)}.")
class NegateReduction(Reduction):
    """Negate the result of another reduction function."""

    def __init__(self, reduction: Reduction) -> None:
        """Store the reduction whose result is negated.

        Parameters
        ----------
        reduction : Reduction
            Wrapped reduction function.
        """
        self.reduction = reduction

    def reduce_torch(
        self,
        time: torch.Tensor,
        data: torch.Tensor,
        params: torch.Tensor,
    ) -> torch.Tensor:
        """Apply the wrapped reduction and flip the sign of its result.

        Parameters
        ----------
        time : torch.Tensor
            Time points of the response.
        data : torch.Tensor
            Data points of the response.
        params : torch.Tensor
            Parameters for the given responses.

        Returns
        -------
        torch.Tensor
            Negated value returned by the wrapped reduction.
        """
        inner_value = self.reduction.reduce_torch(time, data, params)
        return -inner_value
class SimpleReduction(Reduction):
    """Reduction function defined from a lambda function (without using the parameters)."""

    def __init__(self, reduction: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]) -> None:
        """Store the callable used to reduce the response.

        Parameters
        ----------
        reduction : Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
            Callable invoked as ``reduction(time, data)``.
        """
        self.reduction = reduction

    def reduce_torch(
        self,
        time: torch.Tensor,
        data: torch.Tensor,
        params: torch.Tensor,
    ) -> torch.Tensor:
        """Reduce the input data with the stored callable.

        Parameters
        ----------
        time : torch.Tensor
            Time points of the response.
        data : torch.Tensor
            Data points of the response.
        params : torch.Tensor
            Parameters for the given responses (not forwarded to the callable).

        Returns
        -------
        torch.Tensor
            Value returned by the stored callable.
        """
        reduce_fn = self.reduction
        return reduce_fn(time, data)
# Registry of the built-in reductions, keyed by the name used in configurations.
# Every entry acts along the last dimension of ``data``; only the integral
# variants make use of ``time``.
AVAILABLE_REDUCTIONS: Dict[str, Reduction] = {
    key: SimpleReduction(func)
    for key, func in {
        # Statistics of the data
        'mean': lambda time, data: torch.mean(data, dim=-1),
        'max': lambda time, data: torch.amax(data, dim=-1),
        'min': lambda time, data: torch.amin(data, dim=-1),
        'sum': lambda time, data: torch.sum(data, dim=-1),
        'std': lambda time, data: torch.std(data, dim=-1),
        'var': lambda time, data: torch.var(data, dim=-1),
        # Means of the squared and absolute data
        'mse': lambda time, data: torch.mean(torch.square(data), dim=-1),
        'mae': lambda time, data: torch.mean(torch.abs(data), dim=-1),
        # End points of the data
        'last': lambda time, data: data[..., -1],
        'first': lambda time, data: data[..., 0],
        # Extremes of the absolute data
        'max_abs': lambda time, data: torch.amax(torch.abs(data), dim=-1),
        'min_abs': lambda time, data: torch.amin(torch.abs(data), dim=-1),
        # Trapezoidal integrals over time
        'integral': lambda time, data: torch.trapz(data, time, dim=-1),
        'square_integral': lambda time, data: torch.trapz(torch.square(data), time, dim=-1),
        'abs_integral': lambda time, data: torch.trapz(torch.abs(data), time, dim=-1),
    }.items()
}


# TODO: Add test for non-existing 'script' reduction
def read_reduction(config: Union[str, Dict[str, Any]]) -> Reduction:
    """Build a reduction function from its configuration.

    Parameters
    ----------
    config : Union[str, Dict[str, Any]]
        Either the name of a built-in reduction, or a dictionary with a
        ``name`` entry. A dictionary with the name ``script`` is passed to
        ``read_custom_module`` to load a user-defined reduction; its optional
        ``skip_test`` entry disables the sanity test of the loaded reduction.

    Returns
    -------
    Reduction
        Reduction function.

    Raises
    ------
    ValueError
        If the ``script`` reduction is requested in the simple (string)
        format, if a dictionary has no ``name`` entry, or if the name is not
        an available reduction.
    """
    if isinstance(config, str):
        # Simple format: only the name is given, so a script cannot be located
        if config == 'script':
            raise ValueError('Need to pass the file path for the "script" reduction.')
        name = config
    else:
        # Detailed format: the name is an entry of the configuration
        if 'name' not in config:
            raise ValueError('Need to pass the name of the reduction function.')
        name = config['name']
        if name == 'script':
            custom = read_custom_module(config, Reduction)()
            # External reductions are sanitised unless explicitly skipped
            run_test = not config.get('skip_test', False)
            if run_test:
                custom.test_reduction()
            return custom

    # Both formats resolve built-in reductions from the same registry
    if name not in AVAILABLE_REDUCTIONS:
        raise ValueError(f'Reduction function "{name}" is not available.')
    return AVAILABLE_REDUCTIONS[name]