# Source code for schrodinger.models.jsonable

"""
A module for defining jsonable versions of classes (typically classes defined
in third-party modules).

You can also find the registry of classes that are supported by the load(s) and
dump(s) functions in `schrodinger.models.json`. Any object that is an instance
of one of the registered classes will be automatically jsonable using `dump`
and `dumps`. To deserialize, you must specify the registered class to
`load` or `loads`. Example::

    from schrodinger.models import json
    my_set = set(range(1, 4))
    my_set_jsonstr = json.dumps(my_set)
    new_set = json.loads(my_set_jsonstr, DataClass=set)
    assert new_set == my_set
    assert isinstance(new_set, set)


Currently registered DataClasses:
    - structure.Structure
    - set
    - tuple
    - rdkit.Chem.rdchem.Mol
"""
import collections
import enum

from schrodinger import structure
from schrodinger.models.json import JsonableClassMixin
from schrodinger.models import json

from rdkit import Chem


class JsonableSet(JsonableClassMixin, set):
    """
    A `set` subclass that is serialized to json as a list.

    The encoded list holds the elements of the set followed by
    `ENCODING_KEY` as its final item.
    """
    # String appended to the encoded list to signal that the list actually
    # represents a set.
    ENCODING_KEY = '_python_set_'

    def toJsonImplementation(self):
        """
        Return the elements of this set as a list terminated by
        `ENCODING_KEY`.
        """
        return list(self) + [self.ENCODING_KEY]

    @classmethod
    def fromJsonImplementation(cls, json_list):
        """
        Build an instance from a list produced by `toJsonImplementation`.

        The given list is left unmodified.

        :param json_list: List of elements whose last item is `ENCODING_KEY`.
        :raises ValueError: If the list is empty or its last item is not
            `ENCODING_KEY`.
        """
        # Check the marker by index instead of popping it so that the
        # caller's list is never mutated, including when validation fails.
        if not json_list or json_list[-1] != cls.ENCODING_KEY:
            err_msg = f'Given list was not originally encoded using {cls.__name__}'
            raise ValueError(err_msg)
        return cls(json_list[:-1])

    def copy(self):
        """
        Return a new `JsonableSet` holding the same elements as this one.
        """
        return JsonableSet(self)


class JsonableStructure(JsonableClassMixin, structure.Structure):
    """
    A `structure.Structure` subclass that is serialized to json as a string
    in the `structure.MAESTRO` format.
    """

    def toJsonImplementation(self):
        """
        Return this structure written out as a `structure.MAESTRO` string.
        """
        maestro_str = self.writeToString(structure.MAESTRO)
        return maestro_str

    @classmethod
    def fromJsonImplementation(cls, json_str):
        """
        Return the first structure that can be read from `json_str`.

        :param json_str: String to read the structure from. Must not be None.
        :raises json.JSONDecodeError: If the reader yields no structure.
        """
        assert json_str is not None
        with structure.StructureReader.fromString(json_str) as st_reader:
            try:
                first_st = next(st_reader)
            except StopIteration:
                raise json.JSONDecodeError('No structure found', json_str, 0)
            return first_st


class JsonableTuple(JsonableClassMixin, tuple):
    """
    A `tuple` subclass that is serialized to json as a list.

    The encoded list holds the items of the tuple followed by
    `ENCODING_KEY` as its final item.
    """
    # String appended to the encoded list to signal that the list actually
    # represents a tuple.
    ENCODING_KEY = '_python_tuple_'

    def toJsonImplementation(self):
        """
        Return the items of this tuple as a list terminated by
        `ENCODING_KEY`.
        """
        return list(self) + [self.ENCODING_KEY]

    @classmethod
    def fromJsonImplementation(cls, json_list):
        """
        Build an instance from a list produced by `toJsonImplementation`.

        The given list is left unmodified.

        :param json_list: List of items whose last item is `ENCODING_KEY`.
        :raises ValueError: If the list is empty or its last item is not
            `ENCODING_KEY`.
        """
        # Check the marker by index instead of popping it so that the
        # caller's list is never mutated, including when validation fails.
        if not json_list or json_list[-1] != cls.ENCODING_KEY:
            err_msg = f'Given list was not originally encoded using {cls.__name__}'
            raise ValueError(err_msg)
        return cls(json_list[:-1])


