# Source code for schrodinger.application.livedesign.login

"""
Login page of the Maestro LiveDesign Export GUI.

Copyright Schrodinger, LLC. All rights reserved.
"""

import glob
import os
import re
import sys
import tarfile
from collections import namedtuple
from enum import Enum

import requests

from schrodinger.utils import fileutils

# (major, minor) pair describing a LiveDesign server version. Being a tuple,
# instances compare field by field, so thresholds can be tested with < / >=.
Version = namedtuple('Version', ('major', 'minor'))

# Plain string identifiers.
HOST = 'host'
USERNAME = 'username'
CLIENT = 'client'
MODELS = 'models'

# Suffix appended to the session host to form the LDClient endpoint.
API_PATH = '/livedesign/api'
# Presumably the server-relative location of the client tarball -- confirm
# against the code that builds the download URL.
LDCLIENT_PATH = '/livedesign/ldclient.tar.gz'

# Key of the version string in the mapping returned by LDClient.about().
VERSION_NUMBER = 'seurat_version_number'
# Key of the mode entry in the server configuration.
LD_MODE = 'LIVEDESIGN_MODE'
# Captures the numeric 'major' and 'minor' groups of a dotted version string.
VERSION_RE = re.compile(r'(?P<major>\d+)\.(?P<minor>\d+)')

GLOBAL_PROJECT_ID = '0'
GLOBAL_PROJECT_NAME = 'Global'

# --- Server versions at which individual features first appear ---

# Corporate ID compound matching
LD_VERSION_CORP_ID_MATCHING = Version(major=8, minor=1)

# Multiple LD IDs per structure
LD_VERSION_MULTIPLE_IDS = Version(major=8, minor=1)

# Separation between real and virtual compounds
LD_VERSION_REAL_VIRTUAL = Version(major=8, minor=2)

# New LD Export API
LD_VERSION_NEW_EXPORT = Version(major=8, minor=7)

# New LD Import API
LD_VERSION_NEW_IMPORT = Version(major=8, minor=9)

# Custom pose names
LD_VERSION_POSE_NAMES = Version(major=8, minor=9)

# --- User-facing messages ---

CONTACT_SUPPORT_MSG = ('\nPlease contact Schrodinger Technical Support '
                       'with the hostname used to login.')
IMPORT_ERROR_MSG = ('Could not successfully import the necessary files '
                    'fetched from the server.')
NO_LDCLIENT_MSG = 'Cannot retrieve LiveDesign client from the server'

ENTER_HOST_MSG = 'Please enter a valid LiveDesign server name.'
ENTER_CREDENTIALS_MSG = 'Please enter valid username and password to connect.'
INVALID_CREDENTIALS_MSG = 'Invalid username or password.'
VERSION_MISMATCH_MSG = (
    'There is a mismatch between the client version and the server version. '
    'Please try restarting Maestro.')
TIMEOUT_MSG = 'Cannot connect to the LiveDesign server (Operation timed out)'

# {0} is the name of the duplicated item, {1} the tmp directory searched.
MULTIPLE_FILES_FOUND_MSG = (
    'More than one {0} was found in the schrodinger tmp dir at: {1}\n'
    'Please try again after removing any ld_client files or folders from '
    'the tmp directory.')
NO_FILES_FOUND_MSG = 'The required files were not found in the tmp directory.'
# {0} is the path of the tar file that failed to extract.
TAR_ERROR_MSG = ('Unable to extract the necessary files from the fetched tar '
                 'file: {0}')

# Session state read by the helpers in this module. Everything starts out
# unset; the code that fills these in after a login is not part of this
# section of the file.
#
# No ``global`` statement is used here: names assigned at module scope are
# already module globals, and ``global`` only has an effect inside a function
# that rebinds them.
_SESSION_PWD = None
_SESSION_HOST = None
_SESSION_USERNAME = None
# Directories added to sys.path before the downloaded ldclient is imported.
_SESSION_IMPORT_PATHS = []


class LDMode(str, Enum):
    """
    The modes a LiveDesign instance can run in.

    Members are also `str` instances, so they compare equal to their plain
    string values.
    """
    DRUG_DISCOVERY = 'DRUG_DISCOVERY'
    MATERIALS_SCIENCE = 'MATERIALS_SCIENCE'

    def __str__(self):
        # Show the bare value ('DRUG_DISCOVERY') instead of the default Enum
        # rendering ('LDMode.DRUG_DISCOVERY').
        return str(self.value)


