Source code for schrodinger.application.steps.bigquery_deduplication

try:
    from google.auth import compute_engine
    from google.cloud import bigquery
    from google.cloud import storage
    from google.cloud import exceptions
except ImportError:
    compute_engine = None
    bigquery = None
    storage = None
    exceptions = None

from schrodinger.models import parameters
from schrodinger.tasks import stepper
from schrodinger.utils import fileutils

from .basesteps import MolMolMixin

# Scopes
BIGQUERY_SCOPE = "https://www.googleapis.com/auth/bigquery"
CLOUD_SCOPE = "https://www.googleapis.com/auth/cloud-platform"

# Error messages
NO_BUCKET = 'No bucket defined'
NO_PROJECT = 'No project defined'
NO_DATASET = 'No dataset defined'


def generate_clients(node_auth=True):
    """
    Generate BigQuery and Cloud Storage clients, returned in that order.
    Defaults to authenticating via the compute node. Otherwise, requires
    service account authentication via the file specified at
    GOOGLE_APPLICATION_CREDENTIALS.
    """
    kwargs = {}
    if node_auth:
        kwargs['credentials'] = compute_engine.Credentials()
    return bigquery.Client(**kwargs), storage.Client(**kwargs)

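# Usage sketch (illustrative, not part of the module): with node_auth=False
# the Google client libraries fall back to the service account key file
# named by GOOGLE_APPLICATION_CREDENTIALS; the path below is hypothetical.
#
#     import os
#     os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = '/path/to/service_account.json'
#     bq_client, storage_client = generate_clients(node_auth=False)
#
#     # On a Compute Engine node with the BigQuery/cloud-platform scopes:
#     bq_client, storage_client = generate_clients()
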
class BQDedupeSettings(parameters.CompoundParam):
    bucket: str
    project: str
    dataset: str
    table_name: str = 'smiles_table'

class RandomSampleBQDedupeSettings(BQDedupeSettings):
    rows_to_sample: int = 5000

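# Example settings (hypothetical values, assuming CompoundParam's
# keyword-argument constructor; any field left empty here is reported by
# validateSettings() below):
#
#     settings = RandomSampleBQDedupeSettings(bucket='my-gcs-bucket',
#                                             project='my-gcp-project',
#                                             dataset='dedupe_dataset',
#                                             rows_to_sample=1000)
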
class AbstractBQDedupeStep(stepper.ReduceStep):
    """
    Reduce step that deduplicates its input compounds in a single batch
    using BigQuery.

    Override getQuery() to use a specific query for the deduplication.
    """
    Settings = BQDedupeSettings

    def validateSettings(self):
        errors = []
        if not self.settings.bucket:
            errors.append(stepper.SettingsError(self, NO_BUCKET))
        if not self.settings.project:
            errors.append(stepper.SettingsError(self, NO_PROJECT))
        if not self.settings.dataset:
            errors.append(stepper.SettingsError(self, NO_DATASET))
        return errors

    def reduceFunction(self, inps):
        self.bq_client, self.storage_client = generate_clients(node_auth=False)
        with fileutils.tempfilename() as f:
            self._combineSMILES(inps, f)
            table = self._uploadFileToTable(f)
        yield from self._runDedupeQuery(table)

    def getQuery(self, table_name):
        """
        Return the formatted deduplication query to run against BigQuery.

        :param table_name: name of the table to query
        :type table_name: str
        """
        raise NotImplementedError

    def _combineSMILES(self, inps, filepath):
        """
        Combine the inputs into a newline-separated file.
        """
        serializer = self._getInputSerializer()
        with open(filepath, 'w') as f:
            for inp in inps:
                st = serializer.toString(inp)
                f.write(f'{st}\n')

    def _uploadFileToTable(self, filename):
        """
        Upload a CSV file to a BigQuery table, staging it through GCP.

        :param filename: name of the file to upload
        :type filename: str

        :return: newly uploaded table
        :rtype: bigquery.table.Table
        """
        bucket = self.storage_client.get_bucket(self.settings.bucket)
        bucket_file = bucket.blob(self.settings.table_name)
        bucket_file.upload_from_filename(filename)
        table = self._uploadFileToBQ(bucket_file)
        return table

    def _runDedupeQuery(self, table):
        """
        Run the deduplication query and yield the resultant rows.
        Deletes the table afterwards, as it is no longer needed.

        :param table: table to run the query on
        :type table: bigquery.table.Table
        """
        table_name = (f'{self.settings.project}.{self.settings.dataset}.'
                      f'{table.table_id}')
        query_str = self.getQuery(table_name)
        query_job = self.bq_client.query(query_str)
        query_job.result()
        # Yield rows
        serializer = self._getInputSerializer()
        for row in query_job:
            row_data = row['SMILES']
            yield serializer.fromString(row_data)
        # Delete the table, as it is no longer needed after querying
        table = self.bq_client.get_table(table_name)
        self.bq_client.delete_table(table)

    def _uploadFileToBQ(self, bucket_file):
        """
        Load a file from a GCP bucket into a BigQuery table.

        :param bucket_file: GCP file to load into BQ
        :type bucket_file: storage.blob.Blob
        """
        dataset_ref = self._getDataset()
        table_ref = dataset_ref.table(self.settings.table_name)
        table = self._createTableFromFile(bucket_file, table_ref)
        return table

    def _getDataset(self):
        """
        Return the dataset, creating it if it does not exist.
        """
        # Fetch the dataset, creating it on first use
        try:
            dataset_ref = self.bq_client.get_dataset(self.settings.dataset)
        except exceptions.NotFound:
            dataset = bigquery.Dataset(
                f'{self.settings.project}.{self.settings.dataset}')
            dataset_ref = self.bq_client.create_dataset(dataset)
        return dataset_ref

    def _createTableFromFile(self, bucket_file, table_ref):
        """
        Create a BigQuery table from a given file.

        :param bucket_file: GCP file to load into BQ
        :type bucket_file: storage.blob.Blob

        :param table_ref: table reference to store the data in
        :type table_ref: bigquery.table.TableReference

        :return: created table
        :rtype: bigquery.table.Table
        """
        schema = [bigquery.SchemaField("SMILES", "STRING", mode="REQUIRED")]
        job_config = bigquery.LoadJobConfig(schema=schema)
        uri = f"gs://{bucket_file.bucket.name}/{bucket_file.name}"
        load_job = self.bq_client.load_table_from_uri(
            uri, table_ref, job_config=job_config)
        load_job.result()
        table = self.bq_client.get_table(table_ref)
        return table

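# End-to-end flow of the helpers above (a sketch; the SMILES values are
# illustrative): _combineSMILES writes one SMILES string per line, which
# load_table_from_uri ingests as a single-column CSV matching the REQUIRED
# "SMILES" field in the schema:
#
#     c1ccccc1   -> row {'SMILES': 'c1ccccc1'}
#     CCO        -> row {'SMILES': 'CCO'}
#     CCO        -> row {'SMILES': 'CCO'}   (collapsed later by getQuery())
#
# The staged object lives at gs://<settings.bucket>/<settings.table_name>,
# and the BigQuery table at <project>.<dataset>.<table_name>.
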
class MolBQDedupStep(MolMolMixin, AbstractBQDedupeStep):
    pass

class BQDedupeStep(MolBQDedupStep):
    """
    Standard deduplication step using BigQuery.
    """

    def getQuery(self, table_name):
        return f"""
            SELECT SMILES
            FROM `{table_name}`
            GROUP BY SMILES"""

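# For a hypothetical table name 'my-project.dedupe_dataset.smiles_table',
# the method above renders to the following BigQuery SQL; GROUP BY collapses
# identical SMILES strings into a single row each:
#
#     SELECT SMILES
#     FROM `my-project.dedupe_dataset.smiles_table`
#     GROUP BY SMILES
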
class RandomSampleBQDedupeStep(MolBQDedupStep):
    """
    Deduplication step with random sampling enabled. Sampling occurs after
    deduplication. The config's `rows_to_sample` specifies the average
    number of rows to keep.
    """
    Settings = RandomSampleBQDedupeSettings

    def getQuery(self, table_name):
        return f"""
            WITH dedupe_table AS (
                SELECT SMILES
                FROM `{table_name}`
                GROUP BY SMILES)
            SELECT SMILES
            FROM dedupe_table
            WHERE RAND() < {self.settings.rows_to_sample} /
                (SELECT COUNT(*) FROM dedupe_table);"""
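
# Why this samples rows_to_sample rows on average: each deduplicated row
# passes the RAND() filter independently with probability rows_to_sample/N,
# where N = COUNT(*) of dedupe_table, so the number of surviving rows is
# binomial with mean N * (rows_to_sample/N) = rows_to_sample. For example
# (illustrative numbers), with the default rows_to_sample = 5000 and
# N = 1,000,000 distinct SMILES, each row is kept with probability 0.005
# and roughly 5000 rows (plus or minus about 70, one standard deviation)
# come back.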