"""
Module for parameter validation. See `schrodinger.utils.sea` for more details.
Copyright Schrodinger, LLC. All rights reserved.
"""
import re
import inspect
import os
from copy import deepcopy
from .sea import Atom, Map, List
from .common import boolean, is_equal, debug_print
class Evalor:
    """
    Evaluator used while validating parameters.

    An instance evaluates validation criteria against the stored map and
    accumulates two kinds of findings: error messages, and the names of
    parameters for which no validation criteria were found.
    """

    __slots__ = [
        "_map",
        "_err_break",
        "_err",
        "_unchecked_map",
    ]

    def __init__(self, map, err_break="\n\n"):
        """
        :param map: 'map' contains all parameters to be checked.
        :param err_break: Separator appended after each recorded error message.
        """
        self._map, self._err_break = map, err_break
        self._err = ""
        self._unchecked_map = []

    def __call__(self, arg):
        """
        Evaluates 'arg' against the stored map and returns the result.

        :param arg: The validation criteria.
        """
        return _eval(self._map, arg)

    @property
    def err(self):
        """
        The error messages recorded so far, concatenated into one string.
        """
        return self._err

    def is_ok(self):
        """
        Returns True when no error was recorded and no parameter was left
        unchecked, or False otherwise.
        """
        return not (self._err or self._unchecked_map)

    def record_error(self, mapname=None, err=""):
        """
        Appends an error message to the error record.

        :param mapname: The name of the checked parameter. Its first character
                        is dropped before it is used as a "name: " label. Pass
                        None to record the message without a label.
        :param err: The error message.
        """
        debug_print("ERROR\n%s" % err)
        if mapname is not None:
            label = mapname[1:] + ": "
            self._err += label
        self._err += err + self._err_break

    @property
    def unchecked_map(self):
        """
        Returns a string listing the parameters that have not been checked.
        Each name has its first character dropped and is followed by a space.
        """
        return "".join(name[1:] + " " for name in self._unchecked_map)

    def copy_from(self, ev):
        """
        Takes over the map, the error record and the unchecked-parameter list
        of 'ev'. The list is shared with 'ev', not duplicated, and the error
        separator of this object is left untouched.

        :param ev: A 'Evalor' object.
        """
        self._map, self._err, self._unchecked_map = (
            ev._map,
            ev._err,
            ev._unchecked_map,
        )
def check_map(map, valid, ev, tag=set()):  # noqa: M511
    """
    Checks the validity of a map and reports the outcome through 'ev'.

    :param map: The parameter map to check. Nothing is checked unless
                'map.has_tag(tag)' is true; the check itself runs on 'map.sval'.
    :param valid: The validation criteria for 'map'.
    :param ev: The evaluator, where errors and unchecked parameters are collected.
    :param tag: The tag set passed to 'map.has_tag' and on to the recursive check.
    :return: The accumulated error string when at least one error was recorded,
             or None otherwise (including when nothing was tagged).
    """
    if not map.has_tag(tag):
        debug_print("(none is tagged with: %s)" % (", ".join(tag)))
        return None

    _check_map(map.sval, valid, ev, "", tag)

    debug_print("\nUnchecked maps:")
    debug_print(ev.unchecked_map if ev._unchecked_map != [] else "(none)")

    debug_print("\nError summary:")
    if ev._err != "":
        debug_print(ev._err)
        return ev._err
    debug_print("(none)")
    return None
def __op_mul(map, arg):
    """
    Evaluates the "multiplication" expression and returns the product of all
    elements of 'arg'. The running product starts at 1.0.

    :param arg: The 'arg' should be a 'sea.List' object that contains two or more elements.
    :param map: The original map that the elements in the 'arg' refer.
    """
    product = 1.0
    for factor in (_eval(map, item) for item in arg):
        product *= factor
    return product
def __op_eq(map, arg):
    """
    Evaluates the "equal" expression and returns whether arg[0] and arg[1] are
    equal. When either side is a float the comparison is delegated to
    'is_equal'; otherwise '==' is used.

    :param arg: The 'arg' should be a 'sea.List' object that contains two elements. Elements beyond the first two will be
                ignored.
    :param map: The original map that the elements in the 'arg' refer.
    """
    lhs = _eval(map, arg[0])
    rhs = _eval(map, arg[1])
    if any(isinstance(side, float) for side in (lhs, rhs)):
        return is_equal(lhs, rhs)
    return lhs == rhs