class _JsonableNamedTupleMeta(type):
    """
    Create a jsonable named tuple class.

    The field names and types are taken from the annotations in the class
    body. The returned class is built from scratch with `type`; apart from
    the annotations, only `__module__`, `__qualname__` and `__doc__` are
    carried over from the class body.
    """

    def __new__(cls, cls_name, bases, cls_dict):
        """
        Return the jsonable named tuple class described by the class body.

        :param cls_name: Name of the class being created.
        :param bases: Bases given in the class statement. Ignored for every
            class other than `JsonableNamedTuple` itself.
        :param cls_dict: Namespace of the class body.
        """
        if cls_name == 'JsonableNamedTuple':
            # The base class itself is created as an ordinary class.
            return super().__new__(cls, cls_name, bases, cls_dict)
        fields = cls_dict.get('__annotations__', {})
        namedtuple_cls = collections.namedtuple('_JsonableNamedTuple',
                                                list(fields.keys()))

        def toJsonImplementation(self):
            return JsonableTuple(self)

        @classmethod
        def fromJsonImplementation(cls, json_list):
            values = list(JsonableTuple.fromJson(json_list))
            # Decode each field using its annotated type.
            for value_idx, value_type in enumerate(
                    cls.__annotations__.values()):
                values[value_idx] = json.decode(
                    values[value_idx], DataClass=value_type)
            return cls(*values)

        new_bases = tuple([JsonableClassMixin] + namedtuple_cls.mro())
        new_cls_dict = {
            'toJsonImplementation': toJsonImplementation,
            'fromJsonImplementation': fromJsonImplementation,
            '__annotations__': fields,
        }
        # Keep the identity of the class as it was written by the user.
        # Without `__module__`, `type` would record this module as the
        # defining module, so lookups by module and qualified name (as pickle
        # does) would not find the class.
        for identity_key in ('__module__', '__qualname__', '__doc__'):
            if identity_key in cls_dict:
                new_cls_dict[identity_key] = cls_dict[identity_key]
        jsonable_namedtuple_cls = type(cls_name, new_bases, new_cls_dict)
        # This is functionally equivalent to:
        # class cls_name(JsonableClassMixin, namedtuple_cls):
        #   def toJsonImplementation(self):
        #      return JsonableTuple(self)
        #   @classmethod
        #   def fromJsonImplementation(cls, json_list):
        #      # ...
        #   __annotations__ = fields
        return jsonable_namedtuple_cls


class JsonableNamedTuple(JsonableClassMixin, metaclass=_JsonableNamedTupleMeta):
    """
    A jsonable NamedTuple. Subclasses behave like normal named tuples and
    are jsonable as long as their fields are jsonable. Example::

        class Coordinate(JsonableNamedTuple):
            x: float
            y: float
            description: str

        coord = Coordinate(x=1, y=2, description="molecule coord")
        assert coord == (1, 2, "molecule coord")
        serialized_coord = json.dumps(coord)
        deserialized_coord = json.loads(serialized_coord, DataClass=Coordinate)
        assert deserialized_coord == (1, 2, "molecule coord")

    WARNING:: Instances of subclasses of this class will not evaluate as
        instances of `JsonableNamedTuple`. This replicates the behavior
        of `typing.NamedTuple`.
    """


class _JsonableEnumBase(JsonableClassMixin):
    """
    Mixin base that gives jsonable enums an explicit `__reduce_ex__`.

    The Enum class checks mixins to see if `__reduce_ex__` is defined. If it
    isn't, it makes the class unpicklable and, as a consequence,
    undeepcopyable. Since JsonableClassMixin doesn't need anything extra, the
    regular Enum pickling protocol is reused as is.
    """

    def __reduce_ex__(self, proto):
        """
        Delegate to `enum.Enum.__reduce_ex__`.
        """
        enum_reduce = enum.Enum.__reduce_ex__
        return enum_reduce(self, proto)


