Source code for schrodinger.application.desmond.cmj

"""
This module provides fundamental facilities for writing a multisim driver
script, for writing multisim concrete stage classes, and for dealing with
protocol files.

Copyright Schrodinger, LLC. All rights reserved.

"""

import copy
import gc
import glob
import itertools
import os
import pickle
import shutil
import signal
import subprocess
import sys
import json
import tarfile
import threading
import time
import re
import weakref
from io import BytesIO
from typing import BinaryIO
from typing import List
from typing import Optional
from typing import Union

import schrodinger.application.desmond.bld_ver as bld
import schrodinger.application.desmond.cmdline as cmdline
import schrodinger.application.desmond.envir as envir
import schrodinger.application.desmond.picklejar as picklejar
import schrodinger.application.desmond.util as util
import schrodinger.infra.mm as mm
import schrodinger.job.jobcontrol as jobcontrol
import schrodinger.job.queue as que
import schrodinger.utils.sea as sea
from schrodinger.application.desmond import constants
from schrodinger.application.desmond import queue
from schrodinger.utils import fileutils

from .picklejar import Picklable
from .picklejar import PicklableMetaClass
from .picklejar import PickleJar

# Contributors: Yujie Wu

# Info
# Version strings: `VERSION` is hard-coded here; `BUILD` comes from
# `bld.desmond_build_version()`.
VERSION = "3.8.5.19"
BUILD = bld.desmond_build_version()

# Machinery
# Module-level handles that start as None and are assigned elsewhere
# (presumably by the running driver -- not visible in this chunk).
# `Job.__setstate__` uses `ENGINE.stage` to turn a pickled stage index back
# into a stage object.
QUEUE = None
ENGINE = None

# Log
# Log levels ordered from least to most verbose. `_print` emits a message
# only if the index of its level is <= the index of `GENERAL_LOGLEVEL`.
LOGLEVEL = [
    "silent",
    "quiet",
    "verbose",
    "debug",
]
GENERAL_LOGLEVEL = "quiet"

# Suffixes
# `PACKAGE_SUFFIX` is substituted into the default `compress` pattern of
# `StageBase.PARAM`.
PACKAGE_SUFFIX = ".tgz"
CHECKPOINT_SUFFIX = "-multisim_checkpoint"

# Filenames
# "$MASTERJOBNAME" is a placeholder; presumably expanded to the master job
# name wherever this constant is used (not visible here).
CHECKPOINT_FNAME = "$MASTERJOBNAME" + CHECKPOINT_SUFFIX

# Names of stages considered production simulation stages (usage not
# visible in this chunk).
_PRODUCTION_SIMULATION_STAGES = ["lambda_hopping", "replica_exchange"]


def _print(loglevel, msg):
    """
    Print `msg` if `loglevel` is not more verbose than `GENERAL_LOGLEVEL`.

    Debug messages get a "MSJDEBUG: " prefix. stdout is flushed after every
    printed message.

    :param loglevel: One of the names in `LOGLEVEL`.
    :param msg: The message to print.
    :raises ValueError: if `loglevel` or `GENERAL_LOGLEVEL` is not a known level.
    """
    requested = LOGLEVEL.index(loglevel)
    threshold = LOGLEVEL.index(GENERAL_LOGLEVEL)
    if requested > threshold:
        return
    text = ("MSJDEBUG: %s" % msg) if loglevel == "debug" else msg
    print(text)
    sys.stdout.flush()


def print_tonull(msg):
    """
    Discard `msg`. A no-op printer with the same signature as the other
    `print_*` helpers.
    """
    return None


def print_silent(msg):
    """Print `msg` at the "silent" log level."""
    _print(loglevel="silent", msg=msg)


def print_quiet(msg):
    """Print `msg` at the "quiet" log level."""
    _print(loglevel="quiet", msg=msg)


def print_verbose(msg):
    """Print `msg` at the "verbose" log level."""
    _print(loglevel="verbose", msg=msg)


def print_debug(msg):
    """Print `msg` at the "debug" log level (prefixed with "MSJDEBUG: ")."""
    _print(loglevel="debug", msg=msg)


def _time_str_to_time(time_str, scale=1.0):
    h, m, s = [e[:-1] for e in time_str.split()]
    return scale * (float(h) * 3600 + float(m) * 60 + float(s))


def _time_to_time_str(inp_time):
    h, r = divmod(int(inp_time), 3600)
    m, s = divmod(r, 60)
    return "%sh %s' %s\"" % (h, m, s)


class JobStatus(object):
    """
    Status of a multisim subjob, stored as an integer code.

    Codes below 200 are "good", 2xx codes are non-retriable failures and codes
    above 300 are retriable failures (see `is_good` and `is_retriable`). An
    optional error message can be attached with `set`.
    """

    # Good status
    WAITING = 101
    RUNNING = 102
    SUCCESS = 103

    # Bad status and non-retriable
    BACKEND_ERROR = 201
    PERMANENT_LICENSE_FAILURE = 202
    NON_RETRIABLE_FAILURE = 299

    # Bad status and retriable
    TEMPORARY_LICENSE_FAILURE = 301
    KILLED = 302
    FIZZLED = 303
    LAUNCH_FAILURE = 304
    FILE_NOT_FOUND = 305
    FILE_CORRUPT = 306
    STRANDED = 307
    CHECKPOINT_REQUESTED = 308
    CHECKPOINT_WITH_RESTART_REQUESTED = 309
    RETRIABLE_FAILURE = 399

    STRING = {
        WAITING: "is waiting for launching",
        RUNNING: "is running",
        SUCCESS: "was successfully finished",
        PERMANENT_LICENSE_FAILURE: ("could not run due to permanent license "
                                    "failure"),
        TEMPORARY_LICENSE_FAILURE: "died due to temporary license failure",
        KILLED: "was killed",
        FIZZLED: "fizzled",
        STRANDED: "was stranded",
        LAUNCH_FAILURE: "failed to launch",
        FILE_NOT_FOUND: ("was finished, but registered output files were not "
                         "found"),
        FILE_CORRUPT: ("was finished, but an essential output file was found "
                       "corrupt"),
        BACKEND_ERROR: "died due to backend error",
        RETRIABLE_FAILURE: "died on unknown retriable failure",
        NON_RETRIABLE_FAILURE: "died on unknown non-retriable failure",
        CHECKPOINT_REQUESTED: "user requested job be checkpointed",
        CHECKPOINT_WITH_RESTART_REQUESTED: "user requested job be checkpointed and restarted"
    }

    def __init__(self, code=WAITING):
        self._code = code
        self._error = None

    def __str__(self):
        """
        Return the human-readable description of the code, followed on a new
        line by the attached error message (if any).
        """
        s = ""
        try:
            s += JobStatus.STRING[self._code]
        except KeyError:
            if self._error is None:
                s += "unknown error"
        if self._error is not None:
            s += "\n" + self._error
        return s

    def __eq__(self, other):
        """
        Compare by status code.

        `other` may be another `JobStatus` or any value accepted by `int()`.
        For anything else `NotImplemented` is returned, so Python falls back
        to its default comparison (e.g. ``status == None`` is False) instead
        of raising.
        """
        if isinstance(other, JobStatus):
            return self._code == other._code
        try:
            return self._code == int(other)
        except (TypeError, ValueError):
            return NotImplemented

    def __ne__(self, other):
        result = self.__eq__(other)
        # Propagate NotImplemented so Python can try the reflected operation.
        if result is NotImplemented:
            return result
        return not result

    def set(self, code, error=None):
        """
        Set the status code and the attached error message.

        :param code: A `JobStatus` (its code is copied) or a value accepted by
            `int()`.
        :param error: Optional error message; replaces any previous one.
        :raises NotImplementedError: if `code` is a string that `int()`
            cannot parse.
        """
        if isinstance(code, JobStatus):
            # Copy the integer code, not the object, so the numeric checks in
            # `is_good`/`is_retriable` keep working.
            self._code = code._code
        else:
            try:
                self._code = int(code)
            except ValueError:
                raise NotImplementedError
        self._error = error

    def is_good(self):
        return self._code < 200

    def is_retriable(self):
        return self._code > 300

    def should_restart_from_checkpoint(self):
        return self._code == self.CHECKPOINT_WITH_RESTART_REQUESTED


class JobOutput(object):
    """
    Registry of the files and directories a job is expected to produce.

    Each registered path may have a checker callable (used by `check`) and an
    optional tag (looked up with `get`). One path can be designated as the
    structure file (see `set_struct_file` / `struct_file`).
    """

    def __init__(self):
        # Key: file name. Value: None or a callable that checks the file
        self._file = {}
        self._type = {}  # Key: file name. Value: "file" | "dir"
        self._tag = {}  # Key: tag.       Value: file name
        self._struct = None
        # Note on pickling: the pickled state (see `__getstate__`) stores None
        # in place of each checker in `self._file'; the live object keeps its
        # checkers.

    def __len__(self):
        """
        Returns the number of registered output files.
        """
        return len(self._file)

    def __iter__(self):
        """
        Iterates through the registered output files.
        Note that the order of the files here are not necessarily the same order
        of file registration.
        """
        yield from self._file

    def __list__(self):
        """
        Returns the registered output file names as a list.

        `__list__` is not a Python protocol method (``list(obj)`` uses
        `__iter__`); it is kept for callers that invoke it directly.
        """
        return list(self._file)

    def __deepcopy__(self, memo=None):
        """
        Returns a deep copy of this object, including the structure file name
        and any additional attributes set on the instance.
        """
        if memo is None:
            memo = {}
        cls = type(self)
        newobj = cls.__new__(cls)
        memo[id(self)] = newobj
        for key, value in self.__dict__.items():
            setattr(newobj, key, copy.deepcopy(value, memo))
        return newobj

    def __getstate__(self):
        """
        Returns the pickling state. Checker callables are replaced by None in
        the returned state, presumably because they may not be picklable. The
        live object is not modified.
        """
        state = copy.copy(self.__dict__)
        # Build a new dict; mutating `state["_file"]` would also mutate
        # `self._file`, which the shallow copy above shares.
        state["_file"] = dict.fromkeys(self._file)
        return state

    def update_basedir(self, old_basedir, new_basedir):
        """
        Rewrites every registered path that starts with `old_basedir` (plus a
        path separator) so that it starts with `new_basedir` instead. Only the
        leading base directory is rewritten; occurrences elsewhere in a path
        are left alone.

        Covers file/dir registrations, tag targets, the structure file and,
        if present, the `cms` attribute (a list of paths set elsewhere).
        """
        old_prefix = old_basedir + os.sep
        new_prefix = new_basedir + os.sep

        def rebase(path):
            if path.startswith(old_prefix):
                return new_prefix + path[len(old_prefix):]
            return path

        self._file = {rebase(k): v for k, v in self._file.items()}
        self._type = {rebase(k): v for k, v in self._type.items()}
        for key, value in self._tag.items():
            self._tag[key] = rebase(value)
        try:
            if self._struct:
                self._struct = rebase(self._struct)
        except AttributeError:
            pass
        try:
            self.cms = [rebase(e) for e in self.cms]
        except AttributeError:
            # `cms` is optional; it is not set by `__init__`.
            pass

    def add(self, filename, checker=None, tag=None, type="file"):
        """
        Registers an output file or directory. Falsy `filename` values are
        ignored.

        :param filename: Path of the output.
        :param checker: None or a callable taking the file name and returning
            an error message (falsy when the file is fine). Only used for
            type "file".
        :param tag: Optional tag. A tag refers to one path: re-using a tag for
            a different path unregisters the path previously tagged with it.
        :param type: either "file" or "dir".
        :raises ValueError: if `type` is neither "file" nor "dir".
        """
        if not filename:
            return
        if type not in ("file", "dir"):
            raise ValueError(
                'Valid values for \'type\' are "file" and "dir". '
                f'But "{type}" is given')
        if tag is not None:
            old_filename = self._tag.get(tag)
            # Re-adding the same path under the same tag must not unregister it.
            if old_filename is not None and old_filename != filename:
                self._file.pop(old_filename, None)
                self._type.pop(old_filename, None)
            self._tag[tag] = filename
        self._file[filename] = checker
        self._type[filename] = type

    def remove(self, filename):
        """
        Unregisters `filename` and drops the tag pointing to it, if any.
        Unknown file names are ignored.
        """
        self._file.pop(filename, None)
        self._type.pop(filename, None)
        for key, value in self._tag.items():
            if value == filename:
                del self._tag[key]
                break

    def get(self, tag):
        """Returns the path registered under `tag`, or None."""
        return self._tag.get(tag)

    def check(self, status):
        """
        Verifies the registered outputs and records the outcome in `status`.

        Stops at the first problem: FILE_CORRUPT (with the checker's message)
        if a file's checker reports an error, FILE_NOT_FOUND if a file or
        directory is missing. Otherwise sets SUCCESS.
        """
        for fname in self._file:
            _print("debug", "checking output file: %s" % fname)
            if self._type[fname] == "file":
                if os.path.isfile(fname):
                    checker = self._file[fname]
                    if checker:
                        err_msg = checker(fname)
                        if err_msg:
                            status.set(JobStatus.FILE_CORRUPT, err_msg)
                            return
                else:
                    _print("debug", "Output file: %s not found" % fname)
                    try:
                        _print("debug", "Files in current directory: %s" % str(
                            os.listdir(os.path.dirname(fname))))
                    except OSError:
                        _print(
                            "debug",
                            "Directory not found: %s" % os.path.dirname(fname))
                    status.set(JobStatus.FILE_NOT_FOUND)
                    return
            elif self._type[fname] == "dir":
                if not os.path.isdir(fname):
                    _print("debug", "Output directory: %s not found" % fname)
                    try:
                        _print("debug", "Files in parent directory: %s" % str(
                            os.listdir(os.path.dirname(fname))))
                    except OSError:
                        _print(
                            "debug",
                            "Directory not found: %s" % os.path.dirname(fname))
                    status.set(JobStatus.FILE_NOT_FOUND)
                    return
        status.set(JobStatus.SUCCESS)

    def set_struct_file(self, fname):
        """Designates `fname` as the structure file, registering it if needed."""
        self._struct = fname
        if fname not in self._file:
            self.add(fname)

    def struct_file(self):
        """
        Returns the designated structure file. If none is set, returns the
        first registered file with a Maestro/CMS extension, or None.
        """
        if self._struct:
            return self._struct
        for fname in self:
            if fname.endswith((".mae", ".cms", ".maegz", ".cmsgz",
                               ".mae.gz", ".cms.gz")):
                return fname
        return None

    def log_file(self):
        """Returns the first registered file ending in ".log", or None."""
        for fname in self:
            if fname.endswith(".log"):
                return fname
        return None


class JobInput(JobOutput):
    """
    Registry of a job's input files. Uses the same bookkeeping as `JobOutput`
    and adds helpers for locating configuration files.
    """

    def __deepcopy__(self, memo=None):
        """
        Returns a deep copy of this object, including the structure file name
        and any additional attributes set on the instance.
        """
        if memo is None:
            memo = {}
        cls = type(self)
        newobj = cls.__new__(cls)
        memo[id(self)] = newobj
        for key, value in self.__dict__.items():
            setattr(newobj, key, copy.deepcopy(value, memo))
        return newobj

    def cfg_file(self):
        """Returns the first registered file ending in ".cfg", or None."""
        return next((f for f in self if f.endswith(".cfg")), None)

    def incfg_file(self):
        """Returns the first registered file ending in "in.cfg", or None."""
        return next((f for f in self if f.endswith("in.cfg")), None)

    def outcfg_file(self):
        """Returns the first registered file ending in "out.cfg", or None."""
        return next((f for f in self if f.endswith("out.cfg")), None)


