# Source code for fastr.core.cardinality

from abc import abstractmethod
import re
from typing import Optional, Sequence

import sympy

from .samples import SampleIndex
from .. import exceptions
from ..helpers import log


class CardinalitySpec:
    """
    Base class for the cardinality specification of an Input or Output.

    Subclasses provide ``__str__``, ``__eq__`` and ``_validate``; the
    ``calculate_*`` hooks return ``None`` here and are overridden by specs
    that can compute a concrete cardinality.
    """

    def __init__(self, parent):
        # The Input/Output this cardinality spec belongs to
        self.parent = parent

    @abstractmethod
    def __str__(self) -> str:
        """
        String version of the cardinality spec, should be parseable by
        create_cardinality
        """

    def __repr__(self) -> str:
        """
        Console representation of the cardinality spec
        """
        return '<{} {}>'.format(type(self).__name__, self)

    @abstractmethod
    def __eq__(self, other):
        """Test for equality"""

    def __ne__(self, other):
        return not self == other

    @property
    def predefined(self):
        """
        Indicate whether the cardinality is predefined or can only be
        calculated after execution
        """
        return False

    def validate(self, payload: Optional[dict], cardinality: int, planning=True) -> bool:
        """
        Validate cardinality given a payload and cardinality

        :param payload: Payload of the corresponding job
        :param cardinality: Cardinality to validate
        :param planning: value returned when ``cardinality`` is still a
                         symbolic (sympy) expression and cannot be checked yet
        :return: Validity of the cardinality given the spec and payload
        :raises FastrTypeError: if ``_validate`` returns neither a bool nor
                                a sympy expression
        """
        if isinstance(cardinality, (sympy.Symbol, sympy.Expr)):
            return planning

        result = self._validate(payload, cardinality)

        if isinstance(result, bool):
            return result
        elif isinstance(result, (sympy.Symbol, sympy.Expr)):
            # A symbolic outcome cannot be disproven, so treat it as valid
            return True
        else:
            raise exceptions.FastrTypeError(
                'Cardinality validation should be either a sympy expression'
                ' or boolean, found a {}'.format(type(result).__name__)
            )

    @abstractmethod
    def _validate(self, payload: Optional[dict], cardinality: int) -> bool:
        """
        Validate cardinality given a payload and cardinality

        :param payload: Payload of the corresponding job
        :param cardinality: Cardinality to validate
        :return: Validity of the cardinality given the spec and payload
        """

    def calculate_planning_cardinality(self) -> Optional[int]:
        """
        Calculate the cardinality during planning. For cardinalities that
        only have validation and not a pre-calculable value, this returns None.

        :return: calculated cardinality, or None if it cannot be determined
        """
        return None

    def calculate_execution_cardinality(self, key=None) -> Optional[int]:
        """
        Calculate the cardinality during execution. This should be available
        and not give unknowns once the data is present and the key is given.

        :param key: Key for which the cardinality is calculated
        :return: calculated cardinality, or None if it is not (yet) known
        """
        if key is None:
            return None

        samples = self.parent.samples
        if samples is None:
            # No sample data collected (yet), so nothing to look up
            return None

        sample = samples.get(key)
        if sample is None:
            return None

        return sample.cardinality

    def calculate_job_cardinality(self, payload: dict) -> Optional[int]:
        """
        Calculate the actual cardinality when a job needs to know how many
        arguments to create for a non-automatic output.

        :param payload: Payload of the corresponding job
        :return: the cardinality, or None if this spec cannot determine it
        """
        return None
