"""
Framework for writing computational workflows and running them in a highly
distributed manner. Each step of the workflow is either a "mapping" operation
(see `MapStep`) or "reducing" operation (see `ReduceStep`). These steps can then
be chained together using the `Chain` class.
For a more complete introduction, see the WordCount tutorial:
https://confluence.schrodinger.com/display/~jtran/Stepper+WordCount+Tutorial
For documentation on specific stepper features, see the following feature list.
You can ctrl+f the feature tag to jump to the relevant docstrings.
+----------------+------------------+
| Feature | Tag |
+================+==================+
| MapStep | _map_step_ |
+----------------+------------------+
| ReduceStep | _reduce_step_ |
+----------------+------------------+
| Chain | _chain_ |
+----------------+------------------+
| Settings | _settings_ |
+----------------+------------------+
| Configuration | _configuration_ |
+----------------+------------------+
| Serialization | _serialization_ |
+----------------+------------------+
| File Handling | _file_handling_ |
+----------------+------------------+
| Licensing | _licensing_ |
+----------------+------------------+
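As a quick orientation, here is a minimal sketch of defining and chaining two
steps. The step and chain names are illustrative only; see the `MapStep`,
`ReduceStep`, and `Chain` docstrings below for full details::

    class DoubleStep(MapStep):
        Input = int
        Output = int

        def mapFunction(self, value):
            # Each input may yield any number of outputs.
            yield value * 2

    class SumStep(ReduceStep):
        Input = int
        Output = int

        def reduceFunction(self, values):
            # Consume the whole input iterable and yield a single output.
            yield sum(values)

    class DoubleAndSumChain(Chain):
        def buildChain(self):
            self.addStep(DoubleStep())
            self.addStep(SumStep())

    chain = DoubleAndSumChain()
    chain.setInputs([1, 2, 3])
    print(chain.getOutputs())  # [12]
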
To run steps that aren't defined in the distribution:
The script should be executed inside the working directory and import steps from
a local package in the working directory.
Working dir contents::
script.py
my_lib/
__init__.py
steps.py
Minimal code in script.py if it needs to run under job control::
from schrodinger.job import launchapi
from schrodinger.ui.qt.appframework2 import application
from my_lib.steps import MyStep
def get_job_spec_from_args(argv):
jsb = launchapi.JobSpecificationArgsBuilder(argv)
jsb.setInputFile(__file__)
jsb.setInputDirectory('my_lib')
return jsb.getJobSpec()
def main():
step = MyStep()
        step.getOutputs()
if __name__ == '__main__':
application.run_application(main)
#===============================================================================
# BigQuery Functionality
#===============================================================================
Stepper has BigQuery integrations. To use the BigQuery features, you must
install the requisite GCP libraries. To do so, run:
$SCHRODINGER/run python3 -m pip install google-cloud-storage google-cloud-bigquery==1.28.0 google-auth
To uninstall, run:
$SCHRODINGER/run python3 -m pip uninstall google-cloud-storage google-cloud-bigquery google-auth
This will be done automatically after cases BLDMGR-4907 and BLDMGR-4908 are
complete.
To use BigQuery with your workflow, set the `batch_by_bigquery` batch setting
to True. One dataset will be used per workflow run with up to two tables per
bigquery-batched step. The tables are named using step ids so there will be
no chance of a table name collision.
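
For example, a minimal sketch of enabling BigQuery batching on a step (this
assumes a hypothetical batchable step `MyStep` defined elsewhere, and that the
SCHRODINGER_BQ_PROJECT environment variable, plus SCHRODINGER_BQ_KEY if a
service account key is required, is set)::

    step = MyStep()
    step.setBatchSettings(BatchSettings(size=100, batch_by_bigquery=True))
    step.setInputFile('inputs.in')
    for output in step.outputs():
        print(output)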
"""
DOUBLE_BATCH_THRESHOLD = 50
import collections
import configparser
import copy
import csv
import glob
import inspect
import itertools
import os
import pprint
import time
import traceback
import uuid
import zipfile
from typing import Any
from typing import Iterable
from typing import Optional
import more_itertools
import requests
from ruamel import yaml
from schrodinger.job import jobcontrol
from schrodinger.models import json
from schrodinger.models import parameters
from schrodinger.Qt import QtCore
from schrodinger.tasks import hosts
from schrodinger.tasks import jobtasks
from schrodinger.tasks import queue
from schrodinger.tasks import tasks
from schrodinger.ui.qt.appframework2 import application
from schrodinger.utils import imputils
from schrodinger.utils import env
import logging
logger = logging.getLogger('schrodinger.tasks.stepper')
logger.setLevel(logging.DEBUG)
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter('<STEPPER> %(levelname)s:%(message)s'))
logger.addHandler(handler)
#===============================================================================
# BigQuery Functions
#===============================================================================
try:
from google.cloud import bigquery
from google.cloud import storage
from google.oauth2 import service_account
except ImportError:
bigquery = storage = service_account = None
PROJECT = os.environ.get('SCHRODINGER_BQ_PROJECT')
_NO_PROJECT_ERR_MSG = (
"No bigquery project is defined for this run. Set "
"an environment variable for SCHRODINGER_BQ_PROJECT and try again.")
KEY_PATH = os.environ.get('SCHRODINGER_BQ_KEY') # Service account key
DATASET = 'stepper_test_sets'
BIGQUERY_SCOPE = "https://www.googleapis.com/auth/bigquery"
CLOUD_SCOPE = "https://www.googleapis.com/auth/cloud-platform"
BQ_CLIENT = None
MODULE_ROOT_BLACKLIST = ('schrodinger',)
def _generate_credentials():
if KEY_PATH is None:
return None
credentials = service_account.Credentials.from_service_account_file(
KEY_PATH,
scopes=[CLOUD_SCOPE, BIGQUERY_SCOPE],
)
return credentials
def _generate_clients():
credentials = _generate_credentials()
bq_client = bigquery.Client(project=PROJECT, credentials=credentials)
return bq_client
def _get_bq_client():
global BQ_CLIENT
if BQ_CLIENT is None:
BQ_CLIENT = _generate_clients()
return BQ_CLIENT
def _get_fully_qualified_table_id(table_id):
if PROJECT is None:
raise ValueError(_NO_PROJECT_ERR_MSG)
if PROJECT not in table_id:
return f'{PROJECT}.{table_id}'
else:
return table_id
def _create_table(table_id):
table_id = _get_fully_qualified_table_id(table_id)
client = _get_bq_client()
schema = [bigquery.SchemaField("data", "STRING", mode="REQUIRED")]
table = bigquery.Table(table_id, schema=schema)
client.create_table(table)
logger.info(f"<BigQuery>: Created table {table}")
def _create_dataset(dataset_name):
if PROJECT is None:
raise ValueError(_NO_PROJECT_ERR_MSG)
client = _get_bq_client()
dataset_name = f'{PROJECT}.{dataset_name}'
dataset = bigquery.Dataset(dataset_name)
dataset.location = "US"
dataset = client.create_dataset(dataset) # Make an API request.
logger.info(f"<BigQuery>: Created dataset {dataset.dataset_id}")
def _delete_dataset(dataset_name):
if PROJECT is None:
raise ValueError(_NO_PROJECT_ERR_MSG)
client = _get_bq_client()
dataset_name = f'{PROJECT}.{dataset_name}'
client.delete_dataset(dataset_name, delete_contents=True, not_found_ok=True)
def _delete_table(table_id):
table_id = _get_fully_qualified_table_id(table_id)
client = _get_bq_client()
client.delete_table(table_id)
def _load_into_table(fname, table_id):
table_id = _get_fully_qualified_table_id(table_id)
client = _get_bq_client()
job_config = bigquery.LoadJobConfig(
schema=[bigquery.SchemaField("data", "STRING", mode="REQUIRED")])
with open(fname, 'rb') as infile:
load_job = client.load_table_from_file(
infile, table_id, job_config=job_config) # Make an API request.
try:
load_job.result() # Wait for the job to complete.
except:
print("BigQuery Load Errors: ", load_job.errors)
raise
table = client.get_table(table_id)
logger.info("<BigQuery>: Loaded {} rows to table {}".format(
table.num_rows, table_id))
def _load_in_batches(gen,
serializer,
table_id,
csv_size_limit=3e9,
chunk_size=1000000):
"""
    Load batches of outputs into a table specified by `table_id`. Outputs are
    batched so that the CSV files are around `csv_size_limit` bytes. The CSV
    files are written in chunks of `chunk_size` rows before being checked for
    size.
:param gen: A generator of outputs to load into the table
:type gen: Iterator
:param serializer: A serializer to serialize the outputs, see `Serializer`
:type serializer: Serializer
:param table_id: The table to load the outputs into. Should include both
dataset and table name, i.e. "<DATASET>.<TABLE>"
:type table_id: str
"""
output_generator = more_itertools.peekable(gen)
def outputs_exhausted():
ITERATOR_EXHAUSTED = object()
return output_generator.peek(ITERATOR_EXHAUSTED) is ITERATOR_EXHAUSTED
tmp_out_fname = str(uuid.uuid4()) + '.csv'
while not outputs_exhausted():
with open(tmp_out_fname, 'w', newline='') as csvfile:
writer = csv.writer(csvfile)
fsize = os.fstat(csvfile.fileno()).st_size
# Write out chunks of outputs until the file is larger than
# csv_size_limit
while fsize < csv_size_limit and not outputs_exhausted():
for output in itertools.islice(output_generator, chunk_size):
writer.writerow([serializer.toString(output)])
fsize = os.fstat(csvfile.fileno()).st_size
# Upload file and then clean it up
_load_into_table(tmp_out_fname, table_id)
os.remove(tmp_out_fname)
def _table_row_count(table_id):
table_id = _get_fully_qualified_table_id(table_id)
bq_client = _get_bq_client()
table = bq_client.get_table(table_id)
return table.num_rows
def _get_table_data(table_id, starting_idx=None, num_rows=None):
"""
Get contents of a table specified by `table_id`. If `starting_idx` is
specified, then the data will start at that row index. Up to `num_rows`
will be returned.
"""
bq_client = _get_bq_client()
table_id = _get_fully_qualified_table_id(table_id)
def _unwrap_row_iterator():
for row in bq_client.list_rows(
table_id, start_index=starting_idx, max_results=num_rows):
yield row['data']
return _unwrap_row_iterator()
#===============================================================================
# StepMonitor calls
#===============================================================================
def get_stepmonitor_config_path():
return os.path.join(os.environ['SCHRODINGER'], 'stepmon_config.ini')
def _get_stepmonitor_config() -> Optional[configparser.ConfigParser]:
config = configparser.ConfigParser()
config_path = get_stepmonitor_config_path()
if not os.path.exists(config_path):
return None
config.read(config_path)
return config
def get_stepmonitor_url() -> Optional[str]:
config = _get_stepmonitor_config()
if config is None:
return None
return config['server_settings']['url']
def _get_update_rate() -> Optional[float]:
config = _get_stepmonitor_config()
if config is None:
return None
return float(config['stepper_settings']['update_rate_in_secs'])
def POST_update_job(job_id: str, run_info: dict, parent_id: str):
"""
Make a POST request to the stepmonitor server with the current state of the
workflow's run_info.
:param job_id: The job id for this workflow.
:param run_info: The run info describing the status of the workflow.
:param parent_id: The job id of the parent of this workflow.
"""
base_url = get_stepmonitor_url()
if base_url is None:
err_msg = "Can't make a request without a stepmonitor URL defined."
raise ValueError(err_msg)
endpoint = base_url + 'job/update/' + job_id
payload = {'run_info': run_info}
if parent_id is not None:
payload['parent_id'] = parent_id
resp = requests.post(endpoint, data=payload)
resp.raise_for_status()
#===============================================================================
# Batching
#===============================================================================
def _assert_step_hasnt_started(func):
"""
Decorator that prevents a step method from running if the output generator
has already been created.
"""
def wrapped_func(self, *args, **kwargs):
if self._outputs_gen is not None:
raise RuntimeError(
f'Cannot call {func.__name__} because this step has already '
'started (i.e. outputs() or getOutput() has already been called).'
)
return func(self, *args, **kwargs)
return wrapped_func
def _prettify_time(time_in_float):
utc_time = time.gmtime(time_in_float)
return time.strftime('%Y-%m-%d %H:%M:%S %Z', utc_time)
def _prettify_duration(time_in_sec):
def div_w_remainder(numer, denom):
return int(numer // denom), numer % denom
days, remaining_sec = div_w_remainder(time_in_sec, 24 * 60 * 60)
hours, remaining_sec = div_w_remainder(remaining_sec, 60 * 60)
minutes, remaining_sec = div_w_remainder(remaining_sec, 60)
seconds = int(remaining_sec)
pretty_string = f'{hours:02d}:{minutes:02d}:{seconds:02d}'
if days:
pretty_string = f'{days:02d}:{pretty_string}'
return pretty_string
class BQTable(parameters.CompoundParam):
table_id: str = None
starting_idx: int = None
batch_size: int = None
    def delete(self):
_delete_table(self.table_id)
class _DehydratedStep(parameters.CompoundParam):
"""
See `_BaseStep._dehydrateStep` for documentation.
"""
step_module_path: str
step_class_name: str
step_id: str
step_config: dict
starting_step_id: str = None
input_file: str = None
input_bq_table: BQTable
output_bq_table: BQTable
class StepTaskOutput(parameters.CompoundParam):
output_file: str = None
run_info: dict
class StepTaskMixin(parameters.CompoundParamMixin):
"""
This class must be mixed in with a subclass of AbstractComboTask. The
resulting task class may be used to run any step as a task, provided the
input, output, and settings classes are all JSONable.
"""
_double_batch: bool = False
input: StepTaskInput
output: StepTaskOutput
DEFAULT_TASKDIR_SETTING = tasks.AUTO_TASKDIR
    def __init__(self, *args, step=None, **kwargs):
super().__init__(*args, **kwargs)
self._step_class = None
self._step = None
if step is not None:
self.setStep(step)
    def addLicenseReservation(self, license, num_tokens=1):
try:
super().addLicenseReservation(license, num_tokens)
except AttributeError:
pass
def _setUpInputFile(self, filepath):
"""
        Given a filepath, do any necessary setup to register the file (e.g.
        add it to a list of input files) and return the path that should be
        used by the backend task (e.g. the absolute path for a subprocess task
        or a relative path for a job task).
"""
raise NotImplementedError
    def setStep(self, step):
self._step = step
dehyd_step = step._dehydrateStep()
# Set up files for this step and all of its component steps
for step_id, step_settings in dehyd_step.step_config.items():
if step_id.startswith(step.getStepId()):
for setting, value in step_settings.items():
if isinstance(value, StepperFile):
step_settings[setting] = self._setUpInputFile(value)
elif isinstance(value, StepperFolder):
step_settings[setting] = self._setUpInputFolder(value)
self.input.dehydrated_step = dehyd_step
self._step_class = type(step)
self._setUpStepTask(dehyd_step)
    def getStepClass(self):
return self._step_class
def _setUpStepTask(self, dehyd_step: _DehydratedStep):
self._preprocessModuleRoot(dehyd_step)
self._preprocessInputFiles(dehyd_step.input_file)
if dehyd_step.input_file is not None:
dehyd_step.input_file = self._setUpInputFile(dehyd_step.input_file)
def _preprocessModuleRoot(self, dehyd_step: _DehydratedStep):
"""
If the dehydrated step is defined in a package in a non-blacklisted
folder in the working directory, add the package as an input folder for
the task so it will be available for import in the backend task folder.
If the step is defined in the main script we will not be able to import
it in the backend, so a `ValueError` exception is raised.
:raise ValueError: if the module root is __main__.
"""
root = dehyd_step.step_module_path.split('.')[0]
if root == '__main__':
raise ValueError(
f'Step class {dehyd_step.step_class_name} should be defined'
f' outside of __main__.')
if os.path.isdir(root) and not root.lower() in MODULE_ROOT_BLACKLIST:
print(f'Using nonstandard package {root}')
self._setUpInputFolder(root)
def _preprocessInputFiles(self, input_file: str):
"""
        Before starting the task, convert any input StepperFiles to paths
relative to the backend machine.
"""
if self._step.Input is StepperFile:
# If the inputs for the steps are StepperFiles, then we need
# to read the input file and register the inputs and convert
# them to the right path for the backend.
serializer = self._step._getInputSerializer()
# Read the inputs and register them
inp_files = []
for inp in serializer.deserialize(input_file):
inp_files.append(self._setUpInputFile(inp))
# Write out the inputs again with the correct paths for
# the backend
serializer.serialize(inp_files, input_file)
def _makeBackendStep(self):
step = _rehydrate_step(self.input.dehydrated_step)
self._step_class = type(step)
if not self._double_batch:
step.setBatchSettings(None)
return step
    def mainFunction(self):
self._step = self._makeBackendStep()
step = self._step
if step._output_table is None:
batch_outp_name = self.name + '.out'
step.writeOutputsToFile(batch_outp_name)
self.output.output_file = batch_outp_name
else:
step.writeOutputsToTable()
self.output.run_info = step._run_info
if self.input.debug_mode:
self._runDebug()
def _runDebug(self):
pass
def _postprocessOutputFiles(self):
"""
After the task returns, convert any output StepperFiles to paths
relative to the frontend machine.
"""
if self._step.Output is StepperFile:
output_file = self.getTaskFilename(self.output.output_file)
self._processOutputStepperFiles(output_file)
def _processOutputStepperFiles(self, output_file: str):
"""
        Reads in the output file containing stepper files, processes them
        (e.g. registers them as output files, converts them to the correct
        paths), and then writes them back out again.
:param output_file: File storing list of unprocessed output stepper
files
"""
processed_outputs = []
serializer = self._step.getOutputSerializer()
for outp in serializer.deserialize(output_file):
outp = self._setUpOutputFile(outp)
processed_outputs.append(outp)
serializer.serialize(processed_outputs, output_file)
def _setUpOutputFile(self, outp_file):
return StepperFile(
os.path.join(self.getTaskDir(), self.getTaskFilename(outp_file)))
class StepSubprocessTask(StepTaskMixin, tasks.ComboSubprocessTask):
def _setUpInputFile(self, filepath):
return StepperFile(os.path.abspath(filepath))
def _setUpInputFolder(self, folderpath):
return StepperFolder(os.path.abspath(folderpath))
@tasks.postprocessor
def _postprocessOutputFiles(self):
return super()._postprocessOutputFiles()
class StepJobTask(StepTaskMixin, jobtasks.ComboJobTask):
_use_async_jobhandler: bool = True
input: StepTaskInput
output: StepTaskOutput
def _makeBackendStep(self, *args, **kwargs):
step = super()._makeBackendStep(*args, **kwargs)
step.progressUpdated.connect(self._progressUpdated)
return step
def _progressUpdated(self, run_info_str):
if get_stepmonitor_url() is not None:
self._run_info_str = run_info_str
if not self._run_info_update_timer.isActive():
self._run_info_update_timer.start()
def _postRunInfoUpdate(self):
job = self._job
POST_update_job(job.JobId, self._run_info_str, job.ParentJobId)
def _setUpInputFile(self, filepath):
self.addInputFile(filepath)
if os.path.isabs(filepath) or filepath.startswith('..'):
return os.path.basename(filepath)
else: # filepath is relative path in launch directory
return filepath
def _setUpInputFolder(self, folderpath):
self.addInputDirectory(folderpath)
if os.path.isabs(folderpath) or folderpath.startswith('..'):
return os.path.basename(folderpath)
else: # folderpath is relative path in launch directory
return folderpath
def _runDebug(self):
# Register all input and output files so they're brought back to the
# launch machine.
for file_ in _get_stepper_debug_files():
self.addOutputFile(file_)
    def mainFunction(self):
self._job = jobcontrol.get_backend().getJob()
if get_stepmonitor_url() is not None:
self._run_info_update_timer = QtCore.QTimer(self)
self._run_info_update_timer.setSingleShot(True)
self._run_info_update_timer.setInterval(_get_update_rate() * 1000)
self._run_info_update_timer.timeout.connect(self._postRunInfoUpdate)
super().mainFunction()
self.addOutputFile(self.output.output_file)
if self._step.Output is StepperFile:
self._processOutputStepperFiles(self.output.output_file)
if get_stepmonitor_url() is not None:
self._run_info_update_timer.stop()
try:
self._postRunInfoUpdate()
except Exception:
traceback.print_exc()
@tasks.postprocessor
def _postprocessOutputFiles(self):
return super()._postprocessOutputFiles()
def _setUpOutputFile(self, outp_file):
if self.isBackendMode():
self.addOutputFile(outp_file)
return StepperFile(
os.path.join(self.getTaskDir(), self.getTaskFilename(outp_file)))
#===============================================================================
# Running steps in batches
#===============================================================================
class BatchSettings(parameters.CompoundParam):
size: int = 10
task_class: type = StepJobTask
hostname: str = 'localhost'
batch_by_bigquery: bool = False
_bq_dataset: str = None
class Serializer:
""" <_serialization_>
A class for defining special serialization for some datatype. Serialization
by default uses the `json` protocol, but if a specialized protocol is wanted
instead, users can subclass this class to do so.
Subclasses should:
- Define `DataType`. This is the class that this serializer can
encode/decode.
- Define `toString(self, output)`, which defines how to serialize
an output.
- Define `fromString(self, input_str)`, which defines how to
deserialize an input.
This can then be used as the `InputSerializer` or `OutputSerializer` for
any step.
Here's an example for defining an int that's serialized in base-two
as opposed to base-ten::
class IntBaseTwoSerializer(Serializer):
DataType = int
def toString(self, output):
return bin(output) # 7 -> '0b111'
def fromString(self, input_str):
return int(input_str[2:], 2) # '0b111' -> 7
This can then be used anywhere you'd use an int as the output or input in a
step. For example::
class SquaringStep(MapStep):
Input = int
InputSerializer = IntBaseTwoSerializer
Output = int
OutputSerializer = IntBaseTwoSerializer
def mapFunction(self, inp):
yield inp**2
Now, any time that a `SquaringStep` would read its inputs from a file
    or write its outputs to a file, it'll do so using a base-two
representation.
"""
DataType = NotImplemented
    def serialize(self, items, fname):
"""
Write `items` to a file named `fname`.
:type items: iterable[self.DataType]
:type fname: str
"""
with open(fname, 'w') as outfile:
for outp in items:
outfile.write(self.toString(outp) + '\n')
    def deserialize(self, fname):
"""
Read in items from `fname`.
:type fname: str
:rtype: iterable[self.DataType]
"""
with open(fname, 'r') as infile:
for line in infile:
inp = self.fromString(line.strip('\n'))
yield inp
    def fromString(self, input_str):
raise NotImplementedError
    def toString(self, output):
raise NotImplementedError
@classmethod
def __init_subclass__(cls):
if cls.DataType is NotImplemented:
raise NotImplementedError(
"DataType must be specified for Serializers")
super().__init_subclass__()
class _DynamicSerializer(Serializer):
"""
The default serializer that simply uses `json.loads` and `json.dumps`
"""
DataType = object
def __init__(self, dataclass):
self._dataclass = dataclass
def fromString(self, inp_str):
return json.loads(inp_str, DataClass=self._dataclass)
def toString(self, outp):
return json.dumps(outp)
class StepperFolder(json.JsonableClassMixin, str):
"""
See `_BaseStep` for documentation.
"""
    @classmethod
def fromJsonImplementation(cls, json_str):
return cls(json_str)
    def toJsonImplementation(self):
return str(self)
class StepperFile(json.JsonableClassMixin, str):
"""
See `_BaseStep` for documentation.
"""
    @classmethod
def fromJsonImplementation(cls, json_str):
return cls(json_str)
    def toJsonImplementation(self):
return str(self)
class ValidationIssue(RuntimeError):
    def __init__(self, source_step, msg):
self.source_step = source_step
self.msg = msg
super().__init__(msg)
def __repr__(self):
return f'{type(self).__name__}("{self.source_step.getStepId()}", "{self.msg}")'
def __str__(self):
return f'{type(self).__name__}("{self.source_step.getStepId()}", "{self.msg}")'
class SettingsError(ValidationIssue):
"""
Used in conjunction with `_BaseStep.validateSettings` to report an error
with settings. Constructed with the step with the invalid settings and an
error message, e.g.
    `SettingsError(bad_step, "Step does not have required settings.")`
"""
class SettingsWarning(ValidationIssue):
"""
Used in conjunction with `_BaseStep.validateSettings` to report a warning
    with settings. Constructed with the step with the questionable settings and
    a warning message, e.g.
    `SettingsWarning(bad_step, "Step setting FOO should ideally be non-negative")`
"""
class _BaseStep(QtCore.QObject):
"""
The features and behavior described in this docstring apply to all steps
and chains.
To use a step, instantiate it, set the inputs, and request outputs.
Accessing outputs causes the step to get input from the input source and
run the step operation. There is no concept of "running" or "starting" the
step.
class SquareStep(MapStep):
def mapFunction(self, inp):
yield inp * inp
step = SquareStep()
step.setInputs([1, 2, 3])
print(step.getOutputs()) # [1, 4, 9]
The outputs are produced with a generator. Thus, calling
`step.getOutputs()` twice will always result in an empty list for the
second call.
Settings
======== <_settings_>
Every step can parameterize how it operates using a set of settings. The
settings of a step are defined as a subclass of `CompoundParam` at the
class level, and can be set per-instance using keyword arguments at
instantiation time. Example::
class MultiplyByStep(MapStep):
class Settings(parameters.CompoundParam):
multiplier: int = 1
by_4_step = MultiplyByStep(multiplier=4)
by_4_step.setInputs([1, 2, 3])
by_4_step.getOutputs() == [4, 8, 12]
=============
Configuration
=============
A configuration is a dictionary that specifies settings values for steps
within a chain.
A step can take a configuration dictionary that maps step
selectors to default setting values. For example::
Chain(config={'A':{'max_rounds':10}})
    This configuration is passed down through the `Chain` and sets the
    `max_rounds` setting of every step matching the selector "A" to 10. A
    fuller example combining the selector types appears after the selector
    descriptions below.
There are three currently supported selectors:
General selectors e.g. "A":
This will select all steps of type "A" (Note that this does not
select subclasses of "A")
Child selectors e.g. "A>B"
This will select all steps of type "B" that
are in chains of type "A". Multiple ">" operators can be linked
together. For example, "A>B>C" will select all "C" steps in "B"
chains which are in the "A" chain.
ID selector e.g. "A.B_0"
This will select the first "B" step in chain "A". The top level
chain never has an index. Steps in a chain are indexed relative to
other steps of the same type in that chain. For example,
if chain "A" is composed of steps BCBCC, then the ids would be
"A.B_0", "A.C_0", "A.B_1", "A.C_1", "A.C_2"
=============
File Handling
=============
To specify a file, use the `StepperFile` class as the input type, output
type, or as a subparam on the `Settings` class. Files specified in these
locations will automatically be copied to and from compute machines.
You can similarly specify `StepperFolder` to have folders copied over
to compute machines. Currently, `StepperFolder` can only be used with
step settings, not as step inputs or outputs.
Strings specified in `config` for `StepperFile` and `StepperFolder` will
be automatically cast.
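
    For example, a step that needs an auxiliary file can declare it as a
    setting (a sketch; the step, setting, and file names are hypothetical)::

        class FilterStep(MapStep):
            Input = str
            Output = str

            class Settings(parameters.CompoundParam):
                exclusion_list: StepperFile = None

            def mapFunction(self, smiles):
                # The file is copied to the compute machine automatically;
                # self.settings.exclusion_list is a usable local path there.
                with open(self.settings.exclusion_list) as excluded:
                    if smiles not in {line.strip() for line in excluded}:
                        yield smiles

    The path can be supplied directly (`FilterStep(exclusion_list='bad.smi')`)
    or through a config, where plain strings are cast to `StepperFile`
    automatically.
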
========
Licenses
========
    Some steps may require a license for each node that they run on. All
batchable steps support this feature.
To specify the number of license reservations a step requires, override
`getLicenseRequirements` and return a dictionary mapping licenses
to the number of tokens required for that license. For example::
from schrodinger.utils import license
class LicenseRequiringStep(MapStep):
Input = str
Output = str
def getLicenseRequirements(self):
return {license.GLIDE_MAIN: 2}
Once you've specified what licenses are required, any batched steps will
automatically have the right number of licenses reserved.
    .. NOTE:: A batched `Chain` by default accounts for any license
        reservations that might be necessary to run its component steps.
"""
progressUpdated = QtCore.pyqtSignal(str) # serialized run_info dict
Input = None
InputSerializer = _DynamicSerializer
Output = None
OutputSerializer = _DynamicSerializer
Settings = parameters.CompoundParam
def __init__(self,
settings=None,
config=None,
step_id=None,
_run_info=None,
**kwargs):
super().__init__()
if not step_id:
self._step_id = type(self).__name__
else:
self._step_id = step_id
if _run_info is None:
_run_info = collections.defaultdict(dict)
self._setRunInfo(_run_info)
self._outputs_gen = None
self.setSettings(settings, **kwargs)
self._setCompositionPath(type(self).__name__)
self._setConfig(config)
self._input_file = None
self._inputs = None
def isBigQueryBatched(self):
return False
@classmethod
def __init_subclass__(cls):
"""
        Validate the step class definition.
"""
if cls.InputSerializer is not _DynamicSerializer:
if cls.Input is None or not issubclass(
cls.Input, cls.InputSerializer.DataType):
msg = (
'Incompatible InputSerializer specified. \n'
f'Step "{cls.__name__}" has Input "{cls.Input}" '
f'but InputSerializer has DataType "{cls.InputSerializer.DataType}"'
)
raise TypeError(msg)
if cls.OutputSerializer is not _DynamicSerializer:
if cls.Output is None or (
cls.Output != cls.OutputSerializer.DataType and
not issubclass(cls.Output, cls.OutputSerializer.DataType)):
msg = (
'Incompatible OutputSerializer specified. \n'
f'Step "{cls.__name__}" has Output "{cls.Output}" '
f'but OutputSerializer has DataType "{cls.OutputSerializer.DataType}"'
)
raise TypeError(msg)
if (not isinstance(cls.Settings, type) or
not issubclass(cls.Settings, parameters.CompoundParam)):
raise TypeError("Custom settings must subclass CompoundParam")
super().__init_subclass__()
def _getCanonicalizedConfig(self):
return {self.getStepId(): self.settings.toDict()}
def report(self, prefix=''):
"""
Report the settings and batch settings for this step.
"""
logger.info(f'{prefix} - {self.getStepId()}')
for opts in (self.settings, self._batch_settings):
if opts and opts.toDict():
logger.info(
f'{prefix} {opts.__class__.__name__}: {opts.toDict()}')
def prettyPrintRunInfo(self):
"""
Format and print info about the step's run.
"""
run_info = copy.deepcopy(self.getRunInfo())
self._prettifyRunInfo(run_info)
        # Listify the dict into tuples since pretty-print doesn't respect
        # dictionary insertion order
run_info = list(run_info.items())
pprint.pprint(run_info)
def _prettifyRunInfo(self, run_info_dict):
"""
Recurse through `run_info_dict` and listify dicts into item tuples.
This improves the readability of pretty-print and preserves the
dictionary insertion order.
"""
for k, v in run_info_dict.items():
if isinstance(v, dict):
self._prettifyRunInfo(v)
def __copy__(self):
copied_step = type(self)(
settings=copy.copy(self.settings),
config=self._getCanonicalizedConfig(),
step_id=self.getStepId())
return copied_step
def _getInputSerializer(self):
if issubclass(self.InputSerializer, _DynamicSerializer):
return _DynamicSerializer(dataclass=self.Input)
else:
return self.InputSerializer()
def getOutputSerializer(self):
if issubclass(self.OutputSerializer, _DynamicSerializer):
return _DynamicSerializer(dataclass=self.Output)
else:
return self.OutputSerializer()
def _validateStepperFileSettings(self):
"""
        Look through the settings for StepperFiles and StepperFolders and
        confirm that they're set to valid file and folder paths.
:return: A list of `SettingsError`, one for each invalid stepper file
:rtype: list[SettingsError]
"""
results = []
if self.settings is None:
return results
settings = self.settings
for subparam_name, abstract_subparam in self.Settings.getSubParams(
).items():
if abstract_subparam.DataClass is StepperFile:
stepperfile = abstract_subparam.getParamValue(settings)
if stepperfile is None:
results.append(
SettingsError(
self,
f"<{self._step_id}> setting '{subparam_name}' has "
"not been set."))
elif not os.path.isfile(stepperfile):
results.append(
SettingsError(
self,
f"<{self._step_id}> setting '{subparam_name}' "
f"set to invalid file path: '{str(stepperfile)}'"))
if abstract_subparam.DataClass is StepperFolder:
stepperfolder = abstract_subparam.getParamValue(settings)
if stepperfolder is None:
results.append(
SettingsError(
self,
f"<{self._step_id}> setting '{subparam_name}' has "
"not been set."))
elif not os.path.isdir(stepperfolder):
results.append(
SettingsError(
self,
f"<{self._step_id}> setting '{subparam_name}' "
"set to invalid dir path: '{str(stepperfolder)}'"))
return results
def validateSettings(self):
"""
Check whether the step settings are valid and return a list of
`SettingsError` and `SettingsWarning` to report any invalid settings.
Default implementation checks that all stepper files are set to valid
file paths.
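
        Subclasses that need additional checks can extend this method; a
        minimal sketch (the `num_retries` setting is hypothetical)::

            def validateSettings(self):
                issues = super().validateSettings()
                if self.settings.num_retries < 0:
                    issues.append(
                        SettingsError(self, "num_retries must be non-negative"))
                return issues
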
        :rtype: list[SettingsError or SettingsWarning]
"""
return self._validateStepperFileSettings()
def _setCompositionPath(self, path):
"""
Update the composition path. The composition path is the string
        that defines a step's ancestry. For example, a composition path "A>B>C"
means that this step, C, is in a chain B, which is itself in a chain
A.
"""
self._comp_path = path
def _setStepId(self, new_id):
self._step_id = new_id
def getStepId(self):
return self._step_id
def _setRunInfo(self, run_info):
self._run_info = run_info
def getRunInfo(self):
return self._run_info
def _setConfig(self, config):
if config:
            # Sort the selectors by length so that child selectors are applied
            # in order of selectivity (based on the assumption that more
            # selective selectors have longer keys).
if '__sorted' not in config:
config = dict(
sorted(config.items(), key=lambda item: len(item[0])))
config['__sorted'] = True
for k in config:
if self._comp_path.endswith(k):
self._applyConfigSettings(config[k])
# Apply ID selector settings last so they take final priority
if self._step_id in config:
self._applyConfigSettings(config[self._step_id])
self._config = config
def _applyConfigSettings(self, new_settings):
if new_settings:
for k, v in new_settings.items():
if v is None:
continue
if not hasattr(self.Settings, k):
raise SettingsError(self, f"Step \"{type(self).__name__}\""
f" has no setting \"{k}\"")
if getattr(self.Settings, k).DataClass is StepperFile:
new_settings[k] = StepperFile(v)
elif getattr(self.Settings, k).DataClass is StepperFolder:
new_settings[k] = StepperFolder(v)
self.settings.setValue(**new_settings)
def setInputFile(self, fname):
self._input_file = fname
self.setInputs(self._inputsFromFile(fname))
def setInputBQTable(self, bq_table, bq_dataset=None):
self._setInputBQTable(bq_table)
serializer = self._getInputSerializer()
def lazy_gen():
for inp in _get_table_data(bq_table.table_id, bq_table.starting_idx,
bq_table.batch_size):
yield serializer.fromString(inp)
self.setInputs(lazy_gen())
def setOutputBQTable(self, bq_table):
self._output_table = BQTable(bq_table)
def _inputsFromFile(self, fname):
serializer = self._getInputSerializer()
yield from serializer.deserialize(fname)
def writeOutputsToFile(self, fname):
"""
        Write outputs to `fname`. By default, the output file will consist of
        one line for each output, serialized by this step's output serializer
        (see `Serializer`). Override this method if more complex behavior is
        needed.
"""
serializer = self.getOutputSerializer()
serializer.serialize(self.outputs(), fname)
def writeOutputsToTable(self):
serializer = self.getOutputSerializer()
_load_in_batches(self.outputs(), serializer, self._getOutputBQTableId())
def setUp(self):
"""
Hook for adding any type of work that needs to happen before any
outputs are created.
"""
pass
def cleanUp(self):
"""
Hook for adding any type of work that needs to happen after all
outputs are exhausted or if some outputs are created and the step
is destroyed.
"""
pass
@_assert_step_hasnt_started
def setSettings(self, settings=None, **kwargs):
"""
        Supply the settings for this step to use when running. The supplied
        settings must be an instance of this step's Settings class; if None is
        passed in, a default settings object will be used.
"""
if settings is not None and kwargs:
raise ValueError('Cannot specify both settings and kwargs')
elif self.Settings is None:
if settings is not None or kwargs:
raise ValueError("Specified settings for a step that doesn't "
"expect settings")
elif settings is None:
settings = self.Settings(**kwargs)
elif not isinstance(settings, self.Settings):
raise ValueError(f"settings should be of type {self.Settings}, not "
f"{type(settings)}.")
self.settings = settings
@_assert_step_hasnt_started
def setInputs(self, inputs):
"""
Set the input source for this step. This should be an iterable. Items
from the input source won't actually be accessed until the outputs for
this step are accessed.
"""
if inputs is None:
inputs = []
self._inputs = inputs
def inputs(self):
yield from self._inputs
@_assert_step_hasnt_started
def outputs(self):
"""
Creates the output generator for this step and returns it.
"""
self.setUp()
self._run_info[self.getStepId()] = {}
outputs_gen = self._makeOutputGenerator()
outputs_gen = self._outputsWithCounting(outputs_gen)
self._outputs_gen = self._cleanUp_after_generator(outputs_gen)
return self._outputs_gen
def _outputsWithCounting(self, output_gen):
self._output_count = 0
self._end_time = None
def wrapped_output_gen():
for output in output_gen:
self._output_count += 1
yield output
self._end_time = time.time()
self._updateRunInfo()
return wrapped_output_gen()
def _cleanUp_after_generator(self, gen):
"""
Call the step's cleanUp method once the generator has been
exhausted.
"""
try:
for output in gen:
yield output
finally:
self.cleanUp()
def _updateRunInfo(self):
step_run_info = self._run_info[self.getStepId()]
step_run_info['num_inputs'] = self._input_count
step_run_info['num_outputs'] = getattr(self, '_output_count', 0)
if getattr(self, '_end_time', None) is not None:
duration = self._end_time - self._start_time
else:
duration = time.time() - self._start_time
# Turn end_time, start_time, and duration into human readable strings
step_run_info['start_time'] = _prettify_time(self._start_time)
end_time = getattr(self, '_end_time', None)
if end_time:
step_run_info['end_time'] = _prettify_time(self._end_time)
step_run_info['duration'] = _prettify_duration(duration)
self._emitProgressUpdated()
application.process_events()
def _getElapsedTime(self):
if self._start_time is None:
raise RuntimeError("Can't get elapsed time when step hasn't been "
"started.")
return _prettify_duration(time.time() - self._start_time)
def _emitProgressUpdated(self):
"""
Emit a progress updated signal with serialized dump of this step's run
info. Note that the run info will only contain information about this
step and not about any batches. We emit only the stripped down info
since it's expected that any listeners will also be listening to
progress changes in the batched jobs.
"""
run_info = copy.deepcopy(self.getRunInfo())
run_info.pop('batches', None)
for v in run_info.values():
v.pop('batches', None)
self.progressUpdated.emit(json.dumps(run_info))
def _makeOutputGenerator(self):
raise NotImplementedError()
def getOutputs(self):
"""
Gets all the outputs in a list by fully iterating the output generator.
"""
return list(self.outputs())
def getLicenseRequirements(self):
return {}
def _rehydrate_step(dehydrated_step: _DehydratedStep):
"""
Recreate the step that `dehydrated_step` was created from.
"""
with env.prepend_sys_path(os.getcwd()):
step_module = imputils.get_module_from_path(
dehydrated_step.step_module_path)
step_class = getattr(step_module, dehydrated_step.step_class_name)
return step_class._rehydrateStep(dehydrated_step)
class _BigQueryBatchableStepMixin:
def _getBQDataset(self):
if self._batch_settings is None:
return None
else:
return self._batch_settings._bq_dataset
def _setBQDataset(self, dataset: str):
if self._batch_settings is not None:
self._batch_settings._bq_dataset = dataset
@staticmethod
def is_bq_step(step):
return isinstance(step,
_BatchableStepMixin) and step.isBigQueryBatched()
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._input_table = None
self._output_table = None
def cleanUpTables(self):
if self._input_table is not None:
self._input_table.delete()
if self._output_table is not None:
self._output_table.delete()
def _dehydrateStep(self):
dehyd_step = super()._dehydrateStep()
if self._input_table is not None:
dehyd_step.input_bq_table = self._input_table
if self._output_table is not None:
dehyd_step.output_bq_table = self._output_table
return dehyd_step
@classmethod
def _rehydrateStep(cls, dehydrated_step):
rehyd_step = super()._rehydrateStep(dehydrated_step)
if dehydrated_step.input_bq_table.table_id is not None:
rehyd_step.setInputBQTable(dehydrated_step.input_bq_table)
if dehydrated_step.output_bq_table.table_id is not None:
rehyd_step.setOutputBQTable(dehydrated_step.output_bq_table)
return rehyd_step
def _initializeInputBQTable(self):
#TODO: Uniquify this, probably using a dataset created by the chain
input_table_id = self.getStepId().replace('.', '-') + '_inputs'
if self._batch_settings._bq_dataset is None:
bq_dataset_name = _initialize_stepper_bq_dataset(self)
self._batch_settings._bq_dataset = bq_dataset_name
input_table_id = self._batch_settings._bq_dataset + '.' + input_table_id
input_table = BQTable(table_id=input_table_id)
self._setInputBQTable(input_table)
def _setInputBQTable(self, table):
self._input_table = table
def _getInputBQTableId(self):
if self._input_table is None:
return None
return self._input_table.table_id
def _initializeOutputBQTable(self):
#TODO: Uniquify this, probably using a dataset created by the chain
output_table_id = self.getStepId().replace('.', '-') + '_outputs'
output_table_id = self._batch_settings._bq_dataset + '.' + output_table_id
output_table = BQTable(table_id=output_table_id)
self._output_table = output_table
try: # FIXME: Remove
_delete_table(self._getOutputBQTableId())
except:
pass
_create_table(self._getOutputBQTableId())
def _getOutputBQTableId(self):
if self._output_table is None:
return None
return self._output_table.table_id
def _createInputBQTable(self):
self._initializeInputBQTable()
try: # FIXME: Remove
_delete_table(self._getInputBQTableId())
except:
pass
_create_table(self._getInputBQTableId())
def _createOutputBQTable(self):
self._initializeOutputBQTable()
def _uploadToBQTable(self, inps):
serializer = self._getInputSerializer()
_load_in_batches(inps, serializer, self._getInputBQTableId())
def _splitInpTableToBatchTables(self):
bsize = self._batch_settings.size
rowcount = _table_row_count(self._input_table.table_id)
if rowcount >= DOUBLE_BATCH_THRESHOLD * bsize:
bsize = DOUBLE_BATCH_THRESHOLD * bsize
for start_idx in range(0, rowcount, bsize):
tb = BQTable(
table_id=self._input_table.table_id,
starting_idx=start_idx,
batch_size=bsize)
yield tb
def _makeBigQueryStep(self, batch_table):
step = copy.copy(self)
step.setInputBQTable(batch_table)
step.setOutputBQTable(self._output_table)
return step
def _makeBigQueryBatchTask(self, table):
step = self._makeBigQueryStep(table)
task = self._batch_settings.task_class(step=step)
if table.batch_size > self._batch_settings.size:
task._double_batch = True
if issubclass(self._batch_settings.task_class, StepJobTask):
task.job_config.host_settings.host = hosts.Host(
self._batch_settings.hostname)
for req_license, num_tokens in self.getLicenseRequirements().items():
task.addLicenseReservation(req_license, num_tokens)
return task
def _queryOutputBQTable(self):
serializer = self.getOutputSerializer()
for outp in _get_table_data(self._output_table.table_id):
yield serializer.fromString(outp)
@_assert_step_hasnt_started
def outputs(self):
if not (self._batch_settings and
self._batch_settings.batch_by_bigquery):
return super().outputs()
if not self._getInputBQTableId():
self._createInputBQTable()
self._uploadToBQTable(self._inputsWithCounting())
if not self._output_table:
self._createOutputBQTable()
self._run_info[self.getStepId()] = {
'batches': collections.defaultdict(dict)
}
tasks = []
for idx, table in enumerate(self._splitInpTableToBatchTables()):
task = self._makeBigQueryBatchTask(table)
task.name = self.getStepId() + '_' + str(idx)
tasks.append(task)
task.taskDone.connect(self._onTaskDone)
task.taskStarted.connect(self._onTaskStarted)
task.taskFailed.connect(self._onTaskFailed)
self._start_time = time.time()
self._logDebugMsg(f'Running {len(tasks)} batched tasks')
queue.run_tasks_on_dj(tasks)
self._logDebugMsg(f'Batched step {self.getStepId()} complete')
def get_table_results():
yield from self._queryOutputBQTable()
self._outputs_gen = get_table_results()
return self._outputs_gen
def _logDebugMsg(self, msg):
logger.debug(f' +{self._getElapsedTime()}: {msg}')
def _onTaskDone(self):
task = self.sender()
self._logDebugMsg(f'Batch {task.name} completed successfully')
def _onTaskStarted(self):
task = self.sender()
self._logDebugMsg(f'Batch {task.name} started')
def _onTaskFailed(self):
task = self.sender()
self._logDebugMsg(f'Batch {task.name} failed!')
class __BatchableStepMixin:
"""
    A step that can distribute its input into multiple batches and process
    them in parallel as tasks. Example::
# Running a batcher as a single step
b = ProcessSmilesChain(batch_size=10)
b.setInputFile(smiles_filename)
for output in b.outputs():
print(output)
"""
def __init__(self, *args, batch_size=None, batch_settings=None, **kwargs):
if batch_size and batch_settings:
raise TypeError("Can't pass both batch_size and batch_settings")
elif batch_size is not None:
batch_settings = BatchSettings(size=batch_size)
self._batch_settings = batch_settings
super().__init__(*args, **kwargs)
@_assert_step_hasnt_started
def setBatchSettings(self, batch_settings):
"""
Set the batch settings for this step. Will raise an exception if this
is done after the step has already started processing inputs.
:type batch_settings: BatchSettings
"""
self._batch_settings = batch_settings
def isBigQueryBatched(self):
return self._batch_settings and self._batch_settings.batch_by_bigquery
def _prettifyRunInfo(self, run_info_dict):
super()._prettifyRunInfo(run_info_dict)
if 'batches' in run_info_dict:
batch_infos = []
if not isinstance(run_info_dict['batches'], dict):
return
for batch_job_id, batch_info in run_info_dict['batches'].items():
self._prettifyRunInfo(batch_info)
batch_infos.append((batch_job_id, list(batch_info.items())))
run_info_dict['batches'] = batch_infos
def _applyConfigSettings(self, new_settings):
new_settings = copy.deepcopy(new_settings)
if 'batch_settings' in new_settings:
for k in new_settings['batch_settings']:
if not hasattr(BatchSettings, k):
raise SettingsError(
self,
f"Specified batch setting does not exist: \"{k}\"")
self.setBatchSettings(
BatchSettings(**new_settings.pop('batch_settings')))
super()._applyConfigSettings(new_settings)
def _getCanonicalizedConfig(self):
"""
Return a config that can be used to set the settings for a different
instance of this step to the same settings as this step.
"""
if isinstance(self.settings, parameters.CompoundParam):
canon_config = super()._getCanonicalizedConfig()
if self._batch_settings:
batch_settings_dict = self._batch_settings.toDict()
# Setting task class through config is currently unsupported
batch_settings_dict.pop('task_class')
canon_config[self.getStepId()][
'batch_settings'] = batch_settings_dict
return canon_config
return {}
def _dehydrateStep(self):
"""
Create a `_DehydratedStep` from this instance of a step. A dehydrated
step has all the information necessary to recreate a step sans inputs
and can be serialized in a json file.
"""
dehyd = _DehydratedStep()
step_module = inspect.getmodule(self)
dehyd.step_module_path = imputils.get_path_from_module(step_module)
dehyd.step_class_name = type(self).__name__
dehyd.step_id = self._step_id
dehyd.step_config = self._getCanonicalizedConfig()
dehyd.input_file = self._input_file
return dehyd
@classmethod
def _rehydrateStep(cls, dehydrated_step):
"""
Recreate the step that `dehydrated_step` was created from.
"""
step = cls(
step_id=dehydrated_step.step_id, config=dehydrated_step.step_config)
if dehydrated_step.input_file:
step.setInputFile(dehydrated_step.input_file)
return step
def _makeStep(self, input_file):
step = copy.copy(self)
step.setInputFile(input_file)
return step
def getLicenseRequirements(self):
return {}
def _makeBatchTask(self, batch_file):
step = self._makeStep(batch_file)
task = self._batch_settings.task_class(step=step)
if issubclass(self._batch_settings.task_class, StepJobTask):
task.job_config.host_settings.host = hosts.Host(
self._batch_settings.hostname)
for req_license, num_tokens in self.getLicenseRequirements().items():
task.addLicenseReservation(req_license, num_tokens)
return task
def _queueBatchSteps(self, task_queue):
for batch_num, batch_file, double_batch in self._splitInputsIntoBatchFiles(
):
application.process_events()
task = self._makeBatchTask(batch_file)
task._double_batch = double_batch
task.name, _ = os.path.splitext(os.path.basename(batch_file))
task_queue.addTask(task)
def _splitInputsIntoBatchFiles(self):
serializer = self._getInputSerializer()
inps = self._inputsWithCounting()
continue_with_double_batching = False
for batch_num, batch_of_lines in enumerate(
more_itertools.ichunked(inps, self._batch_settings.size)):
batch_fname = self.getStepId() + '_batch_' + str(batch_num) + '.in'
serializer.serialize(batch_of_lines, batch_fname)
yield batch_num, batch_fname, False
if batch_num + 1 >= DOUBLE_BATCH_THRESHOLD:
continue_with_double_batching = True
break
if continue_with_double_batching:
double_batch_size = self._batch_settings.size * DOUBLE_BATCH_THRESHOLD
double_batches = more_itertools.ichunked(inps, double_batch_size)
for batch_num, batch_of_lines in enumerate(
double_batches, start=batch_num + 1):
batch_fname = self.getStepId() + '_batch_' + str(
batch_num) + '.in'
serializer.serialize(batch_of_lines, batch_fname)
yield batch_num, batch_fname, True
@_assert_step_hasnt_started
def outputs(self):
"""
        Like the superclass method, this returns a generator for the outputs.
        Iterating the generator begins the batching process by requesting outputs
from the input source (previous step), accumulating them into batches
of the specified size, and queuing them all up.
"""
if self._batch_settings is None:
return super().outputs()
else:
self._run_info[self.getStepId()] = {
'batches': collections.defaultdict(dict)
}
task_dj = queue.TaskDJ(max_failures=queue.NOLIMIT)
self._queueBatchSteps(task_dj)
if not task_dj.waiting_jobs:
# We didn't have any batches to process, so just return early
return []
outputs_gen = self._makeBatchedOutputsGenerator(task_dj)
outputs_gen = self._outputsWithCounting(outputs_gen)
self._outputs_gen = outputs_gen
return outputs_gen
def _updateBatchRunInfo(self, batch_name, new_batch_info):
stepid = self.getStepId()
batch_info = self._run_info[stepid]['batches'][batch_name]
batch_info.update(new_batch_info)
batch_info.update(batch_info.pop(stepid))
def _makeBatchedOutputsGenerator(self, task_dj):
for task in task_dj.updatedTasks():
if task.status is task.DONE:
self._updateBatchRunInfo(task.name, task.output.run_info)
branch_count = task.name.count('.')
logger.info(f'{">"*branch_count}START {task.name} log')
logger.info(task.getLogAsString().strip())
logger.info(f'{">"*branch_count}END {task.name} log')
task.wait()
outp_file = task.getTaskFilename(task.output.output_file)
assert os.path.isfile(outp_file), outp_file
serializer = self.getOutputSerializer()
for outp in serializer.deserialize(outp_file):
yield outp
elif task.status is task.FAILED:
logger.error("task failed")
branch_count = task.name.count('.')
logger.error(f"FAILURE WHEN RUNNING {task.name}")
try:
_write_repro_file(task)
except Exception:
logger.error(
"Error when writing the reproduction zip. Try "
f"reproducing manually with {task.name}'s inputs.")
else:
logger.error(
f"Files for reproducing step saved to: {task.name}_repro.rzip"
)
logger.error(f'{">"*branch_count}START {task.name} log')
try:
logger.error(task.getLogAsString())
except FileNotFoundError:
logger.error(f"** LOG FILE NOT SAVED FOR {task.name} **")
logger.error(f'{">"*branch_count}END {task.name} log')
class _BatchableStepMixin(_BigQueryBatchableStepMixin, __BatchableStepMixin):
pass
class UnbatchedReduceStep(_BaseStep):
""""
An unbatchable ReduceStep. See ReduceStep for more information.
"""
def _makeOutputGenerator(self):
return self.reduceFunction(self._inputsWithCounting())
def _inputsWithCounting(self):
self._start_time = time.time()
self._input_count = 0
self._updateRunInfo()
for input in self._inputs:
self._input_count += 1
yield input
    def reduceFunction(self, inputs):
raise NotImplementedError
class ReduceStep(_BatchableStepMixin, UnbatchedReduceStep):
""" <_reduce_step_>
A computational step that performs a function on a collection of inputs
to produce output items.
To construct a ReduceStep:
* Implement reduceFunction
    * Define Input (the type expected by the reduceFunction)
    * Define Output (the type of item produced by the reduceFunction)
    * Define Settings (data class for any settings needed by the
      reduceFunction)
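
    A minimal sketch (the step name is hypothetical)::

        class LongestWordStep(ReduceStep):
            Input = str
            Output = str

            def reduceFunction(self, words):
                # Reduce the whole input iterable to a single output.
                yield max(words, key=len, default='')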
"""
    def reduceFunction(self, inputs):
"""
        The main computation for this step. This function should take in an
        iterable of inputs and return an iterable of outputs.
Example::
def reduceFunction(self, words):
# Find all unique words
seen_words = set()
for word in words:
if word not in seen_words:
seen_words.add(word)
yield word
"""
return super().reduceFunction(inputs)
class UnbatchedMapStep(UnbatchedReduceStep):
""" <_unbatchability_>
An unbatchable MapStep. See MapStep for more information.
"""
    def reduceFunction(self, inputs):
for input in inputs:
for output in self.mapFunction(input):
yield output
    def mapFunction(self, input):
raise NotImplementedError()
class MapStep(_BatchableStepMixin, UnbatchedMapStep):
""" <_map_step_>
A computational step that performs a function on input items from an input
source to produce output items.
To construct a MapStep:
* Implement mapFunction
* Define Input (the type expected by the mapFunction)
    * Optionally define an InputSerializer (see `Serializer` for more info.)
    * Define Output (the type of item produced by the mapFunction)
    * Optionally define an OutputSerializer (see `Serializer` for more info.)
* Define Settings (data class for any settings needed by the mapFunction)
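
    A minimal sketch (the step name is hypothetical)::

        class TokenizeStep(MapStep):
            Input = str
            Output = str

            def mapFunction(self, line):
                # A single input line may yield any number of outputs.
                for token in line.split():
                    yield token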
"""
    def mapFunction(self, input):
"""
The main computation for this step. This function should take in a
single input item and return an iterable of outputs. This allows a
        single input to produce multiple outputs (e.g. enumeration).
        The outputs may be yielded from a generator in order to reduce memory
        usage.
If only a single output is produced for each input, return it as a
single-element list.
:param input: this will be a single input item from the input source.
            Implementers are encouraged to use a more descriptive, context-
            specific variable name. Example:
def mapFunction(self, starting_smiles):
...
"""
return super().mapFunction(input)
class UnbatchedChain(UnbatchedReduceStep):
def __copy__(self):
copied_step = super().__copy__()
copied_step._setStartingStep(self._starting_step_id)
return copied_step
def _setStartingStep(self, starting_step: str):
if starting_step is not None:
self._validateStartingStepId(starting_step)
self._starting_step_id = starting_step
    def validateSettings(self):
"""
Check whether the chain settings are valid and return a list of
`SettingsError` and `SettingsWarning` to report any invalid settings.
Default implementation simply returns problems from all child steps.
        :rtype: list[SettingsError or SettingsWarning]
"""
problems = []
for step in self:
problems += step.validateSettings()
return problems
@property
def Input(self):
if not self._steps:
return None
return self[0].Input
@property
def Output(self):
if not self._steps:
return None
return self[-1].Output
    def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._setStartingStep(None)
self._updateChain()
def __getitem__(self, idx):
return self._steps[idx]
def _setStepId(self, new_id):
super()._setStepId(new_id)
self._updateChain()
    def __len__(self):
return len(self._steps)
def _setConfig(self, config):
super()._setConfig(config)
self._updateChain()
def _getCanonicalizedConfig(self):
"""
Return a config that can be used to set the settings for a different
instance of this chain and its substeps to the same settings as this
chain and its substeps.
"""
config = super()._getCanonicalizedConfig()
for child_step in self:
config.update(child_step._getCanonicalizedConfig())
return config
def _updateChain(self):
self._steps = []
self.buildChain()
self._updateComponentStepIDs()
self._updateComponentStepConfigs()
self.validateChain()
def _updateComponentStepIDs(self):
step_type_counter = collections.Counter()
for step in self:
step_count = step_type_counter[type(step)]
step._setStepId(
f'{self._step_id}.{type(step).__name__}_{step_count}')
step_type_counter[type(step)] += 1
    def addStep(self, step):
self._steps.append(step)
step._setCompositionPath(self._comp_path + '>' + step._comp_path)
step._setRunInfo(self._run_info)
step.progressUpdated.connect(self._emitProgressUpdated)
def _updateComponentStepConfigs(self):
for step in self:
step._setConfig(self._config)
    def report(self, prefix=''):
"""
Report the workflow steps and their settings (recursively).
:param prefix: the text to start each line with
:type prefix: str
"""
super().report(prefix)
for step in self:
step.report(prefix + ' ')
    def validateChain(self):
"""
Checks that the declaration of the chain is internally consistent - i.e.
that each step is valid and each step's Input class matches the
preceding step's Output class.
"""
if len(self) == 0:
return
for prev_step, next_step in more_itertools.pairwise(self):
err_msg = (f"Mismatched Input and Output.\n"
f"Previous step: {prev_step}\n"
f"Output: {prev_step.Output}\n"
f"Next step: {next_step}\n"
f"Input: {next_step.Input}\n")
if None in (next_step.Input, prev_step.Output):
assert prev_step.Output is next_step.Input, err_msg
else:
assert prev_step.Output == next_step.Input or issubclass(
prev_step.Output, next_step.Input), err_msg
first_step = self[0]
        msg = (f'Mismatched input of first step. The Input for the chain '
f'("{type(self).__name__}") is specified as {self.Input}'
' but the Input for the first step '
f'("{type(first_step).__name__}") is {first_step.Input}')
assert first_step.Input is self.Input, msg
last_step = self[-1]
        msg = (f'Mismatched output of last step. The Output for the chain '
f'("{type(self).__name__}") is specified as {self.Output}'
' but the Output for the last step '
f'("{type(last_step).__name__}") is {last_step.Output}')
assert last_step.Output is self.Output, msg
def _validateStartingStepId(self, step_id: str):
"""
Checks to see if the `step_id` actually matches a step in this chain.
If not, raise a ValueError.
"""
if step_id == self.getStepId():
return
for idx, step in enumerate(self):
if step_id.startswith(step.getStepId()):
if isinstance(step, Chain):
step._validateStartingStepId(step_id)
break
else:
if step.getStepId() == step_id:
break
else:
raise ValueError("Invalid starting step ID: " + step_id)
    def reduceFunction(self, inputs):
bq_dataset = self._getBQDataset()
self._updateChain()
if bq_dataset is not None:
self._setBQDataset(bq_dataset)
if len(self) == 0:
return inputs
if self.hasBQBatching() and self._getBQDataset() is None:
name = _initialize_stepper_bq_dataset(self)
self._setBQDataset(name)
starting_step_idx = 0
if self._starting_step_id is not None:
for idx, step in enumerate(self):
if self._starting_step_id.startswith(step.getStepId()):
starting_step_idx = idx
break
starting_step = self[starting_step_idx]
if isinstance(starting_step, Chain):
starting_step.setInputs(
inputs, starting_step_id=self._starting_step_id)
else:
starting_step.setInputs(inputs)
is_bq_step = _BigQueryBatchableStepMixin.is_bq_step
for step in self:
if is_bq_step(step):
step._initializeOutputBQTable()
for prev_step, next_step in more_itertools.pairwise(
self[starting_step_idx:]):
if is_bq_step(prev_step) and is_bq_step(next_step):
next_step.setInputBQTable(prev_step._output_table)
# Call outputs in order to execute the step and populate
# the output BQ table. If we don't do this, we won't be able
# to use the output table in downstream steps.
prev_step.outputs()
else:
next_step.setInputs(prev_step.outputs())
last_step = self[-1]
return last_step.outputs()
    def buildChain(self):
"""
        This method must be implemented by subclasses to build the chain. The
        chain is built by calling `self.addStep` for each component step. The
        chain's composition may depend on `self.settings`.
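
        A minimal sketch, assuming `PrepareStep` and `ScoreStep` are steps
        defined elsewhere and `use_scoring` is a setting on this chain::

            def buildChain(self):
                self.addStep(PrepareStep())
                if self.settings.use_scoring:
                    self.addStep(ScoreStep())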
"""
raise NotImplementedError()
def _initialize_stepper_bq_dataset(step):
suffix = str(uuid.uuid4())[:8]
dataset_name = step.getStepId().replace('.', '-') + '_' + suffix
print("Using Bigquery Dataset: ", dataset_name)
_create_dataset(dataset_name)
return dataset_name
class Chain(_BatchableStepMixin, UnbatchedChain):
""" <_chain_>
Run a series of steps. The steps must be created by overriding buildChain.
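
    A chain can also process its inputs in distributed batches. A sketch,
    assuming `ProcessSmilesChain` is a `Chain` subclass defined elsewhere::

        chain = ProcessSmilesChain(batch_size=10)
        chain.setInputFile('molecules.smi')
        chain.writeOutputsToFile('processed.out')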
"""
def _getBQDataset(self):
if self._batch_settings is None:
for step in self:
if (isinstance(step, _BigQueryBatchableStepMixin) and
step._getBQDataset() is not None):
return step._getBQDataset()
else:
return None
else:
return self._batch_settings._bq_dataset
def _setBQDataset(self, dataset: str):
for step in self:
if isinstance(step, _BigQueryBatchableStepMixin):
step._setBQDataset(dataset)
super()._setBQDataset(dataset)
    def hasBQBatching(self):
        for step in self:
            if isinstance(step, UnbatchedChain):
                if step.hasBQBatching() or step.isBigQueryBatched():
                    return True
            elif step.isBigQueryBatched():
                return True
        # Only report False once every component step has been checked.
        return False
    def getLicenseRequirements(self):
req_licenses = collections.Counter()
for step in self:
if not (isinstance(step, _BatchableStepMixin) and
step._batch_settings is not None):
req_licenses = req_licenses | collections.Counter(
step.getLicenseRequirements())
return dict(req_licenses)
def _dehydrateStep(self):
dehyd = super()._dehydrateStep()
dehyd.starting_step_id = self._starting_step_id
return dehyd
@classmethod
def _rehydrateStep(cls, dehydrated_step: _DehydratedStep) -> 'Chain':
"""
Recreate the step that `dehydrated_step` was created from.
"""
step = super()._rehydrateStep(dehydrated_step)
step._setStartingStep(dehydrated_step.starting_step_id)
return step
def _line_count(filename):
count = 0
with open(filename, 'r') as file:
for line in file:
count += 1
return count
### Debugging helper methods, not for use in production.
def _get_all_stepper_input_files():
input_file_pattern = os.path.join('**', '*.in')
return glob.glob(input_file_pattern, recursive=True)
def _get_all_stepper_output_files():
output_file_pattern = os.path.join('**', '*.out')
return glob.glob(output_file_pattern, recursive=True)
def _get_all_stepper_zip_files():
output_file_pattern = os.path.join('**', '*.rzip')
return glob.glob(output_file_pattern, recursive=True)
def _write_repro_file(steptask):
"""
    Write an rzip with...
- the input file for the step
- the yaml config file for the step
- a command for rerunning the step with the above input files
- any necessary settings files/folders
"""
repro_fname = f'{steptask.name}_repro.rzip'
with zipfile.ZipFile(repro_fname, 'w') as repro_zipfile:
dehyd_step = steptask._step._dehydrateStep()
for step_id, step_settings in dehyd_step.step_config.items():
if step_id.startswith(steptask._step.getStepId()):
for name, value in step_settings.items():
if isinstance(value, StepperFile):
repro_zipfile.write(value, value)
elif isinstance(value, StepperFolder):
for root, _, files in os.walk(value):
for filename in files:
src_path = os.path.join(root, filename)
repro_zipfile.write(src_path)
yaml_fname = f'{steptask.name}.yaml'
with open(yaml_fname, 'w') as yaml_file:
yaml.dump(dict(dehyd_step.step_config), yaml_file)
cmd_fname = f'{steptask.name}.sh'
with open(cmd_fname, 'w') as cmd_file:
cmd_file.write(
f'$SCHRODINGER/run stepper.py '
f'{dehyd_step.step_module_path}.{dehyd_step.step_class_name} '
f'{dehyd_step.input_file} bad_step.out -config {yaml_fname} '
f'-workflow-id {dehyd_step.step_id}')
repro_zipfile.write(dehyd_step.input_file,
os.path.basename(dehyd_step.input_file))
repro_zipfile.write(yaml_fname)
repro_zipfile.write(cmd_fname)
def _get_stepper_debug_files():
# Return all stepper repro zip files
return _get_all_stepper_zip_files()