class JobErrorHandler:
    """
    Error-handler callables that can be passed to `Job` as `err_handler`.
    Each takes the job and may report on it or adjust its status.
    """

    @staticmethod
    def default(job):
        """
        If the job status is bad, print the launch directory and command,
        attempt to print the log file and, for GPU jobs, nvidia-smi output.
        """
        if job.status.is_good():
            return
        cmd = job.jlaunch_cmd
        # `jlaunch_cmd` is normally an argument list, but it may be a callable
        # or a placeholder string after unpickling, which `list2cmdline`
        # cannot format (a callable raises TypeError).
        if isinstance(cmd, (list, tuple)):
            cmd_str = subprocess.list2cmdline(cmd)
        else:
            cmd_str = str(cmd)
        job._print("quiet", "jlaunch_dir: %s\n" % job.dir +
                   "jlaunch_cmd: %s" % cmd_str)
        log_fname = job.output.log_file()
        if log_fname and os.path.isfile(log_fname):
            job._print("quiet", "Log file   : %s" % log_fname)
            # Replace undecodable bytes so a garbled log cannot make the
            # error handler itself fail.
            with open(log_fname, "r", errors="replace") as f:
                log_content = f.readlines()
            job._print("quiet",
                       "Log file content:\n%s" % ">".join(log_content))
            job._print("quiet", "(end of log file)\n")
        else:
            job._print("quiet", "No log file registered for this job\n")
        # call nvidia-smi and print output to log file
        if job.USE_GPU:
            try:
                output = subprocess.check_output(
                    "nvidia-smi", universal_newlines=True)
                job._print("quiet", "nvidia-smi output:\n%s" % output)
            except (OSError, subprocess.CalledProcessError):
                # OSError covers a missing binary as well as e.g. permission
                # errors.
                job._print("quiet", "No nvidia-smi output available\n")

    @staticmethod
    def restart_for_backend_error(job):
        """
        Run the default handler and if the status is
        killed or backend error, mark the failure as retriable.
        """
        if not job.status.is_good():
            JobErrorHandler.default(job)

        if job.status in [JobStatus.BACKEND_ERROR, JobStatus.KILLED]:
            job.status.set(JobStatus.RETRIABLE_FAILURE)


def exit_code_is_defined(job):
    """
    Return True if job has an exit code. Failed jobs may not have exit codes if
    they are killed by the queueing system or otherwise untrackable.

    An exit code counts as defined when `int(job.ExitCode)` succeeds; values
    such as None or an empty string count as undefined.
    """
    try:
        int(job.ExitCode)
    except (TypeError, ValueError):
        # TypeError: e.g. ExitCode is None; ValueError: e.g. an empty string.
        return False
    return True


class Job(object):
    """
    One subjob of a multisim stage.

    Holds its launch command and directory, the registered input/output
    files, the job-control handle once launched, and its `JobStatus`.
    """

    # most jobs do not use gpu
    USE_GPU = False

    class Time(object):
        """
        Record of a job's timing information as produced by `Job.get_time`.
        Fields hold display values.
        """

        def __init__(self, launch, start, end, num_cpu, cpu_time, duration):
            self.launch = launch
            self.start = start
            self.end = end
            self.num_cpu = num_cpu
            self.cpu_time = cpu_time
            self.duration = duration

    @staticmethod
    def _get_time_helper(jobtime):
        """
        Convert a job-control timestamp into ``(seconds_since_epoch, text)``.

        Returns ``(None, "(unknown)")`` if the timestamp is missing or cannot
        be parsed with `jobcontrol.timestamp_format`.
        """
        try:
            t = time.mktime(time.strptime(jobtime, jobcontrol.timestamp_format))
        except (AttributeError, TypeError, ValueError):
            return None, "(unknown)"
        return t, time.ctime(t)

    @staticmethod
    def get_time(jctrl, num_cpu):
        """
        Collect the launch/start/stop times of `jctrl` into a `Job.Time`.

        :param jctrl: Job-control object providing `LaunchTime`, `StartTime`
            and `StopTime`.
        :param num_cpu: Number of CPUs used, or "(unknown)".
        :return: `Job.Time`. Values that cannot be determined are reported as
            "N/A", "(unknown)", "(not started)" or "(stranded)".
        """
        _, str_launch_time = Job._get_time_helper(jctrl.LaunchTime)
        if not jctrl.StartTime:
            return Job.Time(str_launch_time, "(not started)", "N/A", num_cpu,
                            "N/A", "N/A")
        start_time, str_start_time = Job._get_time_helper(jctrl.StartTime)
        if not jctrl.StopTime:
            # Started but never stopped.
            return Job.Time(str_launch_time, str_start_time, "(stranded)",
                            num_cpu, "N/A", "N/A")
        stop_time, str_stop_time = Job._get_time_helper(jctrl.StopTime)

        if (start_time is not None and stop_time is not None and
                num_cpu != "(unknown)"):
            cpu_time = util.time_duration(start_time, stop_time, num_cpu)
            duration = util.time_duration(start_time, stop_time)
        else:
            cpu_time = "(unknown)"
            duration = "(unknown)"
        return Job.Time(str_launch_time, str_start_time, str_stop_time, num_cpu,
                        cpu_time, duration)

    def get_proc_time(self):
        """
        Returns the CPU time of this job in seconds, or 0.0 if it is not
        available (job not started, stranded, or times unknown).
        """
        proc_time = Job.get_time(self.jctrl, self.num_cpu).cpu_time
        # `get_time` uses these placeholders when no CPU time is available.
        if proc_time in ("(unknown)", "N/A"):
            return 0.0
        return _time_str_to_time(proc_time)

    def __init__(self,
                 jobname,
                 parent,
                 stage,
                 jlaunch_cmd,
                 dir,
                 host_list=None,
                 prefix=None,
                 what=None,
                 err_handler=JobErrorHandler.default):
        """
        :param jobname: Name of the job.
        :param parent: `Job` from which this job was derived, or None.
        :param stage: Stage owning this job; stored as a weak proxy.
        :param jlaunch_cmd: Job launch command (normally a list of arguments;
            `__getstate__` also handles a callable).
        :param dir: Launch directory, also where outputs are copied back.
        :param host_list: Hosts where this job can run.
        :param prefix: Prefix directory of the launch directory; inherited
            from `parent` when None.
        :param what: More specific description of the job.
        :param err_handler: None or a callable invoked with this job by
            `finish`.
        """
        self.jobname = jobname
        self.tag = None
        # Job object from which this `Job' object was derived.
        self.parent = parent
        # other Job objects from which this `Job' object was derived.
        self.other_parent = None
        # Job control object, will be set once the job is launched.
        self.jctrl = None
        self.jlaunch_cmd = jlaunch_cmd  # Job launch command
        # List of hosts where this job can be running
        self.host_list = host_list
        # Actual host where this job is running
        self.host = jobcontrol.Host("localhost")
        # By default, subjobs do not need a host other than localhost.
        self.need_host = False
        self.num_cpu = 1
        self.use_hostcpu = False
        # Launch directory, also where the job's outputs will be copied back
        self.dir = dir
        self.prefix = prefix  # Prefix directory of the launch directory
        self.what = what  # A string that stores more specific job description
        self.output = JobOutput()  # Output file names
        self.input = JobInput()  # Input  file names
        self.status = JobStatus()  # Job status
        # `None' or a callable object that will be called to handle job errors.
        self.err_handler = err_handler

        self._jctrl_hist = []
        self._has_run = False

        if self.parent and self.prefix is None:
            self.prefix = self.parent.prefix
        if isinstance(stage, weakref.ProxyType):
            self.stage = stage
        else:
            self.stage = weakref.proxy(stage)
        # Note on pickling: `self.err_handler' will not be picked.

    def __deepcopy__(self, memo=None):
        """
        Deep-copies the job, except that `stage`, `jctrl` and `parent` are
        shared with the original, `other_parent` is shallow-copied and the
        job-control history is reset.
        """
        if memo is None:
            memo = {}
        newobj = object.__new__(self.__class__)
        memo[id(self)] = newobj
        for k, v in self.__dict__.items():
            if k in ["stage", "jctrl", "parent"]:
                value = v
            elif k == "other_parent":
                value = copy.copy(self.other_parent)
            elif k == "_jctrl_hist":
                value = []
            else:
                value = copy.deepcopy(v, memo)
            setattr(newobj, k, value)
        return newobj

    def __getstate__(self, state=None):
        """
        Returns the pickling state: the error handler is dropped, the job
        control object and history are replaced by strings, a callable launch
        command by a placeholder, and the stage by its index.
        """
        state = state if (state) else copy.copy(self.__dict__)
        if "err_handler" in state:
            del state["err_handler"]
        if "jctrl" in state:
            state["jctrl"] = str(self.jctrl)
        if "_jctrl_hist" in state:
            state["_jctrl_hist"] = ["removed_in_serialization"]
        if "jlaunch_cmd" in state:
            if callable(state["jlaunch_cmd"]):
                state["jlaunch_cmd"] = "removed_in_serialization"
        if "stage" in state:
            state["stage"] = (self.stage if (isinstance(self.stage, int)) else
                              self.stage._INDEX)
        return state

    def __setstate__(self, state):
        """
        Restores the state; the stage index is resolved through the
        module-level `ENGINE` when it is set.
        """
        self.__dict__.update(state)
        if "stage" in state and ENGINE:
            self.stage = weakref.proxy(ENGINE.stage[self.stage])

    def __repr__(self):
        """
        Returns the jobname string in the format: <jobname>.
        """
        r = f"<{self.jobname}"
        if self.jctrl:
            r += f"({self.jctrl})"
        r += ">"
        return r

    def _print(self, loglevel, msg):
        """
        The internal print function of this job. Printing is at the same
        'loglevel' as self.stage.
        """
        self.stage._print(loglevel, msg)

    def _log(self, msg):
        """
        The internal log function of this job.
        """
        self.stage._log(msg)

    def _jlaunch_cmd_str(self):
        """
        Returns the launch command as a printable string. A non-list command
        (e.g. a callable, or the serialization placeholder) is shown via
        `str()` because `list2cmdline` cannot format it.
        """
        cmd = self.jlaunch_cmd
        if isinstance(cmd, (list, tuple)):
            return subprocess.list2cmdline(cmd)
        return str(cmd)

    def _host_str(self):
        """
        Returns a string representing the hosts.
        """
        if isinstance(self.host, list):
            # List of (host, cpu-count) pairs.
            host_str = ""
            for e in self.host:
                host_str += "%s:%d " % (e[0].name, e[1])
        else:
            host_str = self.host.name
            if self.use_hostcpu and ":" not in host_str:
                host_str += ":%d" % self.num_cpu

        return host_str

    def describe(self):
        """
        Prints a summary of this job: launch time and host (unless the launch
        failed), jobname and stage, and at "verbose" level the prefix, launch
        command and outputs.
        """
        if self.status != JobStatus.LAUNCH_FAILURE:
            self._print("quiet", "  Launch time: %s" % self.jctrl.LaunchTime)
            # TODO: should this look for -HOST in the jlaunch_cmd?
            self._print("quiet", "  Host       : %s" % self._host_str())
        self._print(
            "quiet",
            "  Jobname    : %s\n" % self.jobname + "  Stage      : %d (%s)" %
            (self.stage._INDEX, self.stage.NAME))
        self._print(
            "verbose", "  Prefix     : %s\n" % self.prefix +
            "  Jlaunch_cmd: %s\n" % self._jlaunch_cmd_str() +
            "  Outputs    : %s" % str(list(self.output)))
        if self.what:
            self._print("quiet", "  Description: %s" % self.what)

    def process_completed_job(self,
                              jctrl: jobcontrol.Job,
                              checkpoint_requested=False,
                              restart_requested=False):
        """
        Check for valid output and set status of job, assuming job is already
        complete.

        :param checkpoint_requested: Set to True if the job should checkpoint.
            Default is False.
        :param restart_requested: Set to True if the job should checkpoint and restart.
            Default is False.
        """
        self.jctrl = jctrl

        # Make sure the job data has been downloaded and flushed to disk
        self.jctrl.download()
        # Not available on windows
        if hasattr(os, 'sync'):
            os.sync()

        self._print(
            "debug",
            "Job seems finished. Checking its exit-status and exit-code...")
        exit_status = self.jctrl.ExitStatus
        self._print("debug", "Job exit-status = '%s'" % exit_status)
        if exit_status == "killed":
            self._print("debug", "Job exit-code = N/A")
            self.status.set(JobStatus.KILLED)
        elif exit_status == "fizzled":
            self._print("debug", "Job exit-code = N/A")
            self.status.set(JobStatus.FIZZLED)
        elif not exit_code_is_defined(self.jctrl):
            # If the exit code is not set, the backend must have died
            # without collecting the exit code. This could happen if a job
            # is qdeled, or the backend gets killed by OOM, or the job
            # monitoring process is killed by any reason.
            # Set status to a retriable status.
            self.status.set(JobStatus.KILLED)
        else:
            # `exit_code_is_defined` only guarantees that int() accepts the
            # value; normalize it so the numeric comparisons below hold even
            # if it is not already an int.
            exit_code = int(self.jctrl.ExitCode)
            if exit_code == 0:
                if checkpoint_requested:
                    self.status.set(JobStatus.CHECKPOINT_REQUESTED)
                elif restart_requested:
                    self.status.set(JobStatus.CHECKPOINT_WITH_RESTART_REQUESTED)
                else:
                    self.output.check(self.status)
            elif exit_code == 17:
                # The mmlic3 library will return the following error codes upon
                # checkout:
                #    0 : success
                #   15 : temporary, retryable failure; perhaps the server
                #        couldn't be contacted
                #   16 : all licenses are in use. SGE is capable of requeuing
                #        the job.
                #   17 : fatal, unrecoverable license error.
                self.status.set(JobStatus.PERMANENT_LICENSE_FAILURE)
            elif exit_code in {15, 16}:
                self.status.set(JobStatus.TEMPORARY_LICENSE_FAILURE)
            else:
                self.status.set(JobStatus.BACKEND_ERROR)

    def requeue(self, jctrl: jobcontrol.Job):
        """
        Restart a checkpointed job: delete its stale "-out.tgz" input files
        (those not among the job's output files), ask the stage to restart
        it and set its status back to WAITING.
        """
        # Make sure the job data has been downloaded and flushed to disk
        jctrl.download()
        # Not available on windows
        if hasattr(os, 'sync'):
            os.sync()

        # Delete stale checkpoint files that are not needed for restarting
        def _filter_tgz(input_fnames: List[str]):
            return set(filter(lambda x: x.endswith('-out.tgz'), input_fnames))

        stale_input_tgz_fnames = _filter_tgz(jctrl.InputFiles) - _filter_tgz(
            jctrl.OutputFiles)

        for fname in stale_input_tgz_fnames:
            util.remove_file(fname)

        self._print("quiet", f"Restart checkpointed job: {self.jlaunch_cmd}")
        self._print("quiet",
                    f"Deleted stale input files: {stale_input_tgz_fnames}")

        self.stage.restart_subjobs([self])
        self.status.set(JobStatus.WAITING)

    def finish(self):
        """
        Report the finished job, run the error handler (if any) and hand the
        job back to its stage via `stage.capture`.
        """
        if self.status != JobStatus.LAUNCH_FAILURE:
            jobtime = Job.get_time(self.jctrl, self.num_cpu)
            self._print("quiet", "\n%s %s." % (str(self.jctrl),
                                               str(self.status)))
            self._print(
                "quiet",
                "  Host       : %s\n" % self._host_str() +
                "  Launch time: %s\n" % jobtime.launch +
                "  Start time : %s\n" % jobtime.start +
                "  End time   : %s\n" % jobtime.end +
                "  Duration   : %s\n" % jobtime.duration +
                "  CPUs       : %s\n" % self.num_cpu +
                "  CPU time   : %s\n" % jobtime.cpu_time +
                "  Exit code  : %s\n" % self.jctrl.ExitCode +
                "  Jobname    : %s\n" % self.jobname +
                "  Stage      : %d (%s)" % (self.stage._INDEX, self.stage.NAME),
            )
        if self.err_handler:
            self.err_handler(self)

        if self.status.is_retriable():
            self._print("quiet",
                        "  Retries    : 0 - Job has failed too many times.")
        self.stage.capture(self)


class _create_param_when_needed(object):
    """
    Descriptor that builds a stage class's `PARAM` map on first access and
    then caches it on the class with `setattr(cls, "PARAM", ...)`.

    For `StageBase` itself, the map is built from the stored specification
    and tagged "generic". For any other class, the `PARAM` maps of its
    `StageBase`-derived bases are deep-copied and merged, with the left-most
    base taking precedence. The class's own specification is then applied on
    top with the "stagespec" tag.
    """

    def __init__(self, param):
        # Specification handed to `sea.Map` / `Map.update`; StageBase passes a
        # sea-format string.
        self._param = param

    def __get__(self, obj, cls):
        if cls == StageBase:
            merged = sea.Map(self._param)
            merged.add_tag("generic")
        else:
            merged = None
            # Visit bases right to left: later updates override earlier ones,
            # so the left-most base wins.
            for base in reversed(cls.__bases__):
                if not issubclass(base, StageBase):
                    continue
                inherited = copy.deepcopy(base.PARAM)
                if merged is None:
                    merged = inherited
                else:
                    merged.update(inherited)
            merged.update(self._param, tag="stagespec")
        setattr(cls, "PARAM", merged)
        return merged


class _StageBaseMeta(PicklableMetaClass):
    """
    Metaclass of `StageBase`. After the usual `PicklableMetaClass` setup, it
    records each newly created class in the `stage_cls` dictionary, keyed by
    the class's `NAME`.
    """

    def __init__(cls, name, bases, dict):
        PicklableMetaClass.__init__(cls, name, bases, dict)
        # A later class with the same NAME replaces the earlier entry.
        cls.stage_cls.update({cls.NAME: cls})


