Source code for piglot.solver.input_file_solver
"""Module for input file-based solvers."""
from __future__ import annotations
from typing import Dict, Any, List, TypeVar, Type, Callable, Tuple
from abc import ABC, abstractmethod
import os
import re
import time
import shutil
import numpy as np
from piglot.parameter import ParameterSet
from piglot.solver.solver import OutputResult, CaseResult
from piglot.solver.multi_case_solver import Case, MultiCaseSolver
from piglot.utils.assorted import read_custom_module
# Type variable for ``OutputField.read``: lets ``cls.read`` be typed as returning ``cls``'s own type
T = TypeVar('T', bound='OutputField')
# Type variable for ``InputFileCase.read``: lets ``cls.read`` be typed as returning ``cls``'s own type
V = TypeVar('V', bound='InputFileCase')
[docs]
def write_parameters(
    param_value: Dict[str, float],
    source: str,
    dest: str,
    regex: Callable[[str], str] = lambda param: r'\<' + re.escape(param) + r'\>',
) -> None:
    """Write the set of parameters to the input file.

    Every line of ``source`` is copied to ``dest`` after replacing each match
    of ``regex(name)`` with ``str(value)``, for every parameter in turn.

    Parameters
    ----------
    param_value : Dict[str, float]
        Collection of parameters and their values.
    source : str
        Source input file, to be copied to the destination.
    dest : str
        Destination input file.
    regex : Callable[[str], str], optional
        Function to generate the regex for the parameter substitution.
        By default, replaces "<param_name>" with the value; the parameter name
        is escaped, so names containing regex metacharacters match literally.

    Raises
    ------
    ValueError
        If ``source`` and ``dest`` refer to the same file: opening the
        destination for writing would truncate the source before it is read.
    """
    if os.path.realpath(source) == os.path.realpath(dest):
        raise ValueError(f'Source and destination files are the same: "{source}".')
    # Compile each pattern once, rather than once per line and parameter.
    # A callable replacement inserts the value verbatim: a plain replacement
    # string would have its backslash escapes interpreted by the regex engine.
    substitutions = [
        (re.compile(regex(parameter)), lambda _match, text=str(value): text)
        for parameter, value in param_value.items()
    ]
    with open(source, 'r', encoding='utf8') as fin, open(dest, 'w', encoding='utf8') as fout:
        for line in fin:
            for pattern, replacement in substitutions:
                line = pattern.sub(replacement, line)
            fout.write(line)
[docs]
class InputData:
    """Input data of a single case run by an input file-based solver.

    Attributes
    ----------
    tmp_dir : str
        Directory where the case is run.
    input_file : str
        Name of the generated input file, relative to ``tmp_dir``.
    dependencies : List[str]
        Names of the additional files produced for the case.
    """

    def __init__(self, tmp_dir: str, input_file: str, dependencies: List[str]) -> None:
        self.tmp_dir, self.input_file, self.dependencies = tmp_dir, input_file, dependencies
[docs]
class InputDataGenerator(ABC):
    """Abstract producer of input data for input file-based solvers.

    Subclasses implement :meth:`generate`, which prepares the inputs for one
    evaluation in a temporary directory and describes them as ``InputData``.
    """

    @abstractmethod
    def generate(self, parameters: ParameterSet, values: np.ndarray, tmp_dir: str) -> InputData:
        """Produce the input data for one set of parameter values.

        Parameters
        ----------
        parameters : ParameterSet
            Parameter set of the problem.
        values : np.ndarray
            Parameter values to evaluate.
        tmp_dir : str
            Temporary directory in which the problem is run.

        Returns
        -------
        InputData
            Description of the generated input data.
        """
[docs]
class DefaultInputDataGenerator(InputDataGenerator):
    """Default input data generator for input file-based solvers.

    Parameters are substituted into the input file and into the substitution
    dependencies, while copy dependencies are copied verbatim. All generated
    files are placed directly in the temporary directory under their base name.
    """

    def __init__(
        self,
        input_file: str,
        substitution_dependencies: List[str] = None,
        copy_dependencies: List[str] = None,
    ) -> None:
        """Build the generator.

        Parameters
        ----------
        input_file : str
            Path to the template input file.
        substitution_dependencies : List[str], optional
            Additional files in which parameters are substituted.
        copy_dependencies : List[str], optional
            Additional files copied unchanged.
        """
        self.input_file = input_file
        # Copy the lists so later changes made by the caller do not leak in here
        self.substitution_dependencies = list(substitution_dependencies or [])
        self.copy_dependencies = list(copy_dependencies or [])

    def generate(self, parameters: ParameterSet, values: np.ndarray, tmp_dir: str) -> InputData:
        """Generate the input data for the given set of parameters.

        Parameters
        ----------
        parameters : ParameterSet
            Parameter set for this problem.
        values : np.ndarray
            Current parameters to evaluate.
        tmp_dir : str
            Temporary directory to run the problem.

        Returns
        -------
        InputData
            Input data for this problem.
        """
        param_dict = parameters.to_dict(values)
        # Files are written under their base name, matching the names reported in
        # InputData. Joining the full path would ignore tmp_dir for absolute paths
        # (overwriting the template) and need missing subdirectories for nested ones.
        input_name = os.path.basename(self.input_file)
        write_parameters(param_dict, self.input_file, os.path.join(tmp_dir, input_name))
        dependencies = []
        # Replace parameters in the dependencies
        for dep in self.substitution_dependencies:
            dep_name = os.path.basename(dep)
            write_parameters(param_dict, dep, os.path.join(tmp_dir, dep_name))
            dependencies.append(dep_name)
        # Copy the remaining dependencies unchanged
        for dep in self.copy_dependencies:
            dep_name = os.path.basename(dep)
            shutil.copy(dep, os.path.join(tmp_dir, dep_name))
            dependencies.append(dep_name)
        return InputData(tmp_dir, input_name, dependencies)