def __op_lt(map, arg):
    """
    Evaluates the "less than" expression: True if arg[0] is less than arg[1], False otherwise.

    :param arg: The 'arg' should be a 'sea.List' object that contains two elements. Elements beyond the first two will be
                ignored.
    :param map: The original map that the elements in the 'arg' refer.
    """
    lhs = _eval(map, arg[0])
    rhs = _eval(map, arg[1])
    return lhs < rhs
def __op_le(map, arg):
    """
    Evaluates the "less or equal" expression: True if arg[0] is less than or equal to arg[1], False otherwise.

    :param arg: The 'arg' should be a 'sea.List' object that contains two elements. Elements beyond the first two will be
                ignored.
    :param map: The original map that the elements in the 'arg' refer.
    """
    lhs = _eval(map, arg[0])
    rhs = _eval(map, arg[1])
    return lhs <= rhs
def __op_gt(map, arg):
    """
    Evaluates the "greater than" expression: True if arg[0] is greater than arg[1], False otherwise.

    :param arg: The 'arg' should be a 'sea.List' object that contains two elements. Elements beyond the first two will be
                ignored.
    :param map: The original map that the elements in the 'arg' refer.
    """
    lhs = _eval(map, arg[0])
    rhs = _eval(map, arg[1])
    return lhs > rhs
def __op_ge(map, arg):
    """
    Evaluates the "greater or equal" expression: True if arg[0] is greater than or equal to arg[1], False
    otherwise.

    :param arg: The 'arg' should be a 'sea.List' object that contains two elements. Elements beyond the first two will be
                ignored.
    :param map: The original map that the elements in the 'arg' refer.
    """
    lhs = _eval(map, arg[0])
    rhs = _eval(map, arg[1])
    return lhs >= rhs
def __op_and(map, arg):
    """
    Evaluates the "logic and" expression with short-circuiting: arg[1] is only
    evaluated when arg[0] is true. The value returned is that of the deciding
    operand, exactly as Python's 'and' would return it.

    :param arg: The 'arg' should be a 'sea.List' object that contains two elements. Elements beyond the first two will be
                ignored.
    :param map: The original map that the elements in the 'arg' refer.
    """
    first = _eval(map, arg[0])
    if not first:
        return first
    return _eval(map, arg[1])
def __op_or(map, arg):
    """
    Evaluates the "logic or" expression with short-circuiting: arg[1] is only
    evaluated when arg[0] is false. The value returned is that of the deciding
    operand, exactly as Python's 'or' would return it.

    :param arg: The 'arg' should be a 'sea.List' object that contains two elements. Elements beyond the first two will be
                ignored.
    :param map: The original map that the elements in the 'arg' refer.
    """
    first = _eval(map, arg[0])
    if first:
        return first
    return _eval(map, arg[1])
def __op_not(map, arg):
    """
    Evaluates the "logic not" expression: True if arg[0] is false, False if arg[0] is true.

    :param arg: The 'arg' should be a 'sea.List' object that contain only 1 element. Any other number of elements
                will cause a 'ValueError' exception.
    :param map: The original map that the elements in the 'arg' refer.
    """
    num_arg = len(arg)
    if num_arg != 1:
        raise ValueError(
            "'__op_not' function expects 1 argument, but there are %d" %
            num_arg)
    operand = _eval(map, arg[0])
    return not operand
def __op_at(map, arg):
    """
    Evaluates the "at" expression and returns the referenced value: arg[0] is
    evaluated to a key, the key is looked up in 'map', and the '.val' of the
    entry is returned (or the entry itself when it has no '.val').

    :param arg: The 'arg' should be a 'sea.List' object that contain only 1 element. Any other number of elements
                will cause a 'ValueError' exception.
    :param map: The original map that the elements in the 'arg' refer.
    """
    num_arg = len(arg)
    if num_arg != 1:
        raise ValueError(
            "'__op_at' function expects 1 argument, but there are %d" %
            num_arg)
    key = _eval(map, arg[0])
    entry = map[key]
    return getattr(entry, "val", entry)