class StageBase(Picklable, metaclass=_StageBaseMeta):
    """
    Base class of all multisim stages.

    A stage receives jobs through `push`, creates new jobs in `crunch`,
    submits them to the global ``QUEUE`` in `release`, and forwards finished
    jobs to the next stage in `capture`. Concrete stages override `crunch`
    and, optionally, the `prestage`/`poststage`/`hook_captured_successful_job`
    hooks, and extend ``PARAM`` with their own ``DATA``/``VALIDATE`` maps
    (merged with the base classes' maps by `_create_param_when_needed`).
    """

    # Instance counter used to assign `_ID`s. The ``(value, Picklable)``
    # form is presumably how `PicklableMetaClass` marks a picklable class
    # attribute -- TODO confirm in picklejar.
    count = 0, Picklable
    # Registry of stage classes keyed by NAME (filled in by _StageBaseMeta).
    stage_cls = {}
    stage_obj = {}  # key = stage name; value = stage instance.
    NAME = "generic"

    # Basic stage parameters
    PARAM = _create_param_when_needed("""
    DATA = {
    title         = ?
    should_sync   = true
    dryrun        = false
    prefix        = ""
    jobname       = "$MASTERJOBNAME_$STAGENO"
    dir           = "$[$JOBPREFIX/$]$[$PREFIX/$]$MASTERJOBNAME_$STAGENO"
    compress      = "$MASTERJOBNAME_$STAGENO-out%s"
    struct_output = ""
    should_skip   = false
    effect_if     = ?
    jlaunch_opt   = []
    transfer_asap = no
    }

    VALIDATE = {
    title         = [{type = none} {type = str}]
    should_sync   = {type = bool}
    dryrun        = {type = bool}
    prefix        = {type = str }
    jobname       = {type = str }
    dir           = {type = str }
    compress      = {type = str }
    struct_output = {type = str }
    should_skip   = {type = bool}
    effect_if     = [{type = none} {type = list size = -2 _skip = all}]
    jlaunch_opt   = {
       type  = list size = 0
       elem  = {type = str}
       check = ""
       black_list = ["-HOST" "-USER" "-JOBNAME"]
    }
    transfer_asap = {type = bool}
    }
    """ % (PACKAGE_SUFFIX,))

    def __init__(self, should_pack=True):
        """
        Create a stage with empty job bookkeeping.

        :param should_pack: Whether `pack_stage` may archive/transfer this
            stage's results when not forced.
        """
        # Assigned later by the msj parser (`parse_msj`).
        self.param = None

        # Neighbouring stages and identity of this stage.
        self._PREV_STAGE = self._NEXT_STAGE = None
        self._ID = StageBase.count
        self._INDEX = None  # position in the workflow; not serialized

        # Progress flags.
        self._is_shown = self._is_packed = False
        self._should_pack = should_pack

        # Job bookkeeping (each attribute is its own list object).
        self._pre_job = []  # jobs pushed in, waiting to be crunched
        self._pre_jobre = []  # every job ever pushed in; used for restarts
        self._rls_job = []  # released jobs
        self._cap_job = []  # captured jobs
        self._pas_job = []  # jobs captured directly, bypassing the queue
        self._job = []  # jobs to be sent to the queue

        # Parameter-validation hooks run before/after the main check.
        self._precheck = []
        self._postcheck = []

        # Packing/transfer state.
        self._files4pack = []
        self._files4copy = []
        self._pack_fname = ""

        self._is_called_prestage = False  # has `prestage` been called?
        self._used_jobname = []

        # Timing and GPU accounting.
        self._start_time = None  # per-stage start time
        self._stage_duration = None  # per-stage duration
        self._gpu_time = 0.0  # accumulated GPU time
        self._num_gpu_subjobs = 0  # number of GPU subjobs
        self._packed_fnames = set()

        StageBase.count += 1

    def __getstate__(self, state=None):
        """
        Return the persistent part of this stage.

        :param state: A `picklejar.PickleState` to fill in; a new one is
            created when it is omitted or falsy.
        :return: The filled-in state object.
        """
        if not state:
            state = picklejar.PickleState()

        for attr_name in ("NAME", "_ID", "_is_shown", "_is_packed",
                          "_pre_job", "_pre_jobre", "_rls_job", "_cap_job"):
            setattr(state, attr_name, getattr(self, attr_name))
        # `_pack_fname` may be missing (presumably on objects restored from
        # older checkpoints), in which case it defaults to "".
        state._pack_fname = getattr(self, "_pack_fname", "")
        return state

    def __setstate__(self, state):
        """
        Restore attributes from ``state``.

        :raise TypeError: If ``state`` was saved from a different stage type.
        """
        if state.NAME == self.NAME:
            self.__dict__.update(state.__dict__)
            return
        raise TypeError("Unmatched stage: %s vs %s" % (state.NAME,
                                                       self.NAME))

    def _print(self, loglevel, msg):
        """Forward ``msg`` to the module-level `_print` at ``loglevel``."""
        return _print(loglevel, msg)

    def _log(self, msg):
        """Print ``msg`` at "quiet" level, tagged with stage index and name."""
        line = "stage[%d] %s: %s" % (self._INDEX, self.NAME, msg)
        self._print("quiet", line)

    def _get_macro_dict(self):
        """
        Return a shallow copy of the engine's macro dictionary with
        ``$STAGENO`` set to this stage's index.
        """
        stage_macros = copy.copy(ENGINE.macro_dict)
        stage_macros["$STAGENO"] = self._INDEX
        return stage_macros

    def _gen_unique_jobname(self, suggested_jobname):
        """
        Return a job name not yet used by this stage and reserve it.

        ``suggested_jobname`` is used as is if free; otherwise ``_1``, ``_2``,
        ... is appended until an unused name is found. The chosen name is
        also published as the ``$JOBNAME`` macro.
        """
        candidate = suggested_jobname
        suffix = 0
        while candidate in self._used_jobname:
            suffix += 1
            candidate = suggested_jobname + ("_%d" % suffix)
        self._used_jobname.append(candidate)
        sea.update_macro_dict({"$JOBNAME": candidate})
        return candidate

    def _get_jobname_and_dir(self, job, macro_dict=None):
        """
        Resolve the job name and absolute job directory for ``job``.

        Resets the global sea macro dictionary to this stage's macros, adds
        ``macro_dict`` and, when available, ``$PREFIX``, ``$JOBPREFIX`` and
        ``$JOBTAG``, then reads the ``jobname`` and ``dir`` parameters
        (presumably with the macros expanded by sea). The working directory
        is changed to ``ENGINE.base_dir`` first, so a relative ``dir`` is
        made absolute against it.

        :param job: The job the name/directory are generated for.
        :param macro_dict: Extra macros to define; ``None`` means none.
        :return: ``(jobname, absolute_dir)`` tuple.
        """
        sea.set_macro_dict(self._get_macro_dict())
        # A fresh dict per call instead of a shared mutable default argument.
        sea.update_macro_dict({} if macro_dict is None else macro_dict)
        if self.param.prefix.val != "":
            sea.update_macro_dict({"$PREFIX": self.param.prefix.val})
        if job.prefix != "" and job.prefix is not None:
            sea.update_macro_dict({"$JOBPREFIX": job.prefix})
        try:
            if job.tag is not None:
                sea.update_macro_dict({"$JOBTAG": job.tag})
        except AttributeError:
            # Not every job object has a `tag` attribute.
            pass
        util.chdir(ENGINE.base_dir)
        sea.update_macro_dict({"$JOBNAME": self.param.jobname.val})
        return (
            self.param.jobname.val,
            os.path.abspath(self.param.dir.val),
        )

    def _param_jlaunch_opt_check(self, key, val_list, prefix, ev):
        """
        Extended ``jlaunch_opt`` check (registered via `sea.reg_xcheck` in
        `check_param`).

        Records an error in ``ev`` if any value of ``val_list`` appears in
        ``PARAM.VALIDATE.jlaunch_opt.black_list``. Does nothing when no
        black list is defined.

        :param key: Passed by the checker; unused here.
        :param val_list: The ``jlaunch_opt`` list being checked.
        :param prefix: Prefix passed on to ``ev.record_error``.
        :param ev: `sea.Evalor` collecting validation errors.
        """
        try:
            black_list = set(self.PARAM.VALIDATE.jlaunch_opt.black_list.val)
        except AttributeError:
            return
        bad_opt = set(val_list.val) & black_list
        if bad_opt:
            # Sorted so the message is deterministic (set order is not).
            s = " ".join(sorted(bad_opt))
            ev.record_error(prefix,
                            "Bad values for jlaunch_opt of %s stage: %s" %
                            (self.NAME, s))

    def _reg_param_precheck(self, func):
        """Register ``func`` to run before the main parameter check (once)."""
        if func in self._precheck:
            return
        self._precheck.append(func)

    def _reg_param_postcheck(self, func):
        """Register ``func`` to run after the main parameter check (once)."""
        if func in self._postcheck:
            return
        self._postcheck.append(func)

    def _set(self, key, setter, transformer=None):
        """
        Propagate the parameter ``key`` if the user set it.

        Only parameters tagged ``setbyuser`` are propagated. A callable
        ``setter`` is called with the parameter; a `sea.Atom` ``setter`` gets
        the parameter's value, passed through ``transformer`` if that is
        callable. Any other ``setter`` is ignored.
        """
        param = self.param[key]
        if not param.has_tag("setbyuser"):
            return
        if callable(setter):
            setter(param)
        elif isinstance(setter, sea.Atom):
            setter.val = (transformer(param.val)
                          if callable(transformer) else param.val)

    def _effect(self, param):
        """
        Apply the conditional parameter overrides listed in
        ``param.effect_if``.

        ``effect_if`` is read as a flat list of alternating ``condition,
        block`` items. Each condition is evaluated against the module-level
        ``PARAM`` (not ``self.PARAM``). When it holds, ``block`` is merged into
        ``param`` and this method recurses, so an ``effect_if`` nested in
        ``block`` is applied as well (``block`` gets ``effect_if = none`` if
        it has none, which ends the recursion).

        :param param: Stage parameter map; updated in place.
        :return: ``param`` itself.
        """
        effect_if = param.effect_if
        # Anything other than a list (e.g. `none`) means: no overrides.
        if isinstance(effect_if, sea.List):
            # Items 0/1, 2/3, ... form (condition, block) pairs.
            for condition, block in zip(effect_if[0::2], effect_if[1::2]):
                # TODO: Don't use private function
                val = sea.evalor._eval(PARAM, condition)
                if isinstance(val, bool):
                    condition = val
                elif isinstance(val[0], str):
                    # First item names an operator dispatched through
                    # `_operator`; the remaining items are its arguments.
                    condition = _operator[val[0]](self, PARAM, val[1:])
                else:
                    condition = val[0]
                if condition:
                    # A block given as an atom holds the map text; parse it.
                    if isinstance(block, sea.Atom):
                        block = sea.Map(block.val)
                    # Checks if within the `block' is the 'effect_if' parameter
                    # set.
                    if "effect_if" not in block:
                        block.effect_if = sea.Atom("none")
                    # TODO what is the purpose of this line below?
                    # NOTE(review): this always writes slot 1 (the first
                    # block), whichever pair matched -- confirm intended.
                    effect_if[1] = block
                    block = block.dval
                    param.update(block)
                    self._effect(param)
        return param

    def describe(self):
        """
        Print the stage header ("quiet" level) and the user-set parameters
        ("verbose" level).
        """
        header = "\nStage %d - %s" % (self._INDEX, self.NAME)
        self._print("quiet", header)
        user_settings = self.param.__str__("  ", tag="setbyuser")
        self._print("verbose", "{\n" + user_settings + "}")

    def migrate_param(self, param: sea.Map):
        """
        Hook for translating outdated parameters in ``param``.

        The base implementation does nothing. Subclasses may override it so
        that older msj files keep working, preferably while emitting a
        deprecation warning.
        """

    def check_param(self):
        """
        Validate this stage's parameters (``self.param``).

        Registers this stage's ``jlaunch_opt`` extended check, runs the
        functions registered with `_reg_param_precheck`, checks the data
        against ``self.PARAM.VALIDATE`` (passing the ``setbyuser`` tag), and
        finally runs the functions registered with `_reg_param_postcheck`.
        A `ParseError` raised by a pre/post check is recorded in the returned
        evaluator instead of propagating.

        :return: The `sea.Evalor` holding any recorded errors.
        """

        def clear_trjidx(prmdata):
            """
            Remove ``maeff_output.trjidx`` so trajectory index files are not
            used.
            """
            try:
                if "maeff_output" in prmdata:
                    del prmdata["maeff_output"]["trjidx"]
            except (KeyError, TypeError):
                pass

        check_func_name = "multisim_stage_%d_jlaunch_opt_check" % self._ID
        self.PARAM.VALIDATE.jlaunch_opt.check.val = check_func_name
        sea.reg_xcheck(check_func_name, self._param_jlaunch_opt_check)

        # `self.param`'s parent should be the global `PARAM`, but assigning
        # it to `self.PARAM.DATA` implicitly re-parents it to `self.PARAM`.
        # The `finally` clause restores the parent and the original
        # `self.PARAM.DATA` even when a check raises, so the class-level
        # PARAM is never left pointing at this instance's parameters.
        orig_param_data = self.PARAM.DATA
        self.PARAM.DATA = self.param
        try:
            clear_trjidx(self.PARAM.DATA)
            ev = sea.Evalor(self.param, "\n")
            for func in self._precheck:
                try:
                    func()
                except ParseError as e:
                    ev.record_error(err=str(e))
            sea.check_map(self.PARAM.DATA, self.PARAM.VALIDATE, ev,
                          "setbyuser")
            for func in self._postcheck:
                try:
                    func()
                except ParseError as e:
                    ev.record_error(err=str(e))
        finally:
            self.param.set_parent(PARAM.stage)
            self.PARAM.DATA = orig_param_data
        return ev

    def push(self, job):
        """
        Hand ``job`` to this stage, or pass ``None`` to signal that the
        previous stage has no more jobs.

        The first push into a non-skipped stage triggers `prestage`. A
        ``None`` job always triggers `release`; a real job triggers it
        immediately only when the stage does not synchronize
        (``should_sync`` is false).
        """
        if not (self._is_called_prestage or self.param.should_skip.val):
            self._is_called_prestage = True
            self.prestage()

        if job is not None:
            self._print(
                "debug",
                "Job was just pushed into stage[%d]: %s" % (self._INDEX,
                                                            str(job)),
            )
            self._pre_job.append(job)
            if job not in self._pre_jobre:
                self._pre_jobre.append(job)
            if not self.param.should_sync.val:
                self.release()
            return

        self._print("debug",
                    "All surviving jobs have been pushed into stage[%d]." %
                    self._INDEX)
        self.release()

    def determine(self):
        """
        Apply conditional overrides (`_effect`). If the stage ends up
        skipped, move all pending jobs to the pass-through list.
        """
        effective = self._effect(self.param)
        if not effective.should_skip.val:
            return
        self._pas_job.extend(self._pre_job)
        self._pre_job = []

    def crunch(self):
        """
        Create the jobs of this stage.

        Subclasses are expected to override this; the base implementation
        does nothing.
        """

    def restart_subjobs(self, jobs):
        """
        Called by `release` with the previously released ``jobs`` when the
        stage is restarting. Subclasses that support restarting subjobs
        override this; the base implementation does nothing.
        """

    def release(self, is_restarting=False):
        """
        Calls the 'crunch' method to generate new jobs objects and submits
        them to the 'QUEUE'.

        The first call on a stage object (``_is_shown`` still false) also
        describes the stage and is always treated as restarting: previously
        released jobs are deduplicated, handed to `restart_subjobs`, and
        queued again. Jobs in ``self._pas_job`` bypass the queue and are
        captured directly. In a dry run, queued jobs are marked successful
        and captured instead of being submitted.

        :param is_restarting: Force the restart handling of previously
            released jobs even when the stage was already shown.
        """
        util.chdir(ENGINE.base_dir)
        if not self._is_shown:
            self.describe()
            self._is_shown = True
            is_restarting = True

        if self._start_time is None:
            self._start_time = time.time()

        # May move all pending jobs to `_pas_job` if the stage is skipped.
        self.determine()

        if is_restarting:
            # Deduplicate (order is not preserved) and requeue released jobs.
            self._rls_job = list(set(self._rls_job))
            self.restart_subjobs(self._rls_job)
            self._job = self._rls_job
            self._rls_job = []

        self.crunch()
        jlaunch_opt = [str(e) for e in self.param.jlaunch_opt.val]
        # A list holding only "" is treated as "no extra launch options".
        if jlaunch_opt != [""]:
            for job in self._job:
                job.jlaunch_cmd += jlaunch_opt
        # Everything this stage now waits to capture.
        self._rls_job = self._rls_job + self._job + self._pas_job

        if self.param.dryrun.val:
            for job in self._job:
                job.status.set(JobStatus.SUCCESS)
                self.capture(job)
        else:
            ENGINE.write_checkpoint()
            QUEUE.push(self._job)

        for job in self._pas_job:
            # A callable launch command runs in-process, at most once, and
            # not at all in a dry run.
            if not job._has_run and callable(job.jlaunch_cmd):
                if not self.param.dryrun.val:
                    job.jlaunch_cmd(job)
                job._has_run = True
            self.capture(job)
        self._pas_job = []
        self._job = []

    def capture(self, job):
        """
        Take back a finished (or passed-through) job of this stage.

        A successful job is given to `hook_captured_successful_job` (unless
        the stage is skipped) and forwarded to the next stage unless the hook
        returned ``"dissolved"``. It is then moved from the released list to
        the captured list. Jobs that did not succeed stay in
        ``self._rls_job``. Once no released jobs remain, the stage is done:
        non-skipped stages run `poststage` and `pack_stage`, and the next
        stage is told, by pushing ``None``, that no more jobs are coming.

        :param job: The finished job.
        """
        self._print("debug", "Captured %s" % job)
        if job.status == JobStatus.SUCCESS:
            job_fate = None
            if not self.param.should_skip.val:
                job_fate = self.hook_captured_successful_job(job)
            if self._NEXT_STAGE is not None and job_fate != "dissolved":
                self._NEXT_STAGE.push(job)
            self._rls_job.remove(job)
            if job not in self._cap_job:
                self._cap_job.append(job)
            # Transfer results now instead of at stage completion.
            if self.param.transfer_asap.val:
                self.pack_stage(force=True)

        # GPU accounting for queued subjobs (list launch command with a
        # `jobcontrol.Job` handle), excluding passed-through jobs.
        if (isinstance(job.jlaunch_cmd, list) and
                isinstance(job.jctrl, jobcontrol.Job) and
                job not in self._pas_job):
            if job.USE_GPU:
                self._gpu_time += job.get_proc_time()
                self._num_gpu_subjobs += 1

        ENGINE.write_checkpoint()

        self._print("debug", "released jobs:")
        self._print("debug", self._rls_job)
        self._print("debug", "captured jobs:")
        self._print("debug", self._cap_job)

        if self._rls_job == []:
            if self.param.should_skip.val:
                self._print("quiet", "\nStage %d is skipped.\n" % self._INDEX)
            else:
                self._print("quiet",
                            "\nStage %d completed successfully." % self._INDEX)
                self.poststage()
                self.pack_stage(force=self.param.transfer_asap.val)
            if self._NEXT_STAGE is not None:
                self._NEXT_STAGE.push(None)

    def prestage(self):
        """
        Hook called once, on the first `push` into a non-skipped stage.
        Does nothing by default.
        """

    def poststage(self):
        """
        Hook called from `capture` once all released jobs of a non-skipped
        stage have been captured, just before packing. Does nothing by
        default.
        """

    def hook_captured_successful_job(self, job):
        """
        Hook called from `capture` for each successful job of a non-skipped
        stage. Returning ``"dissolved"`` keeps the job from being pushed to
        the next stage. Returns ``None`` by default.
        """

    def time_stage(self):
        """Store in ``_stage_duration`` the time elapsed since ``_start_time``."""
        self._stage_duration = util.time_duration(self._start_time,
                                                  time.time())

    def pack_stage(self, force=False):
        """
        Archive and/or register this stage's results for transfer.

        Without ``force``, a stage is packed at most once, and only if it is
        not skipped and was created with ``should_pack``. When the
        ``compress`` parameter is non-empty, the collected paths are written
        to a gzip tarball (named after ``compress``, with `PACKAGE_SUFFIX`
        appended unless it already ends in it or ``tar.gz``). The resulting
        paths are registered with the job-control backend, if any. Finally
        the stage duration is printed.

        :param force: Pack even if already packed, skipped, or not packable.
        """
        if force or ((not self.param.should_skip.val) and self._should_pack and
                     (not self._is_packed)):
            self._is_packed = True

            util.chdir(ENGINE.base_dir)

            # Standard checkpoint to a file
            pack_fname = None
            if self.param.compress.val != "":
                sea.update_macro_dict({"$STAGENO": self._INDEX})
                pack_fname = self.param.compress.val
                if not pack_fname.lower().endswith((
                        PACKAGE_SUFFIX,
                        "tar.gz",
                )):
                    pack_fname += PACKAGE_SUFFIX
                self.param.compress.val = pack_fname
                self._pack_fname = pack_fname

            print_debug(f"pack_stage: pack_fname:{pack_fname}")
            # Collects all data paths for transferring.
            data_paths = set()
            all_jobs = self._rls_job + self._cap_job
            # `_fai_job` exists only on some stage objects.
            try:
                all_jobs += self._fai_job
            except AttributeError:
                pass
            for job in all_jobs:
                # Some stages just pass on a job from the previous stage
                # directly to the next stage. So we check the stage ID
                # to avoid packing the same job again.
                if job.stage._ID == self._ID:
                    # When archiving, take the whole job dir; otherwise the
                    # job outputs plus its relative job-control files.
                    if job.dir and pack_fname:
                        data_paths.add(job.dir)
                    else:
                        for e in job.output:
                            data_paths.add(e)
                        reg_file = []
                        if isinstance(job.jctrl, jobcontrol.Job):
                            reg_file.extend(job.jctrl.OutputFiles)
                            reg_file.extend(job.jctrl.InputFiles)
                            reg_file.extend(job.jctrl.LogFiles)
                            if job.jctrl.StructureOutputFile:
                                reg_file.append(job.jctrl.StructureOutputFile)
                        for fname in reg_file:
                            if not os.path.isabs(fname):
                                data_paths.add(os.path.join(job.dir, fname))

            for job in self._pas_job:
                if job.stage._ID == self._ID:
                    if job.dir and pack_fname:
                        data_paths.add(job.dir)
                    else:
                        for e in job.output:
                            data_paths.add(e)

            # Creates a stage-specific checkpoint file -- a copy of the
            # current checkpoint file, suffixed with the stage index.
            ENGINE.write_checkpoint()
            stage_checkpoint_fname = None
            if os.path.isfile(CHECKPOINT_FNAME):
                stage_checkpoint_fname = (
                    os.path.basename(CHECKPOINT_FNAME) + "_" + str(self._INDEX))
                shutil.copyfile(CHECKPOINT_FNAME, stage_checkpoint_fname)
                # Includes this checkpoint file for transferring.
                data_paths.add(os.path.abspath(stage_checkpoint_fname))

            if pack_fname:
                # Fast (level 1) gzip; symlinks are stored as their targets
                # and members are stored relative to ENGINE.base_dir.
                with tarfile.open(
                        pack_fname,
                        mode="w:gz",
                        format=tarfile.GNU_FORMAT,
                        compresslevel=1) as pack_file:
                    pack_file.dereference = True
                    for path in data_paths | set(self._files4pack):
                        print_debug(
                            f"pack_stage: add_to_tar: {path} exists: "
                            f"{os.path.exists(path)} cwd: {os.getcwd()}")
                        if os.path.exists(path):
                            relpath = util.relpath(path, ENGINE.base_dir)
                            pack_file.add(relpath)
                # From here on only the tarball itself is transferred.
                data_paths = [pack_fname]

            if ENGINE.JOBBE:
                for path in data_paths:
                    # Makes all paths relative. Otherwise jobcontrol won't
                    # transfer them!!!
                    path = util.relpath(path, ENGINE.base_dir)
                    if not path:
                        continue
                    # Register each path only once across calls.
                    if path in self._packed_fnames:
                        continue
                    self._packed_fnames.add(path)

                    print_debug(f"pack_stage: outputFile: {path} relpath: "
                                f"{util.relpath(path, ENGINE.base_dir)} "
                                f"cwd: {os.getcwd()}")

                    # Only when we do NOT compress files, we allow to transfer
                    # files ASAP. DESMOND-7401.
                    if (self.param.transfer_asap.val and
                            not pack_fname) and os.path.exists(path):
                        ENGINE.JOBBE.copyOutputFile(path)
                    else:
                        ENGINE.JOBBE.addOutputFile(path)
                for path in self._files4copy:
                    path = util.relpath(path, ENGINE.base_dir)
                    if path in self._packed_fnames:
                        continue
                    self._packed_fnames.add(path)
                    print_debug(f"pack_stage: files4copy: {path} relpath: "
                                f"{util.relpath(path, ENGINE.base_dir)} "
                                f"cwd: {os.getcwd()}")
                    ENGINE.JOBBE.copyOutputFile(path)
            # TypeError is expected when `_start_time` was never set.
            try:
                self.time_stage()
                self._print("quiet", "Stage %d duration: %s\n" %
                            (self._INDEX, self._stage_duration))
            except TypeError:
                self._print(
                    "quiet",
                    "Stage %d duration could not be calculated." % self._INDEX)


