# Source code for schrodinger.utils.sea.evalor

"""
Module for parameter validation. See `schrodinger.utils.sea` for more details.
Copyright Schrodinger, LLC. All rights reserved.
"""
import re
import inspect
import os
from copy import deepcopy

from .sea import Atom, Map, List
from .common import boolean, is_equal, debug_print


class Evalor:
    """
    Evaluator used while checking the validity of parameters.

    An instance holds the parameter map that validation criteria are
    evaluated against, the error text accumulated so far, and the names of
    the parameters for which no validation rule was found.
    """
    __slots__ = [
        "_map",
        "_err_break",
        "_err",
        "_unchecked_map",
    ]

    def __init__(self, map, err_break="\n\n"):
        """
        :param map: 'map' contains all parameters to be checked.
        :param err_break: Separator appended after every recorded error
                          message.
        """
        self._map = map
        self._err_break = err_break
        # Error text accumulated by 'record_error'.
        self._err = ""
        # Names (with their leading separator character) of the parameters
        # that were not checked.
        self._unchecked_map = []

    def __call__(self, arg):
        """
        Evaluates 'arg' against the stored parameter map.

        :param arg: The validation criteria.
        :return: Whatever '_eval' returns for the stored map and 'arg'.
        """
        return _eval(self._map, arg)

    @property
    def err(self):
        """
        The error text recorded so far (an empty string if there is none).
        """
        return self._err

    def is_ok(self):
        """
        Returns True if no error has been recorded and no parameter has
        been left unchecked, or False otherwise.
        """
        return not (self._err or self._unchecked_map)

    def record_error(self, mapname=None, err=""):
        """
        Appends an error message to the accumulated error text.

        :param mapname: The name of the checked parameter. Its first
                        character is dropped before it is used as the
                        message label. No label is written when this is
                        None.
        :param err: The error message.
        """
        debug_print("ERROR\n%s" % err)
        if mapname is not None:
            self._err = self._err + (mapname[1:] + ": ")
        self._err = self._err + (err + self._err_break)

    @property
    def unchecked_map(self):
        """
        A string listing the parameters that have not been checked. Each
        name is written without its first character and is followed by a
        single space.
        """
        return "".join(name[1:] + " " for name in self._unchecked_map)

    def copy_from(self, ev):
        """
        Takes over the state of 'ev'.

        The map, the error text and the unchecked-parameter list are taken
        by reference (no copies are made); the error separator of this
        object is left untouched.

        :param ev: A 'Evalor' object.
        """
        for slot in ("_map", "_err", "_unchecked_map"):
            setattr(self, slot, getattr(ev, slot))
def check_map(map, valid, ev, tag=set()):  # noqa: M511
    """
    Checks the validity of a map and prints a summary of the outcome.

    :param map: The parameters to be checked. Nothing is checked when
                'map.has_tag(tag)' is false.
    :param valid: The validation criteria for 'map'.
    :param ev: The evaluator, where the error messages and the names of
               the unchecked parameters are collected.
    :param tag: The tags handed to 'map.has_tag' and to the recursive
                checker. The default set is never modified here.
    :return: The accumulated error text if at least one error was
             recorded, or None otherwise.
    """
    if not map.has_tag(tag):
        debug_print("(none is tagged with: %s)" % (", ".join(tag)))
        return None

    _check_map(map.sval, valid, ev, "", tag)

    debug_print("\nUnchecked maps:")
    debug_print("(none)" if ev._unchecked_map == [] else ev.unchecked_map)

    debug_print("\nError summary:")
    if ev._err == "":
        debug_print("(none)")
        return None
    debug_print(ev._err)
    return ev._err
def __op_mul(map, arg):
    """
    Evaluates the "multiplication" expression and returns the product of
    arg[0], arg[1], arg[2], ...

    The running product starts from the float 1.0, so an empty 'arg'
    yields 1.0.

    :param arg: The 'arg' should be a 'sea.List' object that contains two or
                more elements.
    :param map: The original map that the elements in the 'arg' refer.
    """
    prod = 1.0
    for e in arg:
        prod *= _eval(map, e)
    return prod