def __op_minus(map, arg):
    """
    Evaluates the "minus" expression and returns the arithmetic result: the
    negated value for one argument, or arg[0] - arg[1] for two arguments.

    :param arg: The 'arg' should be a 'sea.List' object that contains one or two elements.
    :param map: The original map that the elements in the 'arg' refer.
    :raises ValueError: If 'arg' does not contain exactly 1 or 2 elements.
    """
    num_arg = len(arg)
    # Reject an empty argument list as well: it previously slipped through the
    # "> 2" test and failed later while indexing arg[0].
    if (num_arg not in (1, 2)):
        raise ValueError(
            "'__op_minus' function expects 1 or 2 arguments, but there are %d" %
            num_arg)
    if (num_arg == 1):
        return -_eval(map, arg[0])
    return _eval(map, arg[0]) - _eval(map, arg[1])
def __op_cat(map, arg):
    """
    Concatenates the string forms of all evaluated elements of 'arg' and returns the result.

    :param arg: The 'arg' should be a 'sea.List' object that contains at least 1 elements. An empty 'arg' will cause a
                'ValueError' exception.
    :param map: The original map that the elements in the 'arg' refer.
    """
    if len(arg) < 1:
        raise ValueError(
            "'__op_cat' function expects at least 1 argument, but there is none"
        )
    return "".join(str(_eval(map, piece)) for piece in arg)
def __op_sizeof(map, arg):
    """
    Evaluates the "sizeof" expression and returns the length ('len') of the evaluated arg[0].

    :param arg: The 'arg' should be a 'sea.List' object that contain only 1 element. Any other number of elements
                will cause a 'ValueError' exception.
    :param map: The original map that the elements in the 'arg' refer.
    """
    num_arg = len(arg)
    if num_arg != 1:
        raise ValueError(
            "'__op_sizeof' function expects 1 argument, but there are %d" %
            num_arg)
    target = _eval(map, arg[0])
    return len(target)
def is_powerof2(x):
    """
    Returns True if 'x' is a positive power of 2 (1, 2, 4, 8, ...), or False otherwise.

    :param x: The integer to test.
    """
    # A power of 2 has exactly one bit set, so clearing its lowest set bit
    # leaves 0. Zero passes that bit test too but is not a power of 2, hence
    # the explicit sign check.
    return x > 0 and not (x & (x - 1))
def _regex_match(pattern):
    """
    Builds a matcher for 'pattern'.

    :param pattern: The regular expression to match against.
    :return: A function that takes a string and returns the result of
             're.match(pattern, s)': a match object when the pattern matches at
             the beginning of the string, or None otherwise.
    """

    def matcher(s):
        return re.match(pattern, s)

    return matcher
def _xchk_power2(map, valid, ev, prefix):
    """
    External checker: verifies that an integer value is a power of 2 and
    records an error in 'ev' when it is not.

    :param map: 'map' contains the value to be checked. Use 'map.val' to get the value.
    :param valid: 'valid' contains the validation criteria for the to-be-checked value (not used by this checker).
    :param ev: The evaluator, where the error messages are collected.
    :param prefix: The prefix of the checked parameter.
    """
    val = map.val
    if is_powerof2(val):
        debug_print("OK - value is an integer of powere of 2")
        return
    message = "Value %d is not an integer of power of 2" % val
    debug_print("Error:\n" + message)
    ev.record_error(prefix, message)
def _xchk_file_exists(map, valid, ev, prefix):
    """
    External checker: verifies that a file (not a dir) exists and records an
    error in 'ev' when it does not. An empty file name is accepted without
    touching the file system.

    :param map: 'map' contains the value to be checked. Use 'map.val' to get the file name.
    :param valid: 'valid' contains the validation criteria for the to-be-checked value (not used by this checker).
    :param ev: The evaluator, where the error messages are collected.
    :param prefix: The prefix of the checked parameter.
    """
    val = map.val
    if val == "" or os.path.isfile(val):
        debug_print("OK - file exists")
        return
    message = "File not found: %s" % val
    debug_print("Error:\n" + message)
    ev.record_error(prefix, message)
def _xchk_dir_exists(map, valid, ev, prefix):
    """
    External checker: verifies that a dir (not a file) exists and records an
    error in 'ev' when it does not. An empty dir name is accepted without
    touching the file system.

    :param map: 'map' contains the value to be checked. Use 'map.val' to get the dir name.
    :param valid: 'valid' contains the validation criteria for the to-be-checked value (not used by this checker).
    :param ev: The evaluator, where the error messages are collected.
    :param prefix: The prefix of the checked parameter.
    """
    val = map.val
    if val == "" or os.path.isdir(val):
        debug_print("OK - Directory exists")
        return
    message = "Directory not found: %s" % val
    debug_print("Error:\n" + message)
    ev.record_error(prefix, message)