class StructureStageBase(StageBase):
    """
    Base class for stages that read a structure file, transform it, and
    write out an updated structure file.

    Subclasses implement `run`; `crunch` creates one job directory per
    incoming job, calls `run` inside it, and passes the resulting jobs
    straight on (without going through the queue).
    """

    def __init__(self, *args, **kwargs):
        # e.g. NAME "foo" -> TAG "FOO"
        self.TAG = self.NAME.upper()
        super().__init__(*args, **kwargs)

    def crunch(self):
        self._print("debug", f"In {self.NAME}.crunch")
        for parent_job in self._pre_job:
            jobname, jobdir = self._get_jobname_and_dir(parent_job)
            os.makedirs(jobdir, exist_ok=True)

            with fileutils.chdir(jobdir):
                new_job = copy.deepcopy(parent_job)
                new_job.stage = weakref.proxy(self)
                new_job.output = JobOutput()
                new_job.need_host = False
                new_job.dir = jobdir
                new_job.status.set(JobStatus.SUCCESS)
                new_job.parent = parent_job

                # `run` executes inside the job directory, so a relative
                # output name is made absolute against it.
                out_fname = self.run(jobname,
                                     parent_job.output.struct_file())
                if out_fname is not None:
                    new_job.output.set_struct_file(os.path.abspath(out_fname))
                else:
                    new_job.status.set(JobStatus.BACKEND_ERROR)

            self._pas_job.append(new_job)

        self._pre_job = []
        self._print("debug", f"Out {self.NAME}.crunch")

    def run(self, jobname: str, input_fname: str) -> Optional[str]:
        """
        Transform one structure; must be implemented by subclasses.

        :param jobname: Job name assigned for this stage.
        :param input_fname: Path of the input structure file.

        :return: Path of the output structure file, or `None` if the output
            could not be generated.
        """
        raise NotImplementedError


class _get_jc_backend_when_needed(object):
    """
    Descriptor that fetches the job-control backend on first access and
    caches it on the accessing class as ``JOBBE``.
    """

    def __get__(self, obj, cls):
        backend = jobcontrol.get_backend()
        # Cache on the class so later lookups skip this method.
        setattr(cls, "JOBBE", backend)
        return backend