def __op_eq(map, arg):
    """
    Evaluates the "equal" expression and returns True if arg[0] and arg[1]
    are equal or False otherwise.

    When either evaluated operand is a float the comparison is delegated to
    'is_equal'; otherwise the plain '==' operator is used.

    :param arg: The 'arg' should be a 'sea.List' object that contains two
                elements. Elements beyond the first two will be ignored.
    :param map: The original map that the elements in the 'arg' refer.
    """
    a = _eval(map, arg[0])
    b = _eval(map, arg[1])
    if (isinstance(a, float) or isinstance(b, float)):
        return is_equal(a, b)
    return a == b


def __op_lt(map, arg):
    """
    Evaluates the "less than" expression and returns True if arg[0] is less
    than arg[1] or False otherwise.

    :param arg: The 'arg' should be a 'sea.List' object that contains two
                elements. Elements beyond the first two will be ignored.
    :param map: The original map that the elements in the 'arg' refer.
    """
    return _eval(map, arg[0]) < _eval(map, arg[1])


def __op_le(map, arg):
    """
    Evaluates the "less or equal" expression and returns True if arg[0] is
    less than or equal to arg[1] or False otherwise.

    :param arg: The 'arg' should be a 'sea.List' object that contains two
                elements. Elements beyond the first two will be ignored.
    :param map: The original map that the elements in the 'arg' refer.
    """
    return _eval(map, arg[0]) <= _eval(map, arg[1])


def __op_gt(map, arg):
    """
    Evaluates the "greater than" expression and returns True if arg[0] is
    greater than arg[1] or False otherwise.

    :param arg: The 'arg' should be a 'sea.List' object that contains two
                elements. Elements beyond the first two will be ignored.
    :param map: The original map that the elements in the 'arg' refer.
    """
    return _eval(map, arg[0]) > _eval(map, arg[1])


def __op_ge(map, arg):
    """
    Evaluates the "greater or equal" expression and returns True if arg[0]
    is greater than or equal to arg[1] or False otherwise.

    :param arg: The 'arg' should be a 'sea.List' object that contains two
                elements. Elements beyond the first two will be ignored.
    :param map: The original map that the elements in the 'arg' refer.
    """
    return _eval(map, arg[0]) >= _eval(map, arg[1])


def __op_and(map, arg):
    """
    Evaluates the "logic and" expression.

    The result follows Python's 'and' operator: the evaluated arg[0] is
    returned if it is false (arg[1] is then not evaluated), otherwise the
    evaluated arg[1] is returned. The result is therefore truthy only if
    both operands are true, but it is not necessarily a 'bool' object.

    :param arg: The 'arg' should be a 'sea.List' object that contains two
                elements. Elements beyond the first two will be ignored.
    :param map: The original map that the elements in the 'arg' refer.
    """
    return _eval(map, arg[0]) and _eval(map, arg[1])


def __op_or(map, arg):
    """
    Evaluates the "logic or" expression.

    The result follows Python's 'or' operator: the evaluated arg[0] is
    returned if it is true (arg[1] is then not evaluated), otherwise the
    evaluated arg[1] is returned. The result is therefore truthy if either
    operand is true, but it is not necessarily a 'bool' object.

    :param arg: The 'arg' should be a 'sea.List' object that contains two
                elements. Elements beyond the first two will be ignored.
    :param map: The original map that the elements in the 'arg' refer.
    """
    return _eval(map, arg[0]) or _eval(map, arg[1])


def __op_not(map, arg):
    """
    Evaluates the "logic not" expression and returns True if arg[0] is false
    or False if arg[0] is true.

    :param arg: The 'arg' should be a 'sea.List' object that contain only 1
                element. Any other number of elements will cause a
                'ValueError' exception.
    :param map: The original map that the elements in the 'arg' refer.
    :raise ValueError: If 'arg' does not contain exactly 1 element.
    """
    if (len(arg) != 1):
        raise ValueError(
            "'__op_not' function expects 1 argument, but there are %d" %
            len(arg))
    return not _eval(map, arg[0])


def __op_at(map, arg):
    """
    Evaluates the "at" expression and returns the referenced value.

    The evaluated arg[0] is used as the key into 'map'. If the object found
    there has a 'val' attribute, that attribute is returned; otherwise the
    object itself is returned.

    :param arg: The 'arg' should be a 'sea.List' object that contain only 1
                element. Any other number of elements will cause a
                'ValueError' exception.
    :param map: The original map that the elements in the 'arg' refer.
    :raise ValueError: If 'arg' does not contain exactly 1 element.
    """
    if (len(arg) != 1):
        raise ValueError(
            "'__op_at' function expects 1 argument, but there are %d" %
            len(arg))
    k = map[_eval(map, arg[0])]
    try:
        return k.val
    except AttributeError:
        return k


