"""Module for solvers."""
from __future__ import annotations
from dataclasses import dataclass
from typing import List, Dict, Any, Type, TypeVar
from abc import ABC, abstractmethod
import os
import time
import shutil
import numpy as np
from yaml import safe_dump_all, safe_load_all
from piglot.parameter import ParameterSet
from piglot.utils.assorted import pretty_time
from piglot.utils.solver_utils import VerbosityManager
T = TypeVar('T', bound='Solver')
[docs]
@dataclass
class OutputResult:
"""Container for output results."""
time: np.ndarray
data: np.ndarray
[docs]
def get_time(self) -> np.ndarray:
"""Get the time column of the result.
Returns
-------
np.ndarray
Time column.
"""
return self.time
[docs]
def get_data(self) -> np.ndarray:
"""Get the data column of the result.
Returns
-------
np.ndarray
Data column.
"""
return self.data
[docs]
@dataclass
class CaseResult:
"""Class for case results."""
begin_time: float
run_time: float
values: np.ndarray
success: bool
param_hash: str
responses: Dict[str, OutputResult]
[docs]
def write(self, filename: str, parameters: ParameterSet) -> None:
"""Write out the case result.
Parameters
----------
filename : str
Path to write the file to.
parameters : ParameterSet
Set of parameters for this case.
"""
# Build case metadata
metadata = {
"start_time": time.strftime("%a, %d %b %Y %H:%M:%S", time.gmtime(self.begin_time)),
"begin_time": self.begin_time,
"run_time": self.run_time,
"run_time (pretty)": pretty_time(self.run_time),
"parameters": {p.name: float(v) for p, v in zip(parameters, self.values)},
"success": "true" if self.success else "false",
"param_hash": self.param_hash,
}
# Build response data
responses = {
name: list(zip(result.get_time().tolist(), result.get_data().tolist()))
for name, result in self.responses.items()
}
# Dump all data to file
with open(filename, 'w', encoding='utf8') as file:
safe_dump_all((metadata, responses), file)
[docs]
@staticmethod
def read(filename: str, parameters: ParameterSet) -> CaseResult:
"""Read a case result file.
Parameters
----------
filename : str
Path to the case result file.
parameters : ParameterSet
Set of parameters for this case.
Returns
-------
CaseResult
Result instance.
"""
# Read the file
with open(filename, 'r', encoding='utf8') as file:
metadata, responses_raw = safe_load_all(file)
# Parse the responses
responses = {
name: OutputResult(np.array([a[0] for a in data]), np.array([a[1] for a in data]))
for name, data in responses_raw.items()
}
return CaseResult(
metadata["begin_time"],
metadata["run_time"],
np.array([float(metadata["parameters"][p.name]) for p in parameters]),
metadata["success"] == "true",
metadata["param_hash"],
responses,
)
[docs]
class Solver(ABC):
"""Base class for solvers."""
def __init__(
self,
parameters: ParameterSet,
output_dir: str,
tmp_dir: str,
verbosity: str,
) -> None:
self.parameters = parameters
self.output_dir = output_dir
self.tmp_dir = tmp_dir
self.verbosity_manager = VerbosityManager(verbosity, os.path.join(output_dir, 'solver'))
self.begin_time = time.time()
[docs]
@abstractmethod
def prepare(self) -> None:
"""Prepare data for the optimisation."""
[docs]
@abstractmethod
def solve(
self,
values: np.ndarray,
concurrent: bool,
) -> Dict[str, OutputResult]:
"""Solve all cases for the given set of parameter values.
Parameters
----------
values : array
Current parameters to evaluate.
concurrent : bool
Whether this run may be concurrent to another one (so use unique file names).
Returns
-------
Dict[str, OutputResult]
Evaluated results for each output field.
"""
[docs]
@abstractmethod
def get_output_fields(self) -> List[str]:
"""Get all output fields.
Returns
-------
List[str]
Output fields.
"""
[docs]
@abstractmethod
def get_case_params(self, param_hash: str) -> Dict[str, float]:
"""Get the parameters for a given hash.
Parameters
----------
param_hash : str
Hash of the case to load.
Returns
-------
Dict[str, float]
Parameters for this hash.
"""
[docs]
@abstractmethod
def get_output_response(self, param_hash: str) -> Dict[str, OutputResult]:
"""Get the responses from all output fields for a given case.
Parameters
----------
param_hash : str
Hash of the case to load.
Returns
-------
Dict[str, OutputResult]
Output responses.
"""
[docs]
def get_current_response(self) -> Dict[str, OutputResult]:
"""Get the responses from a given output field for all cases.
Returns
-------
Dict[str, OutputResult]
Output responses.
"""
raise NotImplementedError("This solver does not support getting current responses.")
[docs]
@classmethod
@abstractmethod
def read(
cls: Type[T],
config: Dict[str, Any],
parameters: ParameterSet,
output_dir: str,
) -> T:
"""Read the solver from the configuration dictionary.
Parameters
----------
config : Dict[str, Any]
Configuration dictionary.
parameters : ParameterSet
Parameter set for this problem.
output_dir : str
Path to the output directory.
Returns
-------
Solver
Solver to use for this problem.
"""
[docs]
class SingleCaseSolver(Solver, ABC):
"""Generic class for solvers with a single case."""
def __init__(
self,
output_fields: List[str],
parameters: ParameterSet,
output_dir: str,
tmp_dir: str,
verbosity: str,
) -> None:
"""Constructor for the solver class.
Parameters
----------
output_fields : List[str]
List of output fields.
parameters : ParameterSet
Parameter set for this problem.
output_dir : str
Path to the output directory.
tmp_dir : str
Path to the temporary directory.
"""
super().__init__(parameters, output_dir, tmp_dir, verbosity)
self.output_fields = output_fields
self.cases_dir = os.path.join(output_dir, "cases")
self.cases_hist = os.path.join(output_dir, "cases_hist")
[docs]
def prepare(self) -> None:
"""Prepare data for the optimisation."""
self.verbosity_manager.prepare()
# Create output directories
os.makedirs(self.cases_dir, exist_ok=True)
if os.path.isdir(self.cases_hist):
shutil.rmtree(self.cases_hist)
os.mkdir(self.cases_hist)
# Build headers for case log files
for case in self.output_fields:
case_dir = os.path.join(self.cases_dir, case)
with open(case_dir, 'w', encoding='utf8') as file:
file.write(f"{'Start Time /s':>15}\t")
file.write(f"{'Run Time /s':>15}\t")
file.write(f"{'Success':>10}\t")
for param in self.parameters:
file.write(f"{param.name:>15}\t")
file.write(f'{"Hash":>64}\n')
[docs]
def get_output_fields(self) -> List[str]:
"""Get all output fields.
Returns
-------
List[str]
Output fields.
"""
return self.output_fields
[docs]
def get_case_result(self, param_hash: str) -> CaseResult:
"""Get the result for a given case.
Parameters
----------
param_hash : str
Hash of the case to load.
Returns
-------
CaseResult
Result for this hash.
"""
return CaseResult.read(os.path.join(self.cases_hist, param_hash), self.parameters)
[docs]
def get_case_params(self, param_hash: str) -> Dict[str, float]:
"""Get the parameters for a given hash.
Parameters
----------
param_hash : str
Hash of the case to load.
Returns
-------
Dict[str, float]
Parameters for this hash.
"""
result = self.get_case_result(param_hash)
return {param.name: result.values[i] for i, param in enumerate(self.parameters)}
[docs]
def get_output_response(self, param_hash: str) -> Dict[str, OutputResult]:
"""Get the responses from all output fields for a given case.
Parameters
----------
param_hash : str
Hash of the case to load.
Returns
-------
Dict[str, OutputResult]
Output responses.
"""
return self.get_case_result(param_hash).responses
@abstractmethod
def _solve(self, values: np.ndarray, concurrent: bool) -> Dict[str, OutputResult]:
"""Internal solver for the prescribed problems.
Parameters
----------
values : array
Current parameters to evaluate.
concurrent : bool
Whether this run may be concurrent to another one (so use unique file names).
Returns
-------
Dict[str, OutputResult]
Evaluated results for each output field.
"""
[docs]
def solve(
self,
values: np.ndarray,
concurrent: bool,
) -> Dict[str, OutputResult]:
"""Solve all cases for the given set of parameter values.
Parameters
----------
values : array
Current parameters to evaluate.
concurrent : bool
Whether this run may be concurrent to another one (so use unique file names).
Returns
-------
Dict[str, OutputResult]
Evaluated results for each output field.
"""
# Run the solver
begin_time = time.time()
with self.verbosity_manager:
results = self._solve(values, concurrent)
run_time = time.time() - begin_time
# Post-process results: write history entries
param_hash = self.parameters.hash(values)
case_result = CaseResult(begin_time, run_time, values, True, param_hash, results)
case_result.write(os.path.join(self.cases_hist, case_result.param_hash), self.parameters)
return results