# Source code for schrodinger.application.jaguar.workflow_input

"""
Functions and classes for defining the input to a Workflow workflow.
"""
# Contributors: Mark A. Watson, Leif D. Jacobson, Daniel S. Levine

import abc
import os
import time
import copy

import yaml
import numpy as np

from collections import defaultdict

import schrodinger.application.jaguar.workflow_validation as wv
from schrodinger.application.jaguar import utils as jag_utils
from schrodinger.application.jaguar import file_logger
from schrodinger.application.jaguar.autots_bonding import Coordinate
from schrodinger.application.jaguar.exceptions import JaguarUserFacingException
from schrodinger.application.jaguar.input import JaguarInput
from schrodinger.application.jaguar.validation import \
    keyword_value_pair_is_valid

#------------------------------------------------------------------------------


class WorkflowInput(abc.ABC):
    """
    A Base class for specifying parameters in workflow calculations.

    The following functions must be defined in the inheriting class:
        - validate(self)
        - generate_keywords(self)

    """

    @abc.abstractmethod
    def generate_keywords(self):
        """
        Initialize the list of possible keywords

        :return type: dictionary (keys are lowercase string,
                      values are WorkflowKeyword)
        :return param: key, WorkflowKeyword pairs specific to this
                       job's input file
        """
        return {}

    @abc.abstractmethod
    def validate(self):
        """
        Perform a self-consistency check of all currently set keywords.

        write for inheriting class
        :return param: Whether or no input passes workflow-specific
                       validation checks
        :return type: bool
        """
        msg = "This function must be specified in a derived Workflow class"
        raise NotImplementedError(msg)

    input_file_keys = []
    workflow_name = 'Workflow'

    def __init__(self,
                 inputfile=None,
                 keywords=None,
                 jaguar_keywords=None,
                 jobname=None):
        """
        Create a WorkflowInput instance.

        Values are applied in this order: the input file is read first,
        then 'keywords', then 'jaguar_keywords'. A keyword given both in
        'inputfile' and in 'keywords' (or 'jaguar_keywords') therefore
        ends up with the value passed explicitly.

        :type  inputfile: str
        :param inputfile: Path to a Workflow input file

        :type  keywords: dict
        :param keywords: Workflow keyword/value pairs

        :type  jaguar_keywords: dict
        :param jaguar_keywords: Jaguar &gen section keyword/value pairs

        :type jobname: string
        :param jobname: Name of job. If it is None, setJobname() falls
                        back to the basename of the input file name, or
                        to a generated name when there is no input file.
        """

        # Per-instance state; all of it must exist before read() or any
        # of the setters below is called.
        self._inputfile = None
        self._keywords = self.generate_keywords()
        self._constraints = defaultdict(list)
        self._frozen_atoms = defaultdict(list)
        self._jaguarinput = JaguarInput()
        # Jaguar keywords explicitly supplied by the user (from the file
        # and/or the 'jaguar_keywords' argument).
        self._jaguar_user_keys = {}

        # 1) values coming from the input file
        if inputfile is not None:
            self.read(inputfile)
            self._inputfile = inputfile

        # 2) explicit Workflow keywords override the file
        if keywords is not None:
            self.setValues(keywords)

        # 3) explicit Jaguar keywords override the file
        if jaguar_keywords is not None:
            self.setJaguarValues(jaguar_keywords)
            self._jaguar_user_keys.update(jaguar_keywords)

        self.setJobname(jobname)

    def __iter__(self):
        """
        Provide convenient way to iterate over the keywords.
        """
        return iter(self._keywords)

    @property
    def keywords(self):
        return self._keywords

    @property
    def values(self):
        """
        Support access to WorkflowKeyword values via attribute syntax.
        e.g.
            wi = WorkflowInput()
            print wi.values.optimize   # print 'optimize' keyword value
        """

        class KeywordValuesNamespace(object):

            def __init__(self, inp):
                super(KeywordValuesNamespace, self).__setattr__('inp', inp)

            def __getattr__(self, name):
                return self.inp.getValue(name)

            def __setattr__(self, name, value):
                return self.inp.setValue(name, value)

        return KeywordValuesNamespace(self)

    def getValue(self, keyword):
        """
        Return the value for Workflow keyword.
        The return type depends on the keyword.

        :type  keyword: string
        :param keyword: name of keyword

        :raise WorkflowKeywordError if no keyword found
        """

        key = keyword.lower()
        if key in self._keywords:
            return self._keywords[key].value
        else:
            raise wv.WorkflowKeywordError(key, list(self._keywords))

    def setValue(self, keyword, value):
        """
        Set the Workflow keyword 'keyword' to value 'value'.
        Note that there may be type-checking and conversion
        by the WorkflowKeyword class.

        If 'value' is None, the keyword will be reset.

        :type  keyword: string
        :param keyword: name of keyword

        :type  value: anytype
        :param value: value of keyword

        :raise WorkflowKeywordError if no keyword found
        """

        key = keyword.lower()
        if key not in self._keywords:
            raise wv.WorkflowKeywordError(key, list(self._keywords))

        if value is not None:
            self._keywords[key].value = value
        else:
            self.resetKey(key)

    __getitem__ = getValue
    __setitem__ = setValue

    def setValues(self, keywords):
        """
        Set multiple Workflow keywords.

        :type  keywords: dict of string/anytype pairs
        :param keywords: keyword/value pairs
        """

        for k, v in keywords.items():
            self.setValue(k, v)

    def getJaguarValue(self, key):
        """
        Return the value for Jaguar keyword 'key'.
        The return type depends on the keyword.

        :type  key: string
        :param key: name of keyword
        """

        return self._jaguarinput.getValue(key)

    def setJaguarValue(self, key, value):
        """
        Set the Jaguar &gen section keyword 'key' to value 'value'.

        The pair is first passed to keyword_value_pair_is_valid() with
        the value converted to a string; the original (unconverted)
        value is then stored on the wrapped JaguarInput.

        :type  key: string
        :param key: name of keyword

        :type  value: anytype
        :param value: value of keyword

        :raise JaguarKeywordException: if keyword is invalid (raised by
                                       the validation call)
        """

        value_as_text = str(value)
        keyword_value_pair_is_valid(key, value_as_text)
        self._jaguarinput.setValue(key, value)

    def setJaguarValues(self, keywords):
        """
        Set multiple Jaguar &gen section keywords.

        :type  keywords: dict of string/anytype pairs
        :param keywords: Jaguar &gen section keyword/value pairs
        """

        # store the jaguar keywords that have been set by the user (as a set)
        for k, v in keywords.items():
            self.setJaguarValue(k, v)

    def getJaguarNonDefault(self):
        """
        Return a dictionary of all non-default Jaguar keys except 'multip' and
        'molchg', which must be retrieved explicitly.

        """
        jinp_keys = self._jaguarinput.getNonDefault()
        for k, v in self._jaguar_user_keys.items():
            if k not in jinp_keys:
                jinp_keys[k] = v
            elif jinp_keys[k] != v:
                raise JaguarUserFacingException(
                    "Inconsistency in Jaguar keywords of workflow")
        return jinp_keys

    def resetKey(self, keyword):
        """
        Reset keyword to default state.

        :type  keyword: string
        :param keyword: name of keyword
        """

        key = keyword.lower()
        if key in self._keywords:
            self._keywords[key].reset()
        else:
            raise wv.WorkflowKeywordError(key, list(self._keywords))

    def resetAll(self):
        """
        Reset every Workflow keyword to its default state and replace
        the wrapped JaguarInput with a fresh instance.

        NOTE(review): the record of user-set Jaguar keys, the
        constraints and the frozen atoms are not touched here --
        confirm that this is intended.
        """

        for name in self._keywords:
            self.resetKey(name)

        # FIXME: discarding the whole JaguarInput is a blunt way of
        # resetting it; this could be done in a different way.
        self._jaguarinput = JaguarInput()

    def getDefault(self, keyword):
        """
        Return the default value for Workflow keyword 'keyword'.
        The return type depends on the keyword.

        :type  keyword: string
        :param keyword: name of keyword

        :raise WorkflowKeywordError if no keyword found
        """

        key = keyword.lower()
        if key in self._keywords:
            return self._keywords[key].default
        else:
            raise wv.WorkflowKeywordError(key, list(self._keywords))

    def getNonDefaultKeys(self):
        """
        Return a dictionary of all non-default-value WorkflowKeyword
        instances indexed by name.
        """

        nondefault_keywords = {}
        for key, kwd in self._keywords.items():
            if kwd.isNonDefault():
                nondefault_keywords[kwd.name] = kwd
        return nondefault_keywords

    def isNonDefault(self, keyword):
        """
        Has the specified keyword been set to a non-default value?

        :type  keyword: str
        :param keyword: The key to check

        :return: True if the specified keyword is set to a non-default value.
                 False otherwise.
        """

        key = keyword.lower()
        if key in self._keywords:
            return self._keywords[key].isNonDefault()
        else:
            raise wv.WorkflowKeywordError(key, list(self._keywords))

    def setConstraints(self, constraints):
        """
        Set the constraints.

        :type constraints: dict of str: Coordinate instance
        :param constraints: dictionary relating structure title
                            to a list of Coordinate instances which
                            define the constraints
        """
        for title in constraints:
            assert all(
                len(c.indices) > 1 and len(c.indices) <= 4
                for c in constraints[title])

        self._constraints = copy.deepcopy(constraints)

    def getConstraints(self):
        """
        Return the constraints defined in the input file

        :return: a dict relating structure titles to constraints
                 the constraints are given as a list of
                 Coordinate instances and the titles refer to
                 the titles of the input molecules.
        """
        return dict(self._constraints)

    def setFrozenAtoms(self, frozen_atoms):
        """
        Set the constraints.

        :type frozen_atoms: dict of str: Coordinate instance
        :param frozen_atoms: dictionary relating structure title
                            to a list of Coordinate instances which
                            define the frozen atoms
        """
        for title in frozen_atoms:
            assert all(len(c.indices) == 1 for c in frozen_atoms[title])

        self._frozen_atoms = copy.deepcopy(frozen_atoms)

    def getFrozenAtoms(self):
        """
        Return the frozen atoms defined in the input file
        :return: a dict relating structure titles to constraints
                 the constraints are given as a list of
                 Coordinate instances and the titles refer to
                 the titles of the input molecules.
        """
        return dict(self._frozen_atoms)

    def validate_jaguar_keywords(self, sts):
        """
        Perform a check to ensure that Jaguar keywords are not
        set in a way that cannot be handled.

        Currently only the 'basis' keyword is checked, by handing it
        and the structures to the workflow validation module.

        :param sts: Structures whose basis needs validating
        :type sts: list of Structure objects
        """

        wv.basis_set_is_valid(sts, self.getJaguarValue('basis'))

    def save(self, name):
        """
        Create a Workflow input file called 'name' in the current working
        directory based on this class instance.
        Only write the non-default keyword values.

        :type  name: str
        :param name: Path to a Workflow input file
        """

        with open(name, 'w') as fh:
            # Write non-default Workflow keywords
            rkeys = self.getNonDefaultKeys()
            if rkeys:
                for k, kwd in rkeys.items():
                    line = kwd.name + ' = ' + str(kwd.value)
                    fh.write(line + '\n')
            # Write non-default Jaguar &gen section keywords
            jkeys = self.getJaguarNonDefault()
            if jkeys:
                fh.write('&JaguarKeywords\n')
                for k, v in jkeys.items():
                    line = k + ' = ' + str(v)
                    fh.write(line + '\n')
                fh.write('&\n')
            # Write constraints
            if self.have_constraints() or self.have_frozen_atoms():
                fh.write('&Constraints\n')
                for k, constraints in {**self._constraints, **self._frozen_atoms}.items():
                    for constraint in constraints:
                        constraint_string = [k]
                        constraint_string.extend(map(str, constraint.indices))
                        if len(constraint.indices) > 1:
                            constraint_string.append(
                                "%12.4f" % constraint.value)
                        fh.write(" ".join(constraint_string) + "\n")
                fh.write('&\n')

    def setJobname(self, jobname):
        """
        Set the attribute jobname.

        :type jobname: string
        :param jobname: input name of job.
                        If jobname is None we try to use self._inputfile.
                        If that is also None we assign a unique name.
        """

        if jobname is not None:
            self.jobname = jobname
        elif self._inputfile is not None:
            name, ext = os.path.splitext(self._inputfile)
            self.jobname = os.path.basename(name)
        else:
            # assign randomly based on the workflow name and a time stamp
            stamp = str(time.time())
            self.jobname = jag_utils.get_jobname(self.workflow_name, stamp)

    def read(self, inputfile):
        """
        Read an existing Workflow input file.
        Any keywords specified in the input file will override
        existing values in this WorkflowInput instance.

        Jaguar &gen section keywords are defined like:
            &JaguarKeywords
              key=val
              key=val
              ...
            &

        Constraints can be defined with
            &Constraints
                st_title atom_index1 atom_index2... value
            &

        :type  inputfile: str
        :param inputfile: Path to a Workflow input file
        """

        if not os.path.exists(inputfile):
            raise IOError("No such file: '%s'" % inputfile)

        with open(inputfile, 'r') as fh:
            for line in fh:
                if line.strip().startswith('&JaguarKeywords'):
                    # Parse Jaguar keywords &gen section
                    line = next(fh)
                    while line.partition('#')[0].strip() != '&':
                        key, value = self._parse_keyword_line(line)
                        if key is not None:
                            self.setJaguarValue(key, value)
                            self._jaguar_user_keys[key] = value
                        line = next(fh)
                elif line.strip().startswith('&Constraints'):
                    # parse constraints
                    line = next(fh)
                    while line.partition('#')[0].strip() != '&':
                        title, constraint = self._parse_constraint_line(line)
                        if title is not None and constraint is not None:
                            if len(constraint.indices) == 1:
                                self._frozen_atoms[title].append(constraint)
                            else:
                                self._constraints[title].append(constraint)
                        line = next(fh)
                else:
                    # Parse Workflow keywords
                    key, value = self._parse_keyword_line(line)
                    if key is not None:
                        self.setValue(key, value)

    def remove_input_file_paths(self):
        """
        Remove full paths from file specifications.
        A new input file is no longer written, if that is desired
        the user must call the save method.

        """

        # Update internal state of this class
        for key in self.input_file_keys:
            filepath = self.getValue(key)
            if isinstance(filepath, list):
                basename = [os.path.basename(x) for x in filepath]
            else:
                basename = os.path.basename(filepath)
            self.setValue(key, basename)

    def update_input_file_paths(self):
        """
        Update full paths for file specifications to reflect the CWD.
        This is useful if running this job in a subdirectory or on a
        remote host.

        """

        def _prepend_cwd(cwd, filename):
            if os.path.exists(filename):
                return os.path.join(cwd, filename)
            else:
                raise IOError("File does not exist: %s" % filename)

        # remove old paths
        self.remove_input_file_paths()
        cwd = os.getcwd()

        # Update internal state of this class
        for key in self.input_file_keys:

            filepath = self.getValue(key)

            if filepath:  #only act on non-empty lists/strings
                if isinstance(filepath, list):
                    basename = [_prepend_cwd(cwd, x) for x in filepath]
                else:
                    basename = _prepend_cwd(cwd, filepath)

                self.setValue(key, basename)

    def have_constraints(self):
        """
        Do we have internal coordinate constraints
        """
        return len(self._constraints) > 0

    def have_frozen_atoms(self):
        """
        Do we have frozen atom constraints
        """
        return len(self._frozen_atoms) > 0

    def get_input_files(self):
        """
        Return set of expected input files.
        """

        infiles = set([])

        # Main input file
        if self._inputfile is not None:
            infiles.add(self._inputfile)

        # Auxiliary files
        for key in self.input_file_keys:
            filepath = self.getValue(key)
            if isinstance(filepath, list):
                infiles.update(filepath)
            elif filepath:
                infiles.add(filepath)

        return infiles

    def _parse_keyword_line(self, line):
        """
        Parse line of text for keyword=value and store it.
        Ignore comment lines starting with #

        :type  line: str
        :param line: line from input file

        :raise WorkflowKeywordFormatError if line can't be parsed.

        :return (key, value) or (None, None) if empty line.
        """

        key = None
        value = None

        # Ignore blank lines and comment lines starting with #
        # and transform keywords to lower case
        if line.partition('#')[0].strip():
            try:
                key, value = [x.strip() for x in line.split('=', 1)]
                key = key.lower()
            except ValueError:
                raise wv.WorkflowKeywordFormatError(line)

            # Convert to YAML form and parse. This conveniently
            # ignores all text starting with # and coerces
            # strings into corresponding python types.
            # e.g. '1' becomes an int, '2.0' becomes a float,
            #      '[1, 2]' becomes a list of int's,
            #      '[a, b]' becomes a list of str's, etc.
            line = key + ': ' + value
            parsed_line = yaml.load(line)
            value = parsed_line.get(key)

            # If keyword expects a list but input file has only
            # one element, convert it to a list
            keyword = self.keywords.get(key)
            if keyword and isinstance(keyword.valid_type, list):
                if not isinstance(value, list):
                    value = [value]

        return key, value

    def _parse_constraint_line(self, line):
        """
        Parse a line specifying a constraint.
        The format is structure title, list of indexes and value for constraint.
        The indexes are assumed to be referring to an input reactant or product
        molecule as defined by the structure title.

        :type line: str
        :param line: a line from a &Constraint section
        :return: the structure title and a Coordinate instance describing constraint.
        """
        title, constraint = (None, None)

        # Ignore blank lines and comment lines starting with #
        if line.partition('#')[0].strip():
            items = line.split()

            # atoms use a dummy value (frozen)
            if len(items) == 2:
                value = np.zeros(3)
                indexes = [int(items[1])]
            # internal coords
            elif len(items) >= 4 and len(items) <= 6:
                value = float(items[-1])
                indexes = map(int, items[1:-1])
            else:
                raise wv.ConstraintFormatError(line)

            title = items[0]

            constraint = Coordinate(value, *indexes)

        return title, constraint