def __op_minus(map, arg):
    """
    Evaluates the "minus" expression and returns the arithmetic result: the
    negated value for one operand, or the difference arg[0] - arg[1] for two
    operands.

    :param arg: The 'arg' should be a 'sea.List' object that contains one or
                two elements. Any other number of elements will cause a
                'ValueError' exception.
    :param map: The original map that the elements in the 'arg' refer.
    :raise ValueError: If 'arg' does not contain 1 or 2 elements.
    """
    num_arg = len(arg)
    # An empty argument list is rejected here as well; otherwise it would
    # only fail later when arg[0] is accessed.
    if (num_arg not in (1, 2)):
        raise ValueError(
            "'__op_minus' function expects 1 or 2 arguments, but there are %d"
            % num_arg)
    if (num_arg == 1):
        return -_eval(map, arg[0])
    return _eval(map, arg[0]) - _eval(map, arg[1])


def __op_cat(map, arg):
    """
    Concatenates the string forms of all evaluated elements of 'arg' and
    returns the result.

    :param arg: The 'arg' should be a 'sea.List' object that contains at
                least 1 element.
    :param map: The original map that the elements in the 'arg' refer.
    :raise ValueError: If 'arg' is empty.
    """
    if (len(arg) < 1):
        raise ValueError(
            "'__op_cat' function expects at least 1 argument, but there is none"
        )
    return "".join(str(_eval(map, a)) for a in arg)


def __op_sizeof(map, arg):
    """
    Evaluates the "sizeof" expression and returns the length (as given by
    'len') of the evaluated arg[0].

    :param arg: The 'arg' should be a 'sea.List' object that contain only 1
                element. Any other number of elements will cause a
                'ValueError' exception.
    :param map: The original map that the elements in the 'arg' refer.
    :raise ValueError: If 'arg' does not contain exactly 1 element.
    """
    if (len(arg) != 1):
        raise ValueError(
            "'__op_sizeof' function expects 1 argument, but there are %d" %
            len(arg))
    return len(_eval(map, arg[0]))
def is_powerof2(x):
    """
    Returns True if 'x' is a power of 2, or False otherwise.

    A positive integer is a power of 2 exactly when it has a single bit
    set, i.e. when clearing its lowest set bit ('x & (x - 1)') leaves 0.
    Zero and negative integers are not powers of 2; zero needs the explicit
    'x > 0' guard because '0 & -1' is also 0.

    :param x: The integer to be tested.
    """
    return x > 0 and not (x & (x - 1))
def _regex_match(pattern):
    """
    Builds a matcher for the regular expression 'pattern'.

    :param pattern: The regular expression.
    :return: A one-argument function that returns the result of
             're.match(pattern, s)' for its argument 's'.
    """
    return lambda s: re.match(pattern, s)


def _xchk_power2(map, valid, ev, prefix):
    """
    This is an external checker. It checks whether an integer value is power
    of 2 or not.

    :param map: 'map' contains the value to be checked. Use 'map.val' to get
                the value.
    :param valid: 'valid' contains the validation criteria for the
                  to-be-checked value. It is not used by this checker.
    :param ev: The evaluator, where the error messages are collected.
    :param prefix: The prefix of the checked parameter.
    """
    val = map.val
    if (not is_powerof2(val)):
        debug_print("Error:\nValue %d is not an integer of power of 2" % val)
        ev.record_error(prefix,
                        "Value %d is not an integer of power of 2" % val)
    else:
        debug_print("OK - value is an integer of powere of 2")


def _xchk_file_exists(map, valid, ev, prefix):
    """
    This is an external checker. It checks whether a file (not a dir)
    exists. An empty file name is accepted without touching the file system.

    :param map: 'map' contains the value to be checked. Use 'map.val' to get
                the file name.
    :param valid: 'valid' contains the validation criteria for the
                  to-be-checked value. It is not used by this checker.
    :param ev: The evaluator, where the error messages are collected.
    :param prefix: The prefix of the checked parameter.
    """
    val = map.val
    if (val != "" and not os.path.isfile(val)):
        debug_print("Error:\nFile not found: %s" % val)
        ev.record_error(prefix, "File not found: %s" % val)
    else:
        debug_print("OK - file exists")


