You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

71 lines
2.4 KiB

  1. import warnings
  2. warnings.filterwarnings("ignore")
  3. from chemocommons import *
  4. import pandas as pd
  5. import numpy as np
  6. from skmultilearn.cluster import NetworkXLabelGraphClusterer # clusterer
  7. from skmultilearn.cluster import LabelCooccurrenceGraphBuilder # as it writes
  8. from skmultilearn.ensemble import LabelSpacePartitioningClassifier # so?
  9. from skmultilearn.adapt import MLkNN, MLTSVM
  10. from skmultilearn.problem_transform import LabelPowerset # sorry, we only used LP
  11. from sklearn.ensemble import RandomForestClassifier, ExtraTreesClassifier # Okay?
  12. from sklearn.model_selection import LeaveOneOut, RepeatedKFold #, KFold # jackknife, "socalled"
  13. from sklearn.metrics import jaccard_similarity_score, f1_score # for some calculation
  14. from sklearn.utils.multiclass import unique_labels
  15. from lightgbm import LGBMClassifier
  16. loocv = LeaveOneOut() # jackknife
  17. label_names = ["ABCG2", "MDR1", "MRP1", "MRP2", "MRP3", "MRP4", "NTCP2", "S15A1",
  18. "S22A1", "SO1A2", "SO1B1", "SO1B3", "SO2B1"]
  19. Y = pd.read_csv("label_matrix.txt", sep="\t", names=label_names)
  20. Y[Y==-1]=0
  21. ft_FP = pd.read_csv("query_smiles_feature_similarity_four_average.csv", names=label_names)
  22. ft_FP.rename(mapper= lambda x: x + "_FP", axis=1, inplace=True)
  23. ft_OT = pd.read_csv("feature_similarity_chebi_ontology_DiShIn_2.csv", names=label_names)
  24. ft_OT.rename(mapper= lambda x: x + "_OT", axis=1, inplace=True)
  25. X = np.concatenate((ft_FP, ft_OT), axis=1)
  26. scoring_funcs = {"hamming loss": hamming_func,
  27. "aiming": aiming_func,
  28. "coverage": coverage_func,
  29. "accuracy": accuracy_func,
  30. "absolute true": absolute_true_func,
  31. } # Keep recorded
  32. parameters = {
  33. 'classifier': [LabelPowerset()],
  34. 'classifier__classifier': [ExtraTreesClassifier()],
  35. 'classifier__classifier__n_estimators': [50, 100, 500, 1000],
  36. 'clusterer' : [
  37. NetworkXLabelGraphClusterer(LabelCooccurrenceGraphBuilder(weighted=True, include_self_edges=False), 'louvain'),
  38. NetworkXLabelGraphClusterer(LabelCooccurrenceGraphBuilder(weighted=True, include_self_edges=False), 'lpa')
  39. ]
  40. }
  41. ext = GridSearchCV(LabelSpacePartitioningClassifier(), param_grid=parameters, n_jobs=-1, cv=loocv,
  42. scoring=scoring_funcs, verbose=3, refit="absolute true")
  43. ext.fit(X, Y.values)
  44. print(ext.best_score_)
  45. mytuple = (
  46. ext,
  47. )
  48. to_save = dump(mytuple, filename="ext.joblib")