# Source code for schrodinger.protein.tasks.blast

import os
from collections import defaultdict
from typing import List

from schrodinger.application.msv import seqio
from schrodinger.models import jsonable
from schrodinger.models import parameters
from schrodinger.protein import alignment
from schrodinger.protein import sequence
from schrodinger.tasks import jobtasks
from schrodinger.tasks import tasks
from schrodinger.job.util import hunt

# Search program selector; BlastSettings.getBlastSettings() maps BLAST to
# 'blastp' and anything else to 'psiblast'.
BlastAlgorithm = jsonable.JsonableEnum('BlastAlgorithm', 'BLAST PSIBLAST')
# Database to search; passed to BlastPlus as the lower-cased member name.
BlastDatabase = jsonable.JsonableEnum('BlastDatabase', 'PDB NR')
# Scoring matrices for which BlastSettings defines gap-cost tables.
SimilarityMatrix = jsonable.JsonableEnum(
    'SimilarityMatrix', 'BLOSUM45 BLOSUM62 BLOSUM80 PAM30 PAM70')
# Values used for `BlastSettings.location`.
LOCAL = 'local'
REMOTE = 'remote'


class NoBlastHitsError(RuntimeError):
    """
    Raised when a BLAST search completes but returns no hits.

    `BlastTask.backendMain` raises this when `runBlast` gives back an empty
    result.
    """


def _create_inverse_dict(k_v_dict):
    """
    Given a dict with a set of values for each key, create an inverse dict
    where the values point to a set of keys
    """
    v_k_dict = defaultdict(set)
    for k, v_set in k_v_dict.items():
        for v in v_set:
            v_k_dict[v].add(k)
    return dict(v_k_dict)


def _find_closest(x, values):
    # This is O(N) but N is small
    return min(values, key=lambda v: abs(v - x))