def _xchk_dir_exists(map, valid, ev, prefix):
    """
    This is an external checker. It checks whether a dir (not a file)
    exists. An empty dir name is accepted without touching the file system.

    :param map: 'map' contains the value to be checked. Use 'map.val' to get
                the dir name.
    :param valid: 'valid' contains the validation criteria for the
                  to-be-checked value. It is not used by this checker.
    :param ev: The evaluator, where the error messages are collected.
    :param prefix: The prefix of the checked parameter.
    """
    val = map.val
    if (val != "" and not os.path.isdir(val)):
        debug_print("Error:\nDirectory not found: %s" % val)
        ev.record_error(prefix, "Directory not found: %s" % val)
    else:
        debug_print("OK - Directory exists")


def _eval(map, arg):
    """
    Evaluates the expression and returns the results.

    A 'sea.List' whose evaluated first element is (after stripping
    surrounding whitespace) the name of a registered operator is treated as
    a prefix expression: the operator is applied to the remaining elements.
    Any other 'sea.List' evaluates to a Python list of its evaluated
    elements.

    For anything else, 'arg.val' is returned, except that a value starting
    with "@" is taken as a reference: the rest of the value is looked up in
    'map', and the 'val' attribute of the found object (or the object
    itself if it has none) is returned. The bare values '-', '@' and '' are
    returned unchanged.

    :param arg: 'arg' can be either a 'sea.List' object or a 'sea.Atom'
                object, representing a prefix expression.
    :param map: The original map that the elements in the 'arg' refer.
    """
    if (isinstance(arg, List)):
        val0 = _eval(map, arg[0])
        # Operator names are strings, so only a string head can select an
        # operator. Restricting the lookup to strings also keeps an
        # unhashable head (e.g. a nested list) from failing the dict lookup.
        if (isinstance(val0, str)):
            op = __OP.get(val0.strip())
            if (op is not None):
                return op(map, arg[1:])
        return [_eval(map, e) for e in arg]

    val = arg.val
    if (val in ['-', '@', '']):
        return val
    try:
        if (val[0] == "@"):
            k = map[val[1:]]
            try:
                return k.val
            except AttributeError:
                return k
    except TypeError:
        # 'val' is not subscriptable (e.g. a number): it is a plain value.
        pass
    return val


# Operator name -> function implementing the operator.
__OP = {
    "*": __op_mul,
    "==": __op_eq,
    "<": __op_lt,
    "<=": __op_le,
    ">": __op_gt,
    ">=": __op_ge,
    "&&": __op_and,
    "||": __op_or,
    "!": __op_not,
    "@": __op_at,
    "-": __op_minus,
    "cat": __op_cat,
    "sizeof": __op_sizeof,
}

# Type name -> expected Python type, or a (type, range) tuple when the type
# name implies a range. "regex" maps to the factory of pattern matchers.
__TYPE = {
    "str": str,
    "str1": (
        str,
        [1, 1000000000],
    ),
    "float": float,
    "float+": (
        float,
        [0, float("inf")],
    ),
    "float-": (
        float,
        [float("-inf"), 0],
    ),
    "float0_1": (
        float,
        [0, 1.0],
    ),
    "int": int,
    "int0": (
        int,
        [0, 1000000000],
    ),
    "int1": (
        int,
        [1, 1000000000],
    ),
    "bool": boolean,
    "bool0": (
        boolean,
        [False],
    ),
    "bool1": (
        boolean,
        [True],
    ),
    "enum": str,
    "list": list,
    "none": None,
    "regex": _regex_match,
}

# Actual type -> expected types that are also accepted for it.
__CONVERTIBLE_TO = {
    int: [float, str],
    float: [str],
}

# Name -> external checker. More checkers can be added with 'reg_xcheck'.
__xcheck = {
    "power2": _xchk_power2,
    "file_exists": _xchk_file_exists,
    "dir_exists": _xchk_dir_exists,
}
def reg_xcheck(name, func):
    """
    Registers an external checker under the given name. A checker that was
    registered earlier under the same name is replaced.

    :param name: Name of the checker.
    :param func: Callable object that checks validity of a parameter. For
                 interface requirement, see '_xchk_power2', or
                 '_xchk_file_exists', or '_xchk_dir_exists' for example.
    """
    __xcheck.update({name: func})