def _eval(map, arg):
    """
    Evaluates the expression and returns the results.

    A list whose first element evaluates to the name of a known operator is
    treated as a prefix expression: the operator is applied to the remaining
    elements. Any other list is evaluated element by element. An atom whose
    value starts with "@" is a reference and is resolved through 'map'; every
    other atom evaluates to its own value.

    :param arg: 'arg' can be either a 'sea.List' object or a 'sea.Atom' object, representing a prefix expression.
    :param map: The original map that the elements in the 'arg' refer.
    :return: The operator result, a Python list of evaluated elements, or the (possibly dereferenced) atom value.
    """
    if (isinstance(arg, List)):
        # An empty list has no operator and no elements to evaluate.
        if (len(arg) == 0):
            return []
        head = _eval(map, arg[0])
        # Operator names are strings. Looking up anything else is pointless,
        # and an unhashable head (e.g. a nested list) would make the dict
        # lookup raise TypeError.
        if (isinstance(head, str)):
            op = __OP.get(head.strip())
            if (op is not None):
                return op(map, arg[1:])
        return [_eval(map, e) for e in arg]

    val = arg.val
    # Bare operator symbols and the empty string are plain values, never references.
    if (val in ['-', '@', '']):
        return val
    try:
        if (val[0] == "@"):
            k = map[val[1:]]
            try:
                return k.val
            except AttributeError:
                return k
    except TypeError:
        # Non-subscriptable values (e.g. numbers) cannot be references.
        pass
    return val
# Dispatch table from operator name, as it appears at the head of a prefix
# expression, to the function implementing it.
__OP = {
    "*": __op_mul,
    "==": __op_eq,
    "<": __op_lt,
    "<=": __op_le,
    ">": __op_gt,
    # Was wired to __op_gt, which made ">=" reject equal operands.
    ">=": __op_ge,
    "&&": __op_and,
    "||": __op_or,
    "!": __op_not,
    "@": __op_at,
    "-": __op_minus,
    "cat": __op_cat,
    "sizeof": __op_sizeof,
}
# Maps a type name used in validation criteria to what the checkers need:
# either the expected type/callable alone, or a (type, built-in range) pair
# that the atom checker unpacks.
__TYPE = {
    # strings; "str1" carries a built-in range applied to the string length
    "str": str,
    "str1": (str, [1, 1000000000]),
    # floats, optionally with a built-in value range
    "float": float,
    "float+": (float, [0, float("inf")]),
    "float-": (float, [float("-inf"), 0]),
    "float0_1": (float, [0, 1.0]),
    # integers, optionally with a built-in value range
    "int": int,
    "int0": (int, [0, 1000000000]),
    "int1": (int, [1, 1000000000]),
    # booleans; the built-in "range" lists the only accepted value
    "bool": boolean,
    "bool0": (boolean, [False]),
    "bool1": (boolean, [True]),
    # miscellaneous
    "enum": str,
    "list": list,
    "none": None,
    "regex": _regex_match,
}
# For each actual type, the expected types that are still accepted by the
# atom type check even though they differ from the actual type.
__CONVERTIBLE_TO = {int: [float, str], float: [str]}
# Registry of external checkers, keyed by the name used in '_check' and
# '_mapcheck' entries of the validation criteria.
__xcheck = dict(
    power2=_xchk_power2,
    file_exists=_xchk_file_exists,
    dir_exists=_xchk_dir_exists,
)
def reg_xcheck(name, func):
    """
    Registers an external checker under 'name', replacing any checker already
    registered under that name.

    :param name: Name of the checker.
    :param func: Callable object that checks validity of a parameter. For interface requirement, see '_xchk_power2', or
                 '_xchk_file_exists', or '_xchk_dir_exists' for example.
    """
    __xcheck.update({name: func})
