Source code for piglot.utils.response_transformer
"""Module for defining transformations for responses."""
from typing import Union, Dict, Any, Type, List, Tuple
from abc import ABC, abstractmethod
import numpy as np
from piglot.utils.assorted import read_custom_module
from piglot.solver.solver import OutputResult
from piglot.utils.responses import interpolate_response
[docs]
class ResponseTransformer(ABC):
    """Abstract base class for transformations applied to responses.

    Subclasses implement :meth:`transform`, which maps one ``OutputResult`` to
    another. Instances can also be called directly on raw time/data arrays via
    :meth:`__call__`.
    """

    @abstractmethod
    def transform(self, response: OutputResult) -> OutputResult:
        """Transform the input data.

        Parameters
        ----------
        response : OutputResult
            Time and data points of the response.

        Returns
        -------
        OutputResult
            Transformed time and data points of the response.
        """

    def __call__(self, x_old: np.ndarray, y_old: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Apply the transformation to a pair of raw arrays.

        The arrays are wrapped in an ``OutputResult``, passed through
        :meth:`transform`, and the result is unpacked again.

        Parameters
        ----------
        x_old : np.ndarray
            Original time grid.
        y_old : np.ndarray
            Original values.

        Returns
        -------
        Tuple[np.ndarray, np.ndarray]
            Transformed time grid and values.
        """
        transformed = self.transform(OutputResult(x_old, y_old))
        return transformed.time, transformed.data
[docs]
class ChainResponse(ResponseTransformer):
    """Chain of response transformers, applied one after another in list order."""

    def __init__(self, transformers: List[Any]):
        # Each entry is a transformer configuration (name string or dict) that is
        # resolved into a transformer instance by the module-level reader.
        self.transformers = list(map(read_response_transformer, transformers))

    def transform(self, response: OutputResult) -> OutputResult:
        """Feed the response through every transformer of the chain in turn.

        Parameters
        ----------
        response : OutputResult
            Time and data points of the response.

        Returns
        -------
        OutputResult
            Response after the last transformer of the chain; the input response
            itself when the chain is empty.
        """
        result = response
        for step in self.transformers:
            result = step.transform(result)
        return result
[docs]
class MinimumResponse(ResponseTransformer):
    """Reduce a response to its minimum value."""

    def transform(self, response: OutputResult) -> OutputResult:
        """Collapse the response to a single point holding its minimum.

        Parameters
        ----------
        response : OutputResult
            Time and data points of the response.

        Returns
        -------
        OutputResult
            One-point response: time ``[0.0]`` and the minimum of the data.
        """
        lowest = np.min(response.data)
        return OutputResult(np.array([0.0]), np.array([lowest]))
[docs]
class MaximumResponse(ResponseTransformer):
    """Reduce a response to its maximum value."""

    def transform(self, response: OutputResult) -> OutputResult:
        """Collapse the response to a single point holding its maximum.

        Parameters
        ----------
        response : OutputResult
            Time and data points of the response.

        Returns
        -------
        OutputResult
            One-point response: time ``[0.0]`` and the maximum of the data.
        """
        highest = np.max(response.data)
        return OutputResult(np.array([0.0]), np.array([highest]))
[docs]
class NegateResponse(ResponseTransformer):
    """Flip the sign of the response values."""

    def transform(self, response: OutputResult) -> OutputResult:
        """Negate the data points while keeping the time grid unchanged.

        Parameters
        ----------
        response : OutputResult
            Time and data points of the response.

        Returns
        -------
        OutputResult
            Same time grid, data multiplied by -1.
        """
        negated = -response.data
        return OutputResult(response.time, negated)
[docs]
class SquareResponse(ResponseTransformer):
    """Square the response values element-wise."""

    def transform(self, response: OutputResult) -> OutputResult:
        """Square the data points while keeping the time grid unchanged.

        Parameters
        ----------
        response : OutputResult
            Time and data points of the response.

        Returns
        -------
        OutputResult
            Same time grid, element-wise squared data.
        """
        squared = np.square(response.data)
        return OutputResult(response.time, squared)
[docs]
class AffineTransformResponse(ResponseTransformer):
    """Affine map of a response: ``x -> x_scale * x + x_offset`` and
    ``y -> y_scale * y + y_offset``.

    Parameters
    ----------
    x_scale : float
        Multiplier for the time grid (default 1.0).
    x_offset : float
        Shift added to the scaled time grid (default 0.0).
    y_scale : float
        Multiplier for the data values (default 1.0).
    y_offset : float
        Shift added to the scaled data values (default 0.0).
    """

    def __init__(
        self,
        x_scale: float = 1.0,
        x_offset: float = 0.0,
        y_scale: float = 1.0,
        y_offset: float = 0.0,
    ):
        self.x_scale, self.x_offset = x_scale, x_offset
        self.y_scale, self.y_offset = y_scale, y_offset

    def transform(self, response: OutputResult) -> OutputResult:
        """Apply the affine map to both the time grid and the data.

        Parameters
        ----------
        response : OutputResult
            Time and data points of the response.

        Returns
        -------
        OutputResult
            Scaled and shifted time and data points.
        """
        new_time = self.x_scale * response.time + self.x_offset
        new_data = self.y_scale * response.data + self.y_offset
        return OutputResult(new_time, new_data)
[docs]
class ClipResponse(ResponseTransformer):
    """Clip x and y values of the response to given bounds.

    Values outside the bounds are clamped to the nearest bound with ``np.clip``;
    no points are removed, so the output has the same number of points.

    Parameters
    ----------
    x_min, x_max : float
        Bounds for the time grid (unbounded by default).
    y_min, y_max : float
        Bounds for the data values (unbounded by default).
    """

    def __init__(
        self,
        x_min: float = -np.inf,
        x_max: float = np.inf,
        y_min: float = -np.inf,
        y_max: float = np.inf,
    ):
        self.x_min, self.x_max = x_min, x_max
        self.y_min, self.y_max = y_min, y_max

    def transform(self, response: OutputResult) -> OutputResult:
        """Clamp the time grid and data values to the configured bounds.

        Parameters
        ----------
        response : OutputResult
            Time and data points of the response.

        Returns
        -------
        OutputResult
            Clamped time and data points.
        """
        clipped_time = np.clip(response.time, self.x_min, self.x_max)
        clipped_data = np.clip(response.data, self.y_min, self.y_max)
        return OutputResult(clipped_time, clipped_data)
[docs]
class PointwiseErrors(ResponseTransformer):
    """Compute the pointwise errors between the response and a reference.

    The response is interpolated onto the reference time grid, and the difference
    to the reference data is divided by the mean absolute value of the reference.

    Parameters
    ----------
    reference_time : np.ndarray
        Time grid of the reference.
    reference_data : np.ndarray
        Reference values on ``reference_time``.
    """

    def __init__(self, reference_time: np.ndarray, reference_data: np.ndarray) -> None:
        # Store as arrays so list inputs (e.g. from parsed configurations) work with
        # the vectorised arithmetic in ``transform``; a no-op for ndarray inputs.
        self.reference_time = np.asarray(reference_time)
        self.reference_data = np.asarray(reference_data)

    def transform(self, response: OutputResult) -> OutputResult:
        """Transform the input data.

        Parameters
        ----------
        response : OutputResult
            Time and data points of the response.

        Returns
        -------
        OutputResult
            Reference time grid and the normalised pointwise error
            ``(response - reference) / mean(|reference|)``. If the reference is
            identically zero, the unnormalised error is returned instead.
        """
        # Interpolate response to the reference grid
        resp_interp = interpolate_response(
            response.get_time(),
            response.get_data(),
            self.reference_time,
        )
        # Compute normalised error
        factor = np.mean(np.abs(self.reference_data))
        # An all-zero reference has no scale: dividing by zero would produce
        # inf/nan errors, so fall back to the absolute (unnormalised) error.
        if factor == 0.0:
            factor = 1.0
        return OutputResult(self.reference_time, (resp_interp - self.reference_data) / factor)
# Registry of transformers selectable by name in configurations; consulted by
# ``read_response_transformer``. Entries take their constructor arguments as
# keyword options from the configuration dictionary.
AVAILABLE_RESPONSE_TRANSFORMERS: Dict[str, Type[ResponseTransformer]] = dict(
    min=MinimumResponse,
    max=MaximumResponse,
    negate=NegateResponse,
    square=SquareResponse,
    chain=ChainResponse,
    clip=ClipResponse,
    affine=AffineTransformResponse,
)
[docs]
def read_response_transformer(config: Union[str, Dict[str, Any]]) -> ResponseTransformer:
    """Read a response transformer from a configuration.

    Parameters
    ----------
    config : Union[str, Dict[str, Any]]
        Configuration of the response transformer. Either the name of a registered
        transformer (simple format), or a dictionary with a ``name`` key whose
        remaining entries are passed as keyword arguments to the transformer
        constructor (detailed format). For ``name: script`` the remaining entries
        are forwarded to ``read_custom_module`` instead. The dictionary passed in
        is not modified.

    Returns
    -------
    ResponseTransformer
        Response transformer.

    Raises
    ------
    ValueError
        If the name is missing or not available, or if ``"script"`` is given in the
        simple (string) format without the required file path.
    """
    # Parse the transformer in the simple format
    if isinstance(config, str):
        name = config
        if name == 'script':
            raise ValueError('Need to pass the file path for the "script" transformer.')
        if name not in AVAILABLE_RESPONSE_TRANSFORMERS:
            raise ValueError(f'Response transformer "{name}" is not available.')
        return AVAILABLE_RESPONSE_TRANSFORMERS[name]()
    # Detailed format
    if 'name' not in config:
        raise ValueError('Need to pass the name of the response transformer.')
    # Work on a shallow copy: popping 'name' from the caller's dict would make the
    # same configuration unreadable on a second pass (e.g. re-reading a config or
    # a chain whose entries are read recursively).
    options = dict(config)
    name = options.pop('name')
    # Read script transformer
    if name == 'script':
        return read_custom_module(options, ResponseTransformer)()
    if name not in AVAILABLE_RESPONSE_TRANSFORMERS:
        raise ValueError(f'Response transformer "{name}" is not available.')
    return AVAILABLE_RESPONSE_TRANSFORMERS[name](**options)