def _match(map, valid, ev, prefix, tag):
    """
    Finds the best match among several alternative validation criteria.

    'map' is checked against every alternative in 'valid', each time with
    its own deep copy of 'ev'. The best result is then selected as follows:

    1. keep the results with the fewest unchecked parameters;
    2. among those, prefer the results whose error text is the same as that
       of 'ev' (i.e. no new error was recorded);
    3. among those, take the first result with the fewest "Wrong type:"
       errors.

    The state of the selected result is finally copied into 'ev'.

    :param map: The parameter to be checked.
    :param valid: A list of alternative validation criteria.
    :param ev: The evaluator, where the error messages are collected.
    :param prefix: The prefix of the checked parameter.
    :param tag: The tags passed on to the recursive checker.
    """
    kk = map
    vv = valid
    ev_list = []
    for vv_ in vv:
        # NOTE(review): '_if' is read from the whole list of alternatives
        # ('vv'), not from the current alternative ('vv_'). '_check_map'
        # evaluates the '_if' of each alternative itself. Left as is.
        try:
            _if = ev(vv._if)
        except AttributeError:
            pass
        else:
            debug_print("_if: {} = {}".format(
                str(vv._if),
                _if,
            ), False)
            if (_if):
                debug_print("True")
            else:
                debug_print("False - Skip checking the whole map.")
                return
        ev_ = deepcopy(ev)
        # 'tag' is passed on so that alternatives are checked with the same
        # tags as the rest of the map.
        _check_map(kk, vv_, ev_, prefix, tag)
        ev_list.append(ev_)

    if (ev_list != []):
        # Tries to find the best match.
        candidate = [
            ev_list[0],
        ]
        least = len(candidate[0]._unchecked_map)
        for ev_ in ev_list[1:]:
            num = len(ev_._unchecked_map)
            if (num < least):
                candidate = [
                    ev_,
                ]
                least = num
            elif (num == least):
                candidate.append(ev_)

        # Prefers the candidates that did not record any new error.
        best_ev = [ev_ for ev_ in candidate if (ev_._err == ev._err)]
        if (best_ev == []):
            best_ev = candidate
        candidate = best_ev

        best_ev = candidate[0]
        least = best_ev._err.count("Wrong type:")
        for ev_ in candidate[1:]:
            num = ev_._err.count("Wrong type:")
            if (num < least):
                best_ev = ev_
                least = num
        ev.copy_from(best_ev)