class BlastSettings(parameters.CompoundParam):
    """
    Parameters controlling a BLAST / PSI-BLAST search.

    The gap costs are coupled to the similarity matrix: only the
    (gap open, gap extend) combinations listed in `EXT_OPEN_VALUES` are
    offered. Changing `similarity_matrix` resets both gap costs to that
    matrix's entry in `DEFAULTS`; changing `gap_open_cost` recomputes
    `allowed_gap_extend` and, if necessary, moves `gap_extend_cost` to the
    closest allowed value.
    """
    progname: BlastAlgorithm
    location: str = LOCAL
    database_name: BlastDatabase
    word_size: int = 3
    filter_query: bool = False
    # `initializeValue` replaces both gap costs with the `DEFAULTS` entry for
    # the current similarity matrix.
    gap_open_cost: int = 11
    gap_extend_cost: int = 1
    # Gap open values valid for the current matrix
    possible_gap_open: set
    # Gap extend values valid for the current matrix and gap open cost
    allowed_gap_extend: set
    # NOTE(review): the 11/1 gap costs declared above are the BLOSUM62 pair,
    # while the default matrix here is BLOSUM80 - confirm the intended default.
    similarity_matrix: SimilarityMatrix = SimilarityMatrix.BLOSUM80
    num_iterations: int = 3
    # Exponent: the evalue handed to BLAST is 10**evalue_threshold
    evalue_threshold: int = 1
    inclusion_threshold: float = 0.005
    allow_multiple_chains: bool = False
    download_structures: bool = False
    align_after_download: bool = False

    # Default (gap open, gap extend) costs per matrix. These follow the NCBI
    # BLAST default gap costs for each matrix, and every pair must be one of
    # the combinations listed in EXT_OPEN_VALUES.
    DEFAULTS = {
        SimilarityMatrix.BLOSUM45: (15, 2),
        SimilarityMatrix.BLOSUM62: (11, 1),
        SimilarityMatrix.BLOSUM80: (10, 1),
        SimilarityMatrix.PAM30: (9, 1),
        SimilarityMatrix.PAM70: (10, 1),
    }

    # matrix -> {gap extend cost: set of gap open costs allowed with it}
    EXT_OPEN_VALUES = {
        SimilarityMatrix.BLOSUM45: {
            3: {13, 12, 11, 10},
            2: {15, 14, 13, 12},
            1: {19, 18, 17, 16}
        },
        SimilarityMatrix.BLOSUM62: {
            2: {11, 10, 9, 8, 7, 6},
            1: {13, 12, 11, 10, 9}
        },
        SimilarityMatrix.BLOSUM80: {
            2: {8, 7, 6},
            1: {11, 10, 9}
        },
        SimilarityMatrix.PAM30: {
            2: {7, 6, 5},
            1: {10, 9, 8}
        },
        SimilarityMatrix.PAM70: {
            2: {8, 7, 6},
            1: {11, 10, 9}
        },
        # Remote blast allows more combinations than local blast but for
        # simplicity we only allow local-compatible settings
        #SimilarityMatrix.PAM30: {3: {15, 13}, 2: {14, 7, 6, 5}, 1: {14, 10, 9, 8}},
        #SimilarityMatrix.PAM70: {3: {12}, 2: {11, 8, 7, 6}, 1: {11, 10, 9}},
    }

    # matrix -> {gap open cost: set of gap extend costs allowed with it}
    OPEN_EXT_VALUES = {
        matrix: _create_inverse_dict(ext_open)
        for matrix, ext_open in EXT_OPEN_VALUES.items()
    }

    def initConcrete(self):
        """
        Connect the signals that keep the gap costs consistent with the
        similarity matrix and with each other.
        """
        super().initConcrete()
        self.similarity_matrixChanged.connect(self._setMatrixDefaults)
        self.gap_open_costChanged.connect(self._updateAllowedExtendValues)

    def initializeValue(self):
        """
        Apply the gap-cost defaults for the initial similarity matrix.
        """
        super().initializeValue()
        self._setMatrixDefaults()

    @classmethod
    def fromJsonImplementation(cls, json_obj):
        """
        Rebuild the settings from JSON, then re-apply the gap costs.

        :param json_obj: Serialized settings; must contain 'gap_open_cost'
            and 'gap_extend_cost'.
        :type json_obj: dict
        """
        new_input = super().fromJsonImplementation(json_obj)
        # These params can affect each other so they need to be set again in
        # the correct order to reconstitute correctly
        order_dependent_params = ('gap_open_cost', 'gap_extend_cost')
        for param_name in order_dependent_params:
            setattr(new_input, param_name, json_obj[param_name])
        return new_input

    def getBlastSettings(self):
        """
        Returns BLAST settings as a dictionary.

        All values are converted to `str`. 'e_value_threshold' is only
        included for remote searches or when running PSI-BLAST.

        :return: Dictionary of BLAST settings.
        :rtype: dict
        """
        # these settings will be passed as arguments to BlastPlus with only a
        # leading hyphen added
        settings = {
            'progname': 'blastp'
            if self.progname is BlastAlgorithm.BLAST else 'psiblast',
            'location': self.location,
            'database': self.database_name.name.lower(),
            'word_size': self.word_size,
            'filter': 'yes' if self.filter_query else 'no',
            'gap_open_cost': self.gap_open_cost,
            'gap_extend_cost': self.gap_extend_cost,
            'matrix': self.similarity_matrix.name,
            'num_iterations': self.num_iterations,
            'evalue': 10**self.evalue_threshold,
            'expand_hits': self.allow_multiple_chains,
        }
        if self.location == REMOTE or self.progname is BlastAlgorithm.PSIBLAST:
            settings['e_value_threshold'] = self.inclusion_threshold
        settings = {k: str(v) for k, v in settings.items()}
        return settings

    def _setMatrixDefaults(self):
        """
        Reset the allowed gap values and both gap costs for the current
        similarity matrix.
        """
        self._updateOpenValues()
        gap_open, gap_extend = self.DEFAULTS[self.similarity_matrix]
        # Widen/narrow the allowed extend values for the new open cost before
        # assigning the costs themselves.
        self._updateAllowedExtendValues(gap_open)
        self.gap_open_cost = gap_open
        self.gap_extend_cost = gap_extend

    def _updateOpenValues(self):
        """
        Refresh `possible_gap_open` for the current matrix and move
        `gap_open_cost` to the closest allowed value if it is no longer valid.
        """
        values = set(self.OPEN_EXT_VALUES[self.similarity_matrix].keys())
        self.possible_gap_open = values
        if self.gap_open_cost not in values:
            self.gap_open_cost = _find_closest(self.gap_open_cost, values)

    def _updateAllowedExtendValues(self, gap_open):
        """
        Refresh `allowed_gap_extend` for the given gap open cost and move
        `gap_extend_cost` to the closest allowed value if it is no longer
        valid.

        :param gap_open: Gap open cost to look up for the current matrix.
        :type gap_open: int
        """
        values = self.OPEN_EXT_VALUES[self.similarity_matrix].get(gap_open)
        if values is None:
            # Unknown open cost: fall back to every extend cost defined for
            # the matrix.
            values = set(self.EXT_OPEN_VALUES[self.similarity_matrix].keys())
        if self.gap_extend_cost not in values:
            self.gap_extend_cost = _find_closest(self.gap_extend_cost, values)
        self.allowed_gap_extend = set(values)