class IntCardinalitySpec(CardinalitySpec):
    """Cardinality spec requiring exactly ``value`` items (``"N"``)."""

    def __init__(self, parent, value: int):
        super().__init__(parent)
        # The exact required cardinality
        self.value = value

    def __eq__(self, other):
        """Test for equality"""
        if type(self) != type(other):
            return NotImplemented

        return self.value == other.value

    def __str__(self) -> str:
        return '{}'.format(self.value)

    def calculate_job_cardinality(self, payload: dict) -> Optional[int]:
        return self.value

    @property
    def predefined(self):
        # The value is fixed by the spec itself, so it is known before execution
        return True

    def _validate(self, payload: dict, cardinality: int) -> bool:
        return self.value == cardinality

    def calculate_planning_cardinality(self) -> int:
        return self.value

    def calculate_execution_cardinality(self, key=None, node=None) -> int:
        """
        Return the fixed cardinality; it does not depend on the sample.

        :param key: sample key, matching :meth:`CardinalitySpec.calculate_execution_cardinality`
                    (ignored)
        :param node: former name of the first parameter, still accepted as a
                     keyword for backward compatibility (ignored)
        :return: the fixed cardinality
        """
        return self.value
class MinCardinalitySpec(CardinalitySpec):
    """Cardinality spec with only an (inclusive) lower bound: ``"N-*"``."""

    def __init__(self, parent, value: int):
        super().__init__(parent)
        # Smallest accepted cardinality (inclusive)
        self.value = value

    def __eq__(self, other):
        """Test for equality"""
        same_kind = type(other) is type(self)
        return self.value == other.value if same_kind else NotImplemented

    def __str__(self) -> str:
        return '{}-*'.format(self.value)

    def _validate(self, payload: dict, cardinality: int) -> bool:
        # Accept anything at or above the lower bound
        return self.value <= cardinality
class MaxCardinalitySpec(CardinalitySpec):
    """Cardinality spec with only an (inclusive) upper bound: ``"*-M"``."""

    def __init__(self, parent, value: int):
        super().__init__(parent)
        # Largest accepted cardinality (inclusive)
        self.value = value

    def __eq__(self, other):
        """Test for equality"""
        same_kind = type(other) is type(self)
        return self.value == other.value if same_kind else NotImplemented

    def __str__(self) -> str:
        return '*-{}'.format(self.value)

    def _validate(self, payload: dict, cardinality: int) -> bool:
        # Accept anything at or below the upper bound
        return self.value >= cardinality
class RangeCardinalitySpec(CardinalitySpec):
    """Cardinality spec bounded on both sides (inclusive): ``"N-M"``."""

    def __init__(self, parent, min: int, max: int):
        super().__init__(parent)
        # Inclusive lower and upper bounds
        self.min = min
        self.max = max

    def __eq__(self, other):
        """Test for equality"""
        if type(other) is not type(self):
            return NotImplemented
        same_lower = self.min == other.min
        return same_lower and self.max == other.max

    def __str__(self) -> str:
        return '{}-{}'.format(self.min, self.max)

    def _validate(self, payload: dict, cardinality: int) -> bool:
        lower, upper = self.min, self.max
        return lower <= cardinality <= upper
class ChoiceCardinalitySpec(CardinalitySpec):
    """Cardinality spec accepting one of an explicit list of values: ``"[M,N,...]"``."""

    def __init__(self, parent, options: Sequence[int]):
        super().__init__(parent)
        # Copied into a tuple so later changes to the caller's sequence do not leak in
        self.options = tuple(options)

    def __eq__(self, other):
        """Test for equality"""
        if type(self) != type(other):
            return NotImplemented

        return self.options == other.options

    def __str__(self) -> str:
        # The options are ints (create_cardinality converts them with int()),
        # and str.join only accepts strings, so stringify each option first.
        return '[{}]'.format(','.join(str(option) for option in self.options))

    def _validate(self, payload: dict, cardinality: int) -> bool:
        return cardinality in self.options
class AnyCardinalitySpec(CardinalitySpec):
    """Cardinality spec that accepts any cardinality (``"*"``, ``"any"``, ``"unknown"``, ``"*-*"``)."""

    def validate(self, payload: dict, cardinality: int, planning=True) -> bool:
        """
        Any cardinality is valid, regardless of payload or planning stage.

        :param payload: Payload of the corresponding job (ignored)
        :param cardinality: Cardinality to validate (ignored)
        :param planning: accepted for signature compatibility with the base
                         class ``validate`` (ignored)
        :return: always True
        """
        return True

    def _validate(self, payload: dict, cardinality: int) -> bool:
        # Implemented for completeness of the base-class contract
        return True

    def __eq__(self, other):
        """Test for equality"""
        if type(self) != type(other):
            return NotImplemented

        return True

    def __str__(self) -> str:
        return 'any'