class JsonableEnum(_JsonableEnumBase, enum.Enum):
    """
    An `enum.Enum` whose members are serialized to json as their values.
    """

    def __init__(self, *args, **kwargs):
        # The arguments are accepted but not used; each member only needs
        # its json adapters set up.
        self._setJsonAdapters()

    def toJsonImplementation(self):
        """
        Return the value of this member.
        """
        member_value = self.value
        return member_value

    @classmethod
    def fromJsonImplementation(cls, json_obj):
        """
        Return the member of this enum whose value is `json_obj`.
        """
        member = cls(json_obj)
        return member


class JsonableIntEnum(int, JsonableClassMixin, enum.Enum):
    """
    An int-based `enum.Enum` whose members are serialized to json as their
    values.
    """

    def __init__(self, *args, **kwargs):
        # The arguments are accepted but not used; each member only needs
        # its json adapters set up.
        self._setJsonAdapters()

    def toJsonImplementation(self):
        """
        Return the value of this member.
        """
        member_value = self.value
        return member_value

    @classmethod
    def fromJsonImplementation(cls, json_obj):
        """
        Return the member of this enum whose value is `json_obj`.
        """
        member = cls(json_obj)
        return member


class _JsonableMolWrapper(JsonableClassMixin):
    """
    Jsonable wrapper around an rdkit mol, which is stored on `_mol` and
    serialized to json as a mol block string.
    """

    def __init__(self, mol_block=None):
        """
        :param mol_block: Mol block to build the wrapped mol from. If None,
            an empty `Chem.rdchem.Mol` is wrapped instead.
        """
        self._mol = (Chem.rdchem.Mol() if mol_block is None else
                     Chem.MolFromMolBlock(mol_block))

    def toJsonImplementation(self):
        """
        Return the wrapped mol as a mol block string.
        """
        block_str = Chem.MolToMolBlock(self._mol)
        return block_str

    @classmethod
    def fromJsonImplementation(cls, json_str):
        """
        Return a wrapper built from the mol block `json_str`.
        """
        wrapper = cls(mol_block=json_str)
        return wrapper


"""---------------------- DataClass Registry ----------------------------------
To add a new class to the registry, subclass `AbstractJsonSerializer` and
implement the abstract variables and methods. See `AbstractJsonSerializer`
for more information.
"""


class AbstractJsonSerializer:
    """
    Defines how a particular kind of object is serialized. Only use this
    when `json.JsonableClassMixin` can't be used directly. Serializers work
    together with `json.load(s)` and `json.dump(s)`.

    A subclass has to set `ObjectClass` and `JsonableClass` and to override
    both `objectFromJsonable` and `jsonableFromObject`.

    Adding a subclass to this module adds a new class to the global default
    serialization registry. (Consult with relevant parties before doing so...)

    :cvar ObjectClass: The non-jsonable third-party class (e.g. set, rdkit.Mol,
        etc.)
    :cvar JsonableClass: The class that subclasses `ObjectClass` and mixes in
        JsonableClassMixin.
    """
    ObjectClass = NotImplemented
    JsonableClass = NotImplemented

    def __init__(self):
        # Serializers are only ever used through their classmethods.
        msg = "Serializers should not be instantiated."
        raise TypeError(msg)

    @classmethod
    def objectFromJsonable(cls, jsonable_obj):
        """
        Convert an instance of `JsonableClass` into an instance of
        `ObjectClass`. Must be overridden.
        """
        raise NotImplementedError

    @classmethod
    def jsonableFromObject(cls, obj):
        """
        Convert an instance of `ObjectClass` into an instance of
        `JsonableClass`. Must be overridden.
        """
        raise NotImplementedError

    @classmethod
    def objectFromJson(cls, json_obj):
        """
        DO NOT OVERRIDE.

        Decode a json object (i.e. an object made up of json native types)
        into `JsonableClass` and return it converted to `ObjectClass`.
        """
        decoded = json.decode(json_obj, DataClass=cls.JsonableClass)
        return cls.objectFromJsonable(decoded)