class Engine(object):
    """
    State of a multisim master job: command-line settings, the stage
    objects, timing and version information, and the data needed to restart
    from a checkpoint.
    """

    # Job-control backend, looked up lazily on first access.
    JOBBE = _get_jc_backend_when_needed()

    def __init__(self, opt=None):
        """
        :param opt: Parsed command-line options; when given, they are
            applied via `reset`.
        """
        # --- Settings that command-line options may override.
        self.jobname = self.username = None
        self.masterhost = self.host = self.cpu = None
        self.inp_fname = None
        self.msj_fname = None  # .msj file of this (restarting) job
        self.MSJ_FNAME = None  # the original .msj file name
        self.msj_content = None
        self.out_fname = None
        self.set = None  # always reset on restart, hence not serialized
        self.cfg = self.cfg_content = None
        self.maxjob = self.max_retry = None
        self.relay_arg = None
        self.launch_dir = None
        self.description = None
        self.loglevel = GENERAL_LOGLEVEL

        # --- Run state.
        self.stage = []  # serialized; filled in during serialization
        self.date = self.time = None  # date/time of the original job
        self.START_TIME = None  # start time of the original job
        self.start_time = self.stop_time = None  # change on every restart
        self.base_dir = None  # current base dir; changes on restart
        self.refrom = None  # stage number to restart from
        self.base_dir_ = None  # base dir of the previous job
        self.jobid = None  # current job ID; changes on restart
        self.JOBID = None  # ID of the original job; unaffected by restarts

        # --- Versions and installation (updated on restart).
        self.version = VERSION  # MSJ version
        self.build = BUILD
        self.mmshare_ve = envir.CONST.MMSHARE_VERSION
        self.schrodinger = envir.CONST.SCHRODINGER  # installation dir
        self.schrod_old = None  # installation dir of the previous run
        self.old_jobnames = []

        # --- Set while probing the checkpoint file.
        self.chkpt_fname = self.chkpt_fh = None
        self.restart_stage = None

        self.__more_init()

        if opt:
            self.reset(opt)

    def __more_init(self):
        """
        Initialize attributes added after the original checkpoint format.

        Called from `__init__` and when deserializing, so that checkpoint
        files written before these attributes existed still load.
        """
        self.notify = self.macro_dict = None
        self.max_walltime = self.checkpoint_requested_event = None

    def __find_restart_stage_helper(self, stage):
        """
        Callback for `_find_restart_stage`.

        If ``stage`` still has pending or released jobs, remember it as the
        restart stage (unless one was already found) and clear its
        shown/packed flags so it is described and packed again.
        """
        has_unfinished_jobs = stage._pre_job != [] or stage._rls_job != []
        if not has_unfinished_jobs:
            return
        if not self.restart_stage:
            self.restart_stage = stage
        stage._is_shown = False
        stage._is_packed = False

    def _find_restart_stage(self):
        """
        Set ``self.restart_stage`` to the first stage, in `_foreach_stage`
        order, that has unfinished jobs (``None`` if there is none).
        """
        self.restart_stage = None
        visit = self.__find_restart_stage_helper
        self._foreach_stage(visit)

    def _fix_job(self, stage):
        """
        Prepare the jobs of ``stage`` to be restarted in this run.

        The following is done in place:

        - Parents of unfinished (released) jobs are queued again.
        - Job directories and input/output paths are moved from the previous
          base directory (``self.base_dir_``) to the current one
          (``self.base_dir``).
        - Jobs whose ``stage`` is an index get their stage object back.
        - List launch commands are pointed at the current installation
          (``self.schrodinger``).

        :param stage: Stage object whose job lists are fixed in place.
        """
        import schrodinger.application.desmond.stage as stg

        # Simulate and Multisim stages can restart from checkpoint files (as
        # opposed to rerunning from scratch), so their released jobs are kept
        # and the parents of those jobs are not re-queued.
        can_resume = isinstance(stage, (stg.Simulate, stg.Multisim))

        # Deduplicated like a set, but in deterministic first-seen order.
        pre_job = []
        if not can_resume:
            pre_job = list(dict.fromkeys(job.parent for job in stage._rls_job))
        stage._pre_job = pre_job + stage._pre_job

        old_base, new_base = self.base_dir_, self.base_dir
        for job in itertools.chain(stage._pre_job, stage._cap_job,
                                   stage._pre_jobre, stage._rls_job):
            # Rebase the job dir if it lies below the old base dir, or is the
            # old base dir itself (with JOB_SERVER the job dir may not be a
            # subdirectory; needed for restarting MD jobs from the production
            # stage). Only the leading base dir is replaced: `str.replace`
            # would also rewrite later occurrences inside the path.
            # NOTE(review): the same job object can appear in several lists
            # (and stages); rebasing twice is only harmless when the new base
            # dir is not inside the old one.
            if job.dir and old_base and (job.dir == old_base or
                                         job.dir.startswith(old_base +
                                                            os.sep)):
                job.dir = new_base + job.dir[len(old_base):]

            job.output.update_basedir(old_base, new_base)
            try:
                job.input.update_basedir(old_base, new_base)
            except AttributeError:
                pass

        # FIXME: remove existing dirs for stage._rls_job?
        # We leave stage._rls_job as is to be able to restart Desmond and
        # Multisim subjobs from checkpoint files.
        if not can_resume:
            stage._rls_job = []

        # These two lists contain the subjobs that were running when the
        # previous master job was stopped. They are cleared because all
        # unfinished subjobs have been recovered from `stage._rls_job`.
        stage._job = []
        stage._pas_job = []

        # Fixes the "stage" attribute of all jobs of this stage, and the job
        # launching command.
        all_jobs = itertools.chain(stage._pre_job, stage._pre_jobre,
                                   stage._rls_job, stage._cap_job)
        for e in all_jobs:
            if isinstance(e.stage, int):
                e.stage = weakref.proxy(self.stage[e.stage])
            # Skip when no previous installation dir is known: replacing ""
            # would insert the new path between every character.
            cmd = e.jlaunch_cmd
            if (self.schrod_old and isinstance(cmd, list) and cmd and
                    isinstance(cmd[0], str)):
                cmd[0] = cmd[0].replace(self.schrod_old, self.schrodinger)

    def restore_stages(self, print_func=print_quiet):
        """
        Rebuild ``self.stage`` for a restarted job.

        Re-parses the multisim script and/or ``-set`` option and applies the
        ``cfg`` and ``cpu`` overrides. It then rebuilds the stage objects,
        passing the serialized state of the kept stages (all of them, or only
        those before ``self.refrom`` when restarting from a given stage).
        Finally it fixes up the jobs for the new run with `_fix_job`. The
        process exits on parse errors or when the first stage is not a
        ``task`` stage.

        :param print_func: Function used to report what is being parsed.
        :raise ParseError: If the ``cpu`` setting is malformed.
        """
        # DESMOND-7934: Preserve the task stage from the checkpoint
        # if a custom msj is specified.
        checkpoint_stage_list = None
        if self.msj_fname and self.msj_content:
            checkpoint_stage_list = parse_msj(
                None, msj_content=self.msj_content, pset=self.set)

        parsed_sources = [
            source for source in (
                "the multisim script file" if self.msj_fname else None,
                "the '-set' option" if self.set else None,
            ) if source
        ]
        if parsed_sources:
            print_func("\nParsing %s..." % " and ".join(parsed_sources))
        try:
            msj_content = None if self.msj_fname else self.msj_content
            stage_list = parse_msj(self.msj_fname, msj_content, self.set)
        except ParseError as err:
            print_quiet("\n%s\nParsing failed." % str(err))
            sys.exit(1)

        if checkpoint_stage_list and stage_list:
            refrom = self.refrom
            if refrom is None:
                # Without an explicit restart stage, use the first stage with
                # pending or released jobs; if there is none, use 2, which
                # keeps only the first (task) stage from the checkpoint.
                refrom = 2
                for idx, s in enumerate(self.stage):
                    if s._pre_job != [] or s._rls_job != []:
                        refrom = idx
                        break
            # Stages before the restart stage come from the checkpoint; the
            # rest take the newly parsed settings.
            # NOTE(review): refrom == 0 makes `refrom - 1` a negative slice
            # index -- confirm this cannot happen.
            stage_list = (checkpoint_stage_list[:refrom - 1] +
                          stage_list[refrom - 1:])

        # `not stage_list` also covers an empty/None parse result.
        if not stage_list or stage_list[0].NAME != "task":
            print("ERROR: The first stage is not a 'task' stage.")
            sys.exit(1)

        if self.cfg:
            with open(self.cfg, "r") as fh:
                cfg = sea.Map(fh.read())
            for stage in stage_list:
                if stage.NAME == "task":
                    if "desmond" in stage.param.set_family:
                        stage.param.set_family.desmond.update(cfg)
                    else:
                        stage.param.set_family["desmond"] = cfg

        if self.cpu:
            # `self.cpu` is a string holding either a single integer or three
            # space-separated integers.
            try:
                cpu = [int(e) for e in self.cpu.split()]
                n_cpu = len(cpu)
                if n_cpu not in (1, 3):
                    raise ValueError(
                        "Incorrect configuration of the CPU: %s" % self.cpu)
                cpu = cpu[0] if n_cpu == 1 else cpu
            except ValueError as err:
                raise ParseError(
                    "Invalid value for the 'cpu' parameter: '%s'" %
                    self.cpu) from err

            for stage in stage_list:
                if stage.NAME in [
                        "simulate",
                        "minimize",
                        "replica_exchange",
                        "lambda_hopping",
                        "vrun",
                        "fep_vrun",
                        "watermap",
                ]:
                    stage.param["cpu"] = cpu
                    stage.param.cpu.add_tag("setbyuser")
                elif stage.NAME in [
                        "mcpro_simulate",
                        "watermap_cluster",
                        "ffbuilder",
                ]:
                    # These stages take a single total CPU count.
                    stage.param["cpu"] = (cpu if n_cpu == 1 else
                                          cpu[0] * cpu[1] * cpu[2])
                    stage.param.cpu.add_tag("setbyuser")

        # `self.refrom` is None by default, which means: keep all stages.
        restart_midway = bool(self.refrom and self.refrom > 0)
        if restart_midway:
            pre_job_of_restart_stage = self.stage[self.refrom]._pre_jobre
        kept_stages = (self.stage[:self.refrom]
                       if restart_midway else self.stage)
        stage_state = [e.__getstate__() for e in kept_stages]
        self.stage = build_stages(stage_list, self.out_fname, stage_state)

        # Fixes the dir and absolute paths in `Job' objects.
        if restart_midway:
            import schrodinger.application.desmond.stage as stg

            restart_stage = self.stage[self.refrom]
            restart_stage._pre_job = copy.copy(pre_job_of_restart_stage)
            restart_stage._pre_jobre = pre_job_of_restart_stage
            stage_list_for_jobfix = self.stage[:self.refrom + 1]
            for later_stage in self.stage[self.refrom + 1:]:
                later_stage._pre_jobre = []
            if isinstance(restart_stage, stg.DesmondExtend):
                restart_stage._pre_jobre = []
        else:
            stage_list_for_jobfix = self.stage
        for stage in stage_list_for_jobfix:
            self._fix_job(stage)

        # `self.msj_content' contains only user's settings. `stage_list[1:-1]'
        # will avoid the initial ``primer'' and the final ``concluder'' stages.
        self.msj_content = write_msj(stage_list[1:-1], to_str=True)

    def reset(self, opt):
        """
        Overwrite engine settings with the values given on the command line.

        All ``_is_reset_*`` flags on the engine are cleared first. Then every
        option of `opt` that was supplied replaces the matching engine
        attribute. "Supplied" means truthy, except for `maxjob` and
        `max_retries`, where any non-None value counts. `cfg` is always
        copied.

        :param opt: Parsed command-line options.
        """
        state = self.__dict__
        for attr_name in state:
            if attr_name.startswith("_is_reset_"):
                state[attr_name] = False

        def take(opt_name, attr_name=None, convert=None, allow_falsy=False):
            # Copy `opt.<opt_name>` onto `self.<attr_name>` when supplied,
            # optionally transforming the value first.
            value = getattr(opt, opt_name)
            supplied = (value is not None) if allow_falsy else bool(value)
            if supplied:
                setattr(self, attr_name or opt_name,
                        convert(value) if convert else value)

        take("refrom")
        take("jobname")
        take("user", "username")
        take("masterhost")
        take("host")
        take("cpu")
        take("inp", "inp_fname", os.path.abspath)
        take("msj", "msj_fname", os.path.abspath)
        take("out", "out_fname")
        take("set")
        take("maxjob", allow_falsy=True)
        take("max_retries", "max_retry", allow_falsy=True)
        take("relay_arg", convert=lambda raw: sea.Map(raw))
        take("launch_dir")
        take("notify")
        take("encoded_description", "description",
             lambda encoded: cmdline.get_b64decoded_str(encoded))
        # Evaluated in this order so that the most verbose flag given wins.
        for level in ("quiet", "verbose", "debug"):
            if getattr(opt, level):
                self.loglevel = level
        take("max_walltime")

        self.cfg = opt.cfg

    def boot_setup(self, base_dir=None):
        """
        Set up an `Engine` object, but do not start the queue.

        In order, this method:

        - installs the signal handlers;
        - publishes this engine and its log level through the module globals
          `ENGINE` and `GENERAL_LOGLEVEL`;
        - fills in the date/time/start-time stamps that are not already set;
        - prints a banner describing the job;
        - installs the macro dictionary in `sea`;
        - restores the stages;
        - prints the checkpoint state when `chkpt_fh` is set;
        - prints a summary of the user stages;
        - expands the module-level `CHECKPOINT_FNAME` into a path under
          `self.base_dir`.

        :param base_dir: Set to the path for the base_dir or
            `None`, the default, to use the cwd.
        """
        global ENGINE, GENERAL_LOGLEVEL, CHECKPOINT_FNAME

        self._init_signals()

        GENERAL_LOGLEVEL = self.loglevel
        ENGINE = self
        if self.loglevel == "debug":
            _print("quiet", "Multisim debugging mode is on.\n")

        if self.description:
            _print("quiet", self.description)
        #######################################################################
        # Boots the engine.
        _print("quiet", "Booting the multisim workflow engine...")
        # Existing stamps (e.g. on an engine loaded from a checkpoint) are
        # kept; only missing ones are filled in.
        self.date = time.strftime("%Y%m%d") if (not self.date) else self.date
        self.time = time.strftime("%Y%m%dT%H%M%S") if (
            not self.time) else self.time
        self.start_time = time.time() if (
            not self.start_time) else self.start_time
        # Keep the previous base directory before replacing it.
        # NOTE(review): presumably consumed by `restore_stages`; confirm.
        self.base_dir_ = self.base_dir
        self.base_dir = base_dir or os.getcwd()
        self.jobid = envir.get("SCHRODINGER_JOBID")
        # `JOBID` is set only once; later boots keep the first value while
        # `jobid` always holds the current job ID.
        self.JOBID = self.JOBID if (self.JOBID) else self.jobid
        # Values below 1 are normalized to 0.
        self.maxjob = 0 if self.maxjob < 1 else self.maxjob
        self.max_retry = (self.max_retry
                          if (self.max_retry is not None) else int(
                              envir.get("SCHRODINGER_MAX_RETRIES", 3)))
        # `MSJ_FNAME` keeps the first .msj file name; `msj_fname` may be a
        # newer one supplied for this run.
        self.MSJ_FNAME = self.MSJ_FNAME if (self.MSJ_FNAME) else self.msj_fname

        # Resets these variables.
        self.version = VERSION
        self.build = BUILD
        self.mmshare_ver = envir.CONST.MMSHARE_VERSION
        self.schrod_old = self.schrodinger
        self.schrodinger = envir.CONST.SCHRODINGER

        _print("quiet", "           multisim version: %s" % self.version)
        _print("quiet", "            mmshare version: %s" % self.mmshare_ver)
        _print("quiet", "                    Jobname: %s" % self.jobname)
        _print("quiet", "                   Username: %s" % self.username)
        _print("quiet", "            Master job host: %s" % self.masterhost)
        _print("quiet", "                Subjob host: %s" % self.host)
        _print("quiet", "                     Job ID: %s" % self.jobid)
        _print("quiet", "            multisim script: %s" % os.path.basename(
            self.msj_fname if (self.msj_fname) else self.MSJ_FNAME))
        _print("quiet", "       Structure input file: %s" % os.path.basename(
            self.inp_fname))
        if self.cpu:
            _print("quiet", '            CPUs per subjob: "%s"' % self.cpu)
        else:
            _print("quiet",
                   "            CPUs per subjob: (unspecified in command)")
        _print("quiet",
               "             Job start time: %s" % time.ctime(self.start_time))
        _print("quiet", "           Launch directory: %s" % self.launch_dir)
        _print("quiet", "               $SCHRODINGER: %s" % self.schrodinger)
        sys.stdout.flush()

        # Macros available to `sea.expand_macro` (used for CHECKPOINT_FNAME
        # below).
        self.macro_dict = {
            "$MASTERJOBNAME": self.jobname,
            "$USERNAME": self.username,
            "$MASTERDATE": self.date,
            "$MASTERTIME": self.time,
            "$SUBHOST": self.host,
        }
        sea.set_macro_dict(copy.copy(self.macro_dict))

        self.restore_stages()
        # `chkpt_fh` is set by `deserialize` (and cleared by `serialize`), so
        # the checkpoint state is printed only for an engine loaded from a
        # checkpoint file.
        if self.chkpt_fh:

            def show_job_state(stage, engine=self):
                engine._check_stage(stage)
                if stage._final_status[0] == "1":
                    _print("quiet", "    Jobnames of failed subjobs:")
                    for job in stage._rls_job:
                        _print("quiet", "      %s" % job.jobname)

            _print("quiet", "")
            _print("quiet", "Checkpoint state:")
            self._foreach_stage(show_job_state)

        # `self.stage[1:-1]` skips the initial primer and final concluder.
        _print("quiet", "\nSummary of user stages:")
        for stage in self.stage[1:-1]:
            if stage.param.title.val:
                _print("quiet", "  stage %d - %s, %s" %
                       (stage._INDEX, stage.NAME, stage.param.title.val))
            else:
                _print("quiet", "  stage %d - %s" % (stage._INDEX, stage.NAME))
        _print("quiet", "(%d stages in total)" % (len(self.stage) - 2))

        # Expand macros such as $MASTERJOBNAME and anchor the checkpoint
        # file name under the base directory.
        CHECKPOINT_FNAME = os.path.join(self.base_dir,
                                        sea.expand_macro(
                                            CHECKPOINT_FNAME,
                                            sea.get_macro_dict()))

    def boot(self):
        """
        Boot the `Engine` and run the jobs.

        After `boot_setup`, this method:

        - optionally starts a max-walltime timer;
        - creates the module-level `QUEUE`;
        - starts the workflow from the first stage or from the restart stage;
        - runs the queue;
        - passes the outcome to `cleanup`, which ends by calling `sys.exit`.

        If a `StopRequest` or `StopAndRestartRequest` is raised while the
        workflow runs, the run ends with exit code 0. An empty marker file is
        registered as a job output and the per-stage exit-code check is
        skipped. Any other exception is reported and yields exit code 1.
        """
        global QUEUE
        self.boot_setup()
        max_walltime_timer = None
        if self.max_walltime:
            self.checkpoint_requested_event = threading.Event()
            _print("quiet", f"Checkpoint after {self.max_walltime} seconds.")
            max_walltime_timer = threading.Timer(
                self.max_walltime,
                lambda: self.checkpoint_requested_event.set())
            max_walltime_timer.start()

        def touch_output_marker(marker_fname):
            # Create an empty file and register it as a job output file.
            with open(marker_fname, 'w'):
                pass
            self.JOBBE.addOutputFile(os.path.basename(marker_fname))

        try:
            QUEUE = queue.Queue(
                self.host,
                self.maxjob,
                max_retries=self.max_retry,
                periodic_callback=self.handle_jobcontrol_message)
            self.JOBBE.addOutputFile(os.path.basename(CHECKPOINT_FNAME))
            self.start_time = time.time()
            _print("quiet", "\nWorkflow is started now.")
            try:
                if self.START_TIME is None:
                    self.START_TIME = self.start_time
                    self.stage[0].start(self.inp_fname)
                else:
                    if self.refrom is None or self.refrom < 1:
                        self._find_restart_stage()
                    else:
                        self.restart_stage = self.stage[self.refrom]
                    if self.restart_stage:
                        if self.msj_fname:
                            _print("quiet",
                                   "Updating stages with the new .msj file: "
                                   f"{self.msj_fname}...")
                            _print("quiet",
                                   f"Stage {self.restart_stage._INDEX} and "
                                   "after will be affected by the new .msj "
                                   "file.")

                        # We need to rerun the `set_family' functions.
                        self.run_set_family(self.restart_stage._INDEX)

                        _print("quiet", "Restart workflow from stage %d." %
                               self.restart_stage._INDEX)
                        self.restart_stage.push(None)

                        # Special treatment for restarting FEP Mapper workflow
                        if self.refrom is None:
                            next_stage = self.restart_stage._NEXT_STAGE
                            # FIXME: This creates a dependency on stages in
                            # stage.py
                            if next_stage.NAME in [
                                    "calc_ddg",
                                    "vacuum_report",
                                    "solubility_fep_analysis",
                                    "fep_absolute_binding_analysis",
                            ]:
                                next_stage._pre_job = copy.copy(
                                    next_stage._pre_jobre)
                                next_stage._is_packed = False
                                next_stage = next_stage._NEXT_STAGE
                                while next_stage:
                                    next_stage._pre_job = []
                                    next_stage._pre_jobre = []
                                    next_stage._is_packed = False
                                    next_stage = next_stage._NEXT_STAGE
                    else:
                        _print(
                            "quiet",
                            "The previous multisim job has completed "
                            "successfully.")
                        _print("quiet",
                               "If you want to restart from a completed "
                               "stage, specify its stage number to")
                        _print("quiet", "the '-RESTART' option as: "
                               "-RESTART <checkpoint-file>:<stage_number>.")
                self.JOBBE.addMessageName("halt")
                QUEUE.run()
                exit_code = 0
                skip_stage_check = False
            except SystemExit:
                sys.exit(1)
            except StopRequest:
                touch_output_marker(queue.CHECKPOINT_REQUESTED_FILENAME)
                exit_code = 0
                skip_stage_check = True
            except StopAndRestartRequest:
                touch_output_marker(
                    queue.CHECKPOINT_WITH_RESTART_REQUESTED_FILENAME)
                exit_code = 0
                skip_stage_check = True
            except Exception:
                ei = sys.exc_info()
                sys.excepthook(ei[0], ei[1], ei[2])
                _print("quiet",
                       "\n\nUnexpected exception occurred. Terminating the "
                       "multisim execution...")
                exit_code = 1
                skip_stage_check = False
        finally:
            # A `threading.Timer` started from the main thread is a
            # non-daemon thread. If it is not cancelled, it keeps the
            # interpreter alive after `sys.exit` until the walltime elapses.
            # Cancel it on every path, including the SystemExit one.
            if max_walltime_timer is not None:
                max_walltime_timer.cancel()
        self.cleanup(exit_code, skip_stage_check=skip_stage_check)

    def run_set_family(self, max_stage_idx=None):
        """
        Call `set_family` on every stage named "task" whose `_INDEX` is below
        `max_stage_idx`, walking the linked stage chain from the first stage.

        A falsy `max_stage_idx` (None or 0) means "all stages".
        """
        limit = max_stage_idx if max_stage_idx else len(self.stage)
        current = self.stage[0]
        while current is not None:
            if current._INDEX >= limit:
                break
            if current.NAME == "task":
                current.set_family()
            current = current._NEXT_STAGE

    def handle_jobcontrol_message(self, stop=False):
        """
        Stop the workflow if a stop has been requested; otherwise return.

        A stop is requested in any of these cases:

        - `stop` is true;
        - the max-walltime `checkpoint_requested_event` is set;
        - the next job-control message is "halt".

        When stopping, the running subjobs are stopped through the
        module-level `QUEUE`.

        :param stop: Force a stop regardless of pending messages.
        :raise StopAndRestartRequest: When stopping while the max-walltime
            checkpoint event is set.
        :raise StopRequest: When stopping for any other reason.
        """
        restart = False
        if self.checkpoint_requested_event is not None:
            restart = self.checkpoint_requested_event.is_set()

        # `nextMessage` is only consulted when no other stop reason applies.
        if not stop and not restart and self.JOBBE.nextMessage() != "halt":
            return

        _print("quiet",
               "\nReceived 'halt' message. Stopping job on user's request...")
        _print("quiet",
               f"{len(QUEUE.running_jobs)} subjob(s) are currently running.")

        num_killed = QUEUE.stop()
        if num_killed:
            _print("quiet",
                   f"{num_killed} subjob(s) failed to stop and were killed.")
        else:
            _print("quiet", "Subjobs stopped successfully.")
        if restart:
            raise StopAndRestartRequest()
        raise StopRequest()

    def _init_signals(self):
        """
        Route SIGTERM, SIGINT, SIGHUP, SIGUSR1 and SIGUSR2 to
        `_handle_signal`, passing the name of the signal that arrived.

        Signals that the platform does not define are skipped.
        """
        for signal_name in [
                "SIGTERM", "SIGINT", "SIGHUP", "SIGUSR1", "SIGUSR2"
        ]:
            # Certain signals are not available depending on the OS.
            if hasattr(signal, signal_name):
                # Bind the current name as a default argument. A plain
                # closure would see only the loop's final value, so every
                # handler would report the same signal.
                signal.signal(
                    getattr(signal, signal_name),
                    lambda signum, stack_frame, name=signal_name: self.
                    _handle_signal(name),
                )

    def _reset_signals(self):
        """
        Restore the default disposition (`SIG_DFL`) of SIGTERM, SIGINT,
        SIGHUP, SIGUSR1 and SIGUSR2.

        Signals that the platform does not define are skipped.
        """
        for signal_name in ("SIGTERM", "SIGINT", "SIGHUP", "SIGUSR1",
                            "SIGUSR2"):
            # Not every signal exists on every platform; an unguarded
            # `signal.SIGUSR1` would raise AttributeError there.
            if hasattr(signal, signal_name):
                signal.signal(getattr(signal, signal_name), signal.SIG_DFL)

    def _handle_signal(self, signal_name):
        """
        React to a received signal.

        Default signal handling is restored first, then a timestamped notice
        is printed. Finally a forced stop is requested through
        `handle_jobcontrol_message(stop=True)`, and its result is returned.

        :param signal_name: Name of the received signal, e.g. "SIGTERM".
        """
        self._reset_signals()
        stamp = time.asctime()
        print("\n\n%s: %s signal received" % (stamp, signal_name))
        return self.handle_jobcontrol_message(stop=True)

    def _foreach_stage(self, callback):
        """
        Call `callback(stage)` for each user stage.

        User stages are every stage in the linked chain after the first
        stage and before the last one (the stage whose `_NEXT_STAGE` is
        None).
        """

        def user_stages(node):
            # The successor is read only after the consumer has processed
            # `node`, so links changed by `callback` are honoured.
            while node._NEXT_STAGE is not None:
                yield node
                node = node._NEXT_STAGE

        for stage in user_stages(self.stage[0]._NEXT_STAGE):
            callback(stage)

    def _check_stage(self, stage, print_func=print_quiet):
        """
        Classify how far `stage` got, print a one-line summary through
        `print_func`, and store the classification in
        `stage._final_status`.

        The stored string starts with a digit, followed by a space and a
        description:

        - "0": failed or not run;
        - "1": partially completed;
        - "2": completed or skipped.

        Callers in this module read the digit with ``int(status[0])``.
        """
        labels = {
            -2: "2 was skipped",
            -1: "0 not run",
            0: "0 failed",
            1: "1 partially completed",
            2: "2 completed",
        }
        detail = ""
        if not stage._is_shown:
            outcome = -1
        elif stage.param.should_skip.val:
            outcome = -2
        else:
            # `_cap_job` holds finished subjobs, `_rls_job` the failed ones.
            done_count = len(stage._cap_job)
            failed_count = len(stage._rls_job)
            if done_count > 0 and failed_count > 0:
                outcome = 1
                detail = " %d subjobs failed, %d subjobs done." % (
                    failed_count, done_count)
            elif done_count > 0:
                outcome = 2
            elif failed_count > 0:
                outcome = 0
            else:
                outcome = -1
        status = labels[outcome]
        print_func("  Stage %d %s.%s" % (stage._INDEX, status[2:], detail))
        stage._final_status = status

    def cleanup(self, exit_code=0, skip_stage_check=False):
        """
        Pack the stages, print the multisim summary, optionally e-mail the
        log, and exit the process.

        :param exit_code: Exit status determined by the caller; a nonzero
            value is always reported as a failure.
        :param skip_stage_check: Set to True to skip
            checking each stage to determine the exit code.
        :raise SystemExit: Always, via `sys.exit`.
        """
        print("Cleaning up files...")
        sys.stdout.flush()
        self._foreach_stage(
            lambda stage: stage._is_shown and stage.pack_stage())

        self.stop_time = time.time()
        job_duration = util.time_duration(self.start_time, self.stop_time)
        print("\nMultisim summary (%s):" % time.ctime(self.stop_time))
        self._foreach_stage(self._check_stage)
        # FIXME: duration for this restarting?
        print("  Total duration: %s" % job_duration)

        all_gpu_times = []
        all_gpu_subjobs = []

        def collect_gpu_usage(stage):
            all_gpu_times.append(stage._gpu_time)
            all_gpu_subjobs.append(stage._num_gpu_subjobs)

        self._foreach_stage(collect_gpu_usage)
        total_gpu_time = sum(all_gpu_times)
        if total_gpu_time:
            print("  Total GPU time: %s (used by %d subject(s))".replace(
                "subject", "subjob") %
                  (_time_to_time_str(total_gpu_time), sum(all_gpu_subjobs)))

        # `_final_status` is set by `_check_stage` above; its first
        # character is 0 (failed/not run), 1 (partial) or 2 (completed).
        final_status = []
        self._foreach_stage(
            lambda stage: final_status.append(int(stage._final_status[0])))
        # A workflow without user stages has nothing that could have failed.
        is_successful = min(final_status, default=2)

        if exit_code != 0:
            summary = "Multisim failed."
        elif is_successful == 2:
            summary = "Multisim completed."
        elif is_successful == 1:
            summary = "Multisim partially completed."
        else:
            summary = "Multisim failed."
        print(summary)

        if self.notify:
            recipients = (self.notify
                          if isinstance(self.notify, list) else [self.notify])
            print("\nSending log file to the email address(es): %s" %
                  ", ".join(recipients))
            sys.stdout.flush()

            log_fname = self.jobname + "_multisim.log"
            if os.path.isfile(log_fname):
                with open(log_fname, "r", errors="replace") as log_fh:
                    email_message = log_fh.read()
            else:
                # `%s` keeps the fallback message usable even when the job
                # ID, launch directory or description is None.
                email_message = "Log file: %s not found.\n" % log_fname
                email_message += "%s\n" % self.JOBID
                email_message += "%s\n" % self.launch_dir
                email_message += "%s\n" % self.description
                email_message += summary

            import smtplib
            from email.mime.text import MIMEText

            composer = MIMEText(email_message)
            composer["Subject"] = "Multisim: %s" % self.jobname
            composer["From"] = "noreply@schrodinger.com"
            composer["To"] = ", ".join(recipients)

            try:
                smtp = smtplib.SMTP()
                smtp.connect()
                try:
                    smtp.sendmail("noreply@schrodinger.com", recipients,
                                  composer.as_string())
                finally:
                    smtp.close()
            except Exception:
                # Notification is best-effort; never fail the job over it.
                print("WARNING: Failed to send notification email.")
                print("WARNING: There is probably no SMTP server running on "
                      "master host.")

        if exit_code == 0 and is_successful != 2 and not skip_stage_check:
            exit_code = 1
        sys.exit(exit_code)

    def serialize(self, fh: BinaryIO):
        """
        Pickle this engine into `fh`, followed by the `PickleJar` class
        metadata.

        Before pickling, `msj_fname`, `set`, `refrom` and `chkpt_fh` are set
        to None and `stop_time` is stamped with `time.ctime()`. These
        changes remain on `self` afterwards.

        :param fh: Writable binary stream.
        """
        for transient in ("msj_fname", "set", "refrom", "chkpt_fh"):
            setattr(self, transient, None)
        self.stop_time = time.ctime()
        pickle.dump(self, fh)
        PickleJar.serialize(fh)

    def serialize_bytes(self) -> bytes:
        """
        Serialize the engine into an in-memory buffer and return its bytes.
        """
        buffer = BytesIO()
        self.serialize(buffer)
        return buffer.getvalue()

    def __getstate__(self):
        """
        Return a shallow copy of the instance dictionary for pickling.

        In the copy, `checkpoint_requested_event` is replaced by None,
        because a `threading.Event` cannot be pickled. The instance itself
        is left unchanged.
        """
        return {**self.__dict__, "checkpoint_requested_event": None}

    @staticmethod
    def deserialize(fh: BinaryIO):
        """
        Load an engine from a checkpoint stream written by `serialize`.

        The stream holds the pickled engine followed by the `PickleJar`
        class metadata; both are read here, in that order. The returned
        engine keeps `fh` as `chkpt_fh`, and its current `jobname` is
        appended to `old_jobnames`.

        .. warning:: This unpickles data from `fh`, which can execute
            arbitrary code. Only load checkpoint files from trusted sources.

        :param fh: Binary stream positioned at the start of the checkpoint.
        :return: The unpickled engine.
        """
        unpickler = picklejar.CustomUnpickler(fh, encoding="latin1")
        engine = unpickler.load()
        # This adds class metadata that was serialized
        # above. Without this, these values are reset to
        # the default.
        PickleJar.deserialize(fh)
        # A set `chkpt_fh` marks this engine as loaded from a checkpoint.
        engine.chkpt_fh = fh
        # Name-mangled call to the private `__more_init` method, which is
        # defined outside this view.
        # NOTE(review): presumably re-creates state not stored in the
        # pickle; confirm.
        engine.__more_init()
        try:
            engine.old_jobnames.append(engine.jobname)
        except AttributeError:
            # Checkpoints without `old_jobnames` start a new history.
            engine.old_jobnames = [
                engine.jobname,
            ]
        return engine

    def write_checkpoint(self, fname=None, num_retry=10):
        """
        Serialize this engine into the checkpoint file `fname`.

        The engine is first written to ``<fname>.lock`` and then moved onto
        `fname` with `os.replace`. As a result, `fname` is only replaced
        once serialization has finished.

        :param fname: Checkpoint path; defaults to the module-level
            `CHECKPOINT_FNAME`.
        :param num_retry: Number of attempts to move the temporary file into
            place when the move fails with `PermissionError`. At least one
            attempt is always made.
        :raise PermissionError: If the last attempt still fails.
        """
        if not fname:
            fname = CHECKPOINT_FNAME

        # Write to a temporary file
        fname_lock = fname + ".lock"
        with open(fname_lock, "wb") as fh:
            self.serialize(fh)

        # A non-positive `num_retry` would otherwise leave only the
        # temporary file behind without any error.
        num_attempts = max(1, num_retry)
        for i in range(num_attempts):
            try:
                # `os.replace` overwrites an existing destination on every
                # platform (unlike `os.rename` on Windows).
                os.replace(fname_lock, fname)
                return
            except PermissionError as err:  # TODO: DESMOND-9511
                # Diagnostics for DESMOND-9511.
                print(i, os.getcwd(), fname_lock, fname)
                for fn in glob.glob("*"):
                    print(i, fn)
                if i == num_attempts - 1:
                    raise
                print(f"retry {i+1} due to err: {err}")
                time.sleep(30)


