Source code for schrodinger.application.steps.bigquery_deduplication
try:
    from google.auth import compute_engine
    from google.cloud import bigquery
    from google.cloud import storage
    from google.cloud import exceptions
except ImportError:
    compute_engine = None
    bigquery = None
    storage = None
    exceptions = None
from schrodinger.models import parameters
from schrodinger.tasks import stepper
from schrodinger.utils import fileutils
from .basesteps import MolMolMixin
# Scopes
BIGQUERY_SCOPE = "https://www.googleapis.com/auth/bigquery"
CLOUD_SCOPE = "https://www.googleapis.com/auth/cloud-platform"
# Error messages
NO_BUCKET = 'No bucket defined'
NO_PROJECT = 'No project defined'
NO_DATASET = 'No dataset defined'
def generate_clients(node_auth=True):
    """
    Generate BigQuery and Cloud Storage clients. Defaults to authenticating
    via the compute node. Otherwise, requires service account authentication
    via the key file specified by the GOOGLE_APPLICATION_CREDENTIALS
    environment variable.
    """
    kwargs = {}
    if node_auth:
        kwargs['credentials'] = compute_engine.Credentials()
    return bigquery.Client(**kwargs), storage.Client(**kwargs)
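# Illustrative usage, not part of the module: with node_auth=False the client
# libraries fall back to application-default credentials, which read the
# service account key file named by GOOGLE_APPLICATION_CREDENTIALS. The path
# below is hypothetical.
#
#     import os
#     os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = '/path/to/key.json'
#     bq_client, storage_client = generate_clients(node_auth=False)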
class BQDedupeSettings(parameters.CompoundParam):
    bucket: str
    project: str
    dataset: str
    table_name: str = 'smiles_table'
class RandomSampleBQDedupeSettings(BQDedupeSettings):
    rows_to_sample: int = 5000
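# A minimal configuration sketch, assuming CompoundParam subclasses accept
# keyword initialization; all values below are hypothetical:
#
#     settings = RandomSampleBQDedupeSettings(
#         bucket='my-gcs-bucket', project='my-gcp-project',
#         dataset='my_dataset', rows_to_sample=1000)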
class AbstractBQDedupeStep(stepper.ReduceStep):
    """
    Reduce step that deduplicates its input compounds using BigQuery by
    combining them into a single file and uploading it to a table.
    Override getQuery() to use a specific query for the deduplication.
    """
    Settings = BQDedupeSettings
    def validateSettings(self):
        errors = []
        if not self.settings.bucket:
            errors.append(stepper.SettingsError(self, NO_BUCKET))
        if not self.settings.project:
            errors.append(stepper.SettingsError(self, NO_PROJECT))
        if not self.settings.dataset:
            errors.append(stepper.SettingsError(self, NO_DATASET))
        return errors
    def reduceFunction(self, inps):
        self.bq_client, self.storage_client = generate_clients(node_auth=False)
        with fileutils.tempfilename() as f:
            self._combineSMILES(inps, f)
            table = self._uploadFileToTable(f)
            yield from self._runDedupeQuery(table)
    def getQuery(self, table_name):
        """
        Return the formatted deduplication query to run against the table.

        :param table_name: name of the table to query
        :type table_name: str

        :return: formatted SQL query
        :rtype: str
        """
        raise NotImplementedError
    def _combineSMILES(self, inps, filepath):
        """
        Combine the inputs into a newline-separated file.
        """
        serializer = self._getInputSerializer()
        with open(filepath, 'w') as f:
            for inp in inps:
                st = serializer.toString(inp)
                f.write(f'{st}\n')
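    # Illustrative result, with hypothetical SMILES strings: after
    # _combineSMILES runs, the temp file holds one serialized compound per
    # line, e.g.
    #
    #     CCO
    #     c1ccccc1
    #     CC(=O)O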
    def _uploadFileToTable(self, filename):
        """
        Upload a CSV file to a BigQuery table, via a Cloud Storage bucket.

        :param filename: name of the file to upload
        :type filename: str

        :return: Newly uploaded table
        :rtype: bigquery.table.Table
        """
        bucket = self.storage_client.get_bucket(self.settings.bucket)
        bucket_file = bucket.blob(self.settings.table_name)
        bucket_file.upload_from_filename(filename)
        table = self._uploadFileToBQ(bucket_file)
        return table
    def _runDedupeQuery(self, table):
        """
        Run the deduplication query, yield the resulting rows, and delete the
        table.

        :param table: Table to run the query on
        :type table: bigquery.table.Table
        """
        table_name = f'{self.settings.project}.{self.settings.dataset}.{table.table_id}'
        query_str = self.getQuery(table_name)
        query_job = self.bq_client.query(query_str)
        query_job.result()
        # Yield the deduplicated rows
        serializer = self._getInputSerializer()
        for row in query_job:
            row_data = row['SMILES']
            yield serializer.fromString(row_data)
        # Delete the table, as it is no longer needed after querying
        table = self.bq_client.get_table(table_name)
        self.bq_client.delete_table(table)
    def _uploadFileToBQ(self, bucket_file):
        """
        Load a file from a Cloud Storage bucket into a BigQuery table.

        :param bucket_file: GCS blob to load into BQ
        :type bucket_file: storage.blob.Blob

        :return: Created table
        :rtype: bigquery.table.Table
        """
        dataset_ref = self._getDataset()
        table_ref = dataset_ref.table(self.settings.table_name)
        table = self._createTableFromFile(bucket_file, table_ref)
        return table
    def _getDataset(self):
        """
        Return the dataset, creating it if it does not exist.
        """
        try:
            dataset_ref = self.bq_client.get_dataset(self.settings.dataset)
        except exceptions.NotFound:
            # Dataset is missing; create it under the configured project
            dataset = bigquery.Dataset(
                f'{self.settings.project}.{self.settings.dataset}')
            dataset_ref = self.bq_client.create_dataset(dataset)
        return dataset_ref
    def _createTableFromFile(self, bucket_file, table_ref):
        """
        Create a BigQuery table based on a given file.

        :param bucket_file: GCS blob to load into BQ
        :type bucket_file: storage.blob.Blob

        :param table_ref: Table reference to store the data in
        :type table_ref: bigquery.table.TableReference

        :return: Created table
        :rtype: bigquery.table.Table
        """
        schema = [bigquery.SchemaField("SMILES", "STRING", mode="REQUIRED")]
        job_config = bigquery.LoadJobConfig(schema=schema)
        uri = f"gs://{bucket_file.bucket.name}/{bucket_file.name}"
        load_job = self.bq_client.load_table_from_uri(
            uri, table_ref, job_config=job_config)
        load_job.result()
        table = self.bq_client.get_table(table_ref)
        return table
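    # Note: LoadJobConfig defaults to CSV as the source format, so the
    # newline-separated SMILES file parses as a one-column CSV that matches
    # the single REQUIRED "SMILES" field in the schema above.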
class MolBQDedupStep(MolMolMixin, AbstractBQDedupeStep):
    pass
class BQDedupeStep(MolBQDedupStep):
    """
    Standard deduplication step using BigQuery.
    """
    def getQuery(self, table_name):
        return f"""
            SELECT
                SMILES
            FROM
                `{table_name}`
            GROUP BY
                SMILES"""
class RandomSampleBQDedupeStep(MolBQDedupStep):
    """
    Deduplication step with random sampling enabled. Sampling occurs after
    deduplication. The `rows_to_sample` setting specifies the expected
    (average) number of rows to keep.
    """
    Settings = RandomSampleBQDedupeSettings
    def getQuery(self, table_name):
        return f"""
            WITH dedupe_table AS (
                SELECT
                    SMILES
                FROM
                    `{table_name}`
                GROUP BY
                    SMILES)
            SELECT
                SMILES
            FROM
                dedupe_table
            WHERE
                RAND() < {self.settings.rows_to_sample}/(SELECT COUNT(*) FROM dedupe_table);"""
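    # Sampling sketch, with hypothetical numbers: each deduplicated row is
    # kept independently with probability rows_to_sample / COUNT(*), so the
    # returned row count is binomial with mean rows_to_sample.
    #
    #     100,000 distinct SMILES, rows_to_sample = 5000
    #     -> each row kept with probability 0.05
    #     -> ~5000 rows expected, varying from run to run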