You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

493 lines
14 KiB

  1. {
  2. "cells": [
  3. {
  4. "cell_type": "code",
  5. "execution_count": 2,
  6. "metadata": {},
  7. "outputs": [
  8. {
  9. "name": "stderr",
  10. "output_type": "stream",
  11. "text": [
  12. "D:\\Anaconda3\\envs\\py36\\lib\\site-packages\\sklearn\\externals\\six.py:31: DeprecationWarning: The module is deprecated in version 0.21 and will be removed in version 0.23 since we've dropped support for Python 2.7. Please rely on the official version of six (https://pypi.org/project/six/).\n",
  13. " \"(https://pypi.org/project/six/).\", DeprecationWarning)\n"
  14. ]
  15. }
  16. ],
  17. "source": [
  18. "import sys\n",
  19. "sys.path.append(\"../..\")\n",
  20. "import warnings\n",
  21. "warnings.filterwarnings(\"ignore\")\n",
  22. "from chemocommons import *\n",
  23. "import pandas as pd\n",
  24. "import numpy as np\n",
  25. "from skmultilearn.cluster import NetworkXLabelGraphClusterer # clusterer\n",
  26. "from skmultilearn.cluster import LabelCooccurrenceGraphBuilder # as it writes\n",
  27. "from skmultilearn.ensemble import LabelSpacePartitioningClassifier # so?\n",
  28. "from skmultilearn.adapt import MLkNN, MLTSVM\n",
  29. "from skmultilearn.problem_transform import LabelPowerset # sorry, we only used LP\n",
  30. "from sklearn.ensemble import RandomForestClassifier, ExtraTreesClassifier # Okay?\n",
  31. "from sklearn.model_selection import LeaveOneOut, RepeatedKFold #, KFold # jackknife, \"socalled\"\n",
  32. "from sklearn.metrics import jaccard_similarity_score, f1_score # for some calculation\n",
  33. "from sklearn.utils.multiclass import unique_labels\n",
  34. "from lightgbm import LGBMClassifier\n",
  35. "from iterstrat.ml_stratifiers import MultilabelStratifiedKFold\n",
  36. "from joblib import load"
  37. ]
  38. },
  39. {
  40. "cell_type": "code",
  41. "execution_count": 15,
  42. "metadata": {},
  43. "outputs": [],
  44. "source": [
  "# Labels: one column per transporter; -1 entries are treated as negatives (mapped to 0).\n",
  "label_names = [\"ABCG2\", \"MDR1\", \"MRP1\", \"MRP2\", \"MRP3\", \"MRP4\", \"NTCP2\", \"S15A1\",\n",
  "               \"S22A1\", \"SO1A2\", \"SO1B1\", \"SO1B3\", \"SO2B1\"]\n",
  "\n",
  "# random_state only takes effect when shuffle=True; with shuffle=False it is ignored\n",
  "# (and scikit-learn >= 0.24 rejects that combination with a ValueError).\n",
  "rmskf = MultilabelStratifiedKFold(n_splits=10, shuffle=True, random_state=19941115)\n",
  "\n",
  "labels_df = pd.read_csv(\"label_matrix.txt\", sep=\"\\t\", names=label_names).replace(-1, 0)\n",
  "\n",
  "# Two similarity feature blocks, one column per label; suffixes keep column names unique.\n",
  "ft_FP = pd.read_csv(\"query_smiles_feature_similarity_four_average.csv\", names=label_names).add_suffix(\"_FP\")\n",
  "ft_OT = pd.read_csv(\"feature_similarity_chebi_ontology_DiShIn_2.csv\", names=label_names).add_suffix(\"_OT\")\n",
  "\n",
  "# Rows are paired by position below, so all three files must share the same row order.\n",
  "assert len(labels_df) == len(ft_FP) == len(ft_OT), \"label and feature files have different row counts\"\n",
  "\n",
  "X = np.concatenate((ft_FP, ft_OT), axis=1)\n",
  "Y = labels_df.values"
  60. ]
  61. },
  62. {
  63. "cell_type": "code",
  64. "execution_count": 70,
  65. "metadata": {},
  66. "outputs": [
  67. {
  68. "name": "stdout",
  69. "output_type": "stream",
  70. "text": [
  71. "0 th repeat:\n",
  72. "0 th fold.\n",
  73. "1 th fold.\n",
  74. "2 th fold.\n",
  75. "3 th fold.\n",
  76. "4 th fold.\n",
  77. "5 th fold.\n",
  78. "6 th fold.\n",
  79. "7 th fold.\n",
  80. "8 th fold.\n",
  81. "9 th fold.\n",
  82. "1 th repeat:\n",
  83. "0 th fold.\n",
  84. "1 th fold.\n",
  85. "2 th fold.\n",
  86. "3 th fold.\n",
  87. "4 th fold.\n",
  88. "5 th fold.\n",
  89. "6 th fold.\n",
  90. "7 th fold.\n",
  91. "8 th fold.\n",
  92. "9 th fold.\n",
  93. "2 th repeat:\n",
  94. "0 th fold.\n",
  95. "1 th fold.\n",
  96. "2 th fold.\n",
  97. "3 th fold.\n",
  98. "4 th fold.\n",
  99. "5 th fold.\n",
  100. "6 th fold.\n",
  101. "7 th fold.\n",
  102. "8 th fold.\n",
  103. "9 th fold.\n",
  104. "3 th repeat:\n",
  105. "0 th fold.\n",
  106. "1 th fold.\n",
  107. "2 th fold.\n",
  108. "3 th fold.\n",
  109. "4 th fold.\n",
  110. "5 th fold.\n",
  111. "6 th fold.\n",
  112. "7 th fold.\n",
  113. "8 th fold.\n",
  114. "9 th fold.\n",
  115. "4 th repeat:\n",
  116. "0 th fold.\n",
  117. "1 th fold.\n",
  118. "2 th fold.\n",
  119. "3 th fold.\n",
  120. "4 th fold.\n",
  121. "5 th fold.\n",
  122. "6 th fold.\n",
  123. "7 th fold.\n",
  124. "8 th fold.\n",
  125. "9 th fold.\n",
  126. "5 th repeat:\n",
  127. "0 th fold.\n",
  128. "1 th fold.\n",
  129. "2 th fold.\n",
  130. "3 th fold.\n",
  131. "4 th fold.\n",
  132. "5 th fold.\n",
  133. "6 th fold.\n",
  134. "7 th fold.\n",
  135. "8 th fold.\n",
  136. "9 th fold.\n",
  137. "6 th repeat:\n",
  138. "0 th fold.\n",
  139. "1 th fold.\n",
  140. "2 th fold.\n",
  141. "3 th fold.\n",
  142. "4 th fold.\n",
  143. "5 th fold.\n",
  144. "6 th fold.\n",
  145. "7 th fold.\n",
  146. "8 th fold.\n",
  147. "9 th fold.\n",
  148. "7 th repeat:\n",
  149. "0 th fold.\n",
  150. "1 th fold.\n",
  151. "2 th fold.\n",
  152. "3 th fold.\n",
  153. "4 th fold.\n",
  154. "5 th fold.\n",
  155. "6 th fold.\n",
  156. "7 th fold.\n",
  157. "8 th fold.\n",
  158. "9 th fold.\n",
  159. "8 th repeat:\n",
  160. "0 th fold.\n",
  161. "1 th fold.\n",
  162. "2 th fold.\n",
  163. "3 th fold.\n",
  164. "4 th fold.\n",
  165. "5 th fold.\n",
  166. "6 th fold.\n",
  167. "7 th fold.\n",
  168. "8 th fold.\n",
  169. "9 th fold.\n",
  170. "9 th repeat:\n",
  171. "0 th fold.\n",
  172. "1 th fold.\n",
  173. "2 th fold.\n",
  174. "3 th fold.\n",
  175. "4 th fold.\n",
  176. "5 th fold.\n",
  177. "6 th fold.\n",
  178. "7 th fold.\n",
  179. "8 th fold.\n",
  180. "9 th fold.\n",
  181. "[0.86886624 0.82629165 0.95213385 0.93531963 0.97052574 0.97484388\n",
  182. " 0.99404831 0.97431425 0.96509124 0.97319891 0.95619702 0.97633315\n",
  183. " 0.98074716] [0.72206496 0.77961642 0.83943579 0.72212698 0.59753571 0.36666667\n",
  184. " 0.925 0.91736341 0.9194019 0.49666667 0.519 0.5125\n",
  185. " 0.3 ] [0.48469748 0.90494505 0.44450549 0.25406593 0.31071429 0.167\n",
  186. " 0.86666667 0.87695652 0.60961905 0.13333333 0.14097222 0.151\n",
  187. " 0.11 ] [0.57687171 0.83710125 0.57530299 0.36020054 0.38846676 0.21738095\n",
  188. " 0.89090909 0.89447793 0.73041748 0.20373016 0.21516994 0.22747619\n",
  189. " 0.15428571] [0.89082465 0.92430166 0.9057471 0.91329001 0.89746348 0.93414092\n",
  190. " 0.99760552 0.98078294 0.94215135 0.86763828 0.89635493 0.85186265\n",
  191. " 0.85161042]\n"
  192. ]
  193. }
  194. ],
  195. "source": [
  "def measure_per_label(measure, y_true, y_predicted):\n",
  "    \"\"\"Apply a binary metric to every label column independently.\n",
  "\n",
  "    Inspired by skmultilearn, but both inputs here are dense 2-D numpy.ndarray\n",
  "    objects (rows = samples, columns = labels).\n",
  "\n",
  "    Returns a list holding measure(y_true[:, j], y_predicted[:, j]) for each\n",
  "    column j of y_true, in column order.\n",
  "    \"\"\"\n",
  "    n_labels = y_true.shape[1]\n",
  "    scores = []\n",
  "    for col in range(n_labels):\n",
  "        scores.append(measure(y_true[:, col], y_predicted[:, col]))\n",
  "    return scores\n",
  "\n",
  "\n",
  "# joblib.load unpickles arbitrary objects: only load trusted, locally produced files.\n",
  "# The saved object exposes best_estimator_, presumably a fitted hyper-parameter\n",
  "# search (e.g. GridSearchCV) - TODO confirm against the notebook that wrote rf.joblib.\n",
  "NLSP_RF = load(\"rf.joblib\")[0]\n",
  "final_model = NLSP_RF.best_estimator_\n",
  "\n",
  "N_REPEATS = 10\n",
  "N_SPLITS = 10\n",
  "BASE_SEED = 19941115\n",
  "\n",
  "# One entry per (repeat, fold); each entry holds one score per label.\n",
  "label_acc = []\n",
  "label_sp = []  # NOTE: filled with precision_score values, not specificity\n",
  "label_rc = []\n",
  "label_f1 = []\n",
  "label_auc = []\n",
  "\n",
  "# NOTE(review): `metrics` and `dump` are not imported explicitly in this notebook;\n",
  "# presumably they come from `from chemocommons import *` - confirm.\n",
  "for i in range(N_REPEATS):  # repeated N_SPLITS-fold CV, N_REPEATS times\n",
  "    print(i, \"th repeat:\")\n",
  "    # shuffle=True plus a per-repeat seed is what makes the repeats differ; with\n",
  "    # shuffle=False and one fixed seed every repeat would yield identical folds.\n",
  "    kfold = MultilabelStratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=BASE_SEED + i)\n",
  "    for k, (train, test) in enumerate(kfold.split(X, Y)):\n",
  "        print(k, \"th fold.\")\n",
  "        final_model.fit(X[train], Y[train])\n",
  "        # predict/predict_proba results are densified so columns can be scored per label.\n",
  "        y_pred = np.array(final_model.predict(X[test]).todense())\n",
  "        y_proba = np.array(final_model.predict_proba(X[test]).todense())\n",
  "        label_acc.append(measure_per_label(metrics.accuracy_score, Y[test], y_pred))\n",
  "        label_sp.append(measure_per_label(metrics.precision_score, Y[test], y_pred))\n",
  "        label_rc.append(measure_per_label(metrics.recall_score, Y[test], y_pred))\n",
  "        label_f1.append(measure_per_label(metrics.f1_score, Y[test], y_pred))\n",
  "        label_auc.append(measure_per_label(metrics.roc_auc_score, Y[test], y_proba))\n",
  "\n",
  "label_acc = np.array(label_acc)\n",
  "label_sp = np.array(label_sp)\n",
  "label_rc = np.array(label_rc)\n",
  "label_f1 = np.array(label_f1)\n",
  "label_auc = np.array(label_auc)\n",
  "\n",
  "dump((label_acc, label_sp, label_rc, label_f1, label_auc), filename=\"report_array.joblib\")\n",
  "\n",
  "print(label_acc.mean(axis=0), label_sp.mean(axis=0), label_rc.mean(axis=0), label_f1.mean(axis=0), label_auc.mean(axis=0))"
  248. ]
  249. },
  250. {
  251. "cell_type": "code",
  252. "execution_count": 66,
  253. "metadata": {},
  254. "outputs": [
  255. {
  256. "name": "stdout",
  257. "output_type": "stream",
  258. "text": [
  259. "[0.86728061 0.82665222 0.95449621 0.92741062 0.96749729 0.97074756\n",
  260. " 0.99458288 0.9723727 0.96533044 0.97453954 0.95612134 0.97508126\n",
  261. " 0.97995666] [0.09046587 0.45016251 0.03575298 0.01787649 0.00975081 0.00270856\n",
  262. " 0.02491874 0.1099675 0.04767064 0.00433369 0.00595883 0.00270856\n",
  263. " 0.00216685] [0.09046587 0.45016251 0.03575298 0.01787649 0.00975081 0.00270856\n",
  264. " 0.02491874 0.1099675 0.04767064 0.00433369 0.00595883 0.00270856\n",
  265. " 0.00216685] [0.09046587 0.45016251 0.03575298 0.01787649 0.00975081 0.00270856\n",
  266. " 0.02491874 0.1099675 0.04767064 0.00433369 0.00595883 0.00270856\n",
  267. " 0.00216685] 0.0\n"
  268. ]
  269. }
  270. ],
  271. "source": [
  "# Per-label mean of each metric over all CV repeats and folds, as a labelled table.\n",
  "# SP holds precision_score values (see the cross-validation loop).\n",
  "pd.DataFrame(\n",
  "    {\"ACC\": label_acc.mean(axis=0),\n",
  "     \"SP\": label_sp.mean(axis=0),\n",
  "     \"RC\": label_rc.mean(axis=0),\n",
  "     \"F1\": label_f1.mean(axis=0),\n",
  "     \"AUC\": label_auc.mean(axis=0)},\n",
  "    index=label_names,\n",
  ")"
  273. ]
  274. },
  275. {
  276. "cell_type": "code",
  277. "execution_count": 71,
  278. "metadata": {},
  279. "outputs": [
  280. {
  281. "data": {
  282. "text/plain": [
  283. "array([0.86886624, 0.82629165, 0.95213385, 0.93531963, 0.97052574,\n",
  284. " 0.97484388, 0.99404831, 0.97431425, 0.96509124, 0.97319891,\n",
  285. " 0.95619702, 0.97633315, 0.98074716])"
  286. ]
  287. },
  288. "execution_count": 71,
  289. "metadata": {},
  290. "output_type": "execute_result"
  291. }
  292. ],
  293. "source": [
  "# Per-label mean accuracy across every CV repeat and fold.\n",
  "np.mean(label_acc, axis=0)"
  295. ]
  296. },
  297. {
  298. "cell_type": "code",
  299. "execution_count": 72,
  300. "metadata": {},
  301. "outputs": [
  302. {
  303. "data": {
  304. "text/plain": [
  305. "array([0.72206496, 0.77961642, 0.83943579, 0.72212698, 0.59753571,\n",
  306. " 0.36666667, 0.925 , 0.91736341, 0.9194019 , 0.49666667,\n",
  307. " 0.519 , 0.5125 , 0.3 ])"
  308. ]
  309. },
  310. "execution_count": 72,
  311. "metadata": {},
  312. "output_type": "execute_result"
  313. }
  314. ],
  315. "source": [
  "# Per-label mean precision (stored as label_sp) across every CV repeat and fold.\n",
  "np.mean(label_sp, axis=0)"
  317. ]
  318. },
  319. {
  320. "cell_type": "code",
  321. "execution_count": 73,
  322. "metadata": {},
  323. "outputs": [
  324. {
  325. "data": {
  326. "text/plain": [
  327. "array([0.48469748, 0.90494505, 0.44450549, 0.25406593, 0.31071429,\n",
  328. " 0.167 , 0.86666667, 0.87695652, 0.60961905, 0.13333333,\n",
  329. " 0.14097222, 0.151 , 0.11 ])"
  330. ]
  331. },
  332. "execution_count": 73,
  333. "metadata": {},
  334. "output_type": "execute_result"
  335. }
  336. ],
  337. "source": [
  "# Per-label mean recall across every CV repeat and fold.\n",
  "np.mean(label_rc, axis=0)"
  339. ]
  340. },
  341. {
  342. "cell_type": "code",
  343. "execution_count": 21,
  344. "metadata": {},
  345. "outputs": [
  346. {
  347. "data": {
  348. "text/plain": [
  349. "array([0.09046587, 0.45016251, 0.03575298, 0.01787649, 0.00975081,\n",
  350. " 0.00270856, 0.02491874, 0.1099675 , 0.04767064, 0.00433369,\n",
  351. " 0.00595883, 0.00270856, 0.00216685])"
  352. ]
  353. },
  354. "execution_count": 21,
  355. "metadata": {},
  356. "output_type": "execute_result"
  357. }
  358. ],
  359. "source": [
  "# Per-label mean F1 across every CV repeat and fold.\n",
  "np.mean(label_f1, axis=0)"
  361. ]
  362. },
  363. {
  364. "cell_type": "code",
  365. "execution_count": 3,
  366. "metadata": {},
  367. "outputs": [],
  368. "source": [
  "# Reload the metric arrays dumped by the cross-validation cell. joblib.load\n",
  "# unpickles the file, so only load trusted, locally produced artifacts.\n",
  "reports = load(filename=\"report_array.joblib\")"
  370. ]
  371. },
  372. {
  373. "cell_type": "code",
  374. "execution_count": 24,
  375. "metadata": {},
  376. "outputs": [],
  377. "source": [
  "# Build the per-label summary table in one idempotent step. Each saved array has\n",
  "# one row per CV fold and one column per label, so the axis-0 mean gives one value\n",
  "# per label. Column order follows the dump order in the cross-validation cell.\n",
  "metric_names = [\"ACC\", \"SP\", \"RC\", \"F1\", \"AUC\"]  # SP holds precision_score values\n",
  "assert len(reports) == len(metric_names), \"unexpected number of saved metric arrays\"\n",
  "final_reports = pd.DataFrame(\n",
  "    {name: scores.mean(axis=0) for name, scores in zip(metric_names, reports)},\n",
  "    index=label_names,\n",
  ")\n",
  "final_reports.to_csv(\"final_reports.csv\")"
  426. ]
  427. },
  428. {
  429. "cell_type": "code",
  430. "execution_count": 8,
  431. "metadata": {},
  432. "outputs": [
  433. {
  434. "data": {
  435. "text/plain": [
  436. "(13,)"
  437. ]
  438. },
  439. "execution_count": 8,
  440. "metadata": {},
  441. "output_type": "execute_result"
  442. }
  443. ],
  444. "source": [
  "# Per-label mean of the first saved report array (accuracy).\n",
  "my_list = np.mean(reports[0], axis=0)"
  446. ]
  447. },
  448. {
  449. "cell_type": "code",
  450. "execution_count": 7,
  451. "metadata": {},
  452. "outputs": [
  453. {
  454. "data": {
  455. "text/plain": [
  456. "array([0.72206496, 0.77961642, 0.83943579, 0.72212698, 0.59753571,\n",
  457. " 0.36666667, 0.925 , 0.91736341, 0.9194019 , 0.49666667,\n",
  458. " 0.519 , 0.5125 , 0.3 ])"
  459. ]
  460. },
  461. "execution_count": 7,
  462. "metadata": {},
  463. "output_type": "execute_result"
  464. }
  465. ],
  466. "source": [
  "# Per-label mean of the second saved report array (precision, stored as label_sp).\n",
  "np.mean(reports[1], axis=0)"
  468. ]
  469. }
  470. ],
  471. "metadata": {
  472. "kernelspec": {
  473. "display_name": "Python 3",
  474. "language": "python",
  475. "name": "python3"
  476. },
  477. "language_info": {
  478. "codemirror_mode": {
  479. "name": "ipython",
  480. "version": 3
  481. },
  482. "file_extension": ".py",
  483. "mimetype": "text/x-python",
  484. "name": "python",
  485. "nbconvert_exporter": "python",
  486. "pygments_lexer": "ipython3",
  487. "version": "3.7.3"
  488. }
  489. },
  490. "nbformat": 4,
  491. "nbformat_minor": 4
  492. }