class StopRequest(Exception):
    """
    Raised to stop the workflow on request, e.g. after a "halt" job-control
    message or a termination signal.
    """


class StopAndRestartRequest(Exception):
    """
    Raised to stop the workflow when the max-walltime checkpoint event has
    fired, signalling that a restart is requested.
    """


class ParseError(Exception):
    """
    Raised when a user-supplied multisim setting (e.g. the 'cpu'
    parameter) cannot be parsed.
    """


def is_restartable_version(version_string, reference_version=None):
    """
    Return whether a checkpoint written by multisim `version_string` can be
    restarted with multisim `reference_version`.

    Up to the first three dot-separated numeric components of the two
    versions are compared lexicographically (so 3.9.0 is newer than 3.8.5).
    The checkpoint is restartable when its version is not older than the
    reference version.

    :param version_string: Version recorded in the checkpoint, e.g.
        "3.8.5.19".
    :param reference_version: Version to compare against; defaults to this
        module's `VERSION`.
    :return: True if the checkpoint version is not older.
    """
    if reference_version is None:
        reference_version = VERSION
    checkpoint = [int(e) for e in version_string.split(".")]
    current = [int(e) for e in reference_version.split(".")]
    # Compare only the components both versions have (at most three).
    n = min(len(checkpoint), len(current), 3)
    return checkpoint[:n] >= current[:n]


def is_restartable_build(engine):
    """
    Return whether the build recorded in `engine` allows a restart with
    the current build.

    Checkpoints without a recorded build are always restartable. A
    checkpoint written by the commercial Desmond build can only be
    restarted by the commercial build.
    """
    from . import bld_def as bd

    commercial = bd.bld_types[bd.DESMOND_COMMERCIAL]

    missing = object()
    saved_build = getattr(engine, "build", missing)
    if saved_build is missing:
        return True
    return saved_build != commercial or BUILD == commercial


def build_stages(stage_list, out_fname=None, stage_state=[]):  # noqa: M511
    """
    Build up the stages for the job: surround the user stages in
    `stage_list` with the initial Primer and final Concluder stages, link
    them, and apply saved state.

    :param stage_list: User stage objects. Modified in place (the primer is
        inserted at the front and the concluder appended) and returned.
    :param out_fname: Output file name passed to the Concluder stage.
    :param stage_state: Saved stage states matched to the final
        `stage_list` by position. A state is applied only when its `NAME`
        equals the stage's `NAME`. This list is only read.
    :return: `stage_list`.
    """
    import schrodinger.application.desmond.stage as stg

    primer = stg.Primer()
    concluder = stg.Concluder(out_fname)
    primer.param = copy.deepcopy(stg.Primer.PARAM.DATA)
    concluder.param = copy.deepcopy(stg.Concluder.PARAM.DATA)
    stage_list.insert(0, primer)
    stage_list.append(concluder)
    build_stagelinks(stage_list)

    for stage, saved in zip(stage_list, stage_state):
        if stage.NAME != saved.NAME:
            continue
        stage.__setstate__(saved)

    return stage_list


