Source code for fastr.core.cardinality

from abc import abstractmethod
import re
from typing import Optional, Sequence

from sympy.core.expr import Expr
from sympy.core.symbol import Symbol
from sympy import Add, Mul

from .samples import SampleIndex
from .. import exceptions
from ..helpers import log
from ..execution import inputoutputrun


[docs]class CardinalitySpec:
[docs] def __init__(self, parent): self.parent = parent
[docs] @abstractmethod def __str__(self) -> str: """ String version of the cardinality spec, should be parseable by create_cardinality """
[docs] def __repr__(self) -> str: """ Console representation of the cardinality spec """ return '<{} {}>'.format(type(self).__name__, self)
[docs] @abstractmethod def __eq__(self, other: 'CardinalitySpec') -> bool: """Test for equality"""
[docs] def __ne__(self, other: 'CardinalitySpec') -> bool: return not self == other
@property def predefined(self): """ Indicate whether the cardinality is predefined or can only be calculated after execution """ return False
[docs] def validate(self, payload: Optional[dict], cardinality: int, planning: bool = True) -> bool: """ Validate cardinality given a payload and cardinality :param payload: Payload of the corresponding job :param cardinality: Cardinality to validate :param planning: Indicate whether the is for the planning phase or not :return: Validity of the cardinality given the spec and payload """ if isinstance(cardinality, (Symbol, Expr)): return planning result = self._validate(payload, cardinality) if isinstance(result, bool): return result elif isinstance(result, (Symbol, Expr)): return True else: raise exceptions.FastrTypeError( 'Cardinality validation should be either a sympy expression' ' or boolean, found a {}'.format(type(result).__name__) )
@abstractmethod def _validate(self, payload: Optional[dict], cardinality: int) -> bool: """ Validate cardinality given a payload and cardinality :param payload: Payload of the corresponding job :param cardinality: Cardinality to validate :return: Validity of the cardinality given the spec and payload """
[docs] def calculate_planning_cardinality(self) -> Optional[int]: """ Calculate the cardinality given the node and spec, for cardinalities that only have validation and not a pre-calculable value, this return None. :return: calculated cardinality """ return None
[docs] def calculate_execution_cardinality(self, key=None) -> Optional[int]: """ Calculate the cardinality given the node and spec, during execution this should be available and not give unknowns once the data is present and the key is given. :param key: Key for which the cardinality is calculated :return: calculated cardinality """ if key is None: return None sample = self.parent.samples.get(key) if sample is None: return None return sample.cardinality
[docs] def calculate_job_cardinality(self, payload: dict) -> Optional[int]: """ Calculate the actually cardinality when a job needs to know how many arguments to create for a non-automatic output. """ return None
[docs]class IntCardinalitySpec(CardinalitySpec):
[docs] def __init__(self, parent, value: int): super().__init__(parent) self.value = value
[docs] def __eq__(self, other) -> bool: """Test for equality""" if type(self) != type(other): return NotImplemented return self.value == other.value
[docs] def __str__(self) -> str: return '{}'.format(self.value)
[docs] def calculate_job_cardinality(self, payload: dict) -> Optional[int]: return self.value
@property def predefined(self): return True def _validate(self, payload: dict, cardinality: int) -> bool: return self.value == cardinality
[docs] def calculate_planning_cardinality(self) -> int: return self.value
[docs] def calculate_execution_cardinality(self, key=None) -> int: return self.value
[docs]class MinCardinalitySpec(CardinalitySpec):
[docs] def __init__(self, parent, value: int): super().__init__(parent) self.value = value
[docs] def __eq__(self, other): """Test for equality""" if type(self) != type(other): return NotImplemented return self.value == other.value
[docs] def __str__(self) -> str: return '{}-*'.format(self.value)
def _validate(self, payload: dict, cardinality: int) -> bool: return cardinality >= self.value
[docs]class MaxCardinalitySpec(CardinalitySpec):
[docs] def __init__(self, parent, value: int): super().__init__(parent) self.value = value
[docs] def __eq__(self, other): """Test for equality""" if type(self) != type(other): return NotImplemented return self.value == other.value
[docs] def __str__(self) -> str: return '*-{}'.format(self.value)
def _validate(self, payload: dict, cardinality: int) -> bool: return cardinality <= self.value
[docs]class RangeCardinalitySpec(CardinalitySpec):
[docs] def __init__(self, parent, min: int, max: int): super().__init__(parent) self.min = min self.max = max
[docs] def __eq__(self, other): """Test for equality""" if type(self) != type(other): return NotImplemented return self.min == other.min and self.max == other.max
[docs] def __str__(self) -> str: return '{}-{}'.format(self.min, self.max)
def _validate(self, payload: dict, cardinality: int) -> bool: return self.min <= cardinality <= self.max
[docs]class ChoiceCardinalitySpec(CardinalitySpec):
[docs] def __init__(self, parent, options: Sequence[int]): super().__init__(parent) self.options = tuple(options)
[docs] def __eq__(self, other): """Test for equality""" if type(self) != type(other): return NotImplemented return self.options == other.options
[docs] def __str__(self) -> str: return '[{}]'.format(','.join(str(x) for x in self.options))
def _validate(self, payload: dict, cardinality: int) -> bool: return cardinality in self.options
[docs]class AnyCardinalitySpec(CardinalitySpec): def _validate(self, payload: dict, cardinality: int) -> bool: return True
[docs] def __eq__(self, other): """Test for equality""" if type(self) != type(other): return NotImplemented return True
[docs] def __str__(self) -> str: return 'any'
[docs]class AsCardinalitySpec(CardinalitySpec):
[docs] def __init__(self, parent, target): super().__init__(parent) self.target = target
[docs] def __eq__(self, other): """Test for equality""" if type(self) != type(other): return NotImplemented return self.target == other.target
[docs] def __str__(self) -> str: return 'as:{}'.format(self.target)
@property def predefined(self): return True @property def node(self): return self.parent.node def _validate(self, payload: dict, cardinality: int) -> bool: if payload is None: return True if self.target.endswith(']'): # Need to index an input or output to get the cardinality start = self.target.find('[') index = int(self.target[start + 1:-1]) target = self.get_target() value = payload['inputs'].get(target) if value is None: value = payload['outputs'].get(target, []) if len(value) == 0: # Argument not given, thus expected cardinality is zero expected_cardinality = 0 else: try: index_key = list(value.keys())[index] except IndexError: length = len(list(value.keys())) raise exceptions.FastrIndexError( 'Cardinality references to invalid field ' '(target {} has length {} '.format(target, length) + 'and thus does not have index {})'.format(index) ) expected_cardinality = len(value[index_key]) else: value = payload['inputs'].get(self.target) if value is None: value = payload['outputs'].get(self.target, []) expected_cardinality = len(value) return cardinality == expected_cardinality
[docs] def get_target(self) -> str: # Check if we need to index the target if self.target.endswith(']'): # Need to get length of the as input start = self.target.find('[') target = self.target[:start] else: target = self.target return target
[docs] def get_ordereddict_cardinality(self): # Get the index of the OrderedDict target start = self.target.find('[') index = int(self.target[start + 1:-1]) # Get the full target and cardinality target = self.node.inputs.get(self.get_target()) cardinality = target.cardinality() # On Job execution, target is inputoutputrun if isinstance(target, inputoutputrun.InputRun): # Multiple subinputs, index. Occurs in job execution. cardinality = target.get_subinput_cardinality(index) elif isinstance(cardinality, Add): # Multiple elements added, split. Occurs in planning. elements = cardinality.args n_elements = len(elements) if index > n_elements: raise exceptions.FastrIndexError( 'Cardinality references to invalid field ' '(target {} has length {} '.format(target, n_elements) + 'and thus does not have index {})'.format(index) ) else: cardinality = elements[index] elif isinstance(cardinality, Mul): # All elements are multplied: del multiplier, as we # only request a single element. Occurs in execution. elements = cardinality.args elements = list(elements) del elements[0] # Delete the multiplier cardinality = cardinality.func(*elements) return cardinality
[docs] def calculate_planning_cardinality(self) -> Optional[int]: # Get cardinality from target if possible check_target = self.node.inputs.get(self.get_target()) if check_target is not None: if self.target.endswith(']'): # print('Planning: ' + str(self.target) + str(check_target.cardinality())) return self.get_ordereddict_cardinality() else: return self.node.inputs.get(self.target).cardinality() else: raise exceptions.FastrCardinalityError( 'Cardinality references to invalid field ' '({} is not an Input in this Node)'.format(self.target) )
[docs] def calculate_execution_cardinality(self, key=None) -> Optional[int]: check_target = self.node.inputs.get(self.get_target()) if check_target is None: raise exceptions.FastrCardinalityError('Cardinality references to invalid field ' '({} is not an Input in this Node)'.format(self.target)) if self.target.endswith(']'): # Target is element of OrderedDict return self.get_ordereddict_cardinality() else: # Normal procedure target = check_target if key is None: # No key is used, call target without key cardinality = target.cardinality(None) elif all(x == 0 for x in target.size): # Target is empty, cardinality can be set to 0 cardinality = 0 elif target.size == (1,): # Target has only sample, it will be repeated, use first sample cardinality = target.cardinality((0,)) elif len(self.node.input_groups) == 1: # The InputGroups are not mixed, we can request the key if len(key) == len(target.size): cardinality = target.cardinality(key) else: index_map = dict(zip(self.parent.dimnames, key)) lookup = {v: dimname for dimname in self.parent.dimnames for value in self.node.parent.nodegroups.values() if dimname in value for v in value} lookup.update({x: x for x in self.parent.dimnames}) if all(x in lookup for x in target.dimnames): # Print there is broadcasting going on, we need to undo that here matched_dimnames = [lookup[x] for x in target.dimnames] matched_index = SampleIndex(index_map[x] for x in matched_dimnames) cardinality = target.cardinality(matched_index) else: raise exceptions.FastrSizeMismatchError( 'InputGroup has inconsistent size/dimension info for Input ' '{}, cannot figure out broadcasting used!'.format(target.fullid) ) else: log.debug('Unmixing key "{}" for cardinality retrieval'.format(key)) # The InputGroups are mixed, find the part of the ID relevant to this Input test = self.node.input_group_combiner.unmerge(key) index = test[target.input_group] if len(index) == len(target.size): cardinality = target.cardinality(index) else: raise exceptions.FastrSizeMismatchError('TODO: add broadcasting to this branch?') return cardinality
[docs] def calculate_job_cardinality(self, payload: dict) -> Optional[int]: target = payload['inputs'].get(self.target) if target is None: raise exceptions.FastrCardinalityError( 'Cardinality references to invalid field ' '({} is not an Input in this Node)'.format(self.target) ) return len(target)
[docs]class ValueCardinalitySpec(CardinalitySpec):
[docs] def __init__(self, parent, target): super().__init__(parent) self.target = target
[docs] def __eq__(self, other): """Test for equality""" if type(self) != type(other): return NotImplemented return self.target == other.target
[docs] def __str__(self) -> str: return 'val:{}'.format(self.target)
@property def node(self): return self.parent.node def _validate(self, payload: dict, cardinality: int) -> bool: if payload is None: return True value = payload['inputs'].get(self.target) if value is None: value = payload['outputs'].get(self.target) if value is None: raise exceptions.FastrValueError('Cannot calculate val: type cardinality if value not in payload!') if len(value) != 1: return False try: value = int(value[0].value) except (ValueError, TypeError): return False return cardinality == value
[docs] def calculate_execution_cardinality(self, key=None) -> Optional[int]: if self.target in self.node.inputs: # We cannot access to the jobs inputs it appears, so we # check if the output has already been generated. if self.parent.samples is not None and key in self.parent.samples: value = self.parent.samples[key].data log.debug('Got val via output data result, got {}'.format(value)) return len(value) else: log.debug('Cannot get val: cardinality if there is no execution data!') return None elif self.target in self.node.outputs: # Get the value an output if key is None: return None output = self.node.outputs[self.target] if output.samples is None: return None # Try to cast via str to int (To make sure Int datatypes fares well) try: return int(str(output[key])) except exceptions.FastrKeyError: return None
[docs] def calculate_job_cardinality(self, payload: dict) -> Optional[int]: target = payload['inputs'].get(self.target) if target is None: raise exceptions.FastrCardinalityError( 'Cardinality references to invalid field ' '({} is not an Input in this Node)'.format(self.target) ) if len(target) != 1: raise exceptions.FastrValueError( 'Cannot determine cardinality from multiple values ' '(requested {}, found {} as value)'.format(self, target) ) return int(str(target[0]))
[docs]def create_cardinality(desc: str, parent) -> CardinalitySpec: """ Create simplified description of the cardinality. This changes the string representation to a tuple that is easier to check at a later time. :param desc: the string version of the cardinality :param parent: the parent input or output to which this cardinality spec belongs :return: the simplified cardinality description :raises FastrCardinalityError: if the Input/Output has an incorrect cardinality description. The translation works with the following table: ==================== ============================= =============================================================== cardinality string cardinality spec description ==================== ============================= =============================================================== ``"*"``, ``any`` ``('any',) Any cardinality is allowed ``"N"`` ``('int', N)`` A cardinality of N is required ``"N-M"`` ``('range', N, M)`` A cardinality between N and M is required ``"*-M"`` ``('max', M)`` A cardinality of maximal M is required ``"N-*"`` ``('min', N)`` A cardinality of minimal N is required ``"[M,N,...,O,P]"`` ``('choice', [M,N,...,O,P])`` The cardinality should one of the given options ``"as:input_id"`` ``('as', 'input_id')`` The cardinality should match the cardinality of the given Input ``"val:input_id"`` ``('val', 'input_id')`` The cardinliaty should match the value of the given Input ==================== ============================= =============================================================== .. note:: The maximumu, minimum and range are inclusive """ if isinstance(desc, int) or re.match(r'^\d+$', desc) is not None: # N return IntCardinalitySpec(parent, int(desc)) elif desc in ['*', 'any', 'unknown']: # * (anything is okay) return AnyCardinalitySpec(parent) elif re.match(r'^\[\d+(,\d+)*]', desc) is not None: # [M,N,..,O,P] return ChoiceCardinalitySpec(parent, [int(x) for x in desc[1:-1].split(',')]) elif '-' in desc: match = re.match(r'^(\d+|\*)-(\d+|\*)$', desc) if match is None: raise exceptions.FastrCardinalityError("Not a valid cardinality description string (" + desc + ")") lower, upper = match.groups() if lower == '*' and upper == '*': # *-* (anything is okay) return AnyCardinalitySpec(parent) elif lower == '*' and upper != '*': # N-* return MaxCardinalitySpec(parent, int(upper)) elif lower != '*' and upper == '*': # *-M return MinCardinalitySpec(parent, int(lower)) else: # N-M return RangeCardinalitySpec(parent, int(lower), int(upper)) elif desc.startswith("as:"): # as:other return AsCardinalitySpec(parent, desc[3:]) elif desc.startswith("val:"): # val:other return ValueCardinalitySpec(parent, desc[4:]) else: raise exceptions.FastrCardinalityError("Not a valid cardinality description string (" + desc + ")")