def _match(map, valid, ev, prefix, tag):
    """
    Finds the best match.

    'valid' is a list of alternative validation criteria. 'map' is checked
    against every alternative with its own deep copy of 'ev', and the copy
    judged to be the best match is then copied back into 'ev'.

    :param map: The parameter (map, list or atom) to be checked.
    :param valid: The alternatives to check 'map' against.
    :param ev: The evaluator; it is updated in place with the best result.
    :param prefix: The prefix of the checked parameter.
    :param tag: Not used by this function.
    """
    kk = map
    vv = valid
    ev_list = []
    for vv_ in vv:
        # NOTE(review): this reads '_if' from the whole list 'vv', not from the
        # current alternative 'vv_' -- confirm that this is intended.
        try:
            _if = ev(vv._if)
        except AttributeError:
            pass
        else:
            debug_print("_if: {} = {}".format(
                str(vv._if),
                _if,
            ), False)
            if (_if):
                debug_print("True")
            else:
                debug_print("False - Skip checking the whole map.")
                return
        # Each alternative is checked against a private copy of the evaluator
        # so that the alternatives do not pollute each other's results.
        ev_ = deepcopy(ev)
        # NOTE(review): 'tag' is not forwarded here, so the nested check runs
        # with the default tag of '_check_map' -- confirm that this is intended.
        _check_map(kk, vv_, ev_, prefix)
        ev_list.append(ev_)
    if (ev_list != []):
        # Tries to find the best match.
        # Step 1: keep the alternatives with the fewest unchecked parameters.
        candidate = [
            ev_list[0],
        ]
        least = len(candidate[0]._unchecked_map)
        for ev_ in ev_list[1:]:
            num = len(ev_._unchecked_map)
            if (num < least):
                candidate = [
                    ev_,
                ]
                least = num
            elif (num == least):
                candidate.append(ev_)
        # Step 2: among those, prefer the alternatives whose error record is
        # unchanged, i.e. that added no error on top of what 'ev' already had.
        best_ev = []
        for ev_ in candidate:
            if (ev_._err == ev._err):
                best_ev.append(ev_)
        if (best_ev == []):
            best_ev = candidate
        candidate = best_ev
        # Step 3: break remaining ties by the fewest "Wrong type:" errors; the
        # earliest alternative wins when the counts are equal.
        best_ev = candidate[0]
        least = best_ev._err.count("Wrong type:")
        if (len(candidate) > 1):
            for ev_ in candidate[1:]:
                num = ev_._err.count("Wrong type:")
                if (num < least):
                    best_ev = ev_
                    least = num
        ev.copy_from(best_ev)