class AsCardinalitySpec(CardinalitySpec):
    """
    Cardinality spec that mirrors the cardinality of another field of the
    same node: ``"as:target"``.
    """

    def __init__(self, parent, target):
        super().__init__(parent)
        # Id of the field whose cardinality is mirrored
        self.target = target

    def __eq__(self, other):
        """Test for equality"""
        if type(self) != type(other):
            return NotImplemented

        return self.target == other.target

    def __str__(self) -> str:
        return 'as:{}'.format(self.target)

    @property
    def predefined(self):
        return True

    @property
    def node(self):
        """The node that owns the parent Input/Output."""
        return self.parent.node

    def _validate(self, payload: dict, cardinality: int) -> bool:
        """
        Check that ``cardinality`` equals the number of values of the target
        field in the job payload.

        The target is looked up in ``payload['inputs']`` first and then in
        ``payload['outputs']``; a target found in neither counts as zero values.
        Without a payload the check passes.
        """
        if payload is None:
            return True

        value = payload['inputs'].get(self.target)
        if value is None:
            value = payload['outputs'].get(self.target, [])
        return cardinality == len(value)

    def calculate_planning_cardinality(self) -> Optional[int]:
        """
        Return the planning cardinality of the target Input.

        :raises FastrCardinalityError: if the target is not an Input of the node
        """
        # Get cardinality from target if possible
        target = self.node.inputs.get(self.target)
        if target is not None:
            return target.cardinality()
        else:
            raise exceptions.FastrCardinalityError(
                'Cardinality references to invalid field '
                '({} is not an Input in this Node)'.format(self.target)
            )

    def calculate_execution_cardinality(self, key=None) -> Optional[int]:
        """
        Return the cardinality of the target Input for the sample ``key``.

        The key of this field may not map directly onto the target's sample
        index (broadcasting, mixed input groups), so it is translated before
        the target's cardinality is requested.

        :param key: sample index of this field, or None
        :return: the cardinality of the matching target sample
        :raises FastrCardinalityError: if the target is not an Input of the node
        :raises FastrSizeMismatchError: if the key cannot be mapped onto the target
        """
        target = self.node.inputs.get(self.target)
        if target is None:
            raise exceptions.FastrCardinalityError('Cardinality references to invalid field '
                                                   '({} is not an Input in this Node)'.format(self.target))

        if key is None:
            # No key is used, call target without key
            cardinality = target.cardinality(None)
        elif all(x == 0 for x in target.size):
            # Target is empty, cardinality can be set to 0
            cardinality = 0
        elif target.size == (1,):
            # Target has only one sample, it will be repeated, use first sample
            cardinality = target.cardinality((0,))
        elif len(self.node.input_groups) == 1:
            # The InputGroups are not mixed, we can request the key
            if len(key) == len(target.size):
                # Same dimensionality: the key can be used as-is
                cardinality = target.cardinality(key)
            else:
                # Map each of this field's dimension names to its index in the key
                index_map = dict(zip(self.parent.dimnames, key))
                # Map every dimension name that shares a node group (from the
                # parent network's nodegroups) with one of this field's dimnames
                # onto that dimname, plus identity entries for our own dimnames.
                # NOTE(review): presumably this undoes dimension renaming caused
                # by node grouping/broadcasting -- confirm against nodegroups semantics
                lookup = {v: dimname for dimname in self.parent.dimnames
                          for value in self.node.parent.nodegroups.values()
                          if dimname in value for v in value}
                lookup.update({x: x for x in self.parent.dimnames})

                if all(x in lookup for x in target.dimnames):
                    # Apparently there is broadcasting going on, we need to undo that here:
                    # rebuild the target's index from our key, dimension by dimension
                    matched_dimnames = [lookup[x] for x in target.dimnames]
                    matched_index = SampleIndex(index_map[x] for x in matched_dimnames)
                    cardinality = target.cardinality(matched_index)
                else:
                    raise exceptions.FastrSizeMismatchError(
                        'InputGroup has inconsistent size/dimension info for Input '
                        '{}, cannot figure out broadcasting used!'.format(target.fullid)
                    )
        else:
            log.debug('Unmixing key "{}" for cardinality retrieval'.format(key))
            # The InputGroups are mixed, find the part of the ID relevant to this Input
            test = self.node.input_group_combiner.unmerge(key)
            index = test[target.input_group]

            if len(index) == len(target.size):
                cardinality = target.cardinality(index)
            else:
                # Broadcasting is not supported in the mixed-input-group case
                raise exceptions.FastrSizeMismatchError('TODO: add broadcasting to this branch?')

        return cardinality

    def calculate_job_cardinality(self, payload: dict) -> Optional[int]:
        """
        Return the number of values the job payload holds for the target Input.

        :raises FastrCardinalityError: if the target is not in ``payload['inputs']``
        """
        target = payload['inputs'].get(self.target)
        if target is None:
            raise exceptions.FastrCardinalityError(
                'Cardinality references to invalid field '
                '({} is not an Input in this Node)'.format(self.target)
            )
        return len(target)