# Error text handed back when a client is requested before the session login
# information has been set.
LOGIN_ERR_MSG = ('Please use the Maestro LiveDesign Login Panel to first '
                 'login into a LiveDesign host.')


def download_ld_client(url, tmp_dir, tar_filename, glob_path, timeout=None):
    """
    Download the ld client tarball into tmp_dir and unpack it there.

    Files and directories in tmp_dir whose names start with tar_filename are
    removed before the download. Exceptions raised by `requests.get` (for
    example on timeout) and by writing the file are propagated unchanged.

    :param url: url of the ld client
    :type url: str

    :param tmp_dir: Directory under which ld client will be downloaded
    :type tmp_dir: str

    :param tar_filename: tar filename of the client (without .tar.gz ext)
    :type tar_filename: str

    :param glob_path: glob path with wildcards (ex: ldclient-*), resolved
        relative to tmp_dir
    :type glob_path: str

    :param timeout: Timeout for the download request (in secs)
    :type timeout: int or NoneType

    :return: Path of the single entry in tmp_dir that matches glob_path
    :rtype: str

    :raises RuntimeError: if the tar file cannot be extracted, or if
        glob_path matches no entry or more than one entry in tmp_dir
    """

    # Remove any previous existing ldclient related files / directories
    remove_previous_existing_files(tmp_dir, tar_filename)

    # Fetch before opening the output file so that a failed request does not
    # leave an empty archive behind.
    # NOTE(review): verify=False disables TLS certificate verification for
    # this download; kept as-is to preserve existing behavior.
    response = requests.get(url, verify=False, timeout=timeout)
    filename = os.path.join(tmp_dir, tar_filename + '.tar.gz')
    with open(filename, "wb") as f:
        f.write(response.content)

    # Un-tar the tar file into tmp_dir. The context manager closes the
    # archive even when extraction fails part-way.
    # NOTE(review): extractall() writes members wherever the archive says.
    # The archive comes from the server, so consider an extraction filter
    # where the supported Python versions provide one.
    try:
        with tarfile.open(filename) as tar:
            tar.extractall(path=tmp_dir)
    except tarfile.TarError as e:
        raise RuntimeError(TAR_ERROR_MSG.format(filename)) from e

    # Locate the un-compressed client. tmp_dir is escaped so that glob
    # metacharacters in the directory name are taken literally; only
    # glob_path contributes wildcards.
    possible_paths = glob.glob(os.path.join(glob.escape(tmp_dir), glob_path))

    if len(possible_paths) > 1:
        raise RuntimeError(
            MULTIPLE_FILES_FOUND_MSG.format(tar_filename, tmp_dir))
    if not possible_paths:
        raise RuntimeError(NO_FILES_FOUND_MSG)

    return possible_paths[0]


def remove_previous_existing_files(dirpath, name):
    """
    Remove all files and directories in dirpath whose names match 'name*'.

    :param dirpath: Path to the directory under which the files need to be
                    removed
    :type dirpath: str

    :param name: to match any files or directories in the form of 'name*'
    :type name: str
    """
    # Escape glob metacharacters in the directory so that a path such as
    # "/tmp/run[1]" is taken literally instead of silently matching nothing.
    # name is deliberately left unescaped so existing callers see no change.
    pattern = os.path.join(glob.escape(dirpath), name + '*')

    for path in glob.glob(pattern):
        if os.path.isdir(path):
            fileutils.force_rmtree(path)
        else:
            fileutils.force_remove(path)


def required_login_info_set():
    """
    Checks and returns whether the required login information is set in order to
    setup LDClient

    :return: True only if the session host, username, password and import
        paths are all set (non-empty)
    :rtype: bool
    """
    # Coerce to bool: a bare ``and`` chain hands back one of its operands
    # (e.g. the password or the import path list) rather than True/False.
    return bool(_SESSION_HOST and _SESSION_USERNAME and _SESSION_PWD and
                _SESSION_IMPORT_PATHS)