def _check_atom(atom, valid, ev, prefix):
    """
    Checks the validity of atom.

    The check runs in three stages: the type, the range, and the external
    checkers listed under '_check'. A failure in the type stage is recorded in
    'ev' and ends the check; range failures are recorded and the external
    checkers still run.

    :param atom: The value to be checked; its 'val' and '_type' are inspected.
    :param valid: The validation criteria; 'type', 'range' and '_check' are read when present.
    :param ev: The evaluator, where the error messages are collected.
    :param prefix: The prefix of the checked parameter.
    """
    rr = None  # Range
    # type
    debug_print(prefix + ":")
    debug_print(" checking its type...", False)
    try:
        t = ev(valid.type)
        if (t.startswith("regex:")):
            # Everything after "regex:" is the pattern; 'tt' becomes a matcher function.
            tt = __TYPE["regex"](t[6:])
        else:
            tt = __TYPE[t]
        if (isinstance(tt, tuple)):
            # Type names with a built-in range are stored as (type, range).
            tt, rr = tt[0], tt[1]
    except AttributeError:
        # Raised when 'valid' has no 'type' or when the evaluated 'type' has no 'startswith'.
        ev.record_error(
            prefix,
            "Wrong type: expecting a composite parameter, but got an atom")
        return
    except KeyError:
        ev.record_error(
            prefix,
            "Wrong type: %s. 'type' is likely a parameter than a description." %
            t)
        return
    atom_val = atom.val
    if (atom_val is None):
        # None is only acceptable when the expected type itself is None.
        if (tt is None):
            debug_print("OK - value None is acceptable")
        else:
            ev.record_error(prefix,
                            "Wrong value: expecting %s, but got None" % str(tt))
        return
    if (atom._type == str and inspect.isfunction(tt) and tt != boolean):
        # 'tt' is a matcher function here; 'boolean' is excluded explicitly.
        if (tt(atom_val)):
            debug_print("OK - {} matches the pattern: {}".format(
                atom_val, t[6:]))
        else:
            ev.record_error(
                prefix,
                "Wrong type: expecting a string matching {}, but got {}".format(
                    t[6:],
                    atom_val,
                ))
            return
    elif (atom._type != tt and (atom._type not in __CONVERTIBLE_TO or
                                tt not in __CONVERTIBLE_TO[atom._type])):
        # The actual type differs from the expected one and is not listed as convertible to it.
        ev.record_error(prefix, "Wrong type: expecting {}, but got {}".format(
            "boolean" if tt == boolean else str(tt),
            str(atom._type),
        ))
        return
    else:
        debug_print("OK - %s" % t)
    # range
    debug_print(" checking its range...", False)
    try:
        # A range built into the type name takes precedence over 'valid.range'.
        if (rr is None):
            rr = ev(valid.range)
    except AttributeError:
        debug_print("N/A")
    else:
        if (t == "enum" or tt == boolean):
            # For enums and booleans the range is the collection of accepted values.
            if (atom_val not in rr):
                ev.record_error(
                    prefix,
                    "Wrong value: should be one of {}, but got '{}'".format(
                        str(rr),
                        str(atom_val),
                    ))
            else:
                debug_print("OK - '{}' is one of {}".format(
                    str(atom_val),
                    str(rr),
                ))
        elif (tt == str):
            # For strings the range bounds the length: rr[0] is the minimum, rr[1] the maximum.
            if (atom._type != tt):
                atom_val = str(atom_val)
            length = len(atom_val)
            if (length > int(rr[1])):
                ev.record_error(prefix,
                                "String is too long (%d char's)" % length)
            elif (length < int(rr[0])):
                ev.record_error(
                    prefix,
                    "String is too short: it must have at least %d char's" %
                    rr[0])
            else:
                debug_print("OK - string has %d char's" % length)
        else:
            # For other types the range bounds the value itself, both ends included.
            if (atom_val > tt(rr[1]) or atom_val < tt(rr[0])):
                ev.record_error(
                    prefix,
                    "Value out of range: expecting within %s, but got '%s'" %
                    (str(rr), str(atom_val)))
            else:
                debug_print("OK - {} is within {}".format(
                    str(atom_val),
                    str(rr),
                ))
    # _check
    try:
        cc = valid._check
    except AttributeError:
        pass
    else:
        debug_print(" external checking...")
        # '_check' names either a list of registered checkers or a single one;
        # an empty name means no external check.
        if (isinstance(cc, List)):
            for e in cc:
                debug_print(" %s: " % e.val, False)
                __xcheck[e.val](atom, valid, ev, prefix)
        elif (cc.val != ""):
            debug_print(" %s: " % cc.val, False)
            __xcheck[cc.val](atom, valid, ev, prefix)
def _check_list(map, valid, ev, prefix, tag):
    """
    Checks the validity of list.

    The check covers, in this order: the declared type (must be a list), the
    size, each element, and the external checkers listed under '_check'.

    :param map: The list to be checked.
    :param valid: The validation criteria; 'type', 'size', 'elem' and '_check' are read when present.
    :param ev: The evaluator, where the error messages are collected.
    :param prefix: The prefix of the checked parameter.
    :param tag: The tag set forwarded to the check of every element.
    """
    # type
    debug_print(prefix + ":")
    debug_print(" checking its type...", False)
    try:
        type_name = ev(valid.type)
        expected = __TYPE[type_name]
    except AttributeError:
        ev.record_error(
            prefix,
            "Wrong type: expecting a composite parameter, but got a list")
        return
    if expected != list:
        ev.record_error(
            prefix,
            "Wrong type: expecting %s, but got <type 'list'>" % str(expected))
        return
    debug_print("OK - list")

    # size: a positive 'size' demands that exact length, a negative one
    # demands at least '-size' elements.
    try:
        debug_print(" checking its size...", False)
        size = ev(valid.size)
        actual = len(map)
        if size > 0 and actual != size:
            ev.record_error(
                prefix,
                "Wrong list length: expecting %d elements, but got %d" %
                (size, actual))
        elif size < 0 and actual < -size:
            ev.record_error(
                prefix,
                "Wrong list length: expecting at least %d elements, but got %d"
                % (-size, actual))
        else:
            debug_print("OK - %d" % actual)
    except AttributeError:
        pass

    # elem
    debug_print(" checking each element in list...", False)
    try:
        elem_valid = valid.elem
        if isinstance(elem_valid, List):
            # Criteria are given per position; elements beyond the last
            # criterion are checked against that last criterion.
            num_valid, num_elem = len(elem_valid), len(map)
            for idx, (item, item_valid) in enumerate(zip(map, elem_valid)):
                _check_map(item, item_valid, ev, "%s[%d]" % (prefix, idx), tag)
            if num_valid < num_elem:
                last_valid = elem_valid[-1]
                for idx in range(num_valid, num_elem):
                    _check_map(map[idx], last_valid, ev,
                               "%s[%d]" % (prefix, idx), tag)
        else:
            # One criterion shared by all elements.
            for idx, item in enumerate(map):
                _check_map(item, elem_valid, ev, "%s[%d]" % (prefix, idx), tag)
    except AttributeError:
        debug_print("OK - No requirement for elements")

    # _check
    try:
        checkers = valid._check
        debug_print(" external checking for the whole list...")
        if isinstance(checkers, List):
            for checker in checkers:
                debug_print(" %s: " % checker.val, False)
                __xcheck[checker.val](map, valid, ev, prefix)
        elif checkers.val != "":
            debug_print(" %s: " % checkers.val, False)
            __xcheck[checkers.val](map, valid, ev, prefix)
    except AttributeError:
        pass
