Source code for schrodinger.application.matsci.mlearn.sklearn_json

"""
# Third-party code. No Schrodinger Copyright.
"""

import json

from sklearn import discriminant_analysis
from sklearn import dummy
from sklearn import svm
from sklearn.decomposition import PCA
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.ensemble import RandomForestClassifier
from sklearn.ensemble import RandomForestRegressor
from sklearn.ensemble import _gb_losses
from sklearn.linear_model import Lasso
from sklearn.linear_model import LinearRegression
from sklearn.linear_model import LogisticRegression
from sklearn.linear_model import Perceptron
from sklearn.linear_model import Ridge
from sklearn.naive_bayes import BernoulliNB
from sklearn.naive_bayes import ComplementNB
from sklearn.naive_bayes import GaussianNB
from sklearn.naive_bayes import MultinomialNB
from sklearn.neural_network import MLPClassifier
from sklearn.neural_network import MLPRegressor
from sklearn.svm import SVR
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import DecisionTreeRegressor

from . import classification as clf
from . import decomposition as dcp
from . import regression as reg

# Version string exposed by this module. Presumably it tracks the third-party
# package this code was taken from -- TODO confirm.
__version__ = '0.1.0'


def serialize_model(model):
    """
    Convert a supported scikit-learn model into its serialized form.

    The model is tested against the supported types in a fixed order and is
    handed to the serializer registered for the first type it is an instance
    of.

    :param model: the model object to serialize
    :return: whatever the matching serializer returns
    :raises ModellNotSupported: if the model is not an instance of any
        supported type
    """
    # Each entry is (model type, module providing the serializer, serializer
    # name). The serializer is resolved by name only after a type has matched,
    # so no serializer attribute is touched for types that are not involved.
    supported = (
        (LogisticRegression, clf, 'serialize_logistic_regression'),
        (BernoulliNB, clf, 'serialize_bernoulli_nb'),
        (GaussianNB, clf, 'serialize_gaussian_nb'),
        (MultinomialNB, clf, 'serialize_multinomial_nb'),
        (ComplementNB, clf, 'serialize_complement_nb'),
        (discriminant_analysis.LinearDiscriminantAnalysis, clf,
         'serialize_lda'),
        (discriminant_analysis.QuadraticDiscriminantAnalysis, clf,
         'serialize_qda'),
        (svm.SVC, clf, 'serialize_svm'),
        (Perceptron, clf, 'serialize_perceptron'),
        (DecisionTreeClassifier, clf, 'serialize_decision_tree'),
        (GradientBoostingClassifier, clf, 'serialize_gradient_boosting'),
        (RandomForestClassifier, clf, 'serialize_random_forest'),
        (MLPClassifier, clf, 'serialize_mlp'),
        (LinearRegression, reg, 'serialize_linear_regressor'),
        (Lasso, reg, 'serialize_lasso_regressor'),
        (Ridge, reg, 'serialize_ridge_regressor'),
        (SVR, reg, 'serialize_svr'),
        (DecisionTreeRegressor, reg, 'serialize_decision_tree_regressor'),
        (GradientBoostingRegressor, reg,
         'serialize_gradient_boosting_regressor'),
        (RandomForestRegressor, reg, 'serialize_random_forest_regressor'),
        (MLPRegressor, reg, 'serialize_mlp_regressor'),
        (PCA, dcp, 'serialize_pca'),
    )

    # First match wins, in the order listed above.
    for model_type, handlers, handler_name in supported:
        if isinstance(model, model_type):
            return getattr(handlers, handler_name)(model)

    raise ModellNotSupported(
        'This model type is not currently supported. Email support@mlrequest.com to request a feature or report a bug.'
    )