def get_ld_client_and_models():
    """
    Returns a new instance of ldclient.client.LDClient() and ldclient.models
    using the login credentials and the host server stored for the session.
    If the required login information is not set, or the ldclient package
    cannot be imported, the first object in the tuple holds the error msg and
    the other two are None - otherwise, the msg is an empty str.

    Exceptions other than ImportError raised while constructing the client are
    not handled here and propagate to the caller.

    :return: tuple of error msg str, new instance of the LD Client, and the
        ldclient.models module
    :rtype: (str, ldclient.client.LDClient(), ldclient.models)
    """
    msg = ''
    ld_client = None
    ld_models = None

    if not required_login_info_set():
        msg = LOGIN_ERR_MSG
        return (msg, ld_client, ld_models)

    # Put the ld_client module paths at the top of the search path. Any
    # earlier copy of a path is dropped first, so repeated calls keep the
    # same priority without growing sys.path with duplicates.
    for path in _SESSION_IMPORT_PATHS:
        sys.path[:] = [entry for entry in sys.path if entry != path]
        sys.path.insert(0, path)
    try:
        import ldclient.client
        import ldclient.models

        ld_client_host = _SESSION_HOST + API_PATH
        ld_client = ldclient.client.LDClient(ld_client_host, _SESSION_USERNAME,
                                             _SESSION_PWD)
        ld_models = ldclient.models
    except ImportError:
        msg = IMPORT_ERROR_MSG + CONTACT_SUPPORT_MSG

    return (msg, ld_client, ld_models)


def format_host(host):
    """
    Format the given host. Strips surrounding whitespace, removes any
    trailing '/'s, and adds the 'https' protocol if the host has none.

    :param host: LiveDesign server host
    :type host: str

    :return: Formatted host
    :rtype: str
    """
    # Strip whitespace first: otherwise a stray trailing space or newline
    # (easy to introduce when typing or pasting) hides the trailing '/' from
    # rstrip and ends up inside the URL.
    host = host.strip().rstrip('/')
    if '://' not in host:
        host = 'https://{}'.format(host)
    return host


def get_host():
    """
    Return the LiveDesign host stored for the current session.

    :return: the host for the current session, or None if no host has been
        set (the module initializes the value to None)
    :rtype: str or NoneType
    """
    return _SESSION_HOST


def get_credentials():
    """
    Return the login credentials stored for the current session.

    :return: the (username, password) pair for the current session; either
        item is None if it has not been set
    :rtype: tuple(str, str)
    """
    return (_SESSION_USERNAME, _SESSION_PWD)


def get_username():
    """
    Return the username stored for the current session.

    :return: the username for the current session, or None if no username
        has been set (the module initializes the value to None)
    :rtype: str or NoneType
    """
    return _SESSION_USERNAME


def get_LD_version(ld_client=None):
    """
    Given an LDClient instance, return the LD version number.

    :param ld_client: optionally, an instance of the LDClient; if omitted, a
        client is created from the session login information
    :type ld_client: ld_client.LDClient or NoneType
    :return: the version of the LD server
    :rtype: Version

    :raises RuntimeError: if no client was given and one could not be
        created, or if the version string reported by the server does not
        start with '<major>.<minor>'
    """
    if ld_client is None:
        msg, ld_client, _ = get_ld_client_and_models()
        if ld_client is None:
            # Surface the reason instead of failing later with an
            # AttributeError on None.
            raise RuntimeError(msg)

    version_str = ld_client.about()[VERSION_NUMBER]
    match = VERSION_RE.match(version_str)
    if match is None:
        raise RuntimeError(
            'Unrecognized LiveDesign version: {!r}'.format(version_str))

    major = int(match.group('major'))
    minor = int(match.group('minor'))
    return Version(major, minor)


def get_LD_mode(ld_client):
    """
    Look up the mode of the LD instance behind the given LDClient, e.g.
    DRUG_DISCOVERY or MATERIALS_SCIENCE. Older LD versions (< 8.6) have no
    concept of a mode; when the server configuration carries no mode entry,
    the default mode of DRUG_DISCOVERY is assumed.

    :param ld_client: instance of the LDClient
    :type ld_client: `ld_client.LDClient`

    :return: the server mode: the configured value if present, otherwise
        `LDMode.DRUG_DISCOVERY`
    :rtype: str
    """
    # Flatten the list of {'key': ..., 'value': ...} entries into a mapping;
    # a later entry with the same key replaces an earlier one.
    settings = {}
    for entry in ld_client.config():
        key = entry['key']
        value = entry['value']
        settings[key] = value

    if LD_MODE in settings:
        return settings[LD_MODE]
    return LDMode.DRUG_DISCOVERY