class ValueCardinalitySpec(CardinalitySpec):
    """
    Cardinality spec whose size is given by the *value* of another field:
    ``"val:target"``.
    """

    def __init__(self, parent, target):
        super().__init__(parent)
        # Id of the field whose (single, integer) value gives the cardinality
        self.target = target

    def __eq__(self, other):
        """Test for equality"""
        if type(other) is not type(self):
            return NotImplemented
        return self.target == other.target

    def __str__(self) -> str:
        return 'val:{}'.format(self.target)

    @property
    def node(self):
        """The node that owns the parent Input/Output."""
        return self.parent.node

    def _validate(self, payload: dict, cardinality: int) -> bool:
        """
        Check ``cardinality`` against the single integer value of the target
        field in the payload (inputs first, then outputs).
        """
        if payload is None:
            return True

        found = payload['inputs'].get(self.target)
        if found is None:
            found = payload['outputs'].get(self.target)
        if found is None:
            raise exceptions.FastrValueError('Cannot calculate val: type cardinality if value not in payload!')

        # Only a single value can be interpreted as a cardinality
        if len(found) != 1:
            return False

        try:
            expected = int(found[0].value)
        except (ValueError, TypeError):
            return False
        return cardinality == expected

    def calculate_execution_cardinality(self, key=None) -> Optional[int]:
        """
        Determine the cardinality for sample ``key`` once execution data exists.

        :return: the cardinality, or None when it cannot be determined yet
        """
        if self.target in self.node.inputs:
            # The job's inputs are not reachable from here, so fall back to the
            # data already generated for this field, if any.
            samples = self.parent.samples
            if samples is None or key not in samples:
                log.debug('Cannot get val: cardinality if there is no execution data!')
                return None
            data = samples[key].data
            log.debug('Got val via output data result, got {}'.format(data))
            return len(data)

        if self.target not in self.node.outputs:
            return None

        # The target is an output of the node: read its value for this key
        if key is None:
            return None
        output = self.node.outputs[self.target]
        if output.samples is None:
            return None
        try:
            # Cast via str so that Int-like datatypes convert cleanly
            return int(str(output[key]))
        except exceptions.FastrKeyError:
            return None

    def calculate_job_cardinality(self, payload: dict) -> Optional[int]:
        """
        Read the cardinality from the single value of the target Input in the payload.

        :raises FastrCardinalityError: if the target is not in ``payload['inputs']``
        :raises FastrValueError: if the target does not hold exactly one value
        """
        values = payload['inputs'].get(self.target)
        if values is None:
            raise exceptions.FastrCardinalityError(
                'Cardinality references to invalid field '
                '({} is not an Input in this Node)'.format(self.target)
            )

        if len(values) != 1:
            raise exceptions.FastrValueError(
                'Cannot determine cardinality from multiple values '
                '(requested {}, found {} as value)'.format(self, values)
            )

        return int(str(values[0]))