class BlastTask(jobtasks.ComboJobTask):
    """
    This is a thin wrapper over BlastPlus object that implements job running
    and incorporation.

    To enable DEBUG_MODE, set DEBUG_MODE to True at the bottom of this class.
    In DEBUG_MODE, no blast call will actually be made and the first top
    10 hits of a BLAST search with 1cmy:a will be returned as the output.

    NOTE(review): no DEBUG_MODE attribute is visible in this class - confirm
    the paragraph above is still accurate.
    """
    DEFAULT_TASKDIR_SETTING = tasks.TEMP_TASKDIR
    PROGRAM_NAME = "BLAST"

    output: List[dict]  # List of blast hits

    class Input(parameters.CompoundParam):
        """
        Task input: the query sequence plus the search settings.

        `query_sequence` is excluded from JSON; its string and name are
        mirrored into private params and used to rebuild it when loading.
        """
        query_sequence: sequence.ProteinSequence = None
        _query_sequence_str: str
        _query_sequence_name: str
        settings: BlastSettings
        genomes: List[str]

        def getQueryName(self):
            """
            Return the query's full name with '_' replaced by ':', or an
            empty string when no query sequence is set.
            """
            query_seq = self.query_sequence
            if query_seq is None:
                return ''
            return query_seq.fullname.replace('_', ':')

        def initConcrete(self):
            """
            Keep the serializable sequence info in sync with the query.
            """
            super().initConcrete()
            self.query_sequenceChanged.connect(self._updateQuerySeqInfo)

        @classmethod
        def getJsonBlacklist(cls):
            """
            @overrides: parameters.CompoundParam

            Rather than encode the whole `ProteinSequence` object, we just
            save the necessary parts and rehydrate it in the backend.
            """
            return [cls.query_sequence]

        @classmethod
        def fromJsonImplementation(cls, json_obj):
            """
            Rebuild the input from JSON and recreate `query_sequence` from
            the stored sequence string and name, if either is present.
            """
            new_input = super().fromJsonImplementation(json_obj)
            if new_input._query_sequence_str or new_input._query_sequence_name:
                new_input.query_sequence = sequence.ProteinSequence(
                    new_input._query_sequence_str,
                    name=new_input._query_sequence_name)
            return new_input

        def _updateQuerySeqInfo(self, query_seq):
            """
            Mirror the query sequence's string and name into the private
            params. A `None` query leaves the stored values untouched.
            """
            if query_seq is not None:
                self._query_sequence_str = str(query_seq)
                self._query_sequence_name = query_seq.name

    def getExpectedRuntime(self):
        """
        Return the expected runtime of the task based on the current settings

        :return: Expected runtime in seconds
        :rtype: int
        """
        settings = self.input.settings
        # Read the location directly; there is no need to build the whole
        # stringified settings dictionary just to inspect one value.
        if settings.location == REMOTE:
            minutes = 30
        elif settings.database_name is BlastDatabase.NR:
            minutes = 3 * 60  # Local NR is expected to take ~2.5 hours
        else:
            minutes = 5
        return minutes * 60

    def getQueryName(self):
        """
        Return the display name of the query sequence (see
        `Input.getQueryName`).
        """
        return self.input.getQueryName()

    def checkLocalDatabase(self):
        """
        Return True if the local database exists and is correctly configured.
        Return False if the local database is missing or truncated.
        """
        blast_plus = self._initBlastPlus()
        try:
            has_local_db = blast_plus.checkLocalBlastInstallation()
        except RuntimeError:
            return False
        else:
            return has_local_db

    @tasks.preprocessor(order=tasks.BEFORE_TASKDIR)
    def _clearInputFiles(self):
        """
        Preprocessor (ordered `BEFORE_TASKDIR`) that empties `input_files`.
        """
        self.input_files.clear()

    def _initBlastPlus(self):
        """
        Create and return a BlastPlus object based on the current settings
        """
        options = []
        for key, value in self.input.settings.getBlastSettings().items():
            # Each setting becomes a "-key value" command line pair.
            options.extend(('-' + key, value))

        from schrodinger.application.prime.packages import blast_plus
        parsed_options = blast_plus.blast_parser().parse_args(options)
        return blast_plus.BlastPlus(parsed_options)

    def backendMain(self):
        """
        Run BLAST on the query sequence and store the parsed hits in
        `output`.

        :raise NoBlastHitsError: If the search returns no hits.
        """
        self._pdb_header_info = {}

        blast_plus = self._initBlastPlus()

        hits = blast_plus.runBlast(
            seqio.to_biopython(self.input.query_sequence))
        if not hits:
            raise NoBlastHitsError("No relevant BLAST results were found.")
        self.output = self._parseHits(hits)

    def getBlastAlignment(self):
        """
        Return an alignment holding the query sequence followed by one
        sequence per hit in `output`.

        :raise ValueError: If the task failed or has not finished.
        :rtype: alignment.ProteinAlignment
        """
        if self.status is self.FAILED:
            raise ValueError("Cannot get alignment. Blast task failed.")
        elif self.status is not self.DONE:
            raise ValueError("Can't get the blast alignment before the task "
                             "is run.")
        seqs = [self.input.query_sequence]
        seqs.extend(
            sequence.ProteinSequence(hit['sequence']) for hit in self.output)
        return alignment.ProteinAlignment(seqs)

    def _parseHits(self, hits):
        """
        Convert BLAST hits into plain dictionaries, one per HSP.

        :param hits: Hits returned by `BlastPlus.runBlast`

        :return: A list of parsed hits.
        :rtype: list of dict
        """

        if not self._pdb_header_info:
            self._initializePDBHeaderInfo()

        hit_list = []
        from schrodinger.application.prime.packages import blast_plus
        for hit in hits:
            name, description = self._getNameAndDescription(hit)
            title = hit.alignment.title
            info = blast_plus.get_info_from_title(title)
            database = blast_plus.get_database_from_title(title)
            # Per-hit values, independent of the individual HSPs
            codes = str(hit.seq_io.seq)
            # Kept under its own name so it does not clobber `info` parsed
            # from the title above.
            header_info = self._pdb_header_info.get(name, {})
            for hsp in hit.alignment.hsps:
                percent_id = 100 * hsp.identities / hsp.align_length
                percent_pos = 100 * hsp.positives / hsp.align_length
                percent_gaps = 100 * hsp.gaps / hsp.align_length
                hit_dict = {
                    'name': name,
                    'info': info,
                    'database': database,
                    'sequence': codes,
                    'score': hsp.score,
                    'evalue': hsp.expect,
                    'percent_id': percent_id,
                    'percent_pos': percent_pos,
                    'percent_gaps': percent_gaps,
                    'pdb_title': description,
                    'pdb_compound': header_info.get('COMPND:', ''),
                    'pdb_source': header_info.get('SOURCE:', ''),
                    'pdb_expdata': header_info.get('EXPDTA:', ''),
                    'pdb_resolution': header_info.get('RESOLUTION:', ''),
                    'pdb_hetname': header_info.get('HETNAM:', ''),
                    'pdb_pfam': header_info.get('PFAM:', ''),
                }
                hit_list.append(hit_dict)

        return hit_list

    @staticmethod
    def _getNameAndDescription(hit):
        """
        Extract the hit's name and description.

        :param hit: Object representing a single BLAST hit
        :type hit: blast_plus.BlastHit

        :return: The hit name and the description text that follows the
            first space of the first '>'-separated entry (empty if there is
            none).
        :rtype: tuple(str, str)
        """
        from schrodinger.application.prime.packages import blast_plus
        # When expand_hits is True, each hit has a reference to an identical
        # alignment object; only the biopython sequence description is updated
        # with the single sequence's name
        full_description = hit.seq_io.description
        name = blast_plus.get_name_from_title(full_description)
        # Only the first '>'-separated entry describes this sequence
        first_entry = full_description.split('>', 1)[0]
        # partition() instead of split()[1]: an entry without a space yields
        # an empty description rather than raising IndexError
        _, _, description = first_entry.partition(' ')
        return name, description.strip()

    @staticmethod
    def _getPDBHeaderFileName():
        """
        Find a path to PDB header info file.

        :return: Path to the PDB header info file, or empty string if not found.
        :rtype: str
        """
        psp_data_dir = hunt('psp', 'data')
        if psp_data_dir:
            header_file_name = os.path.join(psp_data_dir, "headerinfo.dat")

            if os.path.isfile(header_file_name):
                return header_file_name

        return ""

    def _initializePDBHeaderInfo(self):
        """
        Initialize PDB header info. This fills up the pdb_header_info
        dictionary.

        Each line of the header file is split at its first space into a key
        and a text. An "ID:" line starts a new record keyed by its text; the
        following lines are stored in that record under their own key.
        `_pdb_header_info` is left empty when no header file is found.
        """
        self._pdb_header_info = {}
        pdb_header_file_name = self._getPDBHeaderFileName()
        if not pdb_header_file_name:
            return

        header_info = {}
        pdb_id = ""
        with open(pdb_header_file_name, "r") as header_file:
            for line in header_file:
                key, sep, text = line.partition(' ')
                if not sep:
                    # No "key text" structure on this line
                    continue
                text = text.rstrip()
                if key == "ID:":
                    pdb_id = text
                    header_info[pdb_id] = {}
                elif pdb_id:
                    header_info[pdb_id][key] = text
        # Assign once the whole file has been read successfully
        self._pdb_header_info = header_info