def _check_atom(atom, valid, ev, prefix):
    """
    Checks the validity of atom: its type, then its range, then any external
    checker named by 'valid._check'.

    Checking stops at the first type error. A None value is only checked
    against the type.

    :param atom: The atom to be checked.
    :param valid: The validation criteria for the atom.
    :param ev: The evaluator, where the error messages are collected.
    :param prefix: The prefix of the checked parameter.
    """
    rr = None  # Range

    # type
    debug_print(prefix + ":")
    debug_print(" checking its type...", False)
    try:
        t = ev(valid.type)
        if (t.startswith("regex:")):
            tt = __TYPE["regex"](t[6:])
        else:
            tt = __TYPE[t]
        if (isinstance(tt, tuple)):
            # The type name implies a range.
            tt, rr = tt[0], tt[1]
    except AttributeError:
        ev.record_error(
            prefix,
            "Wrong type: expecting a composite parameter, but got an atom")
        return
    except KeyError:
        ev.record_error(
            prefix,
            "Wrong type: %s. 'type' is likely a parameter than a description."
            % t)
        return

    atom_val = atom.val
    if (atom_val is None):
        if (tt is None):
            debug_print("OK - value None is acceptable")
        else:
            ev.record_error(prefix,
                            "Wrong value: expecting %s, but got None" % str(tt))
        return

    if (atom._type == str and inspect.isfunction(tt) and tt != boolean):
        # 'tt' is a pattern matcher built from a "regex:<pattern>" type.
        if (tt(atom_val)):
            debug_print("OK - {} matches the pattern: {}".format(
                atom_val, t[6:]))
        else:
            ev.record_error(
                prefix,
                "Wrong type: expecting a string matching {}, but got {}".format(
                    t[6:],
                    atom_val,
                ))
            return
    elif (atom._type != tt and (atom._type not in __CONVERTIBLE_TO or
                                tt not in __CONVERTIBLE_TO[atom._type])):
        ev.record_error(prefix, "Wrong type: expecting {}, but got {}".format(
            "boolean" if tt == boolean else str(tt),
            str(atom._type),
        ))
        return
    else:
        debug_print("OK - %s" % t)

    # range
    debug_print(" checking its range...", False)
    try:
        if (rr is None):
            rr = ev(valid.range)
    except AttributeError:
        debug_print("N/A")
    else:
        if (t == "enum" or tt == boolean):
            # The range is the collection of acceptable values.
            if (atom_val not in rr):
                ev.record_error(
                    prefix,
                    "Wrong value: should be one of {}, but got '{}'".format(
                        str(rr),
                        str(atom_val),
                    ))
            else:
                debug_print("OK - '{}' is one of {}".format(
                    str(atom_val),
                    str(rr),
                ))
        elif (tt == str):
            # The range limits the length of the string.
            if (atom._type != tt):
                atom_val = str(atom_val)
            length = len(atom_val)
            if (length > int(rr[1])):
                ev.record_error(prefix,
                                "String is too long (%d char's)" % length)
            elif (length < int(rr[0])):
                # Uses the same int() conversion as the comparison above.
                ev.record_error(
                    prefix,
                    "String is too short: it must have at least %d char's" %
                    int(rr[0]))
            else:
                debug_print("OK - string has %d char's" % length)
        else:
            if (atom_val > tt(rr[1]) or atom_val < tt(rr[0])):
                ev.record_error(
                    prefix,
                    "Value out of range: expecting within %s, but got '%s'" %
                    (str(rr), str(atom_val)))
            else:
                debug_print("OK - {} is within {}".format(
                    str(atom_val),
                    str(rr),
                ))

    # _check
    try:
        cc = valid._check
    except AttributeError:
        pass
    else:
        debug_print(" external checking...")
        if (isinstance(cc, List)):
            for e in cc:
                debug_print(" %s: " % e.val, False)
                __xcheck[e.val](atom, valid, ev, prefix)
        elif (cc.val != ""):
            debug_print(" %s: " % cc.val, False)
            __xcheck[cc.val](atom, valid, ev, prefix)


def _check_list(map, valid, ev, prefix, tag):
    """
    Checks the validity of list: its type, its size, each of its elements,
    and then any external checker named by 'valid._check'.

    The evaluated 'valid.size' is interpreted as follows: a positive value
    is the exact number of elements, a negative value gives (as its absolute
    value) the minimum number of elements, and 0 accepts any length.

    If 'valid.elem' is a list, its criteria are applied to the elements
    position by position, and the last criterion is applied to all remaining
    elements. Otherwise 'valid.elem' is applied to every element.

    :param map: The list to be checked.
    :param valid: The validation criteria for the list.
    :param ev: The evaluator, where the error messages are collected.
    :param prefix: The prefix of the checked parameter.
    :param tag: The tags passed on to the recursive checker.
    """
    kk = map
    vv = valid

    # type
    debug_print(prefix + ":")
    debug_print(" checking its type...", False)
    try:
        t = ev(vv.type)
        tt = __TYPE[t]
    except AttributeError:
        ev.record_error(
            prefix,
            "Wrong type: expecting a composite parameter, but got a list")
        return
    if (tt != list):
        ev.record_error(
            prefix,
            "Wrong type: expecting %s, but got <type 'list'>" % str(tt))
        return
    debug_print("OK - list")

    # size
    try:
        debug_print(" checking its size...", False)
        size = ev(vv.size)
        ll = len(kk)
        if (size > 0 and ll != size):
            ev.record_error(
                prefix,
                "Wrong list length: expecting %d elements, but got %d" % (
                    size,
                    ll,
                ))
        elif (size < 0 and ll < -size):
            ev.record_error(
                prefix,
                "Wrong list length: expecting at least %d elements, but got %d"
                % (
                    -size,
                    ll,
                ))
        else:
            debug_print("OK - %d" % ll)
    except AttributeError:
        pass

    # elem
    # NOTE: the 'try' also covers the recursive checks, so an AttributeError
    # raised while checking an element ends the element checks as well.
    debug_print(" checking each element in list...", False)
    try:
        if (isinstance(vv.elem, List)):
            lv, lk = len(vv.elem), len(kk)
            for i, k, v in zip(range(lv), kk, vv.elem):
                _check_map(k, v, ev, ("%s[%d]" % (prefix, i)), tag)
            if (lv < lk):
                # The last criterion applies to all remaining elements.
                v = vv.elem[-1]
                for i in range(lv, lk):
                    _check_map(kk[i], v, ev, ("%s[%d]" % (prefix, i)), tag)
        else:
            for i, elem in enumerate(kk):
                _check_map(elem, vv.elem, ev, ("%s[%d]" % (prefix, i)), tag)
    except AttributeError:
        debug_print("OK - No requirement for elements")

    # _check
    try:
        cc = vv._check
        debug_print(" external checking for the whole list...")
        if (isinstance(cc, List)):
            for e in cc:
                debug_print(" %s: " % e.val, False)
                __xcheck[e.val](map, valid, ev, prefix)
        elif (cc.val != ""):
            debug_print(" %s: " % cc.val, False)
            __xcheck[cc.val](map, valid, ev, prefix)
    except AttributeError:
        pass