class StructureSerializer(AbstractJsonSerializer):
    """
    Serializer pairing `structure.Structure` with `JsonableStructure`.
    """
    ObjectClass = structure.Structure
    JsonableClass = JsonableStructure

    @classmethod
    def objectFromJsonable(cls, jsonable_structure):
        """
        Return a `structure.Structure` built from the handle of the given
        jsonable structure.
        """
        st_handle = jsonable_structure.handle
        return structure.Structure(st_handle)

    @classmethod
    def jsonableFromObject(cls, structure_):
        """
        Return a `JsonableStructure` built from the handle of the given
        structure.
        """
        st_handle = structure_.handle
        return JsonableStructure(st_handle)


class TupleSerializer(AbstractJsonSerializer):
    """
    Serializer pairing `tuple` with `JsonableTuple`.
    """
    ObjectClass = tuple
    JsonableClass = JsonableTuple

    @classmethod
    def objectFromJsonable(cls, jsonable_tuple):
        """
        Return a plain tuple holding the items of `jsonable_tuple`.
        """
        plain_tuple = tuple(jsonable_tuple)
        return plain_tuple

    @classmethod
    def jsonableFromObject(cls, tuple_):
        """
        Return a `JsonableTuple` holding the items of `tuple_`.
        """
        jsonable = JsonableTuple(tuple_)
        return jsonable


class SetSerializer(AbstractJsonSerializer):
    """
    Serializer pairing `set` with `JsonableSet`.
    """
    ObjectClass = set
    JsonableClass = JsonableSet

    @classmethod
    def objectFromJsonable(cls, jsonable_set):
        """
        Return a plain set holding the elements of `jsonable_set`.
        """
        plain_set = set(jsonable_set)
        return plain_set

    @classmethod
    def jsonableFromObject(cls, set_):
        """
        Return a `JsonableSet` holding the elements of `set_`.
        """
        jsonable = JsonableSet(set_)
        return jsonable


class MolSerializer(AbstractJsonSerializer):
    """
    Serializer pairing `Chem.rdchem.Mol` with `_JsonableMolWrapper`.
    """
    ObjectClass = Chem.rdchem.Mol
    JsonableClass = _JsonableMolWrapper

    @classmethod
    def objectFromJsonable(cls, jsonable_mol):
        """
        Return a mol rebuilt from the mol block of the wrapped mol.
        """
        return Chem.MolFromMolBlock(Chem.MolToMolBlock(jsonable_mol._mol))

    @classmethod
    def jsonableFromObject(cls, mol):
        """
        Return a `_JsonableMolWrapper` created from the mol block of `mol`.
        """
        return _JsonableMolWrapper.fromJson(Chem.MolToMolBlock(mol))


# Maps each registered `ObjectClass` to the serializer class that handles it.
# Populated by scanning this module for `AbstractJsonSerializer` subclasses.
DATACLASS_REGISTRY = {}
# Iterate over a copy of the module namespace: the loop binds new
# module-level names (`name`, `attr`, ...) while it runs.
for name, attr in dict(globals()).items():
    try:
        is_serializer = issubclass(attr, AbstractJsonSerializer)
    except TypeError:
        # `issubclass` raises TypeError when `attr` is not a class.
        continue
    if is_serializer:
        serializer = attr
        # Serializers that leave `ObjectClass` as NotImplemented (as
        # `AbstractJsonSerializer` itself does) are not registered.
        if serializer.ObjectClass is not NotImplemented:
            DATACLASS_REGISTRY[serializer.ObjectClass] = serializer


def get_default_serializer(DataClass):
    """
    Look up the serializer registered for `DataClass`.

    :param DataClass: The class to look up in `DATACLASS_REGISTRY`.
    :return: The serializer registered for `DataClass`, or None if there
        isn't one.
    """
    registered = DATACLASS_REGISTRY.get(DataClass, None)
    return registered