def _check_map(map, valid, ev, prefix="", tag=set()):  # noqa: M511
    """
    Checks the validity of a map.

    Dispatches on the kind of 'map': atoms and lists go to their dedicated
    checkers, while a map is checked key by key, recursively. When 'valid' is
    a list of alternatives the best-matching alternative is used instead.

    :param map: The parameter (map, list or atom) to be checked.
    :param valid: The validation criteria. The special entries '_if', '_skip', '_mapcheck' and '_enforce' are
                  honored when present.
    :param ev: The evaluator, where errors and unchecked parameters are collected.
    :param prefix: The prefix of the checked parameter.
    :param tag: The tag set passed to 'map.key_value' and forwarded to nested checks.
    :raises ValueError: If '_skip' is neither a list nor the string "all", or if '_enforce' is not a list.
    """
    # _if: skip the whole check when the condition evaluates to false.
    try:
        _if = ev(valid._if)
    except AttributeError:
        pass
    else:
        debug_print("_if: {} = {}".format(
            str(valid._if),
            _if,
        ), False)
        if (_if):
            debug_print("True")
        else:
            debug_print("False - Skip checking the whole map.")
            return

    if (isinstance(valid, List)):
        return _match(map, valid, ev, prefix, tag)

    if (isinstance(map, Atom)):
        _check_atom(map, valid, ev, prefix)
    elif (isinstance(map, List)):
        _check_list(map, valid, ev, prefix, tag)
    elif (isinstance(map, Map)):
        # _skip: keys that are exempt from checking, or "all" to exempt every key.
        try:
            skip = valid._skip.val
        except AttributeError:
            skip = []
        else:
            if (not isinstance(skip, list) and skip != "all"):
                raise ValueError(
                    "_skip must be either a list of strings or the string \"all\""
                )

        # _mapcheck: external checkers applied to the map as a whole.
        try:
            cc = valid._mapcheck
        except AttributeError:
            pass
        else:
            debug_print(prefix + ":")
            debug_print(" external checking for the whole map...")
            if (isinstance(cc, List)):
                for e in cc:
                    debug_print(" %s: " % e.val, False)
                    __xcheck[e.val](map, valid, ev, prefix)
            elif (cc.val != ""):
                debug_print(" %s: " % cc.val, False)
                __xcheck[cc.val](map, valid, ev, prefix)

        # _enforce: keys that must be present in the map.
        try:
            cc = valid._enforce
        except AttributeError:
            pass
        else:
            if (not isinstance(cc, List)):
                raise ValueError("_enforce must be a list of strings")
            debug_print(prefix + ":")
            debug_print(" enforcing keys...", False)
            # join() puts the separator only between names, so there is no
            # trailing ", " to strip; slicing the result would cut the last
            # key name short.
            missing_key = ", ".join(e for e in cc.val if (e not in map))
            if (missing_key == ""):
                debug_print("OK - All enforced keys present")
            else:
                debug_print("Error\nMissing keys: " + missing_key)
                ev.record_error(prefix, "Missing keys: " + missing_key)

        if ("all" != skip):
            # Key-value pairs
            key_value = [
                (k, kk) for k, kk in map.key_value(tag) if (k not in skip)
            ]
            for k, kk in key_value:
                try:
                    vv = valid[k]
                except KeyError:
                    # No criteria for this key: remember it as unchecked.
                    ev._unchecked_map.append(prefix + '.' + k)
                    continue
                _check_map(kk, vv, ev, prefix + '.' + k, tag)