[docs]
class OutputField(ABC):
    """Abstract output field of an input file-based solver.

    Subclasses implement a pre-run validity check (:meth:`check`), the
    extraction of results (:meth:`get`) and construction from a
    configuration dictionary (:meth:`read`).
    """

    @abstractmethod
    def check(self, input_data: InputData) -> None:
        """Validate the input data before the output is read.

        Parameters
        ----------
        input_data : InputData
            Input data to validate.
        """

    @abstractmethod
    def get(self, input_data: InputData) -> OutputResult:
        """Extract this field's output from the simulation.

        Parameters
        ----------
        input_data : InputData
            Input data of the simulation to read from.

        Returns
        -------
        OutputResult
            Output result for this field.
        """

    @classmethod
    @abstractmethod
    def read(cls: Type[T], config: Dict[str, Any]) -> T:
        """Build the output field from a configuration dictionary.

        Parameters
        ----------
        config : Dict[str, Any]
            Configuration dictionary.

        Returns
        -------
        OutputField
            Output field to use for this problem.
        """
[docs]
class ScriptOutputField(OutputField):
    """Class for script-based output fields.

    Base for user-provided output fields: ``InputFileCase.read`` obtains the
    class through ``read_custom_module`` and instantiates it without arguments,
    so these fields are never built from a configuration via :meth:`read`.
    Subclasses still have to implement :meth:`get`.
    """

    def check(self, input_data: InputData) -> None:
        """Check for validity in the input data before reading.

        Script-based fields perform no check by default.

        Parameters
        ----------
        input_data : InputData
            Input data to check for.
        """

    @classmethod
    def read(cls, config: Dict[str, Any]) -> ScriptOutputField:
        """Read the output field from the configuration dictionary.

        Declared as a classmethod to match ``OutputField.read``; calls such as
        ``ScriptOutputField.read(config)`` behave as before.

        Parameters
        ----------
        config : Dict[str, Any]
            Configuration dictionary.

        Returns
        -------
        ScriptOutputField
            Output field to use for this problem.

        Raises
        ------
        RuntimeError
            Always, as script-based fields cannot be read from a configuration.
        """
        raise RuntimeError("Cannot read the configuration for a script-based output field.")
