Source code for schrodinger.tasks.tasks

"""
A task represents a block of work that has a defined input and output and runs
without user intervention. Different task classes share a common external API
but have different implementations for defining and executing the work, such as
blocking calls, threads, subprocesses, or job control (see jobtasks).

To define a task, follow these basic instructions:

1. Choose a task class to subclass. The choice of task class is primarily
dictated by how the task needs to run - thread, subprocess, job, etc. See the
Task Class Selection Guide for help.

2. Override the input and output params. The task.input and task.output params
may be of any Param type, including CompoundParam (typical). For CompoundParams,
either use an existing class to override task.input, OR define a nested class
named Input within the task. Doing so will automatically override task.input.
The same goes for task.output. Example::

    class FooTask(tasks.ThreadFunctionTask):
        input = AtomPair()  # AtomPair is an existing CompoundParam subclass

        # This will magically override FooTask.output = Output()
        class Output(parameters.CompoundParam):
            charge: float
            processed_atom_pair: AtomPair

3. Define the work of the task. This is done differently for different task
classes, but generally involves overriding a method to either provide python
logic directly as the work to be done or to construct a command line with the
appropriate arguments that will be invoked via the appropriate mechanism for the
task type.

Once a task is defined, it can be instantiated, set up, and started::

    task = FooTask()
    task.input.x = 3
    task.input.y = 4
    task.start()
    assert task.status is tasks.Status.RUNNING
    task.wait()
    assert task.status is tasks.Status.DONE
    print(task.output)
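
The status transitions in the example above can be pictured with a small,
framework-free sketch. `MiniTask` is a hypothetical synchronous stand-in written
only for this illustration (it is not one of the task classes in this module);
it assumes that a failed preprocessing check sends the task straight to FAILED
without passing through RUNNING, as described for `AbstractTask.start()`.

```python
import enum


class Status(enum.IntEnum):
    # Same member names and values as tasks.Status
    WAITING = 0
    RUNNING = 1
    FAILED = 2
    DONE = 3


class MiniTask:
    # Hypothetical synchronous stand-in for a task; not the framework class

    def __init__(self, x):
        self.x = x
        self.output = None
        self.status = Status.WAITING
        self.history = [self.status]

    def _set_status(self, status):
        self.status = status
        self.history.append(status)

    def start(self):
        if self.status is Status.RUNNING:
            raise RuntimeError('Cannot start a task that is already running')
        self._set_status(Status.WAITING)
        if self.x < 0:  # preprocessing failure: RUNNING is never entered
            self._set_status(Status.FAILED)
            return
        self._set_status(Status.RUNNING)
        self.output = self.x ** 2  # the unit of work
        self._set_status(Status.DONE)


ok_task = MiniTask(3)
ok_task.start()
bad_task = MiniTask(-1)
bad_task.start()
```

Real tasks usually finish asynchronously, which is why the example above calls
`task.wait()` before reading the output.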

.. warning::
    `wait()` executes a local event loop, so it should not be called directly
    from a GUI - see PANEL-18317 for discussion. `wait()` is safe to call
    inside a subprocess or job (e.g. if a jobtask spawns child tasks).
    Run `git grep "task[.]wait("` to see safe examples annotated with "# OK".

==================
Pre/postprocessors
==================

Tasks support pre/post processing functions. These can either be methods in the
class that are decorated with the preprocessor or postprocessor decorators, or
external functions that are added to a task instance. Example::

    class MyTask(tasks.BlockingFunctionTask):
        @tasks.preprocessor
        def checkInput(self):
            if self.input.x < 0:
                return False, 'x must be a nonnegative number.'

For more information, see the module-level preprocessor and postprocessor
decorators as well as the start(), preprocessors(), and addPreprocessor()
methods of AbstractTask.

========================
Task directory (taskdir)
========================

Tasks have a concept of a taskdir. While the task framework will never actually
chdir into a different directory, the task provides functions for specifying
and accessing a directory that is considered that task's directory by
convention. Subprocesses started by the task will use the taskdir as their
working directory.

To specify a taskdir, override AbstractTask.DEFAULT_TASKDIR_SETTING or use
task.specifyTaskDir(). Example::

    class MyTask(tasks.BlockingFunctionTask):
        DEFAULT_TASKDIR_SETTING = tasks.AUTO_TASKDIR

    task = MyTask()
    task.specifyTaskDir('foo_dir')

The taskdir is created during preprocessing. Once the taskdir is created, use
task.getTaskDir() and task.getTaskFilename() when reading and writing files for
the task. Example::

    class MyTask(tasks.SubprocessCmdTask):
        @tasks.preprocessor(order=tasks.AFTER_TASKDIR)
        def writeInputFiles(self):
            with open(self.getTaskFilename('foo_data.txt'), 'w') as f:
                f.write(self.input.foo_data)

For more details on taskdir, see task.specifyTaskDir() and task.getTaskDir().
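
As a rough sketch of how the four taskdir settings map to a directory, the
standalone helper below mirrors the rules described for specifyTaskDir().
`resolve_taskdir` and its `temp_root` argument are hypothetical names invented
for this illustration, and the sentinels are local stand-ins for
tasks.AUTO_TASKDIR and tasks.TEMP_TASKDIR. Unlike the framework, the sketch only
computes paths and never creates a directory.

```python
import os

# Local stand-ins for tasks.AUTO_TASKDIR and tasks.TEMP_TASKDIR
AUTO_TASKDIR = object()
TEMP_TASKDIR = object()


def resolve_taskdir(setting, task_name, temp_root):
    # Hypothetical helper: return the path a task would treat as its taskdir
    if setting is None:
        return os.getcwd()  # no taskdir requested, so the CWD is used
    if setting is AUTO_TASKDIR:
        return os.path.abspath(task_name)  # subdirectory named after the task
    if setting is TEMP_TASKDIR:
        # Simplified: the framework creates a uniquely named temporary
        # directory instead of reusing the task name.
        return os.path.join(temp_root, task_name)
    return os.path.abspath(setting)  # explicit relative or absolute path


def task_filename(taskdir, fname):
    # Same idea as task.getTaskFilename(): join the taskdir and the file name
    return os.path.join(taskdir, fname)


auto_dir = resolve_taskdir(AUTO_TASKDIR, 'MyTask', temp_root='scratch')
data_path = task_filename(auto_dir, 'foo_data.txt')
```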

==========================
Input/Output File Handling
==========================

To specify a task input file or folder, use the `TaskFile` or `TaskFolder`
classes as a subparam on the task.input param. If the task runs its unit
of work on a different machine or process, the input files/folders will
automatically be copied to the right location on the compute host. The
path to the `TaskFile`/`TaskFolder` will also be updated so it points
to the right location, regardless of when or where it's accessed.

`TaskFile`/`TaskFolder`s may be nested under the input param in supported
container types. Supported container types are:

    - List
    - Dict
    - Set
    - Tuple
    - CompoundParam

There are few restrictions on how deeply a `TaskFile`/`TaskFolder` may be
nested on the input param. For example, if you have a variable number of input
files, you can define the input with a list::

    class Input(parameters.CompoundParam):
        receptor_filename: TaskFile
        ligand_filenames: List[TaskFile]

Task output files/folders behave in the exact same way as task input
files/folders except they're defined as `TaskFile` or `TaskFolder` on the
output param.
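
To make the nesting rules more concrete, the sketch below walks a nested
structure and collects every `TaskFile` it finds. It is only an illustration:
`iter_task_files` is a hypothetical helper, the local `TaskFile` class is a
stand-in for the one in this module, and the framework's own discovery and
copying logic may work differently. CompoundParam traversal is left out to keep
the example self-contained.

```python
class TaskFile(str):
    # Stand-in for tasks.TaskFile, which is likewise a str subclass
    pass


def iter_task_files(value):
    # Hypothetical helper: yield every TaskFile nested in supported containers
    if isinstance(value, TaskFile):
        yield value
    elif isinstance(value, dict):
        for item in value.values():
            yield from iter_task_files(item)
    elif isinstance(value, (list, tuple, set)):
        for item in value:
            yield from iter_task_files(item)
    # Anything else (plain strings, numbers, ...) is not a task file


inputs = {
    'receptor_filename': TaskFile('receptor.mae'),
    'ligand_filenames': [TaskFile('lig1.mae'), TaskFile('lig2.mae')],
    'label': 'not a file',
}
found = sorted(iter_task_files(inputs))
```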
"""
import contextlib
import copy
import enum
import inspect
import os
import pathlib
import pickle
import random
import shutil
import string
import sys
import tempfile
import traceback
import typing
from collections import namedtuple
from datetime import datetime
from typing import List

from schrodinger.models import json
from schrodinger.models import jsonable
from schrodinger.models import parameters
from schrodinger.models import paramtools
from schrodinger.Qt import QtCore
from schrodinger.Qt.QtCore import QProcess
from schrodinger.tasks import cmdline
from schrodinger.tasks import _filepaths
from schrodinger.ui.qt.appframework2 import application
from schrodinger.utils import fileutils
from schrodinger.utils import funcchains
from schrodinger.utils import imputils
from schrodinger.utils import qt_utils
from schrodinger.utils import scollections
from schrodinger.utils import subprocess as subprocess_utils


class TaskDirNotFoundError(RuntimeError):
    pass


class TaskFile(str):
    """
    See the "Input/Output File Handling" section of the module docstring
    for information.
    """


class TaskFolder(str):
    """
    See the "Input/Output File Handling" section of the module docstring
    for information.
    """


#===============================================================================
# Task pre/post processing
#===============================================================================

# Ordering constants
BEFORE_TASKDIR = -2000  # Runs preprocessor before taskdir creation
AFTER_TASKDIR = 0  # Runs preprocessor after taskdir creation (default)
_TASKDIR_ORDER = -1000  # Order for taskdir creation
_WRITE_JSON_ORDER = 10000

# Taskdir settings
AUTO_TASKDIR = object()
TEMP_TASKDIR = object()


class _ProcessorMarker(funcchains.FuncChainMarker):

    def customizeFuncResult(self, func, result):
        return _cast_processing_result(result, func)


"""
The preprocessor and postprocessor decorators can be used to mark functions to
be run before/after a task. These decorators may be used on task methods
with or without args::


    class MyTask(tasks.BlockingFunctionTask):

        @tasks.preprocessor  # Use without args
        def checkInput(self):
            pass

        @tasks.preprocessor(order=tasks.AFTER_TASKDIR)  # Use with args
        def writeInput(self):
            pass

The optional order argument is a float that is used as a sorting key to
determine the order of execution of pre/postprocessors. It's recommended that
one of the module level ordering constants is used, with +/- increments to fine-
tune the order. For example::

    class MyTask(tasks.BlockingFunctionTask):

        @tasks.preprocessor(order=tasks.AFTER_TASKDIR)
        def checkInput(self):
            pass

        @tasks.preprocessor(order=tasks.AFTER_TASKDIR + 1)
        def writeInput(self):
            pass
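
The effect of the order values can be sketched without the framework. The
snippet below assumes that processors run in ascending order of their sorting
key, which is what the ordering constants imply. The processor names are
hypothetical and the numeric values are copied from the module-level constants.

```python
# Values copied from the module-level ordering constants
BEFORE_TASKDIR = -2000
TASKDIR_CREATION = -1000  # the module's private _TASKDIR_ORDER
AFTER_TASKDIR = 0

# Hypothetical preprocessors, registered in an arbitrary order
registered = [
    ('writeInput', AFTER_TASKDIR + 1),
    ('createTaskDir', TASKDIR_CREATION),
    ('checkInput', AFTER_TASKDIR),
    ('validateName', BEFORE_TASKDIR - 1000),
    ('pickTaskDirName', BEFORE_TASKDIR),
]

# Lower order values run first
by_order = sorted(registered, key=lambda item: item[1])
execution_order = [name for name, _ in by_order]
```

In this sketch `checkInput` runs after the taskdir exists and `writeInput`
follows it because of the `+ 1` increment.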


External functions may also be decorated. In this case, the function must also
be added to a task instance. Example::

    @tasks.preprocessor(order=tasks.AFTER_TASKDIR)
    def foo():
        pass

    task = MyTask()
    task.addPreprocessor(foo)

Pre/postprocessors may optionally return a ProcessingResult. As a convenience,
a (passed, message) tuple return value will automatically be cast into a
ProcessingResult by the decorator. Examples::

        @tasks.preprocessor
        def checkInput(self):
            if self.input.x < 0:  # Preprocessing failure
                return False, 'x must be nonnegative.'

            if self.input.x > 100:  # Preprocessing warning
                return True, 'Large values of x may take a long time.'

            return True  # Pass (equivalent to returning None)

Returning False without a message will be a silent failure.
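
The conversion rules above can be tried out with a simplified, standalone
sketch. `Result` and `cast_result` are stand-ins written for this illustration
and are not the framework's ProcessingResult or its casting function;
`check_input` is a hypothetical preprocessor body that takes a plain number.

```python
from typing import NamedTuple, Optional


class Result(NamedTuple):
    # Simplified stand-in for tasks.ProcessingResult
    passed: bool
    message: Optional[str] = None


def cast_result(value):
    # None is treated the same as True (a silent pass)
    if value is None:
        value = True
    if isinstance(value, bool):
        return Result(value)
    if isinstance(value, tuple):
        return Result(*value)
    raise TypeError(f'Expected bool or tuple, got {value!r}')


def check_input(x):
    # Hypothetical preprocessor body operating on a plain number
    if x < 0:
        return False, 'x must be nonnegative.'
    if x > 100:
        return True, 'Large values of x may take a long time.'


failure = cast_result(check_input(-1))
warning = cast_result(check_input(500))
silent_pass = cast_result(check_input(5))
silent_failure = cast_result(False)
```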

"""
preprocessor = _ProcessorMarker('preprocessor')
postprocessor = _ProcessorMarker('postprocessor')


class ProcessingResult:
    """
    A general-purpose return value for task pre/post processors
    """

    def __init__(self, passed, message=None):
        """
        :param passed: Whether the result is considered to be passing
        :type passed: bool

        :param message: A message for this result
        :type message: str

        """
        self.func = None
        self.passed = passed
        self.message = message

    def processorName(self):
        if self.func is None:
            return
        return self.func.__name__

    def __bool__(self):
        return self.passed

    def __repr__(self):
        return str(self)

    def __str__(self):
        msg = ''
        if self.func is not None:
            msg += f'{self.func.__name__}: '
        if self.passed and not self.message:
            msg += 'Passed'
        elif self.passed and self.message:
            msg += f'WARNING - {self.message}'
        elif not self.passed and not self.message:
            msg += 'FAILED'
        elif not self.passed and self.message:
            msg += f'FAILED - {self.message}'
        return msg


class CallingContext(enum.IntEnum):
    CMDLINE = enum.auto()
    GUI = enum.auto()


def _cast_processing_result(result, func=None):
    """
    Convert the return value of a pre/post-processor to a ProcessingResult,
    if necessary. If a func is supplied, it will be recorded in the
    ProcessingResult.

    :param func: the function that produced this result

    :param result: the return value of the pre/post-processor. This can be
        represented in one of three ways: (1) True/False for passed (2) tuple
        of (passed, message) (3) a ProcessingResult instance
    :type result: bool, tuple, or ProcessingResult

    :return: the wrapped return value
    :rtype: ProcessingResult
    """

    if result is None:
        result = True

    if isinstance(result, bool):
        result = ProcessingResult(result)

    if isinstance(result, tuple):
        result = ProcessingResult(*result)

    if isinstance(result, ProcessingResult):
        result.func = func
        return result

    raise TypeError(f'Return value should be bool or tuple. Got {result}')


#===============================================================================
# Task exceptions
#===============================================================================


class TaskFailure(Exception):
    """
    Exception raised when a task fails for reasons other than an unexpected
    error occurring during execution.
    """

    # This class intentionally left blank.


class TaskKilled(TaskFailure):
    pass


class _TaskTestTimeout(TaskFailure):
    """
    Exception raised if a task times out under pytest
    """
    pass


#===============================================================================
# Status
#===============================================================================

FailureInfo = namedtuple('FailureInfo', 'exception traceback message')


class FailureInfo(FailureInfo):

    def __str__(self):
        if self.exception is None:
            return 'No failure recorded.'
        else:
            return f'Task failure:\n{self.traceback}\n{self.exception}'


class Status(jsonable.JsonableIntEnum):
    WAITING, RUNNING, FAILED, DONE = range(4)


FINISHED_STATUSES = {Status.FAILED, Status.DONE}
STARTABLE_STATUSES = {Status.WAITING, Status.FAILED, Status.DONE}
NON_RUNNING_STATUSES = {Status.WAITING, Status.FAILED, Status.DONE}


def _wait(task, timeout=None):
    """
    Block until the task is finished executing or `timeout` seconds have
    passed.

    :param timeout: Amount of time in seconds to wait before timing out. If
        None or a negative number, this method will wait until the task
        is finished.
    :type timeout: NoneType or int

    :return: whether the task finished during the wait. Returns False if wait
        timed out
    """
    return _wait_for(task, NON_RUNNING_STATUSES, timeout=timeout)


@application.require_application
def _wait_for(task, end_statuses, timeout=None):
    """
    Block until a task reaches one of the specified statuses. Blocks using a
    local event loop.

    :param task: the task to wait on
    :param end_statuses: the task statuses to wait for
    :param timeout: an optional timeout in seconds

    :return: whether the wait succeeded. Returns False if wait timed out
    """
    if task.status in end_statuses:
        return True

    event_loop = QtCore.QEventLoop()

    def check_status(status):
        if status in end_statuses:
            event_loop.exit()

    def time_out_event_loop():
        event_loop.exit()

    if timeout is not None:
        QtCore.QTimer.singleShot(timeout * 1000, time_out_event_loop)
    task.statusChanged.connect(check_status)
    event_loop.exec()
    return task.status in end_statuses


#===============================================================================
# Abstract Task
#===============================================================================


@qt_utils.add_enums_as_attributes(Status)
@qt_utils.add_enums_as_attributes(CallingContext)
class AbstractTask(funcchains.FuncChainMixin, parameters.CompoundParam):
    input: parameters.CompoundParam
    output: parameters.CompoundParam
    status: Status
    name: str
    progress: int
    max_progress: int
    progress_string: str
    calling_context = parameters.NonParamAttribute()
    failure_info = parameters.NonParamAttribute()

    # Convenience Signals
    taskDone = QtCore.pyqtSignal()
    taskStarted = QtCore.pyqtSignal()
    taskFailed = QtCore.pyqtSignal()

    DEFAULT_TASKDIR_SETTING = None
    AUTO_TASKDIR = AUTO_TASKDIR  # Add these to the class namespace for
    TEMP_TASKDIR = TEMP_TASKDIR  # convenience.
    _all_task_tempdirs = []

    _is_debug_enabled = False

    #===========================================================================
    # Construction
    #===========================================================================

    @classmethod
    def runFromCmdLine(cls):
        return cmdline.run_task_from_cmdline(cls)

    @classmethod
    def fromJsonFilename(cls, filename):
        with open(filename) as f:
            json_dict = json.load(f)
            task = cls.fromJson(json_dict)
        return task

    def initConcrete(self):
        super().initConcrete()
        self.statusChanged.connect(self.__onStatusChanged)
        self.failure_info = None
        self._taskdir = None
        self._taskdir_setting = self.DEFAULT_TASKDIR_SETTING
        self.calling_context = None
        self._in_preprocessing = False
        self._interruption_requested = False
        self._tempdir = None

    def initializeValue(self):
        """
        @overrides: parameters.CompoundParam
        """
        if not self.name:
            self.name = self.__class__.__name__

    #===========================================================================
    # Abstract Methods
    #===========================================================================
    INTERRUPT_ENABLED = False

    def run(self):
        # Implementations of run are responsible for directly calling
        # `_finish` or connecting a signal to `_finish`.
        raise NotImplementedError()

    def kill(self):
        """
        Implementations are responsible for immediately stopping the task. No
        threads or processes should be running after this method is complete.

        This method should be called sparingly since in many contexts the task
        will be forced to terminate without a chance to clean up or free
        resources.
        """
        raise NotImplementedError()

    #===========================================================================
    # Public API
    #===========================================================================

    def start(self, skip_preprocessing=False):
        """
        This is the main method for starting a task. Start will check if a task
        is not already running, run preprocessing, and then run the task.

        Failures in preprocessing will interrupt the task start, and the task
        will never enter the RUNNING state.

        :param skip_preprocessing: whether to skip preprocessing. This can be
            useful if preprocessing was already performed prior to calling
            start.
        :type skip_preprocessing: bool

        """
        self.printDebug('start')
        if not self.isStartable():
            raise RuntimeError(
                f"Can't start a task with status {self.status.name}")
        if not self.name:
            raise RuntimeError("Can't start a task with name: ''")
        self.status = Status.WAITING
        self._interruption_requested = False
        self.failure_info = None
        if not skip_preprocessing:
            with self.guard():
                self.runPreprocessing(callback=self._processingCallback)
        if self.failure_info is not None:
            self.status = self.FAILED
            return
        self.status = self.RUNNING
        with self.guard():
            self.run()
        if self.failure_info is not None:
            self.status = self.FAILED
            return

    def wait(self, timeout=None):
        r"""
        Block until the task is finished executing or `timeout` seconds have
        passed.

        .. warning::
            This should not be called directly from GUI code - see PANEL-18317.
            It is safe to call inside a subprocess or job. Run
            `git grep "task\.wait("` to see safe examples annotated with "# OK".

        :param timeout: Amount of time in seconds to wait before timing out. If
            None or a negative number, this method will wait until the task
            is finished.
        :type timeout: NoneType or int
        """
        # Call the module-level wait function
        self.printDebug(f'wait({timeout})')
        retval = _wait(self, timeout)
        self.printDebug('wait done')
        return retval

    def isRunning(self):
        return self.status is self.RUNNING

    def isStartable(self):
        return self.status in STARTABLE_STATUSES

    def specifyTaskDir(self, taskdir_spec):
        """
        Specify the taskdir creation behavior. Use one of the following options:

        A directory name (string). This may be a relative or absolute path

        None - no taskdir is requested. The task will use the CWD as its taskdir

        AUTO_TASKDIR - a new subdirectory will be created in the CWD using the
        task name as the directory name.

        TEMP_TASKDIR - a temporary directory will be created in the schrodinger
        temp dir. This directory is cleaned up when the task is deleted.

        :param taskdir_spec: one of the four options listed above
        """
        if ((self._in_preprocessing and self._taskdir is not None) or
                self.isRunning()):
            raise RuntimeError('Taskdir specification may not be changed once '
                               'the taskdir is created.')
        self._taskdir_setting = taskdir_spec
        self._taskdir = None

    def taskDirSetting(self):
        """
        Returns the taskdir spec. See specifyTaskDir() for details.
        """
        return self._taskdir_setting

    def getTaskDir(self):
        """
        Returns the full path of the task directory. This is only available if
        the task directory exists (after creation of the taskdir or, if no task
        dir is specified, any time).
        """
        if self._taskdir_setting is None:
            return os.getcwd()
        if isinstance(self._taskdir_setting, (str, pathlib.Path)):
            if os.path.exists(self._taskdir_setting):
                self._taskdir = os.path.abspath(self._taskdir_setting)
        if self._taskdir is None:
            raise TaskDirNotFoundError(
                'Taskdir has not been created yet. Consider '
                'moving this call to an AFTER_TASKDIR '
                'preprocessor.')

        return self._taskdir

    def getTaskFilename(self, fname):
        """
        Return the appropriate absolute path for an input or output file in the
        taskdir.
        """
        parent_dir = self.getTaskDir()
        return os.path.join(parent_dir, fname)

    def addPreprocessor(self, func, order=None):
        """
        Adds a preprocessor function to this task instance. If the function has
        been decorated with @preprocessor, the order specified by the decorator
        will be used as the default.

        :param func: the function to add

        :param order: the sorting order for the function relative to all other
            preprocessors. Takes precedence over order specified by the
            preprocessor decorator.
        :type order: float
        """
        if order is None:
            decorated_order = funcchains.get_marked_func_order(func)
            if decorated_order is None:
                order = AFTER_TASKDIR
            else:
                order = decorated_order
        self.addFuncToGroup(func, preprocessor, order)

    def addPostprocessor(self, func, order=0):
        """
        Adds a postprocessor function to this task instance. If the function
        has been decorated with `@postprocessor`, the order specified by the
        decorator will be used.

        :param func: the function to add
        :type func: typing.Callable
        :param order: the sorting order for the function relative to all other
            postprocessors. Takes precedence over order specified by the
            postprocessor decorator.
        :type order: float
        """

        self.addFuncToGroup(func, postprocessor, order)

    def preprocessors(self):
        """
        :return: A list of preprocessors (both decorated methods on the task and
            external functions that have been added via
            addPreprocessor)
        """
        return self.getFuncGroup(preprocessor)

    def postprocessors(self):
        """
        :return: A list of postprocessors, both decorated methods on the task
            and external functions that have been added via `addPostprocessor()`
        :rtype: list[typing.Callable]
        """

        return self.getFuncGroup(postprocessor)

    def reset(self, *args, **kwargs):
        if not args and not kwargs:
            if self.status is self.RUNNING:
                raise RuntimeError("Can't reset a task while it's running")
            elif self.status is self.FAILED:
                self.failure_info = None
        super().reset(*args, **kwargs)

    def replicate(self):
        """
        Create a new task with the same input and settings (but no output)
        """
        old_task = self
        new_task = self.__class__()
        new_task.specifyTaskDir(old_task.taskDirSetting())
        old_preprocess_callbacks = old_task.getAddedFuncs(preprocessor)
        for func, order in old_preprocess_callbacks:
            new_task.addPreprocessor(func, order)
        for func, order in old_task.getAddedFuncs(postprocessor):
            new_task.addPostprocessor(func, order)
        if isinstance(new_task.input, parameters.CompoundParam):
            new_task.input.setValue(old_task.input)
        else:
            new_task.input = old_task.input
        return new_task

    def isDebugEnabled(self):
        return self._is_debug_enabled

    def printDebug(self, *args):
        if not self.isDebugEnabled():
            return
        info = self.getDebugString()
        print(f'{info}:', *args)

    def getDebugString(self):
        return f'{datetime.now()} {self.name}-{self.status.name}'

    def requestInterruption(self):
        """
        Request the task to stop.

        To enable this feature, subclasses should periodically check whether an
        interruption has been requested and terminate if it has been. If such
        logic has been included, `INTERRUPT_ENABLED` should be set to `True`.
        """
        if not self.INTERRUPT_ENABLED:
            raise RuntimeError("Interruption is not enabled for this task.")
        self._interruption_requested = True

    def isInterruptionRequested(self):
        return self._interruption_requested

    #===========================================================================
    # Internal methods
    #===========================================================================
    @preprocessor(order=BEFORE_TASKDIR - 1000)
    def _validateTaskName(self):
        is_valid = fileutils.is_valid_jobname(self.name)
        if not is_valid:
            return False, fileutils.INVALID_JOBNAME_ERR % self.name

    def __copy__(self):
        task_copy = super().__copy__()
        if task_copy.status is task_copy.RUNNING:
            task_copy.status = task_copy.WAITING
        return task_copy

    def __deepcopy__(self, memo):
        task_copy = super().__deepcopy__(memo)
        if task_copy.status is task_copy.RUNNING:
            task_copy.status = task_copy.WAITING
        return task_copy

    def __eq__(self, other):
        """
        Tasks compare equal if all params excluding the status are equal.
        """
        is_eq = super().__eq__(other)
        if is_eq:
            return True
        else:
            if isinstance(other, self.__class__):
                self_copy = copy.copy(self)
                other_copy = copy.copy(other)
                return self_copy.toDict() == other_copy.toDict()
            return False

    def __onStatusChanged(self, status):
        if status is self.RUNNING:
            self.taskStarted.emit()
        elif status is self.FAILED:
            self.taskFailed.emit()
        elif status is self.DONE:
            self.taskDone.emit()

    def _processingCallback(self, result):
        if not result.passed:
            self._recordFailure(TaskFailure(result.message))
        return result.passed

    def _defaultResultCallback(self, result):
        """
        @overrides: funcchains.FuncChainMixin
        """
        if not result.passed:
            raise TaskFailure(result.message)
        return True

    @typing.final
    def runPreprocessing(self, callback=None, calling_context=None):
        """
        Run the preprocessors one-by-one. By default, any failing preprocessor
        will raise a TaskFailure exception and terminate processing. This
        behavior may be customized by supplying a callback function which will
        be called after each preprocessor with the result of that preprocessor.

        This method is "final" so that all preprocessing logic will be enclosed
        in the try/finally block.

        :param callback: a function that takes result and returns a bool that
            indicates whether to continue on to the next preprocessor

        :param calling_context: specify a value here to indicate the context
            in which this preprocessing is being called. This value will be
            stored in an instance variable, self.calling_context, which can be
            accessed from any preprocessor method on this task. Typically this
            value will be either self.GUI, self.CMDLINE, or None, but any value
            may be supplied here and checked for in the preprocessor methods.
            self.calling_context always reverts back to None at the end of
            runPreprocessing.

        """
        self.printDebug('runPreprocessing')
        self._in_preprocessing = True
        self._taskdir = None
        self.calling_context = calling_context
        try:
            return self.processFuncChain(preprocessor, result_callback=callback)
        finally:
            self.calling_context = None
            self._in_preprocessing = False
            self.printDebug('done preprocessing')

    def runPostprocessing(self, callback=None):
        return self.processFuncChain(postprocessor, result_callback=callback)

    def _makeTempTaskDir(self):
        parent_dir = fileutils.get_directory_path(fileutils.TEMP)
        self._tempdir = tempfile.TemporaryDirectory(dir=parent_dir)
        self._taskdir = self._tempdir.name
        self._registerTempDir(self._tempdir)

    def _registerTempDir(self, tmpdir):
        """
        Register a tempdir to the class. This is used to clean up all tempdirs
        in unit tests.
        """
        self._all_task_tempdirs.append(tmpdir)

    def _makeDir(self, taskdir):
        os.makedirs(taskdir)

    @preprocessor(order=_TASKDIR_ORDER)
    def _createTaskDir(self):
        """
        Create a task directory for running the task in.
        """
        if self._taskdir_setting is TEMP_TASKDIR:
            self._makeTempTaskDir()
            return True

        cwd = os.getcwd()
        if self._taskdir_setting is None:
            self._taskdir = cwd
            return True

        if self._taskdir_setting is AUTO_TASKDIR:
            taskdir = os.path.abspath(self.name)
        else:
            taskdir = os.path.abspath(self._taskdir_setting)
        self._taskdir = taskdir

        try:
            self._makeDir(taskdir)
        except FileExistsError:
            if self._taskdir_setting is not AUTO_TASKDIR:
                # Allow specified path to already exist
                return True
            if self.calling_context is self.GUI:
                return (True, f'Task directory {self._taskdir} already exists. '
                        'Contents will be overwritten. Continue?')
            return False, f"Task directory {self._taskdir} already exists."
        return True

    def _recordFailure(self, exception, exc_traceback_str=None):
        """
        Store the exception in `failure_info` and set status to failed
        """
        if self.failure_info is not None:
            return
        message = str(exception)
        self.failure_info = FailureInfo(
            exception=exception, traceback=exc_traceback_str, message=message)
        if exc_traceback_str:
            tb = exc_traceback_str
        else:
            tb = ''
        print(
            f'{tb}{repr(self)}> failed: {type(exception).__name__}("{message}")'
        )

    @contextlib.contextmanager
    def guard(self):
        """
        Context manager that saves any Exception raised inside
        """
        try:
            yield
        except Exception:
            err_type, exc_value, exc_traceback = sys.exc_info()
            if err_type is TaskFailure:
                exc_traceback_str = None
            else:
                exc_traceback_str = ''.join(
                    traceback.format_tb(exc_traceback)[-10:])
                # We have to delete the traceback to prevent a circular ref.
                # See the `traceback` module documentation for additional info.
                del exc_traceback
            self._recordFailure(exc_value, exc_traceback_str)

    def _finish(self):
        self.printDebug('_finish')
        if self.failure_info is not None:
            self.status = Status.FAILED
            return
        with self.guard():
            self.runPostprocessing(callback=self._processingCallback)
        if self.failure_info is not None:
            self.status = Status.FAILED
            return
        self.status = Status.DONE

    def __repr__(self):
        if self.isAbstract():
            return super().__repr__()
        return (f'<{self.__class__.__name__}: {self.name} - '
                f'{Status(self.status).name}>')  # sometimes status is an int

    @classmethod
    def _populateClassParams(cls):
        cls._convertNestedClassToDescriptor('Input', 'input')
        cls._convertNestedClassToDescriptor('Output', 'output')
        super()._populateClassParams()

    @classmethod
    def _convertNestedClassToDescriptor(cls, nested_class_name,
                                        descriptor_name):
        """
        If a nested class of the specified name is defined, this method will
        instantiate that class and set that instance as a class variable.
        Example::

            class Foo:
                class Bar:
                    pass

        Calling Foo._convertNestedClassToDescriptor('Bar', 'bar') will do the
        equivalent of putting bar = Bar() inside the Foo class. Typically used
        to instantiate Param classes as descriptors on the class.

        :param nested_class_name: the name of the class to look for
        :param descriptor_name: the name under which the descriptor instance
            is added to the class
        """
        if nested_class_name in cls.__dict__:
            nested_class = getattr(cls, nested_class_name)
            desc = nested_class()
            desc.__set_name__(cls, descriptor_name)
            setattr(cls, descriptor_name, desc)


#===============================================================================
# Task interfaces
#===============================================================================


class _AbstractFunctionTask(AbstractTask):

    def run(self):
        self._runMainFunction()

    def _guardedMain(self):
        with self.guard():
            self.mainFunction()

    def _runMainFunction(self):
        raise NotImplementedError()

    def mainFunction(self):
        raise NotImplementedError()


class AbstractCmdTask(AbstractTask):

    def run(self):
        cmd = self.makeCmd()
        for idx, arg in enumerate(cmd):
            if not isinstance(arg, str):
                msg = (f"makeCmd() must return a list of strings. Item {idx} "
                       f"is type {type(arg)}.")
                raise ValueError(msg)
        self.runCmd(cmd)

    def runCmd(self, cmd):
        raise NotImplementedError()

    def makeCmd(self):
        return []


class AbstractComboTask(AbstractCmdTask, _AbstractFunctionTask):
    """
    Subclasses should only define params inside of input or output. Top-level
    params defined in subclasses do NOT get serialized between the frontend and
    backend task instances. Thus, any modifications of new top-level params in
    the backend (i.e. mainFunction) will not have any effect on the rehydrated
    frontend task.
    """
    _run_as_backend: bool = False
    ENTRYPOINT = 'combotask_entry_point.py'

    # Private params, not for use by child classes
    _task_module: str
    _task_class: str
    _task_script: str
    _failure_info: str = None
    _failure_tb: str = None
    _combo_id: str = None

    # Only these params will be serialized in frontend/backend conversions
    _FRONTEND_TO_BACKEND_PARAMS = [
        'name', 'input', '_run_as_backend', '_task_module', '_task_class',
        '_task_script', '_combo_id'
    ]
    _BACKEND_TO_FRONTEND_PARAMS = [
        'output', 'status', '_run_as_backend', '_failure_info', '_failure_tb'
    ]

    def _regenerateComboId(self):
        """
        Generate a new combo id for this task. A combo id is a random string
        that is used to prevent tasks with the same task name from overwriting
        each other's combo files (i.e. _frontend.json and _backend.json).
        """
        alphabet = string.ascii_lowercase + string.digits
        self._combo_id = ''.join(random.choices(alphabet, k=12))

    def initializeValue(self):
        super().initializeValue()
        if self._combo_id is None:  # no combo id from rehydrated json file
            self._regenerateComboId()

    @property
    def json_filename(self):
        return self.getTaskFilename(
            f'.{self.name}_{self._combo_id}_frontend.json')

    @property
    def json_out_filename(self):
        return self.getTaskFilename(
            f'.{self.name}_{self._combo_id}_backend.json')

    def start(self, *args, **kwargs):
        """
        @overrides: AbstractTask
        """
        if self.isBackendMode():
            return self.runBackend()
        super().start(*args, **kwargs)

    def isBackendMode(self):
        return self._run_as_backend

    def makeCmd(self):
        """
        @overrides: AbstractCmdTask
        """
        cmd = [
            get_schrodinger_run(), self.ENTRYPOINT, '--task_json',
            self._getFrontEndJsonArg()
        ]
        return cmd

    def _getFrontEndJsonArg(self):
        return self.json_filename

    def _writeFrontendJsonFile(self):
        task_module = self._get_module()
        backend_task = copy.deepcopy(self)
        # deepcopy of a compoundparam only copies params
        backend_task._taskdir = self._taskdir
        if task_module == '__main__':
            print(f'{self} is defined outside the build. Will attempt to copy '
                  'script to backend dir to run. If the script needs to import '
                  'other files, the task will still fail. In this case, move '
                  'the script and its dependencies to an importable location.')
            cp_filename = self._copyScriptToBackend()
            backend_task._task_script = os.path.basename(cp_filename)
        backend_task._task_module = task_module
        backend_task._task_class = type(self).__name__
        # need to get json_filename before setting _run_as_backend to True
        json_filename = self.json_filename
        backend_task._processTaskFilesForFrontendWrite()
        backend_task._run_as_backend = True
        backend_task._writeComboJsonFile(json_filename)

    def _copyScriptToBackend(self):
        script_filename = inspect.getfile(type(self))
        try:
            return shutil.copy(script_filename, self.getTaskDir())
        except shutil.SameFileError:
            return script_filename

    @preprocessor(order=_WRITE_JSON_ORDER)
    def _prepareComboTask(self, *args, **kwargs):
        self._writeFrontendJsonFile()

    def _finish(self):
        super()._finish()
        #  The next time this task is started, it should have a new combo id
        self._regenerateComboId()

    def backendMain(self):
        raise NotImplementedError

    def _processBackend(self):
        json_out_path = self.json_out_filename
        if not os.path.isfile(json_out_path):
            msg = "No json file was returned from the backend. "
            logfile = self.getTaskFilename(self._getLogFilename())
            if os.path.isfile(logfile):
                msg += f"Check {logfile} for more information."
                try:
                    self.printDebug('Log file contents:\n',
                                    self.getLogAsString())
                except Exception as e:
                    self.printDebug(str(e))
            else:
                msg += f"Log file not found at {logfile}"
            exception = RuntimeError(msg)
            # TODO report the path to the log file in this error message.
            self._recordFailure(exception)
        else:
            with open(json_out_path, 'r') as infile:
                # Create a new instance from the backend json output
                TaskClass = type(self)
                try:
                    rehydrated_backend = TaskClass.fromJson(json.load(infile))
                except json.JSONDecodeError as e:
                    self._recordFailure(e)
                else:
                    self._updateFromBackend(rehydrated_backend)

    def _updateFromBackend(self, rehydrated_backend):
        """
        Update the frontend task based on the rehydrated backend task
        """
        if isinstance(self.output, parameters.CompoundParam):
            self.output.setValue(rehydrated_backend.output)
            self._processTaskFilesForBackendRehydration()
        else:
            self.output = rehydrated_backend.output
        if rehydrated_backend.status == rehydrated_backend.FAILED:
            backend_exc = pickle.loads(
                rehydrated_backend._failure_info.encode())
            backend_tb = rehydrated_backend._failure_tb
            self._recordFailure(backend_exc, backend_tb)

    def _writeComboJsonFile(self, filename):
        if self.status is self.FAILED:
            # Use protocol 0 since it's ascii-encodable
            self._failure_info = pickle.dumps(self.failure_info.exception,
                                              0).decode()
            backend_tb = ''.join(
                traceback.format_tb(self.failure_info.exception.__traceback__))
            self._failure_tb = backend_tb
        ser_task = self._createSerializationTask()
        try:
            with open(filename, 'w') as f:
                json.dump(ser_task, f, indent=4)
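        except BaseException:
            # If something goes wrong during serialization, make sure to
            # remove the empty or partially written json file. The file may
            # not exist if open() itself failed.
            with contextlib.suppress(FileNotFoundError):
                os.remove(filename)
            raise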

    def runBackend(self):
        self._processTaskFilesForBackendExecution()
        self.progressChanged.connect(self._onBackendProgressChanged)
        self.max_progressChanged.connect(self._onBackendProgressChanged)
        self.progress_stringChanged.connect(self._onBackendProgressChanged)
        with self.guard():
            try:
                self.backendMain()
            except NotImplementedError:
                self.mainFunction()
        if self.failure_info:
            self.status = self.FAILED
            if not isinstance(self.failure_info.exception, TaskFailure):
                print(self.failure_info.traceback)
            if self.failure_info.message:
                print(self.failure_info.message)
        self._processTaskFilesForBackendWrite()
        # Mark as frontend to ensure correct params are serialized
        self._run_as_backend = False
        self._writeComboJsonFile(self.json_out_filename)

    def _onBackendProgressChanged(self):
        """
        Implement logic that will communicate progress change from the backend
        to the front-end.
        """

    def _get_module(self):
        """
        Return the module string defining where the class for `self` is defined.
        """
        return imputils.get_path_from_module(inspect.getmodule(self))

    def _createSerializationTask(self) -> 'AbstractComboTask':
        """
        Return a new instance of this task that has serialization param values
        set for frontend/backend conversion. Non-serialization params have
        default values.
        """
        ser_task = self.__class__()
        ser_param_names = self._getSerializationParamNames()
        for param_name in ser_param_names:
            param_value = getattr(self, param_name)
            if isinstance(param_value, parameters.CompoundParam):
                param_to_serialize = getattr(ser_task, param_name)
                param_to_serialize.setValue(param_value)
            else:
                setattr(ser_task, param_name, param_value)
        return ser_task

    def _getSerializationParamNames(self) -> List[str]:
        """
        Return a list of the names of params that should be serialized for
        frontend/backend combo task conversion.
        """
        if self._run_as_backend:
            param_names = self._FRONTEND_TO_BACKEND_PARAMS
        else:
            param_names = self._BACKEND_TO_FRONTEND_PARAMS
        return param_names

    #===========================================================================
    # TaskFile Processing
    #===========================================================================

    def _processTaskFilesForFrontendWrite(self):
        """
        This will be called before writing out the combotask frontend json file.
        Transforms all TaskFile and TaskFolder paths in self.input so that the
        json file within the taskdir will be portable, if possible.

        Raises a ValueError if any files/directories do not exist.
        """
        self._processTaskFiles(
            self.input, process_func=_filepaths.get_launch_path)

    def _processTaskFilesForBackendExecution(self):
        """
        This will be called in the backend before executing the mainFunction of
        the combotask. Override if the file paths are different in the backend
        compared to the paths used in the frontend.

        Raises a ValueError if any files/directories do not exist.
        """
        self._processTaskFiles(self.input, process_func=None)

    def _processTaskFilesForBackendWrite(self):
        """
        This will be called in the backend after mainFunction returns and
        before writing the combotask backend json file. Converts absolute paths
        into relative paths so that file references can remain valid if the
        taskdir is copied or moved.

        Raises a ValueError if any files/directories do not exist.
        """

        def process_output(path, launchdir):
            # launchdir is unused; the path is made relative to the current
            # working directory of the backend process.
            return os.path.relpath(path)

        self._processTaskFiles(self.output, process_func=process_output)

    def _processTaskFilesForBackendRehydration(self):
        """
        This will be called before the output of the backend task is set back on
        the frontend task. Transforms all TaskFile and TaskFolder references
        into absolute paths so that they will be valid regardless of the CWD of
        the process that started the task.

        Raises a ValueError if any files/directories do not exist.
        """
        self._processTaskFiles(self.output, process_func=None)

    def _processTaskFiles(self,
                          param,
                          *,
                          process_func,
                          check_exist=True,
                          dir=None):
        if dir is None:
            dir = self.getTaskDir()

        def process_taskfile(path):
            if path is None:
                return None
            if check_exist and not os.path.exists(path):
                raise ValueError(
                    f'Filepath "{path}" does not exist. Make sure all '
                    'taskfiles and task folders point to existing files before '
                    'starting or completing the task.')
            if process_func is None:
                return path
            else:
                new_path = process_func(path, dir)
                return new_path

        if isinstance(param, parameters.CompoundParam):
            paramtools.map_subparams(process_taskfile, param, TaskFile)
            paramtools.map_subparams(process_taskfile, param, TaskFolder)


#===============================================================================
# Task execution mixins
#===============================================================================


def get_schrodinger_run():
    return 'run'


class _SaveTaskReferenceMixin:

    def __init_subclass__(cls):
        super().__init_subclass__()
        # Let each class have its own set so failures are easier to understand
        cls._saved_task_references = scollections.IdSet()

    def start(self, *args, **kwargs):
        super().start(*args, **kwargs)
        if self.status == Status.RUNNING:
            self._saveTaskReference()

    def _finish(self):
        super()._finish()
        self._discardTaskReference()

    def _saveTaskReference(self):
        self._saved_task_references.add(self)

    def _discardTaskReference(self):
        self._saved_task_references.discard(self)


class BlockingMixin:
    """
    Runs `mainFunction` synchronously in the calling thread. Compatible with
    subclasses of `_AbstractFunctionTask`.
    """

    def _runMainFunction(self):
        self._guardedMain()
        self._finish()


class ThreadMixin(_SaveTaskReferenceMixin):

    MAX_THREAD_TASKS = 500
    qthread = parameters.NonParamAttribute()

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.qthread = None

    def kill(self):
        """
        @overrides: AbstractTask

        Killing threads is dangerous and can lead to deadlocks on Windows, so
        we intentionally leave it unimplemented rather than using
        QThread.terminate.
        """
        raise NotImplementedError

    def _runMainFunction(self):
        # Make sure that there is a QApplication running. If there isn't,
        # create a QCoreApplication.
        application.get_application(create=True, use_qtcore_app=True)
        self.qthread = QtCore.QThread()
        # TODO: Decide whether to leave this as a monkey-patch or hook up
        # qthread.started to _guardedMain instead. If we leave it as a patch,
        # we should add a strong warning against calling .start() from multiple
        # threads.
        self.qthread.run = self._guardedMain
        self.qthread.finished.connect(self.__onThreadFinished)
        self.qthread.start()

    @typing.final
    def __onThreadFinished(self):
        self._finish()


class QProcessError(Exception):

    def __init__(self, message):
        super().__init__(message)


class QProcessFailedToStartError(QProcessError):
    pass


class QProcessCrashedError(QProcessError):
    pass


class QProcessTimedout(QProcessError):
    pass


class QProcessWriteError(QProcessError):
    pass


class QProcessReadError(QProcessError):
    pass


class QProcessUnknownError(QProcessError):
    pass


_QProcessErrorToException = {
    QProcess.FailedToStart: QProcessFailedToStartError,
    QProcess.Crashed: QProcessCrashedError,
    QProcess.Timedout: QProcessTimedout,
    QProcess.WriteError: QProcessWriteError,
    QProcess.ReadError: QProcessReadError,
    QProcess.UnknownError: QProcessUnknownError
}


class SubprocessMixin(_SaveTaskReferenceMixin):
    cmd = parameters.NonParamAttribute()
    exit_code = parameters.NonParamAttribute()
    stdout = parameters.NonParamAttribute()
    stderr = parameters.NonParamAttribute()
    qprocess = parameters.NonParamAttribute()

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.cmd = None
        self.exit_code = None
        self.stdout = None
        self.stderr = None
        self.qprocess = None
        self._printing_output_to_terminal = False

    def printingOutputToTerminal(self):
        """
        :return: whether the `StdOut` and `StdErr` output from this task is
            being printed to the terminal
        :rtype: bool
        """

        return self._printing_output_to_terminal

    def setPrintingOutputToTerminal(self, print_to_terminal):
        """
        Set this task to print `StdOut` and `StdErr` output to terminal, or not.

        :param print_to_terminal: whether to send process output to terminal
        :type print_to_terminal: bool
        """

        self._printing_output_to_terminal = print_to_terminal

    def runCmd(self, cmd):
        # Make sure that there is a QApplication running. If there isn't,
        # create a QCoreApplication.
        application.get_application(create=True, use_qtcore_app=True)
        self.exit_code = None
        self.stdout = None
        self.stderr = None
        self.qprocess = None
        self.cmd = cmd

        cmd[0] = subprocess_utils.abs_schrodinger_path(cmd[0])
        self._setupQProcess()
        self.qprocess.start(cmd[0], cmd[1:])

    def _setupQProcess(self):
        self.qprocess = QtCore.QProcess()
        if self.printingOutputToTerminal():
            self.qprocess.setProcessChannelMode(
                QtCore.QProcess.ForwardedChannels)
        self.qprocess.setWorkingDirectory(self.getTaskDir())
        self.qprocess.finished.connect(self.__onSubprocessCompleted)
        self.qprocess.errorOccurred.connect(self.__onErrorOccurred)

    @typing.final
    def __onSubprocessCompleted(self):
        with self.guard():
            self._onSubprocessCompleted()
        self._finish()

    def _onSubprocessCompleted(self):
        self._readOutput()
        if not self.printingOutputToTerminal():
            logfilename = self._getLogFilename()
            with open(self.getTaskFilename(logfilename), 'w') as log_file:
                log_file.write(self.stdout)
                log_file.write(self.stderr)
        self.exit_code = self.qprocess.exitCode()
        if self.exit_code != 0:
            msg = f'{self} returned non-zero exit code.'
            if self.stderr:
                msg += f'\n{self.stderr}'
            self._recordFailure(TaskFailure(msg))

    def _readOutput(self):
        """
        Read stdout and stderr from the QProcess.
        """
        stdout = str(self.qprocess.readAllStandardOutput(), encoding='utf-8')
        stderr = str(self.qprocess.readAllStandardError(), encoding='utf-8')
        # Convert any Windows newlines
        self.stdout = stdout.replace("\r\n", "\n")
        self.stderr = stderr.replace("\r\n", "\n")

    @typing.final
    def __onErrorOccurred(self, error):
        with self.guard():
            self._onErrorOccurred(error)
        self._finish()

    def _onErrorOccurred(self, error):
        msg = (f"Command: {self.cmd} had fatal error: "
               f"{self.qprocess.errorString()}")
        qprocess_exception = _QProcessErrorToException[error](message=msg)
        self.exit_code = self.qprocess.exitCode()
        self._recordFailure(qprocess_exception)

    def _getLogFilename(self):
        return self.name + '.log'

    def getLogAsString(self):
        with open(self.getTaskFilename(self._getLogFilename())) as log_file:
            return log_file.read()

    def kill(self):
        """
        @overrides: AbstractTask

        Kill the subprocess and set the status to FAILED.
        """
        if self.status is not self.RUNNING:
            raise RuntimeError("Can't kill a task that's not running.")
        if self.qprocess:
            self.qprocess.finished.disconnect(self.__onSubprocessCompleted)
            self.qprocess.errorOccurred.disconnect(self.__onErrorOccurred)
            self.qprocess.kill()
            self.qprocess.waitForFinished()
            self._recordFailure(TaskKilled())
        self._finish()


#===============================================================================
# Prepackaged Task Classes
#===============================================================================
class BlockingFunctionTask(BlockingMixin, _AbstractFunctionTask):
    """
    A task that simply runs a function and blocks for the duration of it.
    To use, implement `mainFunction`.
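
    A minimal sketch, assuming hypothetical `SumTask`, `Input` and `Output`
    definitions that follow the conventions in the module docstring::

        class SumTask(BlockingFunctionTask):

            class Input(parameters.CompoundParam):
                x: int
                y: int

            class Output(parameters.CompoundParam):
                total: int

            def mainFunction(self):
                self.output.total = self.input.x + self.input.y

        task = SumTask()
        task.input.x = 3
        task.input.y = 4
        task.start()  # returns only after mainFunction has finished
        print(task.output.total)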
    """


class ThreadFunctionTask(ThreadMixin, _AbstractFunctionTask):
    """
    A task that runs a function in a separate thread.
    To use, implement `mainFunction`.
    """


class SubprocessCmdTask(SubprocessMixin, AbstractCmdTask):
    """
    A task that launches a subprocess.
    To use, implement `makeCmd` and return a list of strings.
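
    A minimal sketch, assuming a hypothetical `ScriptTask` that launches a
    hypothetical script `my_script.py` through `get_schrodinger_run()`::

        class ScriptTask(SubprocessCmdTask):

            class Input(parameters.CompoundParam):
                text: str

            def makeCmd(self):
                # Every item must be a str; run() raises ValueError otherwise
                return [get_schrodinger_run(), 'my_script.py', self.input.text]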
    """


class ComboBlockingFunctionTask(AbstractComboTask):
    """
    This is mostly for testing purposes.
    """

    def runCmd(self, cmd):
        cls = type(self)
        backend_task = cls.fromJsonFilename(self.json_filename)
        backend_task.specifyTaskDir(self.getTaskDir())
        backend_task.start()
        os.rename(backend_task.json_out_filename, self.json_out_filename)
        self._processBackend()
        self._finish()


class ComboSubprocessTask(SubprocessMixin, AbstractComboTask):

    def _processTaskFilesForFrontendWrite(self):
        # FIXME: this seems like unnecessary overhead
        self._processTaskFiles(
            self.input, process_func=self._copyFilesToTaskdir)
        super()._processTaskFilesForFrontendWrite()

    def _copyFilesToTaskdir(self, src_path, launchdir):
        dest_path = os.path.join(launchdir, src_path)
        if os.path.abspath(src_path) != os.path.abspath(dest_path):
            os.makedirs(os.path.dirname(dest_path), exist_ok=True)
            if os.path.isdir(src_path):
                shutil.copytree(src_path, dest_path)
            else:
                shutil.copyfile(src_path, dest_path)
        return dest_path

    def _processTaskFilesForBackendRehydration(self):
        # Can't check for existence until after transforming the filenames
        self._processTaskFiles(
            self.output,
            process_func=_filepaths.get_job_output_path,
            check_exist=False)
        # Now check for existence
        super()._processTaskFilesForBackendRehydration()

    def runBackend(self):
        # Specify the task dir as the cwd since we've already chdir'ed into
        # the directory with all the task files
        self.specifyTaskDir(None)
        return super().runBackend()

    def getTaskDir(self):
        if self.isBackendMode():
            return ''
        return super().getTaskDir()

    def _finish(self):
        with self.guard():
            self._processBackend()
        super()._finish()


class SignalTask(AbstractTask):
    """
    A task that relies on signals to proceed. Runs asynchronously via the event
    loop without requiring a worker thread. To use, implement setUpMain to
    connect any per-run signals and slots. Any slots should be decorated with
    SignalTask.guard_method so that exceptions in slots get converted into
    task failures. To end the task, emit self.mainDone to indicate the task has
    successfully completed. To fail, raise a TaskFailure or other exception.
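
    A minimal sketch of a subclass, assuming a single-shot `QtCore.QTimer` as
    the signal source (`DelayTask` and `_onTimeout` are hypothetical names)::

        class DelayTask(SignalTask):

            def setUpMain(self):
                # Connect per-run signals and slots, then return immediately
                self._timer = QtCore.QTimer()
                self._timer.setSingleShot(True)
                self._timer.timeout.connect(self._onTimeout)
                self._timer.start(1000)  # milliseconds

            @SignalTask.guard_method
            def _onTimeout(self):
                # An exception raised here would be recorded as a task failure
                self.mainDone.emit()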
    """
    mainDone = QtCore.pyqtSignal()

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.mainDone.connect(self._finish)

    @staticmethod
    def guard_method(func):

        def wrapped_func(self, *args, **kwargs):
            with self.guard():
                return func(self, *args, **kwargs)
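            # Only reached if func raised: guard() recorded the exception and
            # suppressed it, so end the task as failed.
            if self.failure_info:
                self._finish()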

        return wrapped_func

    def run(self):
        with self.guard():
            self.setUpMain()

    def setUpMain(self):
        raise NotImplementedError()