You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

61 lines
2.0 KiB

  1. import sys
  2. sys.path.append("../..")
  3. import warnings
  4. warnings.filterwarnings("ignore")
  5. from chemocommons import *
  6. import pandas as pd
  7. import numpy as np
  8. from skmultilearn.cluster import NetworkXLabelGraphClusterer # clusterer
  9. from skmultilearn.cluster import LabelCooccurrenceGraphBuilder # as it writes
  10. from skmultilearn.ensemble import LabelSpacePartitioningClassifier # so?
  11. from skmultilearn.adapt import MLkNN, MLTSVM
  12. from skmultilearn.problem_transform import LabelPowerset # sorry, we only used LP
  13. from sklearn.ensemble import RandomForestClassifier, ExtraTreesClassifier # Okay?
  14. from sklearn.model_selection import LeaveOneOut, RepeatedKFold #, KFold # jackknife, "socalled"
  15. from sklearn.metrics import jaccard_similarity_score, f1_score # for some calculation
  16. from sklearn.utils.multiclass import unique_labels
  17. from lightgbm import LGBMClassifier
  18. loocv = LeaveOneOut() # jackknife
  19. label_names = ["ABCG2", "MDR1", "MRP1", "MRP2", "MRP3", "MRP4", "NTCP2", "S15A1",
  20. "S22A1", "SO1A2", "SO1B1", "SO1B3", "SO2B1"]
  21. Y = pd.read_csv("label_matrix.txt", sep="\t", names=label_names)
  22. Y[Y==-1]=0
  23. ft_FP = pd.read_csv("query_smiles_feature_similarity_four_average.csv", names=label_names)
  24. ft_FP.rename(mapper= lambda x: x + "_FP", axis=1, inplace=True)
  25. ft_OT = pd.read_csv("feature_similarity_chebi_ontology_DiShIn_2.csv", names=label_names)
  26. ft_OT.rename(mapper= lambda x: x + "_OT", axis=1, inplace=True)
  27. X = np.concatenate((ft_FP, ft_OT), axis=1)
  28. scoring_funcs = {"hamming loss": hamming_func,
  29. "aiming": aiming_func,
  30. "coverage": coverage_func,
  31. "accuracy": accuracy_func,
  32. "absolute true": absolute_true_func,
  33. } # Keep recorded
  34. parameters = {'c_k': [2**i for i in range(-5, 5)]}
  35. mtsvm = GridSearchCV(MLTSVM(), param_grid=parameters, n_jobs=-1, cv=loocv,
  36. scoring=scoring_funcs, verbose=3, refit="absolute true")
  37. mtsvm.fit(X, Y.values)
  38. print(mtsvm.best_score_)
  39. mytuple = (
  40. mtsvm,
  41. )
  42. to_save = dump(mytuple, filename="mtsvm.joblib")