def create_cardinality(desc: str, parent) -> CardinalitySpec:
    """
    Create the cardinality spec object described by a cardinality string.

    :param desc: the string version of the cardinality (an ``int`` is also accepted)
    :param parent: the parent input or output to which this cardinality spec belongs
    :return: the cardinality spec
    :raises FastrCardinalityError: if the Input/Output has an incorrect
                                   cardinality description.

    The translation works with the following table:

    ==================== =========================== ==================================================================
    cardinality string   cardinality spec            description
    ==================== =========================== ==================================================================
    ``"*"``, ``any``     ``AnyCardinalitySpec``      Any cardinality is allowed (also ``"*-*"`` and ``unknown``)
    ``"N"``              ``IntCardinalitySpec``      A cardinality of exactly N is required
    ``"N-M"``            ``RangeCardinalitySpec``    A cardinality between N and M is required
    ``"*-M"``            ``MaxCardinalitySpec``      A cardinality of at most M is required
    ``"N-*"``            ``MinCardinalitySpec``      A cardinality of at least N is required
    ``"[M,N,...,O,P]"``  ``ChoiceCardinalitySpec``   The cardinality should be one of the given options
    ``"as:input_id"``    ``AsCardinalitySpec``       The cardinality should match the cardinality of the given Input
    ``"val:input_id"``   ``ValueCardinalitySpec``    The cardinality should match the value of the given Input
    ==================== =========================== ==================================================================

    .. note::

        The maximum, minimum and range are inclusive. Whitespace around the
        options of a choice (e.g. ``"[1, 2]"``) is allowed.
    """
    if isinstance(desc, int):
        # N given as a plain integer
        return IntCardinalitySpec(parent, int(desc))

    if not isinstance(desc, str):
        raise exceptions.FastrCardinalityError(
            'Not a valid cardinality description string ({!r})'.format(desc)
        )

    if re.match(r'^\d+$', desc) is not None:
        # N
        return IntCardinalitySpec(parent, int(desc))

    if desc in ['*', 'any', 'unknown']:
        # * (anything is okay)
        return AnyCardinalitySpec(parent)

    # Field references are checked before the range syntax, because field ids
    # may themselves contain '-' (e.g. "as:input-image")
    if desc.startswith("as:"):
        # as:other
        return AsCardinalitySpec(parent, desc[3:])

    if desc.startswith("val:"):
        # val:other
        return ValueCardinalitySpec(parent, desc[4:])

    # fullmatch (rather than a '$' anchor) so trailing text or a trailing
    # newline cannot slip through and corrupt the [1:-1] slice below
    if re.fullmatch(r'\[\s*\d+(\s*,\s*\d+)*\s*\]', desc) is not None:
        # [M,N,..,O,P]; int() tolerates the surrounding whitespace
        return ChoiceCardinalitySpec(parent, [int(x) for x in desc[1:-1].split(',')])

    if '-' in desc:
        match = re.match(r'^(\d+|\*)-(\d+|\*)$', desc)
        if match is None:
            raise exceptions.FastrCardinalityError("Not a valid cardinality description string (" + desc + ")")

        lower, upper = match.groups()

        if lower == '*' and upper == '*':
            # *-* (anything is okay)
            return AnyCardinalitySpec(parent)
        if lower == '*':
            # *-M
            return MaxCardinalitySpec(parent, int(upper))
        if upper == '*':
            # N-*
            return MinCardinalitySpec(parent, int(lower))
        # N-M
        return RangeCardinalitySpec(parent, int(lower), int(upper))

    raise exceptions.FastrCardinalityError("Not a valid cardinality description string (" + desc + ")")