[docs]
class InputFileCase(Case, ABC):
    """Base case class for input file-based solvers.

    When built with :meth:`read`, the case name is also the path of the
    template input file (it is passed to the default generator and to
    :meth:`get_dependencies`).
    """

    def __init__(
        self,
        name: str,
        fields: Dict[str, OutputField],
        generator: InputDataGenerator,
    ) -> None:
        """Build the case.

        Parameters
        ----------
        name : str
            Name of the case.
        fields : Dict[str, OutputField]
            Output fields of the case, by name.
        generator : InputDataGenerator
            Generator of the input data for each evaluation.
        """
        self.case_name = name
        self.fields = fields
        self.generator = generator

    def name(self) -> str:
        """Return the name of the case.

        Returns
        -------
        str
            Name of the case.
        """
        return self.case_name

    def get_fields(self) -> List[str]:
        """Get the fields to output for this case.

        Returns
        -------
        List[str]
            Fields to output for this case.
        """
        return list(self.fields.keys())

    @abstractmethod
    def _run_case(self, input_data: InputData, tmp_dir: str) -> bool:
        """Run the case for the given set of parameters.

        Parameters
        ----------
        input_data : InputData
            Input data for this problem.
        tmp_dir : str
            Temporary directory to run the problem.

        Returns
        -------
        bool
            Whether the case ran successfully or not.
        """

    def run(
        self,
        parameters: ParameterSet,
        values: np.ndarray,
        tmp_dir: str,
    ) -> CaseResult:
        """Run the case for the given set of parameters.

        Parameters
        ----------
        parameters : ParameterSet
            Parameter set for this problem.
        values : np.ndarray
            Current parameters to evaluate.
        tmp_dir : str
            Temporary directory to run the problem.

        Returns
        -------
        CaseResult
            Result of the case.

        Raises
        ------
        ValueError
            If the generator reports a temporary directory other than the case's.
        RuntimeError
            If the generated input file is missing from the temporary directory.
        """
        # Isolate the input data into a new directory and generate the input data
        tmp_dir = os.path.join(tmp_dir, self.name())
        os.makedirs(tmp_dir, exist_ok=True)
        param_hash = parameters.hash(values)
        input_data = self.generator.generate(parameters, values, tmp_dir)
        # Ensure the temporary directory is consistent
        if input_data.tmp_dir != tmp_dir:
            raise ValueError(
                f'Input data temporary directory "{input_data.tmp_dir}" does not match '
                f'the expected temporary directory "{tmp_dir}".'
            )
        # Ensure the input file has been generated
        input_file = os.path.join(input_data.tmp_dir, input_data.input_file)
        if not os.path.exists(input_file):
            raise RuntimeError(
                f'Input file "{input_data.input_file}" does not exist '
                f'in the temporary directory "{input_data.tmp_dir}".'
            )
        # Sanitise the input data and output fields
        for field in self.fields.values():
            field.check(input_data)
        # Record the wall-clock start time, but measure the duration with a
        # monotonic clock so system clock adjustments cannot distort it
        begin_time = time.time()
        start = time.perf_counter()
        success = self._run_case(input_data, tmp_dir)
        elapsed_time = time.perf_counter() - start
        # Read and return the fields
        responses = {
            field_name: field.get(input_data)
            for field_name, field in self.fields.items()
        }
        return CaseResult(begin_time, elapsed_time, values, success, param_hash, responses)

    @classmethod
    @abstractmethod
    def get_supported_fields(cls) -> Dict[str, Type[OutputField]]:
        """Get the supported fields for this input file type.

        Returns
        -------
        Dict[str, Type[OutputField]]
            Names and supported fields for this input file type.
        """

    @classmethod
    def get_dependencies(cls, input_file: str) -> Tuple[List[str], List[str]]:
        """Get the dependencies for a given input file.

        Override this method to provide custom dependencies.

        Parameters
        ----------
        input_file : str
            Input file to check for dependencies.

        Returns
        -------
        Tuple[List[str], List[str]]
            Substitution and copy dependencies for this input file.

        Raises
        ------
        ValueError
            If the input file does not exist.
        """
        if not os.path.exists(input_file):
            raise ValueError(f'Input file "{input_file}" does not exist.')
        return [], []

    @classmethod
    def read(
        cls: Type[V],
        name: str,
        config: Dict[str, Any],
    ) -> V:
        """Read the case from the configuration dictionary.

        The caller's dictionary is not modified. Keys other than ``fields``,
        ``generator``, ``substitution_dependencies`` and ``copy_dependencies``
        are forwarded as keyword arguments to the class constructor.

        Parameters
        ----------
        name : str
            Name of the case.
        config : Dict[str, Any]
            Configuration dictionary.

        Returns
        -------
        InputFileCase
            Case to use for this problem.

        Raises
        ------
        ValueError
            If fields are missing, unnamed or unsupported, or if dependencies are
            given together with a custom generator.
        """
        # Work on a shallow copy: the keys consumed below must not be popped
        # from the caller's configuration
        config = dict(config)
        # Sanitise fields
        if 'fields' not in config:
            raise ValueError(f'No fields defined for case "{name}".')
        config_fields = config.pop('fields')
        # Read each field
        fields: Dict[str, OutputField] = {}
        supported_fields = cls.get_supported_fields()
        for field_name, field_config in config_fields.items():
            if 'name' not in field_config:
                raise ValueError(f'No name defined for field "{field_name}" of case "{name}".')
            field_type = field_config['name']
            if field_type == 'script':
                fields[field_name] = read_custom_module(field_config, ScriptOutputField)()
            elif field_type in supported_fields:
                fields[field_name] = supported_fields[field_type].read(field_config)
            else:
                raise ValueError(f'Field "{field_type}" not supported for case "{name}".')
        if 'generator' in config:
            # Reject dependencies before loading and instantiating the custom generator
            if 'substitution_dependencies' in config or 'copy_dependencies' in config:
                raise ValueError('Dependencies not supported with custom input data generators.')
            generator = read_custom_module(config.pop('generator'), InputDataGenerator)()
        else:
            # Try to find the dependencies for this case; the config overrides them
            substitution_deps, copy_deps = cls.get_dependencies(name)
            substitution_deps = config.pop('substitution_dependencies', substitution_deps)
            copy_deps = config.pop('copy_dependencies', copy_deps)
            generator = DefaultInputDataGenerator(name, substitution_deps, copy_deps)
        return cls(name, fields, generator, **config)
[docs]
class InputFileSolver(MultiCaseSolver):
    """Base class for solvers driven by input files.

    Subclasses declare, through :meth:`get_case_class`, which InputFileCase
    subclass describes their cases.
    """

    @classmethod
    @abstractmethod
    def get_case_class(cls) -> Type[InputFileCase]:
        """Return the case class used by this solver.

        Returns
        -------
        Type[InputFileCase]
            InputFileCase subclass used by this solver.
        """