def build_stagelinks(stage_list):
    """
    Link `stage_list` into a doubly linked chain and number the stages.

    Each stage's `_PREV_STAGE`/`_NEXT_STAGE` point at its neighbours (None
    at either end), and `_INDEX` is set to its position in the list.

    :raise IndexError: If `stage_list` is empty.
    """
    first, last = stage_list[0], stage_list[-1]
    for idx in range(1, len(stage_list) - 1):
        middle = stage_list[idx]
        middle._PREV_STAGE = stage_list[idx - 1]
        middle._NEXT_STAGE = stage_list[idx + 1]
    first._PREV_STAGE = None
    last._NEXT_STAGE = None
    if len(stage_list) > 1:
        first._NEXT_STAGE = stage_list[1]
        last._PREV_STAGE = stage_list[-2]
    else:
        first._NEXT_STAGE = None
        last._PREV_STAGE = None
    for idx, stage in enumerate(stage_list):
        stage._INDEX = idx


def probe_checkpoint(fname, indent=""):
    """
    Print a human-readable summary of the checkpoint file `fname`.

    The summary lists:

    - the job settings recorded in the checkpoint;
    - the state of every user stage, including the jobnames of failed
      subjobs in partially completed stages;
    - whether the checkpoint can be restarted with the current multisim
      version.

    :param fname: Path of the checkpoint file.
    :param indent: Prefix for the header line and the per-item lines.
    :return: The deserialized engine, with its stages restored.
    """
    print(indent + "Probing checkpoint file: %s" % fname)
    with open(fname, "rb") as fh:
        engine = Engine.deserialize(fh)
    engine.schrod_old = engine.schrodinger

    def probe_print(s):
        print(indent + "  " + s)

    probe_print("     multisim version: %s" % engine.version)
    probe_print("      mmshare version: %s" % engine.mmshare_ver)
    probe_print("              Jobname: %s" % engine.jobname)
    probe_print("    Previous jobnames: %s" % engine.old_jobnames)
    probe_print("             Username: %s" % engine.username)
    probe_print("      Master job host: %s" % engine.masterhost)
    probe_print("          Subjob host: %s" % engine.host)
    if engine.cpu:
        probe_print('      CPUs per subjob: "%s"' % engine.cpu)
    else:
        probe_print("      CPUs per subjob: unspecified in command")
    probe_print("  Original start time: %s" % time.ctime(engine.START_TIME))
    probe_print("      Checkpoint time: %s" % engine.stop_time)
    probe_print("        Master job ID: %s" % engine.jobid)
    probe_print(
        " Structure input file: %s" % os.path.basename(engine.inp_fname))
    probe_print(
        "  Original *.msj file: %s" % os.path.basename(engine.MSJ_FNAME))

    engine.base_dir_ = engine.base_dir
    engine.restore_stages(print_func=print_tonull)
    probe_print("\nStages:")

    engine.chkpt_fname = fname

    def show_failed_jobs(stage, engine=engine):
        engine._check_stage(stage, probe_print)
        # A status starting with "1" means partially completed.
        if stage._final_status[0] == "1":
            probe_print("    Jobnames of failed subjobs:")
            for job in stage._rls_job:
                probe_print("      %s" % job.jobname)

    engine._foreach_stage(show_failed_jobs)

    print()
    print("Current version of multisim is %s" % VERSION)
    print("This checkpoint file "
          "can%sbe restarted with the current version of multisim." %
          (" " if (is_restartable_version(engine.version)) else " not "))

    return engine


def escape_string(s):
    """
    Escape `s` and wrap it in double quotes when needed.

    The rules are:

    - An empty string becomes ``""``.
    - A double quote is backslash-escaped and forces quoting.
    - A backslash immediately followed by a single quote is collapsed to
      just the single quote, and forces quoting.
    - Any character whose code point is <= that of a space (whitespace and
      control characters) forces quoting.

    :param s: String to escape.
    :return: The escaped (and possibly quoted) string.
    """
    if s == "":
        return '""'

    ret = ""
    should_quote = False
    for c in s:
        if c == '"':
            ret += '\\"'
            should_quote = True
        # `endswith` instead of `ret[-1]`: `ret` is empty when `s` starts
        # with a single quote.
        elif c == "'" and ret.endswith("\\"):
            ret = ret[:-1] + "'"
            should_quote = True
        else:
            ret += c
            if c <= " ":
                should_quote = True
    if should_quote:
        ret = '"' + ret + '"'
    return ret


def append_stage(
        cmj_fname,
        stage_type,
        cfg_file=None,
        jobname=None,
        dir=None,
        compress=None,
        parameter=None,
):
    """
    Return the contents of `cmj_fname` with a new stage block appended.

    :param cmj_fname: Path of the existing multisim file.
    :param stage_type: One of "simulate", "minimize" or "replica_exchange".
    :param cfg_file: Optional value for the stage's `cfg_file` setting.
    :param jobname: Optional value for the stage's `jobname` setting.
    :param dir: Optional value for the stage's `dir` setting.
    :param compress: Optional value for the stage's `compress` setting.
    :param parameter: Optional mapping of extra settings written as
        ``key = value``; entries whose value is None are skipped.
    :return: The new text, or None if the file is missing or unreadable, or
        if `stage_type` is unknown (an error is printed in the last two
        cases).
    """
    if parameter is None:
        parameter = {}

    if not os.path.isfile(cmj_fname):
        return None

    try:
        with open(cmj_fname, "r") as fh:
            s = fh.read()
    except IOError:
        print("error: Reading failed. file: '%s'" % cmj_fname)
        return None

    if stage_type not in ("simulate", "minimize", "replica_exchange"):
        print("error: Unknown stage type '%s'" % stage_type)
        return None
    s += "%s {\n" % stage_type

    if cfg_file is not None:
        s += '   cfg_file = "%s"\n' % cfg_file
    if jobname is not None:
        s += '   jobname  = "%s"\n' % jobname
    if dir is not None:
        s += '   dir      = "%s"\n' % dir
    if compress is not None:
        s += '   compress = "%s"\n' % compress
    for p in parameter:
        if parameter[p] is not None:
            s += "   %s = %s\n" % (p, parameter[p])
    s += "}\n"

    return s


def concatenate_relaxation_stages(raw):
    """
    Attempts to concatenate relaxation stages by finding all adjacent
    non-production `simulate` stages. If no concatenatable stages are found,
    None is returned. Otherwise, a new raw map with the relaxation `simulate`
    stages replaced with a single `concatenate` stage is returned.

    The search is repeated on the updated map until no group of two or more
    concatenatable stages remains. `raw` itself is not modified.

    :param raw: the raw map representing the MSJ
    :type raw: `sea.Map`

    :return: a new raw map representing the updated msj, or None.
    :rtype: `sea.Map` or `None`
    """
    updated = copy.deepcopy(raw)

    group, insertion_point = get_concat_stages(updated.stage)
    while len(group) > 1:
        merged = sea.Map()
        merged.__NAME__ = "concatenate"
        simulate_list = sea.List()
        simulate_list.add_tag("setbyuser")
        for member in group:
            updated.stage.remove(member)
            simulate_list.append(member)
        merged.simulate = simulate_list
        # The merged stage inherits title and maeff_output from the first
        # simulate stage.
        merged.title = merged.simulate[0].title
        if 'maeff_output' in merged.simulate[0].val:
            merged.maeff_output = merged.simulate[0].maeff_output
        updated.stage.insert(insertion_point, merged)
        updated.stage.add_tag("setbyuser", propagate=False)
        group, insertion_point = get_concat_stages(updated.stage)

    changed = len(updated.stage) != len(raw.stage)
    return updated if changed else None


def get_concat_stages(stages, param_attr=""):
    """
    Get a list of the stages that can be concatenated together, and the
    insertion point of the resulting concatenate stage. Stages can be
    concatenated if they are adjacent simulate stages with the same restraints,
    excluding the final production stage, which can be lambda hopping, replica
    exchange, or otherwise the last simulate stage.

    Skipped stages (``should_skip`` true) do not break up a run of simulate
    stages that they directly follow.

    :param stages: A list of objects representing multisim stages.
            For flexibility, these can be either maps or stages. For stages, a
            param attribute must be passed that will give the location of the
            param on the stage.
    :type stages: list of (sea.Map or stage.Stage)

    :param param_attr: optional name of the attribute of the objects param, in
            case of a stage.Stage object.
    :type param_attr: str

    :return: the stages to concatenate and the index at which the resulting
        concatenate stage should be inserted (None if there is none).
    :rtype: tuple(list, int or None)
    """
    stages_to_concat = []
    insertion_point = None
    # Index of the last stage of the current run of simulate stages (a
    # skipped stage that directly follows the run extends it). None means
    # no simulate stage has been seen yet; 0 cannot serve as that marker
    # because it is a valid index.
    last_stage = None
    has_permanent_restrain = False
    first_simulate_param = None
    last_gcmc_block = None

    def is_restrained(param):
        return (("restrain" in param and param.restrain.val != "none") or
                bool(has_explicit_restraints(param)))

    for i, stage in enumerate(stages):
        stage_param = getattr(stage, param_attr) if param_attr else stage
        try:
            should_skip = stage_param.should_skip.val
        except AttributeError:
            should_skip = False
        if should_skip:
            # Don't let skipped stages break up otherwise consecutive
            # simulate stages. Only extend a run the skipped stage directly
            # follows, so that a skipped stage cannot bridge a gap left by
            # some other (non-simulate) stage.
            if last_stage is not None and last_stage == i - 1:
                last_stage = i
            continue
        name = stage_param.__NAME__
        # TODO we can't check stage.AssignForcefield.NAME here because we can't
        # import module stage (would be circular). That's pretty strong
        # evidence that we should move these concatenation-related functions to
        # a stage_utils module
        if name == "assign_forcefield":
            has_permanent_restrain |= is_restrained(stage_param)
        if name in _PRODUCTION_SIMULATION_STAGES:
            break
        elif name == "simulate":
            # simulate stages must be adjacent to concatenate
            if last_stage is not None and last_stage != i - 1:
                break

            # gcmc stages can only be concatenated if gcmc blocks are identical
            # across stages
            if "gcmc" in stage_param.keys(tag="setbyuser"):
                gcmc_param = stage_param.gcmc
                if (last_gcmc_block is not None and
                        gcmc_param.val != last_gcmc_block.val):
                    break
            else:
                gcmc_param = sea.Atom("none")
            last_gcmc_block = gcmc_param

            # conditions on restrain block to concatenate
            if first_simulate_param is None:
                # we use whole `stage_param` instead of the restraints
                # themselves to (partially) support both old-style "restrain"
                # and new-style "restraints" in "Concatenate" stage (single
                # "flavor" per stage); ideally this needs to be
                # revised/tightened at some point during or after DESMOND-10079
                first_simulate_param = stage_param
            if restraints_incompatible(stage_param, first_simulate_param,
                                       has_permanent_restrain):
                break

            if insertion_point is None:
                insertion_point = i
            last_stage = i
            stages_to_concat.append(stage)
    else:
        # The loop ran to the end without breaking, so the production stage
        # is a normal simulate stage (rather than one of
        # _PRODUCTION_SIMULATION_STAGES) and must not be concatenated.
        if stages_to_concat:
            stages_to_concat.pop()

    return stages_to_concat, insertion_point


def make_empty_restraints(existing='ignore') -> sea.Map:
    """
    Creates a new-style `restraints` block that defines no new restraints.

    :param existing: value stored under the block's `existing` key
        (e.g. 'ignore' or 'retain')
    :return: `sea.Map` with keys `existing` and `new` (an empty `sea.List`)
    """
    block = sea.Map()
    block["existing"] = existing
    block["new"] = sea.List()
    return block


def get_restrain(sm: sea.Map) -> sea.Sea:
    """
    Fetches the old-style `restrain` setting of a stage.

    :param sm: stage parameters
    :return: the `restrain` value, or `sea.Atom("none")` when
        `sm.get_value("restrain")` raises `KeyError`
    """
    try:
        restrain = sm.get_value("restrain")
    except KeyError:
        restrain = sea.Atom("none")
    return restrain


def get_restraints(sm: sea.Map) -> sea.Map:
    """
    Fetches the new-style `restraints` setting of a stage.

    :param sm: stage parameters
    :return: the `restraints` value, or a fresh empty block from
        `make_empty_restraints()` when `sm.get_value("restraints")` raises
        `KeyError`
    """
    try:
        restraints = sm.get_value("restraints")
    except KeyError:
        restraints = make_empty_restraints()
    return restraints


def get_restraints_xor_convert_restrain(param: sea.Map) -> sea.Map:
    """
    Returns `restraints` or `restrain` (converted into `restraints`
    format) from the `param`, whichever of the two is in use.

    :param param: stage parameters
    :raise ValueError: if `param` has explicit `restraints` and a `restrain`
        setting other than 'none' at the same time
    :return: restraints block
    """
    restrain = get_restrain(param)
    if not has_explicit_restraints(param):
        # only the old-style setting (possibly 'none') is present
        return _restraints_from_restrain(restrain)
    if restrain.val != 'none':
        raise ValueError("Concatenate stage cannot include "
                         "`restrain` and `restraints` simultaneously")
    return get_restraints(param)


def restraints_incompatible(param: sea.Map, initial_param: sea.Map,
                            has_permanent_restrain: bool):
    """
    Checks whether a stage's restraints can be switched to, from those of the
    first stage, within a single concatenated simulation.

    Restraints given via the old-style `restrain` and the new-style
    `restraints` parameters cannot be mixed. Otherwise both sides are brought
    into `restraints` form and compared by `_check_restraints_compatibility`.
    Apart from dropping restraints (or merely retaining existing ones), they
    may differ from the initial ones only in their force constants. If
    `has_permanent_restrain` is truthy, any other difference is rejected,
    because there is no way to scale restraints selectively.

    :param param: the param for a given stage
    :type param: `sea.Map`

    :param initial_param: parameters for the first stage
    :type initial_param: `sea.Map`

    :param has_permanent_restrain: whether or not there are restraints applied
        to all stages via the `permanent_restraints` mechanism
    :type has_permanent_restrain: bool

    :return: a message describing the incompatibility, or an empty string if
        the restraints are compatible (a truthy result means *incompatible*)
    :rtype: str
    """
    cur_restrain = get_restrain(param)
    init_restrain = get_restrain(initial_param)

    uses_restrain = (cur_restrain.val != "none" or
                     init_restrain.val != "none")
    uses_restraints = (has_explicit_restraints(param) or
                       has_explicit_restraints(initial_param))

    if uses_restrain and uses_restraints:
        return ("We cannot concatenate stages that mix restraints "
                "given via the `restraints` and `restrain` parameters")

    if uses_restrain:
        # translate old-style blocks so both sides share the same format
        current = _restraints_from_restrain(cur_restrain)
        initial = _restraints_from_restrain(init_restrain)
    else:
        current = get_restraints(param)
        initial = get_restraints(initial_param)

    return _check_restraints_compatibility(
        current=current,
        initial=initial,
        has_permanent_restrain=has_permanent_restrain)


def has_explicit_restraints(param: sea.Map) -> bool:
    """
    Tells whether a stage's new-style `restraints` block requests anything.

    :param param: the param for a given stage
    :return: True if `param` has a `restraints` block that lists new
        restraints, or whose `existing` setting differs from
        `constants.EXISTING_RESTRAINT.IGNORE`; False otherwise (including
        when there is no `restraints` block at all)
    """
    if "restraints" not in param:
        return False

    explicit_restraints = param.restraints
    has_new = "new" in explicit_restraints and explicit_restraints.new.val
    has_existing = ("existing" in explicit_restraints and
                    explicit_restraints.existing.val !=
                    constants.EXISTING_RESTRAINT.IGNORE)
    # `has_new` may be the (truthy) list of new restraints itself; normalize
    # so this predicate always returns a real bool as documented.
    return bool(has_new or has_existing)