def deserialize_model(model_dict):
    """
    Rebuild a model from its serialized form.

    The value stored under the 'meta' key selects which deserializer receives
    the whole mapping.

    :param model_dict: serialized model; it is indexed with 'meta', so a plain
        dict without that key raises KeyError
    :return: whatever the matching deserializer returns
    :raises ModellNotSupported: if the 'meta' value equals none of the known
        tags
    """
    tag = model_dict['meta']

    # Each entry is (tag, module providing the deserializer, deserializer
    # name). Tags are compared with == rather than hashed, and the
    # deserializer is resolved by name only after a tag has matched.
    known = (
        ('lr', clf, 'deserialize_logistic_regression'),
        ('bernoulli-nb', clf, 'deserialize_bernoulli_nb'),
        ('gaussian-nb', clf, 'deserialize_gaussian_nb'),
        ('multinomial-nb', clf, 'deserialize_multinomial_nb'),
        ('complement-nb', clf, 'deserialize_complement_nb'),
        ('lda', clf, 'deserialize_lda'),
        ('qda', clf, 'deserialize_qda'),
        ('svm', clf, 'deserialize_svm'),
        ('perceptron', clf, 'deserialize_perceptron'),
        ('decision-tree', clf, 'deserialize_decision_tree'),
        ('gb', clf, 'deserialize_gradient_boosting'),
        ('rf', clf, 'deserialize_random_forest'),
        ('mlp', clf, 'deserialize_mlp'),
        ('linear-regression', reg, 'deserialize_linear_regressor'),
        ('lasso-regression', reg, 'deserialize_lasso_regressor'),
        ('ridge-regression', reg, 'deserialize_ridge_regressor'),
        ('svr', reg, 'deserialize_svr'),
        ('decision-tree-regression', reg,
         'deserialize_decision_tree_regressor'),
        ('gb-regression', reg, 'deserialize_gradient_boosting_regressor'),
        ('rf-regression', reg, 'deserialize_random_forest_regressor'),
        ('mlp-regression', reg, 'deserialize_mlp_regressor'),
        ('pca', dcp, 'deserialize_pca'),
    )

    for label, handlers, handler_name in known:
        if tag == label:
            return getattr(handlers, handler_name)(model_dict)

    raise ModellNotSupported(
        'Model type not supported or corrupt JSON file. Email support@mlrequest.com to request a feature or report a bug.'
    )
def to_dict(model):
    """
    Return the serialized form of a model.

    Thin wrapper that forwards to `serialize_model`.

    :param model: the model object to serialize
    :return: the result of `serialize_model` for this model
    """
    serialized = serialize_model(model)
    return serialized
def from_dict(model_dict):
    """
    Rebuild a model from its serialized form.

    Thin wrapper that forwards to `deserialize_model`.

    :param model_dict: the serialized model
    :return: the result of `deserialize_model` for this input
    """
    restored = deserialize_model(model_dict)
    return restored
def to_json(model, model_name):
    """
    Serialize a model and write it to a file as JSON.

    :param model: the model object to serialize, see `serialize_model`
    :param model_name: the file to write, passed to `open`
    :raises ModellNotSupported: propagated from `serialize_model` when the
        model type is not supported; the target file is left untouched in
        that case
    """
    # Build the complete JSON text before touching the file system. Opening
    # with 'w' truncates (or creates) the file, so doing it first would leave
    # an empty or partial file behind whenever serialization or encoding
    # fails, and would destroy the previous contents.
    document = json.dumps(serialize_model(model))

    # The default encoder emits ASCII only, so spelling out UTF-8 does not
    # change the bytes written. It only removes the dependence on the locale.
    with open(model_name, 'w', encoding='utf-8') as model_json:
        model_json.write(document)
def from_json(model_name):
    """
    Read a JSON file and rebuild the model it describes.

    :param model_name: the file to read, passed to `open`
    :return: the result of `deserialize_model` for the parsed document
    :raises ModellNotSupported: propagated from `deserialize_model` when the
        document's model type is not recognised
    """
    # Decode as UTF-8 explicitly instead of relying on the locale default, so
    # documents containing non-ASCII text load the same way on every platform.
    with open(model_name, 'r', encoding='utf-8') as model_json:
        model_dict = json.load(model_json)

    return deserialize_model(model_dict)
class ModellNotSupported(Exception):
    """
    Raised by `serialize_model` and `deserialize_model` when no handler
    matches the given model or serialized model type.
    """