"""
Desmond analyses

Copyright Schrodinger, LLC. All rights reserved.
"""
import contextlib
import inspect
import math
import re
import traceback
from bisect import bisect
from bisect import bisect_left
from pathlib import Path
from typing import Callable
from typing import Iterable
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

from schrodinger.application.desmond import constants
from schrodinger.application.desmond import util
from schrodinger.application.desmond.packages import analysis
from schrodinger.application.desmond.packages import topo
from schrodinger.application.desmond.packages import traj
from schrodinger.utils import sea

QUANTITY_CLASS_MAP = {
    "angle": analysis.Angle,
    "dihedral": analysis.Torsion,
    "distance": analysis.Distance,
}


class DSC(constants.Constants):
    """
    Data selection codes. See `select_data` below for its usage.
    """
    # Makes it hard for codes here to collide with real data of users.
    ANY_VALUE = "<<any-value>>"
    NO_VALUE = "<<absence>>"


def calc_time_series(requests, model_fname, traj_fname):
    """
    :param requests: `list` of requests, where each request is a `list` of a
        quantity name ("angle", "dihedral", or "distance"), its arguments, and
        an optional time selection.
        Examples:
          [['dihedral', 1, 2, 3, 4],
           ['dihedral', 1, 2, 3, 4, {'begin': 0, 'end': 24.001}],]

    :return: `list` of analysis results, one per request, where each result is
        a `list` of (time, value) tuples
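
    Example (a minimal sketch; the file names are hypothetical)::

      requests = [['distance', 1, 2],
                  ['dihedral', 1, 2, 3, 4, {'begin': 0, 'end': 100.0}]]
      series = calc_time_series(requests, "md-out.cms", "md_trj")
      # series[0]: (time, distance) tuples for all frames
      # series[1]: (time, dihedral) tuples for frames with time in [0, 100.0]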
    """
    if not requests:
        return []

    models = list(topo.read_cms(model_fname))
    tr = traj.read_traj(traj_fname)
    sim_times = [fr.time for fr in tr]

    analyzers, times = [], []
    for req in requests:
        has_opt = isinstance(req[-1], dict)
        args = models + (req[1:-1] if has_opt else req[1:])
        times.append(req[-1] if has_opt else None)
        analyzers.append(QUANTITY_CLASS_MAP[req[0]](*args))
    results = analysis.analyze(tr, *analyzers)

    # FIXME: this is to undo the special treatment of analysis.analyze() when
    # there is only one analyzer
    if len(analyzers) == 1:
        results = [results]

    answer = []
    for res, time_sel in zip(results, times):
        if time_sel:
            left = bisect_left(sim_times, time_sel['begin'])
            right = bisect(sim_times, time_sel['end'])
            answer.append(list(zip(sim_times[left:right], res[left:right])))
        else:
            answer.append(list(zip(sim_times, res)))
    return answer


def calc_prob_profile(data, bin_width, min, max, is_periodic=False):
    """
    Computes a normalized histogram (a probability profile) of `data`. The
    interval [min, max) is divided into `int((max - min) / bin_width)` bins;
    each datum is first wrapped into [min, max) and then counted toward the
    bin whose center `min + i * bin_width` is nearest. Out-of-range bin
    indices wrap around if `is_periodic` is true and are clamped to the edge
    bins otherwise.

    :return: `list` of (bin_center, probability) tuples
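
    Example (a minimal sketch with made-up data)::

      # Two bins centered at -180.0 and 0.0, for angles in degrees.
      prob = calc_prob_profile([-170.0, 10.0, 20.0, 30.0], bin_width=180.0,
                               min=-180.0, max=180.0, is_periodic=True)
      # -> [(-180.0, 0.25), (0.0, 0.75)]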
    """
    num_bin = int(float(max - min) / bin_width)
    bins = [0] * num_bin
    value_range = max - min
    for e in data:
        while e > max:
            e -= value_range
        while e < min:
            e += value_range
        i_bin = int((e - min) / bin_width + 0.5)
        if i_bin >= num_bin:
            i_bin = i_bin - num_bin if is_periodic else num_bin - 1
        elif i_bin < 0:
            i_bin = i_bin + num_bin if is_periodic else 0
        bins[i_bin] += 1
    result = []
    num_data = len(data)
    for i, e in enumerate(bins):
        result.append((i * bin_width + min, float(e) / num_data))
    return result


class ForEachDo(tuple):
    """
    An advanced tuple container that forwards attribute accesses and method
    calls to all of its elements. For example::

      a = ForEachDo([" a ", "b", " c"])
      # Constructs a `ForEachDo` object with the three string elements.

      assert isinstance(a, tuple)
      # A `ForEachDo` instance is really a `tuple` instance.

      assert ("a", "b", "c") == a.strip()
      # `strip()` is applied to each element, and the results are aggregated
      # into a tuple.
    """

    def __new__(cls, items):
        return tuple.__new__(cls, items)

    def __getattr__(self, attr):
        return ForEachDo(getattr(e, attr) for e in self)

    def __call__(self, *args, **kwargs):
        return ForEachDo(e(*args, **kwargs) for e in self)


class CompositeKeySyntaxError(SyntaxError):
    pass


class ArkDbGetError(KeyError):
    pass


class ArkDbPutError(Exception):
    pass


class ArkDbDelError(Exception):
    pass


class SubtaskExecutionError(RuntimeError):
    pass


_GetExceptions = (KeyError, IndexError, TypeError, AttributeError)


class _NONVAL:
    """
    This class object is a sentinel indicating that there is no need to match
    the value. We cannot use `None` for this purpose because `None` is a valid
    value in the ARK format.
    """
    pass


def _check_composite_key_syntax(keys: List[str]):
    """
    Checks the syntax of a given composite key (which is decomposed and passed
    as a list of substrings).

    :raise CompositeKeySyntaxError: if syntax error found.
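
    For example, `["Keywords", "[i]", ".RMSD"]` is valid, whereas
    `["[i]", ".RMSD"]` raises because no key precedes the meta-index.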
    """
    new_list_meta_index = None
    for i, e in enumerate(keys):
        if e in ("[i]", "[@]", "[*]", "[$]"):
            if i == 0:
                raise CompositeKeySyntaxError("No key found")
            k = keys[i - 1]
            if k.endswith("."):
                raise CompositeKeySyntaxError(
                    "Wrong key syntax: %s%s. Did you mean: %s%s?" %
                    (k, e, k[:-1], e))   # yapf: disable
            if e == "[@]" and new_list_meta_index:
                raise CompositeKeySyntaxError(
                    "Cannot use meta-index `[@]` after `%s`" %
                    new_list_meta_index)
            if e in ("[@]", "[$]"):
                new_list_meta_index = e


def _get_subkeys(key: str) -> List[str]:
    """
    Decomposes a composite key into a list of subkeys and meta-indices, e.g.,
    "Keywords[i].RMSD" -> ["Keywords", "[i]", ".RMSD"].

    :raise CompositeKeySyntaxError: if the key has a syntax error.
    """
    # Must use `filter` to remove empty strings.
    subkeys = list(filter(None, re.split(r"(\[i\]|\[\$\]|\[\*\]|\[@\])", key)))
    _check_composite_key_syntax(subkeys)
    return subkeys


def _get_key0_subkeys(keys: List[str]) -> Tuple[str, List[str]]:
    key0, subkeys = keys[0], keys[1:]
    if subkeys and subkeys[0].startswith('.'):
        # Removes the leading '.'.
        # An example of the original key is
        #  "Keywords[i].ResultLambda0.Keywords[i].ProtLigInter.HBondResult",
        # when decomposed, it becomes a list of subkeys:
        #  ["Keywords", "[i]", ".ResultLambda0.Keywords", "[i]",
        #   ".ProtLigInter.HBondResult"]
        # We should drop the '.' in ".ResultLambda0.Keywords" and
        # ".ProtLigInter.HBondResult"
        # We only need to deal with the 2nd element of `keys` in each recursive
        # call.
        subkeys[0] = subkeys[0][1:]
    return key0, subkeys


def _get_impl(db: sea.Sea, keys: List[str]) -> sea.Sea:
    """
    Gets a datum from the database `db` with the given key. The key is
    passed in the decomposed form as a list of strings: `keys`.

    :raises: `KeyError` or `IndexError` or `AttributeError` or `TypeError`
        if the key is not found in `db`.
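
    Example (illustrative): with `db` parsed from "a = {b = [{c = 1}]}",
    `_get_impl(db, ["a.b", "[i]", ".c"])` returns the `sea` object wrapping 1.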
    """
    k0, subkeys = _get_key0_subkeys(keys)

    if k0 == "[i]":
        for subdb in db:
            with contextlib.suppress(_GetExceptions):
                return _get_impl(subdb, subkeys)
        else:
            raise KeyError(str(keys))

    if k0 == "[*]":
        all_val = []
        has_val = False
        for subdb in db:
            with contextlib.suppress(_GetExceptions):
                val = _get_impl(subdb, subkeys)
                has_val = True
                if isinstance(val, list):
                    all_val += val
                else:
                    all_val.append(val)
        if has_val:
            return all_val
        raise KeyError(str(keys))

    if isinstance(db, sea.List):
        k0 = k0.split(".", 1)
        k0_k0, k0_subkeys = k0[0], k0[1:]
        try:
            # Does `k0_k0` match the pattern "[<number>]"?
            i = int(k0_k0.strip()[1:-1])
            subdb = db[i]
        except Exception:
            raise IndexError(f"Bad array index: {k0_k0}")
        subkeys = k0_subkeys + subkeys
    else:
        subdb = db.get_value(k0)

    if subkeys:
        return _get_impl(subdb, subkeys)
    return subdb


def _get_matched_subdb(db, keys) -> Tuple[sea.Map, List[str]]:
    """
    Finds the data field that matches the given key. This is only used by
    `_put_impl` (see below).

    For example, suppose `db` is configured as follows:

      Keywords = [
        ...  # Other elements
        {ResultLambda0 = {...}}
        {ResultLambda1 = {...}}
      ]

    and the original key is "Keywords[i].ResultLambda0.ParchedCmsFname". We
    generally do NOT a priori know the index of the "ResultLambda0" data field.
    So we use this function to find it. Note that if there were multiple
    "ResultLambda0" data fields, only the first one is returned.

    If no matching data field is found, an `ArkDbPutError` will be raised.

    :type db: sea.List or sea.Map
    :type keys: List[str]
    """
    k0, subkeys = _get_key0_subkeys(keys)

    if k0 == "[i]":
        if isinstance(db, sea.List):
            for subdb in db:
                with contextlib.suppress(ArkDbPutError):
                    return _get_matched_subdb(subdb, subkeys)
    else:
        k0_k0 = k0.split(".", 1)[0]
        if isinstance(db, sea.Map) and k0_k0 in db:
            return db, keys
    raise ArkDbPutError("No data field matchng key: '%s'" % "".join(keys))


def _put_impl(db, keys, val) -> Union[sea.List, sea.Map]:
    """
    Puts (saves) `val` into the database with the given key. The key is passed
    in the decomposed form as a list of strings: `keys`.

    `db` can be `None`, and then a new `sea.List` or `sea.Map` object will be
    created for the given value `val` and returned; otherwise, `db` will be
    updated with `val` and returned.

    :type db: sea.List or sea.Map or None
    :type keys: List[str]
    :type val: List, or scalar, or an empty `dict` (`{}`), or `sea.Sea`
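
    Example (illustrative)::

      db = _put_impl(None, ["a.b", "[$]"], 1)
      # `db` is now a new `sea.Map` equivalent to "a = {b = [1]}".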
    """
    k0, subkeys = _get_key0_subkeys(keys)

    if k0 == "[i]":
        subdb, subkeys = _get_matched_subdb(db, keys)
        _put_impl(subdb, subkeys, val)
        return db
    elif k0 == "[@]":
        # This requires a preexisting list to make sense, so `db` must be a
        # `sea.List` object.
        k1, subsubkeys = subkeys[0], subkeys[1:]
        k1 = k1.split('.', 1)
        i = int(k1[0])
        subsubkeys = k1[1:] + subsubkeys
        db.insert(i, _put_impl(None, subsubkeys, val) if subsubkeys else val)
        return db
    elif k0 == "[$]":
        # We allow the user to create a new list with this meta-index. For
        # example, say we have an existing `db` whose content is the following:
        #   a.b.c = 1
        # Now say we want to add a new item: a.b.d = [{e = {f = [1 2 3]}}]
        # We can do that with the syntax:
        #   db.put("a.b.d[$].e.f", [1, 2, 3])
        # where "a.b.d" is a new list. So `db` must be either `None` or a
        # `sea.List` object.
        # Note that this cannot be done with the syntax:
        #   db.put("a.b.d", [{"e": {"f": [1, 2, 3]}}])
        # because `val` cannot contain `dict` objects.
        db = db or sea.List()
        db.append(_put_impl(None, subkeys, val) if subkeys else val)
        return db

    db = db or sea.Map()
    if subkeys:
        try:
            subdb = db.get_value(k0)
        except _GetExceptions:
            subdb = None
        db.set_value_fast(k0, _put_impl(subdb, subkeys, val))
    else:
        val = sea.Map() if val == {} else val
        db.set_value_fast(k0, val)
    return db


def _match_single_keyvalue(db: sea.Sea, keys: List[str], value) -> bool:
    k0, subkeys = _get_key0_subkeys(keys)

    if k0 in ["[i]", "[*]"]:
        for subdb in db:
            with contextlib.suppress(_GetExceptions):
                return _match_single_keyvalue(subdb, subkeys, value)
        return False

    if isinstance(db, sea.List):
        k0 = k0.split(".", 1)
        k0_k0, k0_subkeys = k0[0], k0[1:]
        try:
            # Does `k0_k0` match the pattern "[<number>]"?
            i = int(k0_k0.strip()[1:-1])
            subdb = db[i]
        except Exception:
            return False
        subkeys = k0_subkeys + subkeys
    else:
        subdb = db.get_value(k0)

    if subkeys:
        return _match_single_keyvalue(subdb, subkeys, value)

    if value is _NONVAL:
        return True

    # Converts `value` (str) to an object of the correct type and then compares.
    return subdb.val == sea.Map(f"typed = {value}").typed.val


KeyValues = Optional[Union[str, Iterable[str]]]


def _match_keyvalues(db: sea.Sea, keyvalues: KeyValues) -> bool:
    """
    This function tries to find all key-value pairs given in `keyvalues` in the
    database `db`. If that succeeds, it returns `True`, otherwise it returns
    `False`.

    If there are no key-value pairs to find, i.e., if `keyvalues` is `None`, an
    empty string, or an empty iterable, this function returns `True`.

    Each key-value pair is a string in the format of "<key>=<value>". Note that
    the key and the value are joined by a single "=" symbol with no spaces
    around it. The key is in the extended standard composite format (see the
    docstring of the `ArkDb` class below). The value is in the ARK format (note
    that spaces are allowed within the value). The value part is optional;
    when it's missing, the "=" symbol should be absent as well, and this
    function will only look for the key in `db` and disregard the value.
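
    Example (illustrative)::

      _match_keyvalues(db, ("SelectionType=Ligand", "UseSymmetry"))
      # True iff `db` has a "SelectionType" key whose value is "Ligand" and
      # a "UseSymmetry" key with any value.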
    """
    if not keyvalues:
        return True

    # Prevents an easy mistake like `keyvalues=("a.b.c")`, which is a plain
    # string rather than a one-element tuple.
    if isinstance(keyvalues, str):
        keyvalues = (keyvalues,)

    for kv in keyvalues:
        # If no value is specified, `value` will be `_NONVAL`.
        key, value = (kv.split('=', 1) + [_NONVAL])[:2]
        subkeys = _get_subkeys(key)
        if subkeys[-1] in ("[i]", "[*]"):
            raise CompositeKeySyntaxError(
                "Cannot determine array index because there is nothing to "
                f"match after meta-index `{subkeys[-1]}`")

        if "[$]" in subkeys or "[@]" in subkeys:
            raise CompositeKeySyntaxError(
                "Meta-indices `[$]` and `[@]` cannot be used for matching "
                f"data: '{key}'")

        try:
            if not _match_single_keyvalue(db, subkeys, value):
                return False
        except _GetExceptions:
            return False

    return True


def _del_impl(db: sea.Sea, keys: List[str], matches: KeyValues = None):
    """
    Deletes a datum from the database `db` with the given key. The key is
    passed in the decomposed form as a list of strings: `keys`.

    If `matches` is specified, this function will further check the datum to
    see if it can find all key-value pairs specified in `matches` in the datum.
    If it can, the deletion will happen; otherwise it will not (and no
    exceptions will be raised). See the docstring of `_match_keyvalues` for
    more detail on the value of `matches`.

    :raises: `KeyError` or `IndexError` or `AttributeError` or `TypeError`
        if the key is not found in `db`.
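
    Example (illustrative)::

      # Deletes the first "RMSD" datum whose "SelectionType" is "Ligand".
      _del_impl(db, ["Keywords", "[i]", ".RMSD"],
                matches="SelectionType=Ligand")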
    """
    k0, subkeys = _get_key0_subkeys(keys)

    if k0 == "[i]":
        # We allow "[i]" to be at the end of the key, and in such cases
        # `subkeys` is an empty list.
        if subkeys:
            for subdb in db:
                with contextlib.suppress(_GetExceptions):
                    return _del_impl(subdb, subkeys, matches)
            # No element matched.
            raise KeyError(str(keys))
        else:
            for i, subdb in enumerate(db):
                if _match_keyvalues(subdb, matches):
                    del db[i]
                    break
            else:
                raise KeyError(str(keys))
            return

    if k0 == "[*]":
        # We allow "[*]" to be at the end of the key, and in such cases
        # `subkeys` is an empty list.
        has_instance = False
        if subkeys:
            for subdb in db:
                with contextlib.suppress(_GetExceptions):
                    _del_impl(subdb, subkeys, matches)
                    has_instance = True
        else:
            to_be_deleted = [
                i for i, subdb in enumerate(db)
                if _match_keyvalues(subdb, matches)
            ]
            for i in reversed(to_be_deleted):
                del db[i]
            has_instance = bool(to_be_deleted)
        if has_instance:
            return
        raise KeyError(str(keys))

    if isinstance(db, sea.List):
        k0 = k0.split(".", 1)
        k0_k0, k0_subkeys = k0[0], k0[1:]
        try:
            # Does `k0_k0` match the pattern "[<number>]"?
            i = int(k0_k0.strip()[1:-1])
            subdb = db[i]
        except Exception:
            raise IndexError(f"Bad array index: {k0_k0}")
        subkeys = k0_subkeys + subkeys
    else:
        subdb = db.get_value(k0)

    if subkeys:
        return _del_impl(subdb, subkeys, matches)

    if _match_keyvalues(subdb, matches):
        db.del_key(k0)


class ArkDb:
    """
    Abstracts the key-value database where analysis results are stored.
    """

    def __init__(self, fname=None, string=None, db=None):
        if db is not None:
            assert not (bool(fname) or bool(string))
            assert isinstance(db, sea.Sea)
            self._db = db
        else:
            fname = fname and str(fname)
            assert bool(fname) ^ bool(string)
            if fname:
                with open(fname) as fh:
                    string = fh.read()
            self._db = sea.Map(string)

    def __str__(self):
        return str(self._db)

    @property
    def val(self):
        return self._db.val

    def get(self, key: str, default=ArkDbGetError):
        """
        Gets a value keyed by `key`. Note that `None` is a normal return value
        and does NOT mean that the key was not found.

        :raises CompositeKeySyntaxError: if `key` has a syntax error. You
            normally should NOT catch this exception, because this means your
            code has a syntactical error.
        :raises ArkDbGetError: if `key` is not found in the database. You can
            optionally change raising the exception to returning a default
            value by specifying the "default" argument.

        Explanation of `key`'s value:
        - The value is generally a composite key like "a.b.c[1].d", where "a",
          "b", "c", "[1]", and "d" are the subkeys or array-indices at each
          hierarchical level.
        - For array indices, sometimes the exact number is unknown a priori,
          e.g., "ResultLambda0.Keywords[<number>].ProtLigInter", where the
          <number> cannot be specified in the source code. For cases like this,
          we have to iterate over the "ResultLambda0.Keywords" list and find
          "ProtLigInter" by matching the keyword. Note that it's possible (at
          least in principle) that there may be multiple matching elements.
        - In order to express the above indexing ideas, we introduce four new
          syntax components here:
          - [i]  Iterates over elements in the list and returns the first
                 matching element. For getting, putting, finding, and deleting.
          - [*]  Iterates over elements in the list and returns a tuple of all
                 matching elements. Only for getting, finding, and deleting.
          - [$]  Inserts at the end of the list. Only for putting.
          - [@]  Similar to `[$]` except that this is for insertion at an
                 arbitrary position in the list. It must be immediately
                 followed by a number, e.g., `[@]123`, where the number
                 specifies the position in the list. Only for putting.
          We may call these meta-indices.
          Examples:
          - "ResultLambda0.Keywords[i].ProtLigInter"
            - Gets the first "ProtLigInter" data.
          - "ResultLambda0.Keywords[*].ProtLigInter"
            - Gets all "ProtLigInter" data, and returns a tuple.
          - "ResultLambda0.Keywords[@]0.ProtLigInter"
            - Inserts a new "ProtLigInter" data at "ResultLambda0.Keywords[0]"
            - Note the difference from using "ResultLambda0.Keywords[0]", which
              is to change the existing data.
          - "ResultLambda0.Keywords[$].ProtLigInter"
            - Appends a new "ProtLigInter" data to "ResultLambda0.Keywords".
        """
        subkeys = _get_subkeys(key)

        if subkeys[-1] in ("[i]", "[*]"):
            raise CompositeKeySyntaxError(
                "Cannot determine array index because there is nothing to "
                f"match after meta-index `{subkeys[-1]}`")

        if "[$]" in subkeys or "[@]" in subkeys:
            raise CompositeKeySyntaxError(
                "Meta-indices `[$]` and `[@]` cannot be used for getting "
                f"data: '{key}'")

        try:
            val = _get_impl(self._db, subkeys)
            # Note the subtle difference in the return type between list and
            # tuple:
            # - If `val` is a `sea.List` object we return a list.
            # - If `val` is a `list` object, it must be the result of `[*]`, and
            #   we return a tuple.
            return (tuple(e.val for e in val) if isinstance(val, list) and
                    not isinstance(val, sea.List) else val.val)
        except _GetExceptions:
            if isinstance(default, type) and issubclass(default, Exception):
                # `default` is an exception _class_ (not instance).
                raise default(f"Key '{key}' not found")
            return default

    def put(self, key: str, value):
        """
        Puts a value associated with the given key into this database. `value`
        can be a scalar, a `list`, an empty `dict` (`{}`), or a `sea.Sea`
        object. `key` can be a composite key; see the docstring of `ArkDb.get`
        for detail.

        :raises CompositeKeySyntaxError: if `key` has a syntax error. You
            normally should NOT catch this exception, because this means your
            code has a syntactical error.
        :raises ArkDbPutError: if putting failed.
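
        Example (a minimal sketch; the keys are hypothetical)::

          db.put("a.b.c", 1)        # Creates or overwrites "a.b.c".
          db.put("a.b.d[$]", 5)     # Appends 5 to the list "a.b.d",
                                    # creating the list if needed.
          db.put("a.b.d[@]0.e", 7)  # Inserts {e = 7} at index 0 of "a.b.d".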
        """
        subkeys = _get_subkeys(key)

        if "[*]" in subkeys:
            raise CompositeKeySyntaxError(
                f"Meta-index `[*]` cannot be used for putting data: '{key}'")

        try:
            _put_impl(self._db, subkeys, value)
        except (ArkDbPutError, AttributeError) as e:
            raise ArkDbPutError("Putting data at key '%s' failed: %s" %
                                (key, e))   # yapf: disable

    def delete(self, key: str, matches: KeyValues = None, ignore_badkey=False):
        """
        Deletes a given `key` and the value from the database. If the `key` is
        not found, `ArkDbDelError` will be raised unless `ignore_badkey` is
        `True`.

        `matches`, if specified, provides one or more key-value pairs for
        checking on the value. If and only if all key-value pairs are found in
        the value, the key and the value will be deleted from the database.
        Each key-value pair is a string in the format of "<key>=<value>". Note
        that the key and the value are joined by a single "=" symbol with no
        spaces around it. The key is in the extended standard composite format
        (see the docstring of `ArkDb.get` above). The value is in the ARK
        format (note that spaces are allowed within the value). The value part
        is optional; when it's missing, the "=" symbol should be absent as
        well, and this method will only look for the key in the database and
        disregard the value.

        Examples::

          db.delete("a.b.c")
          db.delete("a.b.d[i].e")
          db.delete("a.b.d[i]", matches="e")
          db.delete("a.b.d[i]", matches=("e=5", "h=10"))
        """
        subkeys = _get_subkeys(key)

        if subkeys[-1] in ("[i]", "[*]") and not matches:
            raise CompositeKeySyntaxError(
                "Cannot determine array index because there is nothing to "
                f"match after meta-index `{subkeys[-1]}`")

        if "[$]" in subkeys or "[@]" in subkeys:
            raise CompositeKeySyntaxError(
                "Meta-indices `[$]` and `[@]` cannot be used for deleting "
                f"data: '{key}'")

        try:
            _del_impl(self._db, subkeys, matches)
        except _GetExceptions:
            if not ignore_badkey:
                raise ArkDbDelError(f"Key '{key}' not found")

    def find(self, key: str,
             picker: Optional[Union[int, Iterable[int], Callable]] = None) \
            -> Union[Tuple, ForEachDo]:
        """
        Finds the given `key` and returns the corresponding data as a
        `ForEachDo` object. The `ForEachDo` object allows iterating over the
        found data, each as a new `ArkDb` (or its subclass) object. It also
        allows us to chain operations on the found data.

        Example::

          db.find("stage[*].simulate").put("ensemble", "NVT")
          # Resets all simulate stages' "ensemble" parameter's value to "NVT".

        If the key is not found, this method will return `()` (i.e., empty
        tuple).

        :param picker: This is to cherry-pick the found data. The following
            types or values are supported:
            - None      All found data will be returned.
            - int       Among the found data, a single datum as indexed by
                        `picker` will be returned. The index is zero-based.
            - List[int] Among the found data, multiple data as indexed by
                        `picker`'s elements will be returned. The indices are
                        zero-based.
            - Callable  `picker` will be called on each found datum, and the
                        results will be `filter`-ed and returned.

        Example::

          db.find("stage[*].task", picker=1) \
            .put("set_family.simulate.temperature", 300)
          # Mutates the second "task" stage.

          db.find("stage[*].simulate.restrain", picker=lambda x: x.parent()) \
            .put("temperature", 400)
          # For any simulate stages with "restrain" setting, resets temperature
          # to 400.
        """
        subkeys = _get_subkeys(key)

        if subkeys[-1] in ("[i]", "[*]"):
            raise CompositeKeySyntaxError(
                "Cannot determine array index because there is nothing to "
                f"match after meta-index `{subkeys[-1]}`")

        if "[$]" in subkeys or "[@]" in subkeys:
            raise CompositeKeySyntaxError(
                "Meta-indices `[$]` and `[@]` cannot be used for getting "
                f"data: '{key}'")

        try:
            subdb = _get_impl(self._db, subkeys)
        except _GetExceptions:
            return ()

        # We must distinguish three types of `subdb` here:
        # - `list`      Multiple matches
        # - `sea.List`  A single match to a list datum
        # - others      A single match to a non-list datum
        if not isinstance(subdb, list) or isinstance(subdb, sea.List):
            subdb = [subdb]
        subdbs = subdb
        if picker is None:
            return ForEachDo(ArkDb(db=e) for e in subdbs)
        if isinstance(picker, int):
            picker = (picker,)
        if callable(picker):
            subdbs = tuple(filter(None, (picker(e) for e in subdbs)))
        else:
            subdbs = tuple(subdbs[i] for i in picker)

        return (subdbs and ForEachDo(ArkDb(db=e) for e in subdbs)) or ()

    def write(self, fname: str):
        with open(fname, "w") as fh:
            fh.write(str(self))


class Datum:
    """
    An instance of this class represents a particular datum in the database.
    A datum could be a scalar value, or a list/dict object. Each datum is
    assigned a key for identification in the database. The key can be accessed
    via the `key` public attribute. The actual value of the datum is obtained
    by the `val` public attribute.

    N.B.: A limitation on `val`'s value: For putting, the value cannot be a
    `dict` object (except for the empty `{}`; see `ArkDb.put`).
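
    Example (illustrative)::

      d = Datum("Keywords[i].Replica")
      d.get_from(db)   # Updates `d.val` from the database and returns it.
      d.put_to(db)     # Writes `d.val` into the database under `d.key`.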
    """

    def __init__(self, key: Optional[str], val=None):
        """
        Creates a `Datum` object with the given `key` and the default value
        `val`.
        `key`'s value can be `None`, and in this case the `get_from` method will
        always return the default value `val`.
        """
        self.val = val
        self._key = key

    def __str__(self):
        # Useful for debugging
        return "%s(%s): key=%s, val=%s" % \
            (type(self).__name__, id(self), self.key, self.val)

    @property
    def key(self):
        return self._key

    def get_from(self, arkdb):
        """
        Gets the value of this datum from the database `arkdb`. The new value
        is used to update the public attribute `val` and is also returned.

        :raises ArkDbGetError: if the key is not found in the database.
        :raises CompositeKeySyntaxError: if the key has a syntax error.
        """
        if self.key is not None:
            self.val = arkdb.get(self.key)
        return self.val

    def put_to(self, arkdb):
        """
        Saves the value of this datum into the database `arkdb`.

        :raises ArkDbPutError: if saving the datum fails.
        :raises CompositeKeySyntaxError: if the key has a syntax error.
        """
        if not self.key:
            raise CompositeKeySyntaxError(f"invalid key: '{self.key}'")
        arkdb.put(self.key, self.val)

    def del_from(self, arkdb):
        """
        Deletes the key and the value of this datum from the database `arkdb`.
        No-op if the key is `None`.

        :raises ArkDbDelError: if the key is not found in the database.
        :raises CompositeKeySyntaxError: if the key has a syntax error.
        """
        if self.key is not None:
            # `ArkDb.delete` returns `None`; do not clobber `self.val` with it.
            arkdb.delete(self.key)


class Premise(Datum):
    """
    A premise here is a datum that must be available for a task (see the
    definition below) to be successfully executed.
    """

    def __init__(self, key):
        super().__init__(key)


class Option(Datum):
    """
    An option here is a datum that does NOT have to be available for a task
    (see the definition below) to be successfully executed.
    """
    pass


def select_data(data: Iterable[Datum], **match) -> List[Datum]:
    """
    The following are from the real world:

      Keywords = [
        {RMSD = {
           ASL = "((protein and not (m.n 3) and backbone) and not (a.e H) )"
           Frame = 0
           Panel = pl_interact_survey
           Result = [8.57678438812e-15 0.837188833342 ]
           SelectionType = Backbone
           Tab = pl_rmsd_tab
           Type = ASL
           Unit = Angstrom
         }
        }

        {RMSD = {
           ASL = "m.n 1"
           FitBy = "protein and not (m.n 3)"
           Frame = 0
           Panel = pl_interact_survey
           Result = [3.54861302804e-15 1.36992917763]
           SelectionType = Ligand
           Tab = pl_rmsd_tab
           Type = Ligand
           Unit = Angstrom
           UseSymmetry = true
         }
        }
      ]

    There are two dict data keyed by "RMSD". If, for example, we want to select
    the one with "SelectionType" being "Ligand", we can use this function for
    that:

      rmsds = arkdb.get("Keywords[*].RMSD")
      select_data(rmsds, SelectionType="Ligand")

    :param **match: Key-value pairs for matching `data`. `data`'s elements
        should be `dict` objects. All elements that have all key-value pairs
        specified by `match` are returned. Note that for floating-point
        numbers, if the relative or the absolute difference is less than 1E-7,
        the two numbers are considered the same.

    See `DSC` above for special codes to be used in `match`'s values.
    This function returns an empty list if no matches are found.
    """
    selected = []
    for datum in data:
        if isinstance(datum, dict):
            for k, v in match.items():
                try:
                    vv = datum[k]
                except KeyError:
                    if v != DSC.NO_VALUE:
                        break
                else:
                    if ((v == DSC.NO_VALUE) or
                        (not (v == DSC.ANY_VALUE or v == vv or
                              ((isinstance(v, (float, int)) or
                                isinstance(vv, (float, int))) and math.isclose(
                                    v, vv, rel_tol=1E-7, abs_tol=1E-7))))):
                        break
            else:
                selected.append(datum)
    return selected


def expect_single_datum(data, exc, **match):
    """
    Similar to `select_data`, except that this function expects one and only one
    `dict` object that matches. If that's not the case, an exception of the type
    `type(exc)` will be raised. The error message of `exc` is used to describe
    the `key` used to get `data`.
    On success, a single `dict` object is returned.
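
    Example (illustrative)::

      rmsds = arkdb.get("Keywords[*].RMSD")
      rmsd = expect_single_datum(rmsds, ArkDbGetError("Keywords[*].RMSD"),
                                 SelectionType="Ligand")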
    """
    data = select_data(data, **match)
    if len(data) != 1:
        raise type(exc)('Expected exactly 1 datum with key %s and match: %s,'
                        ' but found %d:\n%s' %
                        (exc, ", ".join("%s=%s" % e for e in match.items()),
                         len(data), data))
    return data[0]


class Task:
    """
    This is a base class. An instance of this class defines a concrete task to
    be executed. All subclasses are expected to implement the `__init__` and the
    `execute` methods. The `execute` should be either a public callable attribute
    or a public method. See `ParchTrajectoryForFepLambda` below for example.

    A task can be composed of one or more subtasks. The relationship among the
    premises of this task and its subtasks is the following:
    - If this task's premises are not met, no subtasks will be executed.
    - Failure of one subtask will NOT affect other subtasks being executed.

    Six public attributes/properties:
    - name - An arbitrary name for the task. Useful for error logging.
    - is_completed - A boolean value indicating if the particular task has been
      completed successfully.
    - results - A list of `Datum` objects as the results of the execution of
      the task. The data will be automatically put into the database.
    - errlog - A list of strings recording the error messages (if any) during
      the last execution of the task. The list is empty if there were no errors
      at all.
    - premises - A list of lists of (argument name, `Premise`) pairs. The first
      list contains the premises of this `Task` object, followed by those of
      the first subtask, and then of the second subtask, and so on. Each
      element list can be empty.
    - options - Similar to `premises` except that the object type is `Option`.
    """

    def __init__(self, name: str, subtasks: Optional[List] = None):
        """
        :param name: An arbitrary name. Useful for error logging.
        """
        self.name = name
        self.is_completed = False
        self.results = []
        self.errlog = []

        # A task can be composed of a list of subtasks.
        self._subtasks = subtasks or []

        # List of lists
        self._premises = None
        self._options = None

    def __str__(self):
        s = [
            "%s: %s" % (self.name, type(self).__name__),
            "  Completed: %s" % ((self.is_completed and "yes") or "no"),
            "  Results's Keys: %s" % ", ".join(
                e.key for e in self.results if isinstance(e, Datum)),
            "  Log: %s" %
            (self.errlog and
             ("\n    " + "\n    ".join(self.errlog)) or "(no errors)")
        ]
        return "\n".join(s)

    @property
    def premises(self):
        if self._premises is None:
            signature = inspect.signature(self.execute)
            self._premises = [[(name, param.annotation)
                               for name, param in signature.parameters.items()
                               if isinstance(param.annotation, Premise)]
                             ] + [sub.premises for sub in self._subtasks]
        return self._premises

    @property
    def options(self):
        if self._options is None:
            signature = inspect.signature(self.execute)
            self._options = [[(name, param.annotation)
                              for name, param in signature.parameters.items()
                              if isinstance(param.annotation, Option)]
                            ] + [sub.options for sub in self._subtasks]
        return self._options

    def clear(self):
        """
        Cleans the state of this object for a new execution.
        """
        self.is_completed = False
        self.results = []
        self.errlog = []

    def execute(self, db: ArkDb):
        """
        Executes this task. This should only be called after all premises of
        this task are met. The premises of the subtasks are ignored until the
        subtask is executed. Subclasses should implement an `execute`, either as
        an instance method, or as an instance's public callable attribute.
        After execution, all results desired to be put into the database should
        be saved as the `results` attribute.

        The first argument of `execute` should always be for the database.
        """
        if self._subtasks:
            if not execute(db, self._subtasks):
                self.errlog = collect_logs(self._subtasks)
                raise SubtaskExecutionError("Subtask execution failed.")


class ParchTrajectoryForFepLambda(Task):
    """
    Task to parch the trajectory for the given FEP lambda state. The lambda
    state is represented by 0 and 1.

    Results are all `Datum` objects:
    - key = "Keywords[i].ResultLambda{result_lambda}.ParchedCmsFname", where
      `{result_lambda}` is the value of the lambda state.
      val = Name of the parched CMS file
    - key = "Keywords[i].ResultLambda{result_lambda}.ParchedTrjFname"
      val = Name of the parched trajectory file

    We leave this class here (1) to explain how the framework basically works
    and (2) to demonstrate how to create a concrete `Task` subclass.

    - Introduction
      From the architectural point of view, one of the common and difficult
      issues in computation is perhaps data coupling: Current computation needs
      data produced by previous ones. It's difficult because the coupling is
      implicit and across multiple programming units/modules/files, which often
      results in bugs when code change in one place implicitly breaks code
      somewhere else.

      Taking this class as an example, the task is trivial when explained at
      the conceptual level: Call the `trj_parch.py` script with properly set
      options to generate a "parched" trajectory. But when we get into the
      details of incorporating this task in a workflow, it becomes very
      complicated, mostly because of the data coupling issue (which is the
      devil here): From the viewpoint of this task, we have to check the
      following data dependencies:
      1. The input files (the output CMS file and the trajectory file) exist.
      2. We identify the input files by file name patterns that depend on the
         current jobname which is supposed to be stored in a (.sid) data file.
         So we have to ensure the jobname exists in the database.
         (Alternatively, we can pass the jobname through a series of function
         calls, but we won't discuss the general issues of that approach.)
      3. To call trj_parch.py, we must set the `-dew-asl` and `-fep-lambda`
         options correctly. The value for these options are either stored in
         .sid data file or passed into this class via an argument of the
         `__init__` method.
      Furthermore, when any of these conditions are not met, informative error
      messages must be logged.
      All of this used to force the developer to write a LOT of boilerplate
      code to get/put data from the database, to check these conditions, and to
      log all errors, for even the most conceptually trivial task. More often
      than not, such boring (and repeated) code is either incomplete or not in
      place at all. And we take the risk of doing computations without
      verifying the data dependencies, until some code change breaks one of the
      conditions.

    - Four types of data
      We must realize where the coupling comes into the architecture of our
      software. For this, it helps to categorize data into the following types
      in terms of the source of the data:
      1. Hard-coded data
        - This type of data is hard coded and rarely needs to be modified or
          customized. Example: `num_solvent=200`.
      2. Arguments
        - Data passed into the function by the caller code. Example,
          `fep_lambda`.
      3. From the database
        - Examples: jobname, ligand ASL, number of lambda windows.
      4. Assumptions
        - Assumptions are data generated by previous stages in a workflow but
          are out of the control of the task of interest.
          For example, we have to assume the CMS and trajectory files following
          certain naming patterns exist in the file system. In theory, the less
          assumptions, the more robust the code. But in practice, it is very
          difficult (if not impossible) to totally avoid assumptions.
      Implicit data coupling happens for the types (3) and (4) data.

    - The task framework
      The basic idea of this framework is to make the types (3) and (4) data
      more explicitly and easily defined in our code, which then makes it
      possible to automatically check their availability and log errors.
      For the type (3) data, we provide the `Premise` and `Option` classes for
      getting the data.
      For the type (4) data, we have to rely on a convention to verify the
      assumptions. But utility functions are provided to make that easier and
      idiomatic.
      In both cases, when the data are unavailable, informative error messages
      will be automatically logged.
      The goal of this framework is to relieve the developer from writing a lot
      of boilerplate code and shift their attention to writing reusable tasks.
    """

    def __init__(self,
                 name,
                 fep_lambda: int,
                 result_lambda: int,
                 cms_fname_pattern: str,
                 trj_fname_pattern: str,
                 out_bname_pattern: str,
                 num_solvent: int = 200):
        """
        The values of the arguments `cms_fname_pattern`, `trj_fname_pattern`,
        and `out_bname_pattern` are plain strings that specify f-string
        patterns to be evaluated later to get the corresponding file names.
        For example, `"{jobname}_replica_{index}-out.cms"` is a plain string
        that uses two f-string variables, `{jobname}` and `{index}`. The
        values of the f-string variables will be obtained on the fly when the
        task is executed. Currently, the following f-string variables are
        available for this task:
          {jobname}    - The FEP job's name
          {fep_lambda} - Same value as that of the argument `fep_lambda`. It's
                         either 0 or 1.
          {result_lambda} - Same value as that of the argument `result_lambda`.
                         It's either 0 or 1.
          {index}      - The index number of the replica corresponding to either
                         the first lambda window or the last one, depending on
                         the value of the `fep_lambda` argument.
        """
        super().__init__(name)

        # Because `execute` depends on the arguments of the `__init__` method,
        # we define `execute` on the fly.
        # It's possible to define `execute` as an instance method. But then we
        # need to save the `cms_fname_pattern`, etc. arguments, which are not
        # used elsewhere. It's less verbose to define `execute` as a callable
        # attribute.
        # yapf: disable
        def execute(_,
            jobname: Premise("Keywords[i].FEPSimulation.JobName"), # noqa: F821
            dew_asl: Premise(f"Keywords[i].ResultLambda{result_lambda}.LigandASL"),  # noqa: F821,F722
            replica: Premise("Keywords[i].Replica"), # noqa: F821
            ref_mae: Option("ReferenceStruct") # noqa: F821
        ):
            # yapf: enable
            """
            We define three `Premise`s for `execute`. Each of them refers to
            a datum keyed by the corresponding string in the database.
            The `Premise`s will be checked against the present database by the
            module-level `execute` function below. If any of these `Premise`s
            are not met, an error will be recorded, and this `execute` function
            will not be called.
            """
            from schrodinger.application.desmond.packages import parch

            num_win = len(replica)
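            # Replica index of the requested end state: 0 when `fep_lambda` is
            # 0, and the last replica (num_win - 1) otherwise.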
            index = fep_lambda and (num_win - 1)
            cms_fname = eval(f"f'{cms_fname_pattern}'")
            cms_fname = util.gz_fname_if_exists(cms_fname)
            cms_fname = util.verify_file_exists(cms_fname)
            trj_fname = util.verify_traj_exists(eval(f"f'{trj_fname_pattern}'"))
            out_bname = eval(f"f'{out_bname_pattern}'")

            # yapf: disable
            cmd = util.commandify([
                cms_fname, trj_fname, out_bname,
                ['-output-trajectory-format', 'auto'],
                ['-dew-asl', dew_asl],
                ['-n', num_solvent],
                ['-fep-lambda', fep_lambda],
                ['-ref-mae', ref_mae]])
            # yapf: enable

            out_cms_fname, out_trj_fname = parch.main(cmd)
            result_field = f"Keywords[i].ResultLambda{result_lambda}"
            self.results = [
                Datum(f"{result_field}.ParchedCmsFname", out_cms_fname),
                Datum(f"{result_field}.ParchedTrjFname", out_trj_fname),
            ]

        self.execute = execute


class ParchTrajectoryForFep(Task):
    """
    Task to generate parched trajectories for both FEP lambda states. The lambda
    state is represented by 0 and 1.

    Results are all `Datum` objects:
    - key = "ResultLambda0.ParchedCmsFname"
      val = Name of the parched CMS file for lambda state 0: "lambda0-out.cms"
    - key = "ResultLambda1.ParchedCmsFname"
      val = Name of the parched CMS file for lambda state 1: "lambda1-out.cms"
    - key = "ResultLambda0.ParchedTrjFname"
      val = Name of the parched trajectory file for lambda state 0:
            "lambda0{ext}", where "{ext}" is the same extension of the input
            trajectory file name.
    - key = "ResultLambda1.ParchedTrjFname"
      val = Name of the parched trajectory file for lambda state 1:
            "lambda0{ext}", where "{ext}" is the same extension of the input
            trajectory file name.

    We leave this class here to demonstrate how to define a concrete `Task`
    subclass by composition.
    """

    def __init__(self, name, num_solvent=200):
        # Hardcodes the file name patterns, which are not expected to change.
        # If different patterns are used, create a new `Task`'s subclass similar
        # to this one.
        cms_fname_pattern = "{jobname}_replica{index}-out.cms"
        trj_fname_pattern = "{jobname}_replica{index}"
        out_bname_pattern = "lambda{fep_lambda}"
        args = [
            cms_fname_pattern, trj_fname_pattern, out_bname_pattern, num_solvent
        ]

        super().__init__(name, [
            ParchTrajectoryForFepLambda(name + "_lambda0", 0, 0, *args),
            ParchTrajectoryForFepLambda(name + "_lambda1", 1, 1, *args)
        ])


class ParchTrajectoryForAbsoluteFep(Task):
    """
    Task to generate the parched trajectory for the
    lambda state with the fully-interacting ligand.

    Results are all `Datum` objects:
    - key = "ResultLambda0.ParchedCmsFname"
      val = Name of the parched CMS file: "lambda0-out.cms"

    - key = "ResultLambda0.ParchedTrjFname"
      val = Name of the parched trajectory file:
            "lambda0{ext}", where "{ext}" is the same extension of the input
            trajectory file name.
    """

    def __init__(self, name, num_solvent=200):
        cms_fname_pattern = "{jobname}_replica0-out.cms"
        trj_fname_pattern = "{jobname}_replica0"
        out_bname_pattern = "lambda0"
        args = [
            cms_fname_pattern, trj_fname_pattern, out_bname_pattern, num_solvent
        ]

        # Absolute binding calculations are set up such that
        # the mutant structure of replica 0 contains the fully
        # interacting ligand. Parch will remove the reference
        # structure (dummy particle) but keep the mutant structure
        # (ligand) when fep_lambda=0.
        # Results are reported to fep_lambda=0, to make it consistent
        # with the rest of the analysis.
        # TODO: we may decide to keep the apo (fep_lambda=1) state of
        #       the protein, in which case we need to handle it here.
        fep_lambda, report_lambda = 1, 0
        super().__init__(name, [
            ParchTrajectoryForFepLambda(f"{name}_lambda0", fep_lambda,
                                        report_lambda, *args)
        ])


def execute(arkdb: ArkDb, tasks: Iterable[Task]) -> bool:
    """
    Executes one or more tasks against the given database `arkdb`.

    This function is guaranteed to do the following:
    1. This function will examine each task's premises against the database.
    2. If the premises are NOT met, it skips the task; otherwise, it will
       proceed to check the task's options against the database.
    3. After getting the premises and options data, it will call the task's
       `execute` callable object. If the execution of the task is completed
       without errors, it will set the task's `is_completed` attribute to true.
    4. During the above steps, errors (if any) will be logged in the task's
       `errlog` list.
    5. After doing the above for all tasks, this function will return `True` if
       all tasks are completed without errors, or `False` otherwise.
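
    Example (a minimal sketch; the file name is hypothetical)::

      db = ArkDb("analysis.sid")
      task = ParchTrajectoryForFep("parch")
      if not execute(db, [task]):
          print("\n".join(collect_logs([task])))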
    """
    for ta in tasks:
        ta.clear()
        kwargs = {}
        for arg_name, dat in ta.premises[0]:
            try:
                dat.get_from(arkdb)
            except ArkDbGetError as e:
                ta.errlog.append(f"Premise '{dat.key}' failed: {e}")
            kwargs[arg_name] = dat.val
        if not ta.errlog:
            # Premises are met.
            for arg_name, dat in ta.options[0]:
                try:
                    dat.get_from(arkdb)
                except ArkDbGetError as e:
                    ta.errlog.append(f"Option '{dat.key}' failed: {e}")
                kwargs[arg_name] = dat.val
            try:
                ta.execute(arkdb, **kwargs)
            except SubtaskExecutionError as e:
                ta.errlog.insert(0, f"{e}")
            except Exception as e:
                ta.errlog.append("Task execution failed:\n%s\n%s" %
                                 (e, traceback.format_exc()))
            else:
                for r in ta.results:
                    if isinstance(r, Datum):
                        r.put_to(arkdb)
                ta.is_completed = True
    return all(ta.is_completed for ta in tasks)


def collect_logs(tasks: Iterable[Task]) -> List[str]:
    """
    Iterates over the given `Task` objects, and aggregates the logs of
    uncompleted tasks into a list to return.
    The returned strings can be joined and printed out:

      print("\n".join(collect_logs(...)))

    and the text will look like the following:

task0: Task
  message
  another message
  another multiword message
task1: ConcreteTaskForTesting
  message
  another arbitrary message
  another completely arbitrary message

    Note that the above is just an example to demonstrate the format as
    explained further below. Do NOT take the error messages literally. All the
    error messages here are unrelated to each other, and any patterns you might
    see are unintended!

    So for each uncompleted task, the name and the class name of the task will
    be printed out, followed by the error messages of the task, each on a
    separate line indented by 2 spaces.

    Note that the purpose of returning a list of strings instead of a single
    string is to make it slightly easier to further indent the text. For
    example, if you want to indent the whole text by two spaces, you can do
    this:

      print("  %s" % "\n  ".join(collect_logs(...)))

    which will look like the following:

  task0: Task
    message
    another message
    another multiword message
  task1: ConcreteTaskForTesting
    message
    another arbitrary message
    another completely arbitrary message

    """
    logs = []
    for ta in tasks:
        if not ta.is_completed:
            logs.append("%s: %s" % (ta.name, type(ta).__name__))
            logs.extend("  " + e for e in ta.errlog)
    return logs