def check_restrain_diffs(restrain, initial_restrain):
    """
    Checks whether two restrain blocks are concatenation-compatible.

    Equal blocks are always compatible. Otherwise, after unwrapping
    single-element lists, both must be `sea.Map` objects whose differences
    (per `sea.diff`) involve only the force-constant keys ``force_constant``,
    ``fc`` or ``force_constants``.

    :param restrain: the restrain block for a given stage
    :type restrain: `sea.Map` or `sea.List`

    :param initial_restrain: the restraints for the first stage
    :type initial_restrain: `sea.Map` or `sea.List`

    :raise ValueError: if the unwrapped blocks are not both `sea.Map` and
        neither of them is a `sea.List`

    :return: a message declaring how the restraints are incompatible, or
        an empty string if they are compatible
    :rtype: str
    """
    if restrain == initial_restrain:
        return ""

    def unwrap(block):
        # a one-element list is treated as its sole element
        is_singleton = isinstance(block, sea.List) and len(block) == 1
        return block[0] if is_singleton else block

    current, initial = unwrap(restrain), unwrap(initial_restrain)

    if isinstance(current, sea.Map) and isinstance(initial, sea.Map):
        scalable_keys = ("force_constant", "fc", "force_constants")
        for difference in sea.diff(current, initial):
            if any(key not in scalable_keys for key in difference):
                return ("We cannot change restraint parameters other than "
                        "the force constant between integrators")
        return ""

    if isinstance(current, sea.List) or isinstance(initial, sea.List):
        return ("We cannot change between lists of restraint parameters "
                "unless they are identical.")

    raise ValueError("restraints definition blocks expected to be "
                     "`sea.List` or `sea.Map`")


def _check_restraints_compatibility(initial: sea.Map, current: sea.Map,
                                    has_permanent_restrain: bool) -> str:
    """
    Checks whether `restraints` blocks may be switched during a concatenate
    stage.

    The check passes if any of the following holds:

    - `current` equals `initial`;
    - `current` merely retains existing restraints;
    - `current` drops all restraints (and `initial` had some);
    - both share the same `existing` setting and their `new` lists differ
      only by force constants (see `check_restrain_diffs`).

    With permanent restraints, only the first two cases are accepted.

    :param initial: preceding `restraints` block
    :type initial: `sea.Map`

    :param current: `restraints` block
    :type current: `sea.Map`

    :param has_permanent_restrain: whether or not there are restraints applied
        to all stages via the `permanent_restraints` mechanism
    :type has_permanent_restrain: bool

    :return: a message declaring how the restraints are incompatible, or
        an empty string if they are compatible
    :rtype: str
    """

    def is_none(block):
        # ignores existing restraints and adds no new ones
        return block.existing.val == 'ignore' and not block.new

    def is_retain(block):
        # keeps existing restraints and adds no new ones
        return block.existing.val == 'retain' and not block.new

    changed = current != initial
    if not changed or is_retain(current):
        return ""

    # restraints cannot be scaled selectively when permanent ones are present
    if has_permanent_restrain:
        return ("Subsequent simulate blocks cannot have differing "
                "restrain blocks when permanent restraints are used")

    # going from no restraints to some restraints is not supported
    if is_none(initial):
        return ("Subsequent simulate blocks cannot have restrain block "
                "unless the first simulate block or concatenate stage does")

    # dropping all restraints is acceptable
    if is_none(current):
        return ""

    if current.existing != initial.existing:
        return ("Subsequent simulate blocks cannot have "
                "differing restraints")

    return check_restrain_diffs(current.new, initial.new)


def _restraints_from_restrain(
        old: Union[sea.Atom, sea.List, sea.Map]) -> sea.Map:
    """
    Translates old-style restraints specification ("restrain")
    into equivalent new-style blurb. Current version is incomplete,
    limited to the features needed for the concatenation support.

    'none' and 'retain' map to an empty block whose `existing` is 'ignore'
    or 'retain' respectively. A map or list is copied into `new`, and every
    entry gets a `force_constants` key copied from its `fc` (or, failing
    that, `force_constant`) setting.

    :param old: old-style "restrain" block (string, map or list)
    :raise ValueError: if `old` is none of the accepted forms
    :return: equivalent new-style "restraints" block
    """
    existing = 'retain' if old.val == 'retain' else 'ignore'
    outcome = make_empty_restraints(existing=existing)

    if old.val not in ('none', 'retain'):
        if isinstance(old, sea.Map):
            outcome["new"].append(old)  # copies `old`
        elif isinstance(old, sea.List):
            outcome["new"].extend(old)  # copies `old`
        else:
            raise ValueError("`restrain` block must be `none`, `retain`, "
                             "`sea.Map` or `sea.List`")

    for entry in outcome["new"]:
        # copies the force constant under the new-style key
        entry["force_constants"] = (entry["fc"] if "fc" in entry else
                                    entry.force_constant)

    return outcome


# `sea.Map` object containing the whole job's msj settings. It is assigned by
# `parse_msj()`, which inserts a dummy entry at `PARAM.stage[0]` so that the
# real stages start at index 1.
PARAM = None


def msj2sea(fname, msj_content=None):
    """
    Parses a file as specified by 'fname' or a string given by 'msj_content'
    (if both are given, the former will be ignored), and returns a 'sea.Map'
    object that represents the stage settings with a structure like the
    following::

      stage = [
         { <stage 1 settings> }
         { <stage 2 settings> }
         { <stage 3 settings> }
         ...
      ]

    Each stage's name can be accessed in this way: raw.stage[0].__NAME__
    (indices are 0-based here), where 'raw' is the returned 'sea.Map' object.
    Stage names are lower-cased.

    Note that an empty 'msj_content' is treated like None, i.e. 'fname' is
    read instead.

    :raise SyntaxError: if a stage name is followed by '=' or if a stage name
        is not followed by a settings block
    """
    if not msj_content:
        with open(fname, "r") as msj_file:
            msj_content = msj_file.read()

    raw = sea.Map("stage = [" + msj_content + "]")

    # User might set a stage as "stagename = {...}" by mistake. Raises a
    # meaningful exception when this happens.
    for s in raw.stage:
        if isinstance(s, sea.Atom) and s.val == "=":
            raise SyntaxError(
                "Stage name cannot be followed by the assignment operator: '='")

    # `raw.stage` alternates <name>, {<settings>}, so names sit at even indices.
    name_indices = list(range(0, len(raw.stage), 2))
    for i in name_indices:
        try:
            s = raw.stage[i + 1]
        except IndexError:
            # `i` indexes the flat name/settings list; report the 1-based
            # stage number. (The old `% i + 1` raised TypeError instead.)
            raise SyntaxError("stage %d is undefined" % (i // 2 + 1)) from None
        s.__NAME__ = raw.stage[i].val.lower()

    # Drop the name atoms back to front so the remaining indices stay valid.
    for i in reversed(name_indices):
        del raw.stage[i]
    return raw


def msj2sea_full(fname, msj_content=None, pset=""):
    """
    Like `msj2sea`, but expands every stage into its complete parameter set.

    Each parsed stage block is applied (tagged "setbyuser") on top of a deep
    copy of `PARAM.DATA` of the stage class registered under the stage's
    name. The resulting map gets `__NAME__` and `__CLS__` set. Extra
    ``key = value`` settings in `pset`, separated by chr(30), are then
    applied with the "setbyuser" tag as well.

    :raise ParseError: on an unrecognized stage name or a malformed `pset`
        entry
    :return: `sea.Map` whose `stage` list holds the expanded stage settings
    """
    raw = msj2sea(fname, msj_content)
    for idx, user_stage in enumerate(raw.stage):
        name = user_stage.__NAME__
        try:
            stage_cls = StageBase.stage_cls[name]
        except KeyError:
            raise ParseError("Unrecognized stage name: %s\n" % name)
        full_param = copy.deepcopy(stage_cls.PARAM.DATA)
        full_param.update(user_stage, tag="setbyuser")
        full_param.__NAME__ = name
        full_param.__CLS__ = stage_cls
        raw.stage[idx] = full_param

    if not pset:
        return raw

    # Temporary placeholder at index 0, presumably so that `stage[N]` keys in
    # `pset` address stages 1-based (as in `parse_msj`'s PARAM) -- TODO confirm
    raw.stage.insert(0, sea.Atom("dummy"))
    for setting in pset.split(chr(30)):
        before, sep, after = setting.partition("=")
        key, value = before.strip(), after.strip()
        # reject: no '=', '=' as first character, or empty key/value
        if not sep or not before or key == "" or value == "":
            raise ParseError("Syntax error in setting: %s" % setting)
        raw.set_value(
            key, sea.Map("value = %s" % value).value.val, tag="setbyuser")
    del raw.stage[0]
    return raw


def parse_msj(fname, msj_content=None, pset=""):
    """
    Parses an msj file (or string) into a list of configured stage objects.

    The expanded settings are also stored in the module-level `PARAM`. A dummy
    placeholder is placed at `PARAM.stage[0]` so that `PARAM.stage[i]` is the
    i-th stage (1-based).

    sea.update_macro_dict must be called prior to calling this function.

    :raise ParseError: if parsing fails, or if any stage's `check_param()`
        reports value errors or unrecognized parameters
    :return: list of stage instances, in msj order
    """
    global PARAM
    try:
        PARAM = msj2sea_full(fname, msj_content, pset)
        PARAM.stage.insert(0, sea.Atom("dummy"))
    except Exception as e:
        raise ParseError(str(e))

    print_debug("All settings of this multisim job...")
    print_debug(PARAM)
    print_debug("All settings of this multisim job... End")

    stages = []
    problems = []
    # Skip the dummy placeholder; `number` is the 1-based stage index.
    for number, stage_param in enumerate(PARAM.stage[1:], start=1):
        stage = stage_param.__CLS__()
        # let the stage migrate settings for backward compatibility
        stage.migrate_param(stage_param)
        stage.param = stage_param

        # FIXME: How to deal with exceptions raised by the parsing and checking
        # functions?
        result = stage.check_param()
        if result.err != "":
            problems.append("Value error(s) for stage[%d]:\n%s\n" %
                            (number, result.err))
        if result.unchecked_map:
            problems.append("Unrecognized parameters for stage[%d]: %s\n\n" %
                            (number, result.unchecked_map))
        stages.append(stage)

    if problems:
        raise ParseError("".join(problems))
    return stages


def write_msj(stage_list, fname=None, to_str=True):
    """
    Given a list of stages, writes out a .msj file of the name 'fname'.

    Each stage is written as ``<stage.NAME> { ... }``, containing only the
    settings of `stage.param` that are tagged "setbyuser".

    If 'to_str' is True, a string will be returned. The returned string
    contains the contents of the .msj file (the file itself additionally
    ends with the newline added by `print`).
    If 'to_str' is False and no file name is provided, then this function does
    nothing.

    :param stage_list: stage objects providing `NAME` and `param`
    :param fname: output file name, or None to skip writing a file
    :param to_str: whether to return the .msj text
    :return: the .msj text if 'to_str' is truthy, otherwise None
    """
    if fname is None and to_str is False:
        return

    s = "".join(stage.NAME + " {\n" +
                stage.param.__str__("  ", tag="setbyuser") + "}\n\n"
                for stage in stage_list)
    if fname is not None:
        # `with` guarantees the file is closed even if writing fails
        with open(fname, "w") as fh:
            print(s, file=fh)
    if to_str:
        return s


def write_sea2msj(stage_list, fname=None, to_str=True):
    """
    Writes raw stage settings maps out in .msj format.

    Works like `write_msj`, except that each element of `stage_list` is
    itself a stage's settings map (such as an element of
    `msj2sea(...).stage`) carrying the stage name in `__NAME__`. Only
    settings tagged "setbyuser" are written.

    :param stage_list: stage settings maps providing `__NAME__`
    :param fname: output file name, or None to skip writing a file
    :param to_str: whether to return the .msj text; if it is False and
        `fname` is None, nothing is done
    :return: the .msj text if 'to_str' is truthy, otherwise None
    """
    if fname is None and to_str is False:
        return

    s = "".join(stage.__NAME__ + " {\n" +
                stage.__str__("  ", tag="setbyuser") + "}\n\n"
                for stage in stage_list)
    if fname is not None:
        # `with` guarantees the file is closed even if writing fails
        with open(fname, "w") as fh:
            print(s, file=fh)
    if to_str:
        return s


def _collect_inputfile_from_file_list(list_, fnames):
    for v in list_:
        if isinstance(v, sea.Atom) and isinstance(v.val, str) and v.val != "":
            fnames.append(v.val)
        elif isinstance(v, sea.Map):
            _collect_inputfile_from_file_map(v, fnames)
        elif isinstance(v, sea.List):
            _collect_inputfile_from_file_list(v, fnames)
    return fnames


def _collect_inputfile_from_file_map(map, fnames):
    for k, v in map.key_value():
        if isinstance(v, sea.Atom) and isinstance(v.val, str) and v.val != "":
            fnames.append(v.val)
        elif isinstance(v, sea.Map):
            _collect_inputfile_from_file_map(v, fnames)
        elif isinstance(v, sea.List):
            _collect_inputfile_from_file_list(v, fnames)
    return fnames


def _collect_inputfile_from_list(list_, fnames):
    for v in list_:
        if isinstance(v, sea.Map):
            _collect_inputfile_from_map(v, fnames)
        elif isinstance(v, sea.List):
            _collect_inputfile_from_list(v, fnames)
    return fnames


def _collect_inputfile_from_map(map, fnames):
    for k, v in map.key_value():
        if (isinstance(v, sea.Atom) and k.endswith("_file") and
                isinstance(v.val, str) and v.val != ""):
            fnames.append(v.val)
        elif isinstance(v, sea.Map):
            if k.endswith("_file"):
                _collect_inputfile_from_file_map(v, fnames)
            else:
                _collect_inputfile_from_map(v, fnames)
        elif isinstance(v, sea.List):
            if k.endswith("_file"):
                _collect_inputfile_from_file_list(v, fnames)
            else:
                _collect_inputfile_from_list(v, fnames)
    return fnames


def collect_inputfile(stage_list):
    """
    Returns a list of file names.

    Stages whose `should_skip` setting is true are ignored. For each other
    stage, its own `collect_inputfile()` is used. If that raises
    `AttributeError`, the stage's parameters are scanned with
    `_collect_inputfile_from_map` instead.

    NOTE(review): an `AttributeError` raised *inside* a stage's own
    `collect_inputfile()` also triggers this fallback.

    :param stage_list: stage objects
    :return: the collected file names, in stage order
    """
    collected = []
    for stage in stage_list:
        if stage.param.should_skip.val:
            continue
        try:
            collected.extend(stage.collect_inputfile())
        except AttributeError:
            _collect_inputfile_from_map(stage.param, collected)
    return collected


class AslValidator:
    """
    Records ASL expressions (strings prefixed with "asl:") that cannot be
    evaluated against a tiny one-atom reference structure.
    """

    # One-atom structure in SD format, used only as an ASL evaluation target.
    CTSTR = """hydrogen


  1  0  0  0  1  0            999 V2000
   -1.6976    2.1561    0.0000 C   0  0  0  0  0  0
M  END
$$$$
"""
    # Structure parsed from CTSTR on first use; shared by all instances.
    CT = None

    def __init__(self):
        # Invalid "asl:"-prefixed strings seen by `validate`, in visit order.
        self.invalid_asl_expr = []

    def is_valid(self, asl):
        """
        Returns True if `asl` evaluates on the reference structure, or False
        if evaluation raises `mm.MmException`.
        """
        if AslValidator.CT is None:
            import schrodinger.structure as structure

            reader = structure.StructureReader.fromString(
                AslValidator.CTSTR, format="sd")
            AslValidator.CT = next(reader)

        import schrodinger.structutils.analyze as analyze

        try:
            analyze.evaluate_asl(AslValidator.CT, asl)
        except mm.MmException:
            return False
        return True

    def validate(self, a):
        """
        Checks `a` and, if it is a container, everything reachable through
        its `apply`. Invalid "asl:" strings are appended to
        `invalid_asl_expr`.
        """
        if isinstance(a, sea.Atom):
            text = a.val
            is_asl = isinstance(text, str) and text[:4].lower() == "asl:"
            if is_asl and not self.is_valid(text[4:]):
                self.invalid_asl_expr.append(text)
        elif isinstance(a, sea.Sea):
            a.apply(lambda x: self.validate(x))


def validate_asl_expr(stage_list):
    """
    Validates all ASL expressions that start with the "asl:" prefix.

    Stages whose `should_skip` setting is true are not inspected.

    :param stage_list: stage objects
    :return: list of the invalid "asl:"-prefixed strings, in visit order
    """
    checker = AslValidator()
    active_stages = (st for st in stage_list if not st.param.should_skip.val)
    for st in active_stages:
        st.param.apply(lambda x: checker.validate(x))
    return checker.invalid_asl_expr


# Registry of checking functions keyed by name; entries are added by
# `reg_checking(name, func)`.
# - Registered functions should share this prototype: foo(stage, PARAM, arg),
#   and should return a boolean value, where PARAM is a sea.Map object in the
#   global scope of this module.

_operator = {}


def reg_checking(name, func):
    _operator[name] = func