def _check_map(map, valid, ev, prefix="", tag=set()):  # noqa: M511
    """
    Checks the validity of a map.

    If 'valid' has an '_if' attribute that evaluates to false, nothing is
    checked. If 'valid' is a list, its elements are alternative criteria and
    the best match is looked for. Otherwise 'map' is checked as an atom, a
    list, or a map, depending on its type.

    For a map, the following special attributes of 'valid' are honored:

    - '_skip': either a list of keys that are not to be checked, or the
      string "all" to skip checking all key-value pairs.
    - '_mapcheck': the name(s) of the external checker(s) to call on the
      whole map.
    - '_enforce': a list of keys that must be present in the map.

    A key of the map without criteria in 'valid' is recorded in 'ev' as
    unchecked.

    :param map: The parameter to be checked.
    :param valid: The validation criteria.
    :param ev: The evaluator, where the error messages are collected.
    :param prefix: The prefix of the checked parameter.
    :param tag: The tags used for selecting the key-value pairs to be
                checked. The default set is never modified here.
    :raise ValueError: If '_skip' or '_enforce' is ill-formed.
    """
    # _if
    try:
        _if = ev(valid._if)
    except AttributeError:
        pass
    else:
        debug_print("_if: {} = {}".format(
            str(valid._if),
            _if,
        ), False)
        if (_if):
            debug_print("True")
        else:
            debug_print("False - Skip checking the whole map.")
            return

    if (isinstance(valid, List)):
        return _match(map, valid, ev, prefix, tag)

    if (isinstance(map, Atom)):
        _check_atom(map, valid, ev, prefix)
    elif (isinstance(map, List)):
        _check_list(map, valid, ev, prefix, tag)
    elif (isinstance(map, Map)):
        # _skip
        try:
            skip = valid._skip.val
        except AttributeError:
            skip = []
        else:
            if (not isinstance(skip, list) and skip != "all"):
                raise ValueError(
                    "_skip must be either a list of strings or the string \"all\""
                )

        # _mapcheck
        try:
            cc = valid._mapcheck
        except AttributeError:
            pass
        else:
            debug_print(prefix + ":")
            debug_print(" external checking for the whole map...")
            if (isinstance(cc, List)):
                for e in cc:
                    debug_print(" %s: " % e.val, False)
                    __xcheck[e.val](map, valid, ev, prefix)
            elif (cc.val != ""):
                debug_print(" %s: " % cc.val, False)
                __xcheck[cc.val](map, valid, ev, prefix)

        # _enforce
        try:
            cc = valid._enforce
        except AttributeError:
            pass
        else:
            if (not isinstance(cc, List)):
                raise ValueError("_enforce must be a list of strings")
            debug_print(prefix + ":")
            debug_print(" enforcing keys...", False)
            # 'join' leaves no trailing separator, so the joined text is
            # reported in full.
            missing_key = ", ".join(e for e in cc.val if (e not in map))
            if (missing_key == ""):
                debug_print("OK - All enforced keys present")
            else:
                debug_print("Error\nMissing keys: " + missing_key)
                ev.record_error(prefix, "Missing keys: " + missing_key)

        if ("all" != skip):
            # Key-value pairs
            key_value = [
                (k, kk) for k, kk in map.key_value(tag) if (k not in skip)
            ]
            for k, kk in key_value:
                try:
                    vv = valid[k]
                except KeyError:
                    ev._unchecked_map.append(prefix + '.' + k)
                    continue
                _check_map(kk, vv, ev, prefix + '.' + k, tag)