nmr_shift_prediction_from_small_data_quantities.py 83 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929
  1. # -*- coding: utf-8 -*-
  2. """NMR shift prediction from small data quantities.ipynb
  3. Automatically generated by Colaboratory.
  4. Original file is located at
  5. https://colab.research.google.com/drive/1yKTRjpWzR8T199eCokuJfd9Y5o2oNtPp
  6. # Data: NMRShiftDB
  7. ## Reading data set from SD file
  8. The nmrshiftdb2 data from https://sourceforge.net/projects/nmrshiftdb2/files/data/ has been downloaded and moved to our own domain in order to avoid the https://sourceforge.net/ madness (redirects, cookie banner, etc.) and ensure that colab can be rerun without manual steps or Google Drive dependencies by all collaborators.
  9. """
# Install required dependencies
!pip3 install rdkit # Used to read and parse nmrshiftdb2 SD file
!pip3 install mendeleev # To access various features related to atoms
import torch
# PyTorch dependencies to represent graph data; wheel index must match the installed torch version
!pip3 install torch-scatter -f https://data.pyg.org/whl/torch-{torch.__version__}.html
!pip3 install torch-sparse -f https://data.pyg.org/whl/torch-{torch.__version__}.html
!pip3 install torch-geometric
# Use CUDA by default
if torch.cuda.is_available():
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
# Downloads the nmrshiftdb2 database if it does not yet exist in our runtime
!wget -nc https://www.dropbox.com/s/n122zxawpxii5b7/nmrshiftdb2withsignals.sd#Flourine
#!wget -nc "https://sourceforge.net/projects/nmrshiftdb2/files/data/nmrshiftdb2withsignals.sd/download" #Carbon
from rdkit import Chem
# Positional args are presumably (sanitize, removeHs, strictParsing) — NOTE(review): confirm against the RDKit SDMolSupplier signature
supplier3d = Chem.rdmolfiles.SDMolSupplier("nmrshiftdb2withsignals.sd",True, False, True) #Fluorine data set
#supplier3d = Chem.rdmolfiles.SDMolSupplier("download") #Carbon
print(f"In total there are {len(supplier3d)} molecules")
from rdkit.Chem import AllChem
# Display one sample molecule (notebook cell output)
mol= supplier3d[152]
mol
  31. """## Graph transformation
  32. In order to use convolutional networks, we need to transform the molecule data into graphs.
  33. The atoms themselves will become nodes, and the bonds between the atoms will become the edges.
  34. """
  35. import mendeleev
  36. import torch
  37. import math
  38. from torch import tensor
  39. from torch_geometric.data import Data
  40. from torch_geometric.loader import DataLoader
  41. import os
  42. import numpy as np
  43. import time
  44. from sklearn.preprocessing import OneHotEncoder
  45. import random
# One hot encoding
## Bonds: categories for the four RDKit bond types
bond_idxes = np.array([Chem.BondType.SINGLE, Chem.BondType.DOUBLE, Chem.BondType.TRIPLE, Chem.BondType.AROMATIC])
bond_idxes = bond_idxes.reshape(len(bond_idxes), 1)
# NOTE(review): `sparse=` was renamed `sparse_output=` in scikit-learn 1.2 — confirm the sklearn version pinned in the runtime
onehot_encoder = OneHotEncoder(sparse=False, handle_unknown="ignore")
onehot_encoder.fit(bond_idxes)
## Hybridization: every RDKit hybridization type name
hybridization_idxes = np.array(list(Chem.HybridizationType.names))
hybridization_idxes = hybridization_idxes.reshape(len(hybridization_idxes), 1)
hybridization_ohe = OneHotEncoder(sparse=False)
hybridization_ohe.fit(hybridization_idxes)
## Valence: total valences 1..7
valences = np.arange(1, 8);
valences = valences.reshape(len(valences), 1)
valence_ohe = OneHotEncoder(handle_unknown="ignore",sparse=False)
valence_ohe.fit(valences)
## Formal Charge: only -1 and 0 are encoded; any other charge maps to an
## all-zeros row because of handle_unknown="ignore"
fc = np.arange(-1, 1);
fc = fc.reshape(len(fc), 1)
fc_ohe = OneHotEncoder(handle_unknown="ignore",sparse=False)
fc_ohe.fit(fc)
## Atomic number: the elements expected in the data (C,H,N,O,F,Cl,P,Na,S)
atomic_nums = np.array([6,1,7,8,9,17,15,11, 16])
atomic_nums = atomic_nums.reshape(len(atomic_nums), 1)
atomic_number_ohe = OneHotEncoder(handle_unknown="ignore",sparse=False)
atomic_number_ohe.fit(atomic_nums)
# Sanity check: encode hydrogen (notebook cell output)
atomic_number_ohe.transform(np.array([[1]]))
  73. def get_molecule_solvent(molecule):
  74. if not molecule:
  75. return {}
  76. for key, value in molecule.GetPropsAsDict().items():
  77. if key.startswith("Solvent"):
  78. return value
  79. return None
# Tally solvent occurrences across the data set and one-hot encode the common ones
solvents={}
for mol in supplier3d:
    a = get_molecule_solvent(mol)
    if a:
        if a not in solvents:
            solvents[a] = 0
        solvents[a] += 1
# Keep only solvents that appear in more than 10 molecules
arr = [k for k, v in solvents.items() if v>10]
solvents = np.array(arr)
solvents = solvents.reshape(len(solvents), 1)
# Rare/unseen solvents encode to an all-zeros row via handle_unknown="ignore"
solvent_ohe = OneHotEncoder(handle_unknown="ignore",sparse=False)
solvent_ohe.fit(solvents)
#solvent_ohe.transform(np.array([[1]]))
  93. el_map={}
  94. def getMendeleevElement(nr):
  95. if nr not in el_map:
  96. el_map[nr] = mendeleev.element(nr)
  97. return el_map[nr]
  98. def nmr_shift(atom):
  99. for key, value in atom.GetOwningMol().GetPropsAsDict().items():
  100. if key.startswith("Spectrum"):
  101. for shift in value.split('|'):
  102. x = shift.split(';')
  103. if (len(x) == 3 and x[2] == f"{atom.GetIdx()}"):
  104. return float(x[0])
  105. return float("NaN") # We use NaN for atoms we don't want to predict shifts
  106. def bond_features(bond):
  107. onehot_encoded_bondtype = onehot_encoder.transform(np.array([[bond.GetBondType()]]))[0]
  108. [x1, y1, z1] = list(bond.GetOwningMol().GetConformer().GetAtomPosition(bond.GetBeginAtomIdx()))
  109. [x2, y2, z2] = list(bond.GetOwningMol().GetConformer().GetAtomPosition(bond.GetEndAtomIdx()))
  110. distance = [(math.sqrt((x2 - x1)**2 + (y2 - y1)**2 + (z2 - z1)**2))] # Distance
  111. return distance+list(onehot_encoded_bondtype)
  112. def atom_features_random(atom, molecule=None):
  113. features = []
  114. features.append(random.randint(0,9)) # Atomic number
  115. return features
  116. def atom_features_update(atom, molecule=None):
  117. features = []
  118. me = getMendeleevElement(atom.GetAtomicNum())
  119. features.extend(solvent_ohe.transform(np.array([[get_molecule_solvent(molecule)]]))[0])
  120. features.extend(hybridization_ohe.transform(np.array([[atom.GetHybridization().name]]))[0])
  121. features.extend(valence_ohe.transform(np.array([[atom.GetTotalValence()]]))[0])
  122. features.append(me.atomic_volume) # Atomic volume
  123. features.append(me.dipole_polarizability) # Dipole polarizability
  124. features.append(me.electron_affinity) # Electron affinity
  125. features.append(me.en_pauling) # Electronegativity
  126. features.append(me.electrons) # No. of electrons
  127. features.append(me.neutrons) # No. of neutrons
  128. features.append(atom.GetChiralTag())
  129. features.append(atom.IsInRing())
  130. return features
  131. def atom_features_2019(atom, molecule=None):
  132. features.append(atom.GetAtomicNum())
  133. features.extend(atomic_number_ohe.transform(np.array([[atom.GetAtomicNum()]]))[0]) # Atomic number
  134. features.extend(valence_ohe.transform(np.array([[atom.GetTotalValence()]]))[0])
  135. features.extend(valence_ohe.transform(np.array([[atom.GetDefaultValence()]]))[0])
  136. features.append(atom.GetTotalValence(atom.GetAtomicNum()))
  137. features.append(atom.GetIsAromatic())
  138. features.extend(hybridization_ohe.transform(np.array([[atom.GetHybridization().name]]))[0])
  139. features.extend(fc_ohe.transform(np.array([[atom.GetFormalCharge()]]))[0])
  140. features.append(atom.IsInRing())
  141. return features
  142. def atom_features_top_single_performers(atom, molecule=None):
  143. me = getMendeleevElement(atom.GetAtomicNum())
  144. features = []
  145. #features.append(atom.GetAtomicNum())
  146. features.extend(atomic_number_ohe.transform(np.array([[atom.GetAtomicNum()]]))[0]) # Atomic number
  147. features.append(me.atomic_radius or 0) # Atomic volume
  148. features.append(me.neutrons) # Van der Waals radius
  149. #features.extend(hybridization_ohe.transform(np.array([[atom.GetHybridization().name]]))[0])
  150. features.append(me.electron_affinity) # Electron affinity
  151. features.append(me.en_pauling) # Electronegativity
  152. #features.append(me.electrons) # No. of neutrons
  153. return features
  154. def atom_features_top_single_performers_with_solvents(atom, molecule=None):
  155. me = getMendeleevElement(atom.GetAtomicNum())
  156. features = []
  157. features.extend(solvent_ohe.transform(np.array([[get_molecule_solvent(molecule)]]))[0])
  158. #features.append(atom.GetAtomicNum())
  159. features.extend(atomic_number_ohe.transform(np.array([[atom.GetAtomicNum()]]))[0]) # Atomic number
  160. features.append(me.atomic_radius or 0) # Atomic volume
  161. features.append(me.neutrons) # Van der Waals radius
  162. #features.extend(hybridization_ohe.transform(np.array([[atom.GetHybridization().name]]))[0])
  163. features.append(me.electron_affinity) # Electron affinity
  164. features.append(me.en_pauling) # Electronegativity
  165. features.append(me.electrons) # No. of neutrons
  166. return features
def atom_features(atom, molecule=None):
    """Full node feature vector for *atom*.

    One-hot blocks (atomic number, solvent, hybridisation, formal charge)
    followed by scalar elemental properties from mendeleev and two RDKit
    flags. Feature order is significant: it must match the fitted scaler.
    """
    me = getMendeleevElement(atom.GetAtomicNum())
    #[x, y, z] = list(atom.GetOwningMol().GetConformer().GetAtomPosition(atom.GetIdx()))
    features = []
    #TBD: Do we need to encode molecule atom itself? Or is atomic number sufficient? One-hot encode?
    features.extend(atomic_number_ohe.transform(np.array([[atom.GetAtomicNum()]]))[0]) # Atomic number one-hot
    #features.append(atom.GetIsAromatic())
    features.extend(solvent_ohe.transform(np.array([[get_molecule_solvent(molecule)]]))[0]) # Solvent one-hot
    features.extend(hybridization_ohe.transform(np.array([[atom.GetHybridization().name]]))[0]) # Hybridization one-hot
    #features.extend(valence_ohe.transform(np.array([[atom.GetTotalValence()]]))[0])
    features.extend(fc_ohe.transform(np.array([[atom.GetFormalCharge()]]))[0]) # Formal charge one-hot
    features.append(me.atomic_radius or 0) # Atomic radius (0 when unknown)
    features.append(me.atomic_volume) # Atomic volume
    features.append(me.atomic_weight) # Atomic weight
    features.append(me.covalent_radius) # Covalent radius
    features.append(me.vdw_radius) # Van der Waals radius
    features.append(me.dipole_polarizability) # Dipole polarizability
    features.append(me.electron_affinity) # Electron affinity
    features.append(me.electrophilicity()) # Electrophilicity index
    features.append(me.en_pauling) # Pauling electronegativity
    features.append(me.electrons) # No. of electrons
    features.append(me.neutrons) # No. of neutrons
    #features.append(x) # X coordinate - TBD: Not sure this is a meaningful feature (but they had in the paper)
    #features.append(y) # Y coordinate - TBD: Not sure this is a meaningful feature (but they had in the paper)
    #features.append(z) # Z coordinate - TBD: Not sure this is a meaningful feature (but they had in the paper)
    #features.append(0 if np.isfinite(float(atom.GetProp('_GasteigerCharge'))) else float(atom.GetProp('_GasteigerCharge'))) #partial charges
    features.append(atom.GetChiralTag()) # Chirality tag
    features.append(atom.IsInRing()) # Ring-membership flag
    return features
  196. def convert_to_graph(molecule, atom_feature_constructor=atom_features):
  197. #Chem.rdPartialCharges.ComputeGasteigerCharges(molecule)
  198. node_features = [atom_feature_constructor(atom, molecule) for atom in molecule.GetAtoms()]
  199. node_targets = [nmr_shift(atom) for atom in molecule.GetAtoms()]
  200. edge_features = [bond_features(bond) for bond in molecule.GetBonds()]
  201. edge_index = [[bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()] for bond in molecule.GetBonds()]
  202. # Bonds are not directed, so lets add the missing pair to make the graph undirected
  203. edge_index.extend([reversed(bond) for bond in edge_index])
  204. edge_features.extend(edge_features)
  205. # Some node_features had null values in carbon data and then the long graph compilation process was stopped.
  206. if any(None in sublist for sublist in node_features):
  207. return None
  208. return Data(
  209. x=tensor(node_features, dtype=torch.float),
  210. edge_index=tensor(edge_index, dtype=torch.long).t().contiguous(),
  211. edge_attr=tensor(edge_features, dtype=torch.float),
  212. y=tensor([[t] for t in node_targets], dtype=torch.float)
  213. )
def scale_graph_data(latent_graph_list, scaler=None):
    """Standardise node and edge features to zero mean / unit variance.

    latent_graph_list: list of torch_geometric Data graphs.
    scaler: optional (node_mean, node_std, edge_mean, edge_std) tuple; when
        given (e.g. applying the training-set scaler to test data), those
        statistics are used instead of computing new ones.

    Returns (scaled_graph_list, scaler) so the fitted scaler can be reused.
    """
    if scaler:
        node_mean, node_std, edge_mean, edge_std = scaler
        print(f"Using existing scaler: {scaler}")
    else:
        #Iterate through graph list to get stacked NODE and EDGE features
        node_stack=[]
        edge_stack=[]
        for g in latent_graph_list:
            node_stack.append(g.x) #Append node features
            edge_stack.append(g.edge_attr) #Append edge features
        node_cat=torch.cat(node_stack,dim=0)
        edge_cat=torch.cat(edge_stack,dim=0)
        # Per-feature statistics over all atoms/bonds in the data set
        node_mean=node_cat.mean(dim=0)
        node_std=node_cat.std(dim=0,unbiased=False)
        edge_mean=edge_cat.mean(dim=0)
        edge_std=edge_cat.std(dim=0,unbiased=False)
    #Apply zero-mean, unit variance scaling, append scaled graph to list
    latent_graph_list_sc=[]
    for g in latent_graph_list:
        # g.x - mean allocates a new tensor, so the in-place /= below does
        # not mutate the original graph
        x_sc=g.x-node_mean
        x_sc/=node_std
        ea_sc=g.edge_attr-edge_mean
        ea_sc/=edge_std
        # Constant features have std 0, so division yields +inf (clamped to
        # 1.0) or NaN from 0/0 (nan_to_num default maps NaN to 0).
        ea_sc=torch.nan_to_num(ea_sc, posinf=1.0)
        x_sc=torch.nan_to_num(x_sc, posinf=1.0)
        temp_graph=Data(x=x_sc,edge_index=g.edge_index,edge_attr=ea_sc, y=g.y)
        latent_graph_list_sc.append(temp_graph)
    scaler= (node_mean,node_std,edge_mean,edge_std)
    return latent_graph_list_sc,scaler
  244. """# Graph Neural Network
  245. ## Separation into training set and test set
  246. """
  247. TRAIN_TEST_SPLIT = 0.8
  248. all_data = list(supplier3d)
  249. random.Random(80).shuffle(all_data)
  250. train_data =all_data[:int(TRAIN_TEST_SPLIT * len(supplier3d))]
  251. test_data =all_data[int(TRAIN_TEST_SPLIT * len(supplier3d)):]
  252. #TODO: Use training data scaler on test data.
  253. train_graphs, scaler = scale_graph_data([convert_to_graph(molecule, atom_feature_constructor = atom_features) for idx, molecule in enumerate(train_data) if molecule])
  254. test_graphs, scaler = scale_graph_data([convert_to_graph(molecule, atom_feature_constructor = atom_features) for idx, molecule in enumerate(test_data) if molecule], scaler=scaler)
  255. #all_data = [convert_to_graph(molecule) for idx, molecule in enumerate(supplier3d) if molecule and idx not in[24,25,30,31,32]]
  256. print(f"Converted {len(supplier3d)} molecules to {len(train_graphs) + len(test_graphs)} graphs")
  257. print(f"Found {sum([sum([1 for shift in graph.y if not math.isnan(shift[0])]) for graph in train_graphs+test_graphs])} individual NMR shifts")
  258. """##2023 model"""
  259. import torch
  260. from torch.nn import Sequential as Seq, LazyLinear, LeakyReLU, LazyBatchNorm1d, LayerNorm
  261. from torch_scatter import scatter_mean, scatter_add
  262. from torch_geometric.nn import MetaLayer
  263. from torch_geometric.data import Batch
# Model size hyper-parameters
NO_GRAPH_FEATURES=128  # width of the per-graph feature vector produced by GlobalModel
ENCODING_NODE=64       # node embedding width
ENCODING_EDGE=32       # edge embedding width
HIDDEN_NODE=128        # hidden width of the node MLPs
HIDDEN_EDGE=64         # hidden width of the edge MLP
HIDDEN_GRAPH=128       # hidden width of the global MLP
  270. def init_weights(m):
  271. if type(m) == torch.nn.Linear:
  272. torch.nn.init.xavier_uniform_(m.weight)
  273. m.bias.data.fill_(0.01)
class EdgeModel(torch.nn.Module):
    """MetaLayer edge update: MLP over [source node, target node, edge features]."""
    def __init__(self):
        super(EdgeModel, self).__init__()
        # Lazy layers infer their input width on the first forward pass.
        self.edge_mlp = Seq(LazyLinear(HIDDEN_EDGE), LeakyReLU(),LazyBatchNorm1d(),
                            LazyLinear(HIDDEN_EDGE), LeakyReLU(),LazyBatchNorm1d(),
                            LazyLinear(ENCODING_EDGE)).apply(init_weights)
    def forward(self, src, dest, edge_attr, u, batch):
        # source, target: [E, F_x], where E is the number of edges.
        # edge_attr: [E, F_e]
        # u: [B, F_u], where B is the number of graphs.
        # batch: [E] with max entry B - 1.
        # u and batch are unused here; the edge update depends only on the
        # incident nodes and the current edge features.
        out = torch.cat([src, dest, edge_attr], 1)
        return self.edge_mlp(out)
class NodeModel(torch.nn.Module):
    """MetaLayer node update.

    node_mlp_1 embeds [source node, edge features] per edge; the per-edge
    messages are summed onto each edge's target node (scatter_add over
    `col`), concatenated with the node's own features, and passed through
    node_mlp_2 to produce the new node embedding.
    """
    def __init__(self):
        super(NodeModel, self).__init__()
        # Lazy layers infer their input width on the first forward pass.
        self.node_mlp_1 = Seq(LazyLinear(HIDDEN_NODE), LeakyReLU(), LazyBatchNorm1d(),
                              LazyLinear(HIDDEN_NODE), LeakyReLU(), LazyBatchNorm1d(), #torch.nn.Dropout(0.17),
                              LazyLinear(HIDDEN_NODE)).apply(init_weights)
        self.node_mlp_2 = Seq(LazyLinear(HIDDEN_NODE), LeakyReLU(),LazyBatchNorm1d(), #torch.nn.Dropout(0.17),
                              LazyLinear(HIDDEN_NODE), LeakyReLU(),LazyBatchNorm1d(),
                              LazyLinear(ENCODING_NODE)).apply(init_weights)
    def forward(self, x, edge_index, edge_attr, u, batch):
        # x: [N, F_x], where N is the number of nodes.
        # edge_index: [2, E] with max entry N - 1.
        # edge_attr: [E, F_e]
        # u: [B, F_u]
        # batch: [N] with max entry B - 1.
        row, col = edge_index
        out = torch.cat([x[row], edge_attr], dim=1)
        out = self.node_mlp_1(out)
        # Sum the messages arriving at each target node; dim_size keeps the
        # output aligned with x even for nodes with no incoming edges.
        out = scatter_add(out, col, dim=0, dim_size=x.size(0))
        out = torch.cat([x, out], dim=1)
        return self.node_mlp_2(out)
class GlobalModel(torch.nn.Module):
    """MetaLayer global update: per-graph features from summed node and edge
    embeddings. The incoming `u` is not used as an input here."""
    def __init__(self):
        super(GlobalModel, self).__init__()
        # Lazy layers infer their input width on the first forward pass.
        self.global_mlp = Seq(LazyLinear(HIDDEN_GRAPH), LeakyReLU(),LazyBatchNorm1d(), #torch.nn.Dropout(0.17),
                              LazyLinear(HIDDEN_GRAPH), LeakyReLU(),LazyBatchNorm1d(),
                              LazyLinear(NO_GRAPH_FEATURES)).apply(init_weights)
    def forward(self, x, edge_index, edge_attr, u, batch):
        # x: [N, F_x], where N is the number of nodes.
        # edge_index: [2, E] with max entry N - 1.
        # edge_attr: [E, F_e]
        # u: [B, F_u]
        # batch: [N] with max entry B - 1.
        row,col=edge_index
        # Sum node features per graph; batch[col] maps each edge to the
        # graph of its target node so edge features aggregate per graph too.
        node_aggregate = scatter_add(x, batch, dim=0)
        edge_aggregate = scatter_add(edge_attr, batch[col], dim=0)
        out = torch.cat([node_aggregate, edge_aggregate], dim=1)
        return self.global_mlp(out)
class GNN_FULL_CLASS(torch.nn.Module):
    """Full message-passing GNN: encode node/edge features, run NO_MP
    MetaLayer rounds, then map each node embedding to one NMR-shift value."""
    def __init__(self, NO_MP):
        super(GNN_FULL_CLASS,self).__init__()
        #Meta Layer for Message Passing
        self.meta = MetaLayer(EdgeModel(), NodeModel(), GlobalModel())
        #Edge Encoding MLP
        self.encoding_edge=Seq(LazyLinear(ENCODING_EDGE), LeakyReLU(), LazyBatchNorm1d(),
                               LazyLinear(ENCODING_EDGE), LeakyReLU(), LazyBatchNorm1d(),
                               LazyLinear(ENCODING_EDGE)).apply(init_weights)
        #Node Encoding MLP
        self.encoding_node = Seq(LazyLinear(ENCODING_NODE), LeakyReLU(),LazyBatchNorm1d(),
                                 LazyLinear(ENCODING_NODE), LeakyReLU(),LazyBatchNorm1d(),
                                 LazyLinear(ENCODING_NODE)).apply(init_weights)
        #Readout MLP: node embedding -> scalar shift prediction
        self.mlp_last = Seq(LazyLinear(HIDDEN_NODE), LeakyReLU(),#torch.nn.Dropout(0.10),
                            LazyBatchNorm1d(),
                            LazyLinear(HIDDEN_NODE), LeakyReLU(),
                            LazyBatchNorm1d(),
                            LazyLinear(1)).apply(init_weights)
        # Number of message-passing rounds
        self.no_mp = NO_MP
    def forward(self,dat):
        #Extract the data from the batch. Note: `u` is bound to dat.y here
        #but overwritten below before it is ever used.
        x, ei, ea, u, btc = dat.x, dat.edge_index, dat.edge_attr, dat.y, dat.batch
        # Embed the node and edge features
        enc_x = self.encoding_node(x)
        enc_ea = self.encoding_edge(ea)
        #Create the empty molecular graph features, graph level one.
        #NOTE(review): u is sized per node (x.size()[0] rows), not per graph;
        #GlobalModel ignores the incoming u, so this appears harmless, but
        #confirm it matches the intended MetaLayer contract.
        u=torch.full(size=(x.size()[0], 1), fill_value=0.1, dtype=torch.float)
        #Message-Passing rounds
        for _ in range(self.no_mp):
            enc_x, enc_ea, u = self.meta(x = enc_x, edge_index = ei, edge_attr = enc_ea, u = u, batch = btc)
        #Per-node prediction head
        targs = self.mlp_last(enc_x)
        return targs
  356. #Additional helpful methods
  357. def init_model(NO_MP, lr, wd):
  358. # Model
  359. NO_MP = NO_MP
  360. model = GNN_FULL_CLASS(NO_MP)
  361. # Optimizer
  362. LEARNING_RATE = lr
  363. optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=wd)
  364. # Criterion
  365. #criterion = torch.nn.MSELoss()
  366. criterion = torch.nn.L1Loss()
  367. return model, optimizer, criterion
  368. def train(model, criterion, optimizer, loader):
  369. loss_sum = 0
  370. for batch in loader:
  371. # Forward pass and gradient descent
  372. labels = batch.y
  373. predictions = model(batch)
  374. loss = criterion(predictions[torch.isfinite(labels)], labels[torch.isfinite(labels)])
  375. # Backpropagation
  376. optimizer.zero_grad()
  377. loss.backward()
  378. optimizer.step()
  379. loss_sum += loss.item()
  380. return loss_sum/len(loader)
  381. def evaluate(model, criterion, loader):
  382. loss_sum = 0
  383. with torch.no_grad():
  384. for batch in loader:
  385. # Forward pass
  386. labels = batch.y
  387. predictions = model(batch)
  388. loss = criterion(predictions[torch.isfinite(labels)], labels[torch.isfinite(labels)])
  389. loss_sum += loss.item()
  390. return loss_sum/len(loader)
  391. def chunk_into_n(lst, n):
  392. size = math.ceil(len(lst) / n)
  393. return list(
  394. map(lambda x: lst[x * size:x * size + size],
  395. list(range(n)))
  396. )
  397. """## Training"""
  398. VALIDATION_SPLITS = 4
  399. EPOCHS = 500
  400. BATCH_SIZE = 128
  401. splits =chunk_into_n(train_graphs, VALIDATION_SPLITS)
  402. split_errors=[]
  403. for idx, split in enumerate(splits):
  404. split_train_data = []
  405. for s in splits:
  406. if s!=split:
  407. split_train_data+=s
  408. train_loader = DataLoader(split_train_data, batch_size = BATCH_SIZE)
  409. test_loader = DataLoader(split, batch_size = BATCH_SIZE)
  410. model, optimizer, criterion = init_model(6,0.001,0.1)
  411. loss_list = []
  412. train_err_list = []
  413. test_err_list = []
  414. model.train()
  415. print(f"Split {idx+1}. Training/test split size:{len(split_train_data)}/{len(split)}")
  416. for epoch in range(EPOCHS):
  417. tloss = train(model, criterion, optimizer, train_loader)
  418. train_err = evaluate(model, criterion, train_loader)
  419. test_err = evaluate(model, criterion, test_loader)
  420. loss_list.append(tloss)
  421. train_err_list.append(train_err)
  422. test_err_list.append(test_err)
  423. print('Epoch: {:03d}, Loss: {:.5f}, Train Err: {:.5f}, Test Err: {:.5f}'.format(epoch+1, tloss,
  424. train_err, test_err))
  425. extra_epochs=0
  426. #Sometimes the optimizer tries to find other local minima, which means that at E500 the solution is not yet at local minima.
  427. while extra_epochs<200 and tloss>2.5:
  428. tloss = train(model, criterion, optimizer, train_loader)
  429. train_err = evaluate(model, criterion, train_loader)
  430. test_err = evaluate(model, criterion, test_loader)
  431. loss_list.append(tloss)
  432. train_err_list.append(train_err)
  433. test_err_list.append(test_err)
  434. extra_epochs+1
  435. print("\n")
  436. split_errors.append(test_err)
  437. print(f"Split errors: {split_errors} with average error {sum(split_errors) / VALIDATION_SPLITS}")
  438. """## Evaluation on test set"""
  439. train_loader = DataLoader(train_graphs, batch_size = BATCH_SIZE)
  440. test_loader = DataLoader(test_graphs, batch_size = BATCH_SIZE)
  441. model, optimizer, criterion = init_model(6,0.001,0.1)
  442. loss_list = []
  443. train_err_list = []
  444. test_err_list = []
  445. model.train()
  446. for epoch in range(EPOCHS):
  447. tloss = train(model, criterion, optimizer, train_loader)
  448. train_err = evaluate(model, criterion, train_loader)
  449. test_err = evaluate(model, criterion, test_loader)
  450. loss_list.append(tloss)
  451. train_err_list.append(train_err)
  452. test_err_list.append(test_err)
  453. print('Epoch: {:03d}, Loss: {:.5f}, Train Err: {:.5f}, Test Err: {:.5f}'.format(epoch+1, tloss,
  454. train_err, test_err))
  455. extra_epochs=0
  456. #Sometimes the optimizer tries to find other local minima, which means that at E500 the solution is not yet at local minima.
  457. while extra_epochs<200 and tloss>2.5:
  458. tloss = train(model, criterion, optimizer, train_loader)
  459. train_err = evaluate(model, criterion, train_loader)
  460. test_err = evaluate(model, criterion, test_loader)
  461. loss_list.append(tloss)
  462. train_err_list.append(train_err)
  463. test_err_list.append(test_err)
  464. extra_epochs+1
  465. print("\n")
  466. #print(len(test_loader))
  467. #print( criterion, test_loader)
  468. evaluate(model, criterion, test_loader)
  469. single = scale_graph_data([convert_to_graph(molecule, atom_feature_constructor = atom_features_top_single_performers_with_solvents) for molecule in [supplier3d[32], supplier3d[32]] if molecule])
  470. evaluate(model, criterion, DataLoader(single, batch_size = 2))
  471. """## Visualisations & Analysis"""
  472. test_loader = DataLoader(testing_data, batch_size = 128)
  473. with torch.no_grad():
  474. for batch in test_loader:
  475. # Forward pass
  476. labels = batch.y
  477. predictions = model(batch)
  478. a =torch.sub(predictions, labels)
  479. for a,b,c,d in zip(predictions, labels, batch.x, a):
  480. if d>50:
  481. print(a,b,c[0],d)
  482. loss = criterion(predictions[torch.isfinite(labels)], labels[torch.isfinite(labels)])
  483. print(loss)
  484. #loss_sum += loss.item()
  485. #return loss_sum/len(loader)
  486. import os
  487. import matplotlib.pyplot as plt
  488. def plot_loss(loss_list):
  489. fig, axs = plt.subplots()
  490. axs.plot(range(1, len(loss_list) + 1), loss_list, label="Loss")
  491. axs.set_xlabel("Epoch")
  492. axs.legend(loc=3)
  493. axs.set_yscale('log')
  494. axs.legend()
  495. def plot_error(train_err_list, test_err_list):
  496. fig, axs = plt.subplots()
  497. axs.plot(range(1, len(train_err_list) + 1), train_err_list, label="Train Err")
  498. axs.plot(range(1, len(test_err_list) + 1), test_err_list, label="Test Err")
  499. axs.set_xlabel("Epoch")
  500. axs.legend(loc=3)
  501. axs.set_yscale('log')
  502. axs.legend()
# Render the training curves recorded during the run above.
plot_loss(loss_list)
plot_error(train_err_list, test_err_list)
  505. """We can draw molecules easily for 2D coordinates, so let's download the dataset containing 2D coordinates."""
# Learning-curve experiment: retrain the GNN from scratch on the first
# 100, 250, 500 and all spectra, recording the final test error per size.
errs=[]
for i in [100,250,500,len(all_data)]:
SPLIT = 0.75 # Let's 75% for training and 25% for validation
# Shuffle all of our data, as input data might be sorted
random.Random(46).shuffle(all_data)
# Training set
training_data = all_data[:int(i*SPLIT)]
# Validation set
testing_data = all_data[int(i*SPLIT):i]
BATCH_SIZE = 128
train_loader = DataLoader(training_data, batch_size = BATCH_SIZE)
test_loader = DataLoader(testing_data, batch_size = BATCH_SIZE)
print(f"Number of train batches: {len(train_loader)}")
print(f"Number of test batches: {len(test_loader)}")
# Loading model from the paper
# Model
NO_MP = 7
model = GNN_FULL_CLASS(NO_MP)
# Optimizer
LEARNING_RATE = 0.05
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=.005)
# Criterion
#criterion = torch.nn.MSELoss()
criterion = torch.nn.L1Loss()
EPOCHS = 500
loss_list = []
train_err_list = []
test_err_list = []
model.train()
for epoch in range(EPOCHS):
tloss = train(model, criterion, optimizer, train_loader)
train_err = evaluate(model, criterion, train_loader)
test_err = evaluate(model, criterion, test_loader)
loss_list.append(tloss)
train_err_list.append(train_err)
test_err_list.append(test_err)
print('Epoch: {:03d}, Loss: {:.5f}, Train Err: {:.5f}, Test Err: {:.5f}'.format(epoch+1, tloss,
train_err, test_err))
print("\n")
# Final test error for this dataset size.
test_err = evaluate(model, criterion, test_loader)
print(test_err)
errs.append(test_err)
print(errs)
  549. """# Existing methods
  550. Looking at existing methods.
  551. * Baseline for model comparison
  552. ## Hose
  553. ### Downloads & Initialization
  554. """
  555. from rdkit import Chem
  556. from rdkit.Chem import AllChem
  557. import matplotlib.pyplot as plt
  558. !pip install git+https://github.com/Ratsemaat/HOSE_code_generator
  559. import hosegen
hg = hosegen.HoseGenerator()
"""### Preparation"""
def get_hose_code(molecule, distance, atom_idx):
"""Return the HOSE code of atom `atom_idx` in `molecule`, using spheres up to `distance` bonds away."""
return hg.get_Hose_codes(molecule, atom_idx, max_radius=distance)
# Smoke check on one molecule from the 3D supplier.
get_hose_code(supplier3d[1],5,0)
  565. def get_atom_indexes(molecul, element_nr):
  566. indexes=[]
  567. if not molecul:
  568. return []
  569. for atom in molecul.GetAtoms():
  570. if atom.GetAtomicNum() == element_nr:
  571. indexes.append(atom.GetIdx())
  572. return indexes
# Quick check of get_atom_indexes on a small SMILES molecule:
# carbons (6), nitrogens (7) and fluorines (9).
mol = Chem.MolFromSmiles('C=CN(CC)CF')
print(get_atom_indexes(mol, 6))
print(get_atom_indexes(mol, 7))
print(get_atom_indexes(mol, 9))
  577. def get_molecule_hose_codes(molecule, radius, atom_idx):
  578. indexes = get_atom_indexes(molecule, 9)
  579. hoses=[]
  580. for idx in indexes:
  581. hoses.append((idx,get_hose_code(molecule, radius, idx)))
  582. return hoses
print(get_molecule_hose_codes(supplier3d[1], 5,9))
import random
# Deterministic 75/25 index split over the whole 3D supplier.
SPLIT = 0.75 # Let's 75% for training and 25% for validation
random.seed(42)
# Shuffle all of our data, as input data might be sorted
train_idx = random.sample([i for i in range(len(supplier3d))],math.ceil(len(supplier3d)*SPLIT))
test_idx = set([i for i in range(len(supplier3d))]).difference(set(train_idx))
  590. def get_molecule_nmr_shifts(molecule):
  591. if not molecule:
  592. return {}
  593. map={}
  594. for key, value in molecule.GetPropsAsDict().items():
  595. if key.startswith("Spectrum"):
  596. for shift in value.split('|'):
  597. x = shift.split(';')
  598. if (len(x) == 3):
  599. map[x[2]] = float(x[0])
  600. return map
# Example: the shifts recorded for molecule 1.
print(get_molecule_nmr_shifts(supplier3d[1]))
def get_molecule_hose_and_nmr_shifts(molecule, distance, element_nr):
"""Return (atom_index, HOSE code, shift) triples for every atom of `element_nr` that has an annotated NMR shift."""
arr=[]
idxs = set(get_atom_indexes(molecule, element_nr))
shifts = get_molecule_nmr_shifts(molecule)
# Only look at atoms where a shift is defined (shift keys are strings).
arr2 = set([int(k) for k in shifts.keys()])
idxs = idxs.intersection(arr2)
for idx in idxs:
hoses = get_hose_code(molecule, distance, idx)
arr.append((idx, hoses, shifts[str(idx)]))
return arr
print(get_molecule_hose_and_nmr_shifts(supplier3d[1], 5, 9))
  614. def split_hose(hose):
  615. return hose.replace('(',' ').replace('/',' ').replace(')',' ').split()
  616. """### Training"""
  617. map={}
  618. for id in train_idx:
  619. for trio in get_molecule_hose_and_nmr_shifts(supplier3d[id], 6, 9 ):
  620. parts=split_hose(trio[1])
  621. for i in range(len(parts)):
  622. hose = "/".join(parts[:(i+1)])
  623. if hose not in map:
  624. map[hose] = []
  625. map[hose].append(trio[2])
  626. print(map)
  627. avg_map = {}
  628. for k,v in map.items():
  629. avg_map[k] =sum(v)/len(v)
  630. """### Evaluation"""
  631. predictions = []
  632. labels=[]
  633. familiar_radius = [] # for visualisation
  634. for id in test_idx:
  635. for trio in get_molecule_hose_and_nmr_shifts(supplier3d[id], 6, 9 ):
  636. parts=split_hose(trio[1])
  637. for i in range(len(parts),0,-1):
  638. hose = "/".join(parts[:(i)])
  639. if hose in map:
  640. predictions.append(avg_map[hose])
  641. familiar_radius.append(i)
  642. labels.append(trio[2])
  643. break
  644. print(labels)
  645. print(predictions)
  646. errors=[]
  647. for i in range(len(predictions)):
  648. errors.append(abs(labels[i] - predictions[i]))
  649. print(sum(errors)/len(errors))
  650. """### Visualisations & Analysis"""
  651. plt.hist(familiar_radius, bins=[0.5, 1.5,2.5,3.5,4.5,5.5,6.5], align="mid")
  652. plt.title("Longest exact HOSE code length in training set")
  653. plt.show()
  654. errors = [[] for i in range(6)]
  655. for i in range(len(predictions)):
  656. errors[familiar_radius[i]-1].append(abs(labels[i] - predictions[i]))
  657. for i in range(len(errors)):
  658. errors[i] = sum(errors[i])/len(errors[i])
  659. plt.plot([i for i in range(6)], errors)
  660. plt.title("Average error rate of HOSE code")
  661. plt.show()
SPLIT = 0.75 # Let's 75% for training and 25% for validation
# Repeat the whole HOSE train/evaluate cycle 10 times at dataset sizes
# 100/250/500/all, then average the MAE per size and compare against the GNN.
# NOTE: the inner loops reuse the name `i` of the outer `for i in range(10)`;
# this is confusing but harmless because the range iterator, not `i`, drives
# the outer loop.
hose_error_measurements=[]
for i in range(10):
err_hose=[]
random.Random(42).shuffle(all_data)
for size in [100,250,500,len(all_data)]:
train_idx = random.sample([i for i in range(size)],math.ceil(size*SPLIT))
test_idx = set([i for i in range(size)]).difference(set(train_idx))
# Build HOSE-prefix -> shifts table from the training indices.
map={}
for id in train_idx:
for trio in get_molecule_hose_and_nmr_shifts(supplier3d[id], 6, 9 ):
parts=split_hose(trio[1])
for i in range(len(parts)):
hose = "/".join(parts[:(i+1)])
if hose not in map:
map[hose] = []
map[hose].append(trio[2])
avg_map = {}
for k,v in map.items():
avg_map[k] =sum(v)/len(v)
# Predict test atoms from the longest known prefix.
predictions = []
labels=[]
familiar_radius = [] # for visualisation
for id in test_idx:
for trio in get_molecule_hose_and_nmr_shifts(supplier3d[id], 6, 9 ):
parts=split_hose(trio[1])
for i in range(len(parts),0,-1):
hose = "/".join(parts[:(i)])
if hose in map:
predictions.append(avg_map[hose])
familiar_radius.append(i)
labels.append(trio[2])
break
# MAE for this repetition/size.
errors=[]
for i in range(len(predictions)):
errors.append(abs(labels[i] - predictions[i]))
err_hose.append(sum(errors)/len(errors))
hose_error_measurements.append(err_hose)
# Average the per-size errors over the 10 repetitions.
sum_100=0
sum_250=0
sum_500=0
sum_all=0
for i in hose_error_measurements:
sum_100 += i[0]
sum_250 += i[1]
sum_500 += i[2]
sum_all += i[3]
avg_err_hose = [sum_100/len(hose_error_measurements),sum_250/len(hose_error_measurements),sum_500/len(hose_error_measurements),sum_all/len(hose_error_measurements)]
# GNN reference errors (from the learning-curve experiment above).
err_gnn = [66.1449966430664, 30.329145431518555, 9.740877151489258, 10.084354400634766]
print(avg_err_hose)
plt.plot([100,250,500,970],err_gnn, label="GNN model", color="red")
plt.plot([100,250,500,970],avg_err_hose, label="HOSE code model", color="green")
plt.xlabel("Number of spectra")
plt.ylabel("ppm")
plt.title("Now. Error in ppm in relation to used examples.")
plt.legend()
plt.show()
  719. """# Results
  720. Since obtaining the results
  721. ## Edge and node features
  722. Testing difderent descriptors with fixed hyperparameters(weight decay=0.005, learning rate=0.0003, NO_MP=7). When testing node features the edge_features were also fixed and when testing edge features then node features were fixed. So in other words feature results were obtained by only playing with corresponding values. Each feature was tested using 4-fold crossvalidation on train data (80% of whole dataset).
  723. """
  724. node_features_results = {'ohe atomic number': [11.206879615783691, 10.606162071228027, 10.678353309631348, 14.426544666290283],
  725. 'isAromatic': [20.246356964111328, 21.289775848388672, 22.023090362548828, 21.481626510620117],
  726. 'hyb ohe': [14.719367980957031, 14.110363483428955, 17.14181661605835, 17.19037103652954],
  727. 'valence ohe': [16.09677743911743, 16.301819801330566, 18.414384841918945, 14.627246856689453],
  728. 'valence': [15.74299955368042, 15.868663311004639, 17.854907989501953, 14.021881580352783],
  729. 'inRing': [20.925933837890625, 19.979466438293457, 20.020546913146973, 19.80652141571045],
  730. 'hybridization': [15.455873012542725, 16.48819398880005, 16.359801769256592, 17.977892875671387],
  731. 'atomic num': [10.075446605682373, 10.49792766571045, 11.835340023040771, 12.882370471954346],
  732. 'atomic charge': [20.686185836791992, 22.381511688232422, 27.067391395568848, 20.001076698303223],
  733. 'atomic radius': [11.015857696533203, 10.15432596206665, 10.100831508636475, 13.576239109039307],
  734. 'atomic volume': [11.949167728424072, 8.738459587097168, 10.427918434143066, 13.632889747619629],
  735. 'atomic weight': [10.183413028717041, 9.64014196395874, 11.397083282470703, 11.427015781402588],
  736. 'covalent radius': [10.514074325561523, 9.394641399383545, 10.941582679748535, 11.956562995910645],
  737. 'vdw radius': [9.800463676452637, 9.21529245376587, 10.46424388885498, 13.804336071014404],
  738. 'dipole polarizability': [10.734958171844482, 10.39086389541626, 14.207355976104736, 14.054275035858154],
  739. 'electron affinity': [12.086753368377686, 9.388242721557617, 9.779114246368408, 14.187011241912842],
  740. 'electrophilicity index': [11.224877834320068, 8.167922496795654, 10.476778030395508, 13.519041061401367],
  741. 'electronegativity': [9.936093807220459, 10.196239948272705, 9.894521236419678, 14.182630062103271],
  742. 'electrons': [10.851109981536865, 9.980653285980225, 9.819206714630127, 14.225136280059814],
  743. 'neutrons': [9.413500308990479, 9.620219707489014, 9.220077753067017, 13.0865478515625],
  744. 'cooridates': [24.324893951416016, 23.782126426696777, 27.821096420288086, 22.696171760559082],
  745. 'formal charge': [21.54022789001465, 21.80484676361084, 25.49346923828125, 21.4456787109375],
  746. 'formal charge ohe': [19.7533016204834, 24.704838752746582, 23.275280952453613, 21.496562004089355],
  747. 'chiral tag': [22.99411392211914, 22.16276741027832, 26.671720504760742, 19.622788429260254],
  748. 'random': [27.85106086730957, 26.663583755493164, 27.1346435546875, 24.93829917907715]
  749. }
# 4-fold cross-validation MAE for the four tested bond-feature encodings.
bond_features_results = {'smart,atoms': [10.044915676116943, 8.644615173339844, 10.853336811065674, 13.183335781097412],
'pure,rdkit': [11.945857048034668, 9.031915664672852, 10.515658855438232, 13.471199989318848],
'pure,atoms': [11.971851825714111, 9.971860885620117, 11.541802883148193, 13.614596843719482],
'smart,rdkit': [9.997193813323975, 9.00916337966919, 9.385982513427734, 12.707754611968994]}
  754. import pandas as pd
  755. import seaborn as sns
  756. import matplotlib.pyplot as plt
  757. df = pd.DataFrame(columns = ['descriptor', 'accuracy'])
  758. df2 = pd.DataFrame(columns = ['descriptor', 'accuracy'])
  759. for key, value in bond_features_results.items():
  760. temp_df = pd.DataFrame({'descriptor': [key]*len(value),
  761. 'accuracy': value})
  762. df2 = df2.append(temp_df, ignore_index = True)
  763. for key, value in node_features_results.items():
  764. temp_df = pd.DataFrame({'descriptor': [key]*len(value),
  765. 'accuracy': value})
  766. df = df.append(temp_df, ignore_index = True)
  767. #print(df)
  768. my_order = df.groupby(by=["descriptor"])["accuracy"].mean().sort_values().index
  769. my_order_2 = df2.groupby(by=["descriptor"])["accuracy"].mean().sort_values().index
  770. fig, ax = plt.subplots(2, figsize=(8, 20))
  771. sns.violinplot(data=df, y="descriptor", x= "accuracy", ax=ax[0], order=my_order)
  772. sns.violinplot(data=df2, y="descriptor", x= "accuracy", ax=ax[1], order=my_order_2)
  773. plt.show()
  774. """## Hyperparameter selection"""
  775. !wget -nc https://www.dropbox.com/s/k6kgqag3daqv90y/hp_results.csv
  776. with open("hp_results.csv") as f:
  777. hp_results = pd.read_csv(f)
  778. with pd.option_context('display.max_rows', None,
  779. 'display.max_columns', None,
  780. 'display.precision', 3,
  781. ):
  782. print(hp_results)
  783. print(hp_results.groupby(["m", "lr", "wd"])['accuracy'].mean())
  784. print(hp_results.groupby(["m"])['accuracy'].mean())
  785. print(hp_results.groupby(["lr"])['accuracy'].mean())
  786. print(hp_results.groupby(["wd"])['accuracy'].mean())
  787. """## Different models(collections of features)
  788. 2019 <br>
  789. Top N Features
  790. ### 2019 model feautures
  791. """
# 4-fold cross-validation MAE obtained with the 2019 paper's feature set.
paper_2019_features_results = [8.671473503112793, 14.09905195236206, 12.203859806060791, 12.20280122756958]
"""### Top N features"""
  794. n_feature_results = {"['neutrons']": [10.858759880065918, 7.9520745277404785, 10.061083316802979, 13.083988666534424], "['neutrons', 'atomic weight']": [10.279156684875488, 10.209141254425049, 9.95799732208252, 13.340500354766846], "['neutrons', 'atomic weight', 'covalent radius']": [10.954660892486572, 8.685662269592285, 8.889871597290039, 11.358663558959961], "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius']": [11.0108060836792, 10.469686031341553, 10.534903049468994, 12.8761568069458], "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index']": [10.096431255340576, 9.76186466217041, 9.827410221099854, 12.698175430297852], "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity']": [10.908030033111572, 8.09628176689148, 9.25707221031189, 11.68080759048462], "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity', 'atomic volume']": [9.88636827468872, 9.501356601715088, 9.424680233001709, 12.281991004943848], "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity', 'atomic volume', 'atomic radius']": [12.68621015548706, 9.14622163772583, 10.40596866607666, 11.856515407562256], "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity', 'atomic volume', 'atomic radius', 'electrons']": [10.266335487365723, 9.108633041381836, 11.96642780303955, 13.67209243774414], "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity', 'atomic volume', 'atomic radius', 'electrons', 'atomic num']": [10.689188957214355, 9.288278579711914, 10.050706624984741, 14.055354118347168], "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity', 'atomic volume', 'atomic radius', 'electrons', 'atomic num', 'electron affinity']": 
[10.014832019805908, 9.91117525100708, 9.540995359420776, 13.940445899963379], "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity', 'atomic volume', 'atomic radius', 'electrons', 'atomic num', 'electron affinity', 'ohe atomic number']": [10.295331001281738, 9.557802677154541, 11.178434371948242, 10.574723720550537], "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity', 'atomic volume', 'atomic radius', 'electrons', 'atomic num', 'electron affinity', 'ohe atomic number', 'dipole polarizability']": [10.669241905212402, 9.187427043914795, 9.547076225280762, 12.491205215454102], "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity', 'atomic volume', 'atomic radius', 'electrons', 'atomic num', 'electron affinity', 'ohe atomic number', 'dipole polarizability', 'hyb ohe']": [10.314173221588135, 10.875812530517578, 12.00346040725708, 11.990623950958252], "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity', 'atomic volume', 'atomic radius', 'electrons', 'atomic num', 'electron affinity', 'ohe atomic number', 'dipole polarizability', 'hyb ohe', 'valence']": [9.84071683883667, 9.872222900390625, 8.903747081756592, 11.116456031799316], "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity', 'atomic volume', 'atomic radius', 'electrons', 'atomic num', 'electron affinity', 'ohe atomic number', 'dipole polarizability', 'hyb ohe', 'valence', 'valence ohe']": [11.037650108337402, 10.227884292602539, 11.569035053253174, 15.139636516571045], "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity', 'atomic volume', 'atomic radius', 'electrons', 'atomic num', 'electron affinity', 'ohe atomic number', 'dipole polarizability', 'hyb ohe', 'valence', 
'valence ohe', 'hybridization']": [9.662659645080566, 10.27301549911499, 10.229931831359863, 12.425177097320557], "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity', 'atomic volume', 'atomic radius', 'electrons', 'atomic num', 'electron affinity', 'ohe atomic number', 'dipole polarizability', 'hyb ohe', 'valence', 'valence ohe', 'hybridization', 'inRing']": [12.145603656768799, 9.442389011383057, 10.04677677154541, 13.444748401641846], "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity', 'atomic volume', 'atomic radius', 'electrons', 'atomic num', 'electron affinity', 'ohe atomic number', 'dipole polarizability', 'hyb ohe', 'valence', 'valence ohe', 'hybridization', 'inRing', 'isAromatic']": [10.830190181732178, 9.220112323760986, 9.297557353973389, 13.27143907546997], "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity', 'atomic volume', 'atomic radius', 'electrons', 'atomic num', 'electron affinity', 'ohe atomic number', 'dipole polarizability', 'hyb ohe', 'valence', 'valence ohe', 'hybridization', 'inRing', 'isAromatic', 'formal charge ohe']": [10.53205156326294, 10.221848964691162, 10.479277610778809, 11.17846155166626], "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity', 'atomic volume', 'atomic radius', 'electrons', 'atomic num', 'electron affinity', 'ohe atomic number', 'dipole polarizability', 'hyb ohe', 'valence', 'valence ohe', 'hybridization', 'inRing', 'isAromatic', 'formal charge ohe', 'atomic charge']": [10.679495811462402, 9.242655277252197, 10.48193883895874, 12.0442533493042], "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity', 'atomic volume', 'atomic radius', 'electrons', 'atomic num', 'electron affinity', 'ohe atomic number', 'dipole polarizability', 'hyb 
ohe', 'valence', 'valence ohe', 'hybridization', 'inRing', 'isAromatic', 'formal charge ohe', 'atomic charge', 'formal charge']": [9.10008192062378, 9.387939929962158, 10.72908878326416, 12.20565414428711], "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity', 'atomic volume', 'atomic radius', 'electrons', 'atomic num', 'electron affinity', 'ohe atomic number', 'dipole polarizability', 'hyb ohe', 'valence', 'valence ohe', 'hybridization', 'inRing', 'isAromatic', 'formal charge ohe', 'atomic charge', 'formal charge', 'chiral tag']": [11.04087209701538, 8.146693706512451, 10.116810321807861, 10.707176208496094], "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity', 'atomic volume', 'atomic radius', 'electrons', 'atomic num', 'electron affinity', 'ohe atomic number', 'dipole polarizability', 'hyb ohe', 'valence', 'valence ohe', 'hybridization', 'inRing', 'isAromatic', 'formal charge ohe', 'atomic charge', 'formal charge', 'chiral tag', 'cooridates']": [11.374996185302734, 10.41971206665039, 11.041788578033447, 13.181884288787842]}
  795. df = pd.DataFrame(columns = ['descriptor', 'accuracy'])
  796. for key, value in n_feature_results.items():
  797. temp_df = pd.DataFrame({'descriptor': str(len(key.split(","))),
  798. 'accuracy': value})
  799. df = df.append(temp_df, ignore_index = True)
  800. fig, ax = plt.subplots(1,figsize=(15, 15))
  801. sns.violinplot(data=df, y="descriptor", x= "accuracy", ax=ax)
  802. plt.show()
  803. print(df.groupby(by=["descriptor"])["accuracy"].mean().sort_values())
  804. """### Comparison"""
  805. df = pd.DataFrame(columns = ['descriptor', 'accuracy'])
  806. temp_df = pd.DataFrame({'descriptor': '2019 model features', 'accuracy': paper_2019_features_results})
  807. df = df.append(temp_df, ignore_index = True)
  808. for key, value in n_feature_results.items():
  809. temp_df = pd.DataFrame({'descriptor': str(len(key.split(","))),
  810. 'accuracy': value})
  811. df = df.append(temp_df, ignore_index = True)
  812. fig, ax = plt.subplots(1,figsize=(15, 15))
  813. comparsion_df = df.loc[(df['descriptor'] =='15') | (df['descriptor'] =='2019 model features')]
  814. sns.violinplot(data=comparsion_df, y="descriptor", x= "accuracy", ax=ax)
  815. plt.show()
  816. """## Varying dataset sizes
  817. ### Flourine
  818. """
# Fluorine results: MAE per model and dataset size, as table and violin plot.
!wget -nc https://www.dropbox.com/s/m1alwawphb1chbc/fluorine_results.csv
with open("fluorine_results.csv") as f:
df = pd.read_csv(f)
df.groupby(["model","dataset_size"]).mean()
fig, ax = plt.subplots(figsize=(8, 10))
sns.violinplot(data=df, y="mae", x="dataset_size", hue="model" , ax=ax)
plt.show()
  826. """### Carbon with choloroform as solvent
  827. """
# NOTE(review): the URL below downloads methanol_results.csv, yet the code
# opens chloroform_results.csv -- likely a copy-paste slip in the URL.  As
# written, this cell only works if chloroform_results.csv already exists in
# the runtime; the correct chloroform URL should be confirmed.
!wget -nc https://www.dropbox.com/s/0nmj08b1hpn8jc0/methanol_results.csv
with open("chloroform_results.csv") as f:
df = pd.read_csv(f)
print(df.groupby(["model","dataset_size"]).mean())
print(df)
fig, ax = plt.subplots(figsize=(8, 10))
sns.violinplot(data=df, y="mae", x="dataset_size", hue="model" , ax=ax)
plt.show()
  836. """### Carbon with methanol as solvent """
# Methanol results: MAE per model and dataset size, plus violin plot.
!wget -nc https://www.dropbox.com/s/0nmj08b1hpn8jc0/methanol_results.csv
with open("methanol_results.csv") as f:
df = pd.read_csv(f)
print(df.groupby(["model","dataset_size"]).mean())
fig, ax = plt.subplots(figsize=(8, 10))
sns.violinplot(data=df, y="mae", x="dataset_size", hue="model" , ax=ax)
plt.show()
  844. """### Carbon with dmso as solvent """
# DMSO results: MAE per model and dataset size, plus violin plot.
!wget -nc https://www.dropbox.com/s/gdwv8pspayssk27/dmso_results.csv
with open("dmso_results.csv") as f:
df = pd.read_csv(f)
print(df.groupby(["model","dataset_size"]).mean())
fig, ax = plt.subplots(figsize=(8, 10))
sns.violinplot(data=df, y="mae", x="dataset_size", hue="model" , ax=ax)
plt.show()
  852. """# Reproducing results
The functions below produced the aforementioned results.
  854. To repeat any experiment you must:
  855. * Run the core block to load necessary helper methods.
  856. * Run the corresponding experiment block to obtain the results.
  857. ## Core block
  858. """
# Fraction of data used for training in the reproduction experiments.
TEST_SPLIT= 0.8
# Install required dependencies
!pip3 install rdkit # Used to read and parse nmrshiftdb2 SD file
!pip3 install mendeleev # To access various features related to atoms
import torch
# PyTorch dependencies to represent graph data
!pip3 install torch-scatter -f https://data.pyg.org/whl/torch-{torch.__version__}.html
!pip3 install torch-sparse -f https://data.pyg.org/whl/torch-{torch.__version__}.html
!pip3 install torch-geometric
# Use CUDA by default
if torch.cuda.is_available():
torch.set_default_tensor_type('torch.cuda.FloatTensor')
# Downloads the nmrshiftdb2 database if it does not yet exist in our runtime
!wget -nc https://www.dropbox.com/s/n122zxawpxii5b7/nmrshiftdb2withsignals.sd#Flourine
from torch.nn import Sequential as Seq, LazyLinear, LeakyReLU, LazyBatchNorm1d, LayerNorm
from torch_scatter import scatter_mean, scatter_add
from torch_geometric.nn import MetaLayer
from torch_geometric.data import Batch,Data
from torch import tensor
from torch_geometric.loader import DataLoader
import numpy as np
from rdkit import Chem
from sklearn.preprocessing import OneHotEncoder
import random
import math
import mendeleev
import pandas as pd
# Network width constants: graph-level feature size, node/edge embedding
# sizes, and the hidden widths of the node/edge/graph MLPs.
NO_GRAPH_FEATURES=128
ENCODING_NODE=64
ENCODING_EDGE=32
HIDDEN_NODE=128
HIDDEN_EDGE=64
HIDDEN_GRAPH=128
  892. def init_weights(m):
  893. if type(m) == torch.nn.Linear:
  894. torch.nn.init.xavier_uniform_(m.weight)
  895. m.bias.data.fill_(0.01)
class EdgeModel(torch.nn.Module):
"""Edge-update network of the MetaLayer: maps [src node, dst node, edge attr] to a new ENCODING_EDGE-dimensional edge representation."""
def __init__(self):
super(EdgeModel, self).__init__()
# Lazy layers infer their input size on first use; init_weights only
# matches plain Linear, so it is a no-op on these modules.
self.edge_mlp = Seq(LazyLinear(HIDDEN_EDGE), LeakyReLU(),LazyBatchNorm1d(),
LazyLinear(HIDDEN_EDGE), LeakyReLU(),LazyBatchNorm1d(),
LazyLinear(ENCODING_EDGE)).apply(init_weights)
def forward(self, src, dest, edge_attr, u, batch):
# source, target: [E, F_x], where E is the number of edges.
# edge_attr: [E, F_e]
# u: [B, F_u], where B is the number of graphs.
# batch: [E] with max entry B - 1.
# Concatenate both endpoint embeddings with the current edge features;
# u and batch are accepted for the MetaLayer interface but unused here.
out = torch.cat([src, dest, edge_attr], 1)
return self.edge_mlp(out)
class NodeModel(torch.nn.Module):
"""Node-update network of the MetaLayer: per-edge message MLP, sum-aggregation onto target nodes, then a second MLP over [node, aggregated messages]."""
def __init__(self):
super(NodeModel, self).__init__()
# Message MLP applied to each (source node, edge attr) pair.
self.node_mlp_1 = Seq(LazyLinear(HIDDEN_NODE), LeakyReLU(), LazyBatchNorm1d(),
LazyLinear(HIDDEN_NODE), LeakyReLU(), LazyBatchNorm1d(), #torch.nn.Dropout(0.17),
LazyLinear(HIDDEN_NODE)).apply(init_weights)
# Update MLP applied to each node after message aggregation.
self.node_mlp_2 = Seq(LazyLinear(HIDDEN_NODE), LeakyReLU(),LazyBatchNorm1d(), #torch.nn.Dropout(0.17),
LazyLinear(HIDDEN_NODE), LeakyReLU(),LazyBatchNorm1d(),
LazyLinear(ENCODING_NODE)).apply(init_weights)
def forward(self, x, edge_index, edge_attr, u, batch):
# x: [N, F_x], where N is the number of nodes.
# edge_index: [2, E] with max entry N - 1.
# edge_attr: [E, F_e]
# u: [B, F_u]
# batch: [N] with max entry B - 1.
row, col = edge_index
# Per-edge message: source-node features concatenated with edge attrs.
out = torch.cat([x[row], edge_attr], dim=1)
out = self.node_mlp_1(out)
# Sum messages over the incoming edges of each target node (indexed by col).
out = scatter_add(out, col, dim=0, dim_size=x.size(0))
out = torch.cat([x, out], dim=1)
return self.node_mlp_2(out)
class GlobalModel(torch.nn.Module):
"""Graph-level update of the MetaLayer: sums node and edge features per graph and maps the concatenation through an MLP to NO_GRAPH_FEATURES."""
def __init__(self):
super(GlobalModel, self).__init__()
self.global_mlp = Seq(LazyLinear(HIDDEN_GRAPH), LeakyReLU(),LazyBatchNorm1d(), #torch.nn.Dropout(0.17),
LazyLinear(HIDDEN_GRAPH), LeakyReLU(),LazyBatchNorm1d(),
LazyLinear(NO_GRAPH_FEATURES)).apply(init_weights)
def forward(self, x, edge_index, edge_attr, u, batch):
# x: [N, F_x], where N is the number of nodes.
# edge_index: [2, E] with max entry N - 1.
# edge_attr: [E, F_e]
# u: [B, F_u]
# batch: [N] with max entry B - 1.
row,col=edge_index
# Per-graph sums of node features and of edge features (edges are
# assigned to the graph of their target node via batch[col]).
node_aggregate = scatter_add(x, batch, dim=0)
edge_aggregate = scatter_add(edge_attr, batch[col], dim=0)
out = torch.cat([node_aggregate, edge_aggregate], dim=1)
return self.global_mlp(out)
class GNN_FULL_CLASS(torch.nn.Module):
"""Full message-passing GNN: encode node/edge features, run NO_MP MetaLayer steps, then map each node embedding to a single predicted shift."""
def __init__(self, NO_MP):
super(GNN_FULL_CLASS,self).__init__()
#Meta Layer for Message Passing
self.meta = MetaLayer(EdgeModel(), NodeModel(), GlobalModel())
#Edge Encoding MLP
self.encoding_edge=Seq(LazyLinear(ENCODING_EDGE), LeakyReLU(), LazyBatchNorm1d(),
LazyLinear(ENCODING_EDGE), LeakyReLU(), LazyBatchNorm1d(),
LazyLinear(ENCODING_EDGE)).apply(init_weights)
# Node Encoding MLP
self.encoding_node = Seq(LazyLinear(ENCODING_NODE), LeakyReLU(),LazyBatchNorm1d(),
LazyLinear(ENCODING_NODE), LeakyReLU(),LazyBatchNorm1d(),
LazyLinear(ENCODING_NODE)).apply(init_weights)
# Per-node read-out MLP producing one scalar (the predicted shift).
self.mlp_last = Seq(LazyLinear(HIDDEN_NODE), LeakyReLU(),#torch.nn.Dropout(0.10),
LazyBatchNorm1d(),
LazyLinear(HIDDEN_NODE), LeakyReLU(),
LazyBatchNorm1d(),
LazyLinear(1)).apply(init_weights)
self.no_mp = NO_MP
def forward(self,dat):
#Extract the data from the batch
x, ei, ea, u, btc = dat.x, dat.edge_index, dat.edge_attr, dat.y, dat.batch
# Embed the node and edge features
enc_x = self.encoding_node(x)
enc_ea = self.encoding_edge(ea)
#Create the empty molecular graphs for feature extraction, graph level one
# NOTE(review): `u` (unpacked from dat.y above) is immediately overwritten
# with a constant 0.1 tensor sized by the *node* count x.size()[0], not the
# graph count -- confirm this is what GlobalModel expects downstream.
u=torch.full(size=(x.size()[0], 1), fill_value=0.1, dtype=torch.float)
#Message-Passing
for _ in range(self.no_mp):
enc_x, enc_ea, u = self.meta(x = enc_x, edge_index = ei, edge_attr = enc_ea, u = u, batch = btc)
targs = self.mlp_last(enc_x)
return targs
  978. el_map={}
  979. def atom_features_default():
  980. feature_getters = {}
  981. feature_getters["ohe atomic number"] = lambda atom:atomic_number_ohe.transform(np.array([[atom.GetAtomicNum()]]))[0] # Atomic number
  982. feature_getters["hyb ohe"] = lambda atom: hybridization_ohe.transform(np.array([[atom.GetHybridization().name]]))[0]
  983. feature_getters["valence ohe"] = lambda atom: valence_ohe.transform(np.array([[atom.GetTotalValence()]]))[0]
  984. feature_getters["hybridization"] = lambda atom: atom.GetHybridization()
  985. feature_getters["atomic radius"]= lambda atom: getMendeleevElement(atom.GetAtomicNum()).atomic_radius or 0 # Atomic radius
  986. feature_getters["atomic volume"] = lambda atom: getMendeleevElement(atom.GetAtomicNum()).atomic_volume # Atomic volume
  987. feature_getters["atomic weight"] = lambda atom: getMendeleevElement(atom.GetAtomicNum()).atomic_weight # Atomic weight
  988. feature_getters["dipole polarizability"] = lambda atom: getMendeleevElement(atom.GetAtomicNum()).dipole_polarizability # Dipole polarizability
  989. feature_getters["electron affinity"] = lambda atom: getMendeleevElement(atom.GetAtomicNum()).electron_affinity # Electron affinity
  990. feature_getters["electronegativity"] = lambda atom: getMendeleevElement(atom.GetAtomicNum()).en_pauling # Electronegativity
  991. feature_getters["electrons"] = lambda atom: getMendeleevElement(atom.GetAtomicNum()).electrons # No. of electrons
  992. feature_getters["neutrons"] = lambda atom: getMendeleevElement(atom.GetAtomicNum()).neutrons # No. of neutrons
  993. feature_getters["formal charge ohe"] = lambda atom: fc_ohe.transform(np.array([[atom.GetFormalCharge()]]))[0]
  994. #feature_getters["gaisteigerCharge"] = lambda atom: 0 if np.isfinite(float(atom.GetProp('_GasteigerCharge'))) else float(atom.GetProp('_GasteigerCharge')) #partial charges
  995. feature_getters["chiral tag"] = lambda atom: atom.GetChiralTag()
  996. return feature_getters
  997. def bond_feature_smart_distance_and_rdkit_type(bond):
  998. onehot_encoded_bondtype = onehot_encoder.transform(np.array([[bond.GetBondType()]]))[0]
  999. [x1, y1, z1] = list(bond.GetOwningMol().GetConformer().GetAtomPosition(bond.GetBeginAtomIdx()))
  1000. [x2, y2, z2] = list(bond.GetOwningMol().GetConformer().GetAtomPosition(bond.GetEndAtomIdx()))
  1001. ex_dist = getNaiveBondLength(bond)
  1002. distance = [(math.sqrt((x2 - x1)**2 + (y2 - y1)**2 + (z2 - z1)**2)) - ex_dist] # Distance
  1003. return distance+ list(onehot_encoded_bondtype)
  1004. def getMendeleevElement(nr):
  1005. if nr not in el_map:
  1006. el_map[nr] = mendeleev.element(nr)
  1007. return el_map[nr]
  1008. def nmr_shift(atom):
  1009. for key, value in atom.GetOwningMol().GetPropsAsDict().items():
  1010. if key.startswith("Spectrum"):
  1011. for shift in value.split('|'):
  1012. x = shift.split(';')
  1013. if (len(x) == 3 and x[2] == f"{atom.GetIdx()}"):
  1014. return float(x[0])
  1015. return float("NaN") # We use NaN for atoms we don't want to predict shifts
  1016. def bond_features_distance_only(bond):
  1017. #onehot_encoded_bondtype = onehot_encoder.transform(np.array([[bond.GetBondType()]]))[0]
  1018. [x1, y1, z1] = list(bond.GetOwningMol().GetConformer().GetAtomPosition(bond.GetBeginAtomIdx()))
  1019. [x2, y2, z2] = list(bond.GetOwningMol().GetConformer().GetAtomPosition(bond.GetEndAtomIdx()))
  1020. distance = [(math.sqrt((x2 - x1)**2 + (y2 - y1)**2 + (z2 - z1)**2))] # Distance
  1021. return distance#+list(onehot_encoded_bondtype)
  1022. def flatten(l):
  1023. ret=[]
  1024. for el in l:
  1025. if isinstance(el, list) or isinstance(el, np.ndarray):
  1026. ret.extend(el)
  1027. else:
  1028. ret.append(el)
  1029. return ret
  1030. def turn_to_graph (molecule, atom_feature_getters= atom_features_default().values(), bond_features=bond_features_distance_only):
  1031. node_features = [flatten([getter(atom) for getter in atom_feature_getters ]) for atom in molecule.GetAtoms() ]
  1032. node_targets = [nmr_shift(atom) for atom in molecule.GetAtoms()]
  1033. edge_features = [bond_features(bond) for bond in molecule.GetBonds()]
  1034. edge_index = [[bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()] for bond in molecule.GetBonds()]
  1035. # Bonds are not directed, so lets add the missing pair to make the graph undirected
  1036. edge_index.extend([reversed(bond) for bond in edge_index])
  1037. edge_features.extend(edge_features)
  1038. # Some node_features had null values in carbon data and then the long graph compilation process was stopped.
  1039. if any(None in sublist for sublist in node_features):
  1040. return None
  1041. return Data(
  1042. x=tensor(node_features, dtype=torch.float),
  1043. edge_index=tensor(edge_index, dtype=torch.long).t().contiguous(),
  1044. edge_attr=tensor(edge_features, dtype=torch.float),
  1045. y=tensor([[t] for t in node_targets], dtype=torch.float)
  1046. )
  1047. def init_model(NO_MP, lr, wd):
  1048. # Model
  1049. NO_MP = NO_MP
  1050. model = GNN_FULL_CLASS(NO_MP)
  1051. # Optimizer
  1052. LEARNING_RATE = lr
  1053. optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=wd)
  1054. # Criterion
  1055. #criterion = torch.nn.MSELoss()
  1056. criterion = torch.nn.L1Loss()
  1057. return model, optimizer, criterion
  1058. def train(model, criterion, optimizer, loader):
  1059. loss_sum = 0
  1060. for batch in loader:
  1061. # Forward pass and gradient descent
  1062. labels = batch.y
  1063. predictions = model(batch)
  1064. loss = criterion(predictions[torch.isfinite(labels)], labels[torch.isfinite(labels)])
  1065. # Backpropagation
  1066. optimizer.zero_grad()
  1067. loss.backward()
  1068. optimizer.step()
  1069. loss_sum += loss.item()
  1070. return loss_sum/len(loader)
  1071. def evaluate(model, criterion, loader):
  1072. loss_sum = 0
  1073. with torch.no_grad():
  1074. for batch in loader:
  1075. # Forward pass
  1076. labels = batch.y
  1077. predictions = model(batch)
  1078. loss = criterion(predictions[torch.isfinite(labels)], labels[torch.isfinite(labels)])
  1079. loss_sum += loss.item()
  1080. return loss_sum/len(loader)
  1081. def chunk_into_n(lst, n):
  1082. size = math.ceil(len(lst) / n)
  1083. return list(
  1084. map(lambda x: lst[x * size:x * size + size],
  1085. list(range(n)))
  1086. )
  1087. def get_data_loaders(data, split, batch_size):
  1088. random.Random().shuffle(data)
  1089. training_data = all_data[:int(len(all_data)*split)]
  1090. testing_data = all_data[int(len(all_data)*split):]
  1091. train_loader = DataLoader(training_data, batch_size = batch_size)
  1092. test_loader = DataLoader(testing_data, batch_size = batch_size)
  1093. return train_loader, test_loader
  1094. def getBondElements(bond):
  1095. a= bond.GetEndAtom().GetSymbol()
  1096. b = bond.GetBeginAtom().GetSymbol()
  1097. return a+b if a<b else b+a
  1098. def getNaiveBondLength(bond):
  1099. a = getMendeleevElement(bond.GetEndAtom().GetAtomicNum()).atomic_radius or 0
  1100. b = getMendeleevElement(bond.GetBeginAtom().GetAtomicNum()).atomic_radius or 0
  1101. return a/200.0 + b/200.0
  1102. def train_model(train_set, test_set, split, batch_size, epochs, weight_decay=0.005, learning_rate=0.0003, NO_MP=7):
  1103. NO_MP = NO_MP
  1104. model = GNN_FULL_CLASS(NO_MP)
  1105. # Optimizer
  1106. LEARNING_RATE = learning_rate
  1107. optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=weight_decay)
  1108. # Criterion
  1109. #criterion = torch.nn.MSELoss()
  1110. criterion = torch.nn.L1Loss()
  1111. model.train()
  1112. train_loader = DataLoader(train_set, batch_size = batch_size)
  1113. test_loader = DataLoader(test_set, batch_size = batch_size)
  1114. for epoch in range(epochs):
  1115. tloss = train(model, criterion, optimizer, train_loader)
  1116. train_err = evaluate(model, criterion, train_loader)
  1117. test_err = evaluate(model, criterion, test_loader)
  1118. #print(train_err, test_err)
  1119. return (test_err, tloss, train_err)
  1120. def scale_graph_data(latent_graph_list):
  1121. #Iterate through graph list to get stacked NODE and EDGE features
  1122. node_stack=[]
  1123. edge_stack=[]
  1124. for g in latent_graph_list:
  1125. node_stack.append(g.x) #Append node features
  1126. edge_stack.append(g.edge_attr) #Append edge features
  1127. node_cat=torch.cat(node_stack,dim=0)
  1128. edge_cat=torch.cat(edge_stack,dim=0)
  1129. #Calculate NODE feature MEAN
  1130. node_mean=node_cat.mean(dim=0)
  1131. #Calculate NODE feature STD
  1132. node_std=node_cat.std(dim=0,unbiased=False)
  1133. #Calculate EDGE feature MEAN
  1134. edge_mean=edge_cat.mean(dim=0)
  1135. #Calculate EDGE feature STD
  1136. edge_std=edge_cat.std(dim=0,unbiased=False)
  1137. #Apply zero-mean, unit variance scaling, append scaled graph to list
  1138. latent_graph_list_sc=[]
  1139. for g in latent_graph_list:
  1140. x_sc=g.x-node_mean
  1141. x_sc/=node_std
  1142. ea_sc=g.edge_attr-edge_mean
  1143. ea_sc/=edge_std
  1144. ea_sc[ea_sc != ea_sc] = 0
  1145. x_sc[x_sc != x_sc] = 0
  1146. temp_graph=Data(x=x_sc,edge_index=g.edge_index,edge_attr=ea_sc, y=g.y)
  1147. latent_graph_list_sc.append(temp_graph)
  1148. return latent_graph_list_sc
# --- One-hot encoders shared by the feature getters -------------------------
## Bond type (single/double/triple/aromatic)
bond_idxes = np.array([Chem.BondType.SINGLE, Chem.BondType.DOUBLE, Chem.BondType.TRIPLE, Chem.BondType.AROMATIC])
bond_idxes = bond_idxes.reshape(len(bond_idxes), 1)
# NOTE(review): `sparse=` was renamed `sparse_output=` in scikit-learn >= 1.2 — confirm the pinned sklearn version.
onehot_encoder = OneHotEncoder(sparse=False, handle_unknown="ignore")
onehot_encoder.fit(bond_idxes)
## Hybridization
hybridization_idxes = np.array(list(Chem.HybridizationType.names))
hybridization_idxes = hybridization_idxes.reshape(len(hybridization_idxes), 1)
hybridization_ohe = OneHotEncoder(sparse=False)
hybridization_ohe.fit(hybridization_idxes)
## Valence (1..7)
valences = np.arange(1, 8);
valences = valences.reshape(len(valences), 1)
valence_ohe = OneHotEncoder(handle_unknown="ignore",sparse=False)
valence_ohe.fit(valences)
## Formal Charge
# NOTE(review): arange(-1, 1) covers only -1 and 0; +1 charges encode as all-zeros via handle_unknown="ignore" — confirm intended.
fc = np.arange(-1, 1);
fc = fc.reshape(len(fc), 1)
fc_ohe = OneHotEncoder(handle_unknown="ignore",sparse=False)
fc_ohe.fit(fc)
## Atomic number (the element set present in the dataset)
atomic_nums = np.array([6,1,7,8,9,17,15,11, 16])
atomic_nums = atomic_nums.reshape(len(atomic_nums), 1)
atomic_number_ohe = OneHotEncoder(handle_unknown="ignore",sparse=False)
atomic_number_ohe.fit(atomic_nums)
atomic_number_ohe.transform(np.array([[1]]))  # smoke-test the fitted encoder
# Load the 3-D SD file and build a train/test split (fixed seed for reproducibility).
supplier3d = Chem.rdmolfiles.SDMolSupplier("nmrshiftdb2withsignals.sd",True, False, True) #Flourine
all_data = list(supplier3d)
random.Random(80).shuffle(all_data)
number_of_elements=len(all_data)
# TEST_SPLIT is defined earlier in the file (outside this chunk).
training_data = all_data[:int(number_of_elements*TEST_SPLIT)]
testing_data = all_data[int(number_of_elements*TEST_SPLIT):number_of_elements]
  1180. """## Experiment blocks
  1181. ### Node features
  1182. """
  1183. NO_SPLITS= 4
  1184. descriptor_dict = {}
  1185. descriptor_dict["random"]=lambda atom: random.random()
  1186. descriptor_dict["ohe atomic number"] = lambda atom:atomic_number_ohe.transform(np.array([[atom.GetAtomicNum()]]))[0] # Atomic number
  1187. descriptor_dict["isAromatic"] = lambda atom: atom.GetIsAromatic()
  1188. descriptor_dict["hyb ohe"] = lambda atom: hybridization_ohe.transform(np.array([[atom.GetHybridization().name]]))[0]
  1189. descriptor_dict["valence ohe"] = lambda atom: valence_ohe.transform(np.array([[atom.GetTotalValence()]]))[0]
  1190. descriptor_dict["valence"] = lambda atom: atom.GetTotalValence()
  1191. descriptor_dict["inRing"] = lambda atom: atom.IsInRing()
  1192. descriptor_dict["hybridization"] = lambda atom: atom.GetHybridization()
  1193. descriptor_dict["atomic num"] = lambda atom: atom.GetAtomicNum() # Atomic number
  1194. descriptor_dict["atomic charge"]= lambda atom: atom.GetFormalCharge() # Atomic charge
  1195. descriptor_dict["atomic radius"]= lambda atom: getMendeleevElement(atom.GetAtomicNum()).atomic_radius or 0 # Atomic radius
  1196. descriptor_dict["atomic volume"] = lambda atom: getMendeleevElement(atom.GetAtomicNum()).atomic_volume # Atomic volume
  1197. descriptor_dict["atomic weight"] = lambda atom: getMendeleevElement(atom.GetAtomicNum()).atomic_weight # Atomic weight
  1198. descriptor_dict["covalent radius"] = lambda atom: getMendeleevElement(atom.GetAtomicNum()).covalent_radius # Covalent radius
  1199. descriptor_dict["vdw radius"] = lambda atom: getMendeleevElement(atom.GetAtomicNum()).vdw_radius # Van der Waals radius
  1200. descriptor_dict["dipole polarizability"] = lambda atom: getMendeleevElement(atom.GetAtomicNum()).dipole_polarizability # Dipole polarizability
  1201. descriptor_dict["electron affinity"] = lambda atom: getMendeleevElement(atom.GetAtomicNum()).electron_affinity # Electron affinity
  1202. descriptor_dict["electrophilicity index"] = lambda atom: getMendeleevElement(atom.GetAtomicNum()).electrophilicity() # Electrophilicity index
  1203. descriptor_dict["electronegativity"] = lambda atom: getMendeleevElement(atom.GetAtomicNum()).en_pauling # Electronegativity
  1204. descriptor_dict["electrons"] = lambda atom: getMendeleevElement(atom.GetAtomicNum()).electrons # No. of electrons
  1205. descriptor_dict["neutrons"] = lambda atom: getMendeleevElement(atom.GetAtomicNum()).neutrons # No. of neutrons
  1206. descriptor_dict["cooridates"] = lambda atom: list(atom.GetOwningMol().GetConformer().GetAtomPosition(atom.GetIdx())) # coordinates
  1207. descriptor_dict["formal charge"] = lambda atom: atom.GetFormalCharge()
  1208. descriptor_dict["formal charge ohe"] = lambda atom: fc_ohe.transform(np.array([[atom.GetFormalCharge()]]))[0]
  1209. #feature_getters["gaisteigerCharge"] = lambda atom: 0 if np.isfinite(float(atom.GetProp('_GasteigerCharge'))) else float(atom.GetProp('_GasteigerCharge')) #partial charges
  1210. descriptor_dict["chiral tag"] = lambda atom: atom.GetChiralTag()
  1211. node_features_results = {}
  1212. for getter in descriptor_dict.items():
  1213. descriptor, func = getter
  1214. print(f"Evaluating descriptor: {descriptor}")
  1215. getter_results = []
  1216. mol_graphs = scale_graph_data([turn_to_graph(mol, [func]) for mol in training_data if mol])
  1217. splits =chunk_into_n(mol_graphs, NO_SPLITS)
  1218. for split in splits:
  1219. train_data = []
  1220. for s in splits:
  1221. if s!=split:
  1222. train_data+=s
  1223. getter_results.append(train_model(train_data,split, 1-1.0/NO_SPLITS, 128, 500)[0])
  1224. node_features_results[descriptor] = getter_results
  1225. """### Edge features"""
  1226. NO_SPLITS=4
  1227. BONDS= ["CC", "CF", "CH", "CCl","CN", "FN", "FH","HN" ,"ClN","HO", "CO","NO"] #Pair characters ordererd alphabetically(CO rather than OC)
  1228. bonded_atoms_idxs = np.array(BONDS)
  1229. bonded_atoms_idxs = bonded_atoms_idxs.reshape(len(bonded_atoms_idxs), 1)
  1230. bonded_atoms_encoder = OneHotEncoder(sparse=False, handle_unknown="ignore")
  1231. bonded_atoms_encoder.fit(bonded_atoms_idxs)
  1232. def getBondElements(bond):
  1233. a= bond.GetEndAtom().GetSymbol()
  1234. b = bond.GetBeginAtom().GetSymbol()
  1235. return a+b if a<b else b+a
  1236. def getNaiveBondLength(bond):
  1237. a = getMendeleevElement(bond.GetEndAtom().GetAtomicNum()).atomic_radius or 0
  1238. b = getMendeleevElement(bond.GetBeginAtom().GetAtomicNum()).atomic_radius or 0
  1239. return a/200.0 + b/200.0
  1240. def bond_feature_pure_distance_and_rdkit_type(bond):
  1241. onehot_encoded_bondtype = onehot_encoder.transform(np.array([[bond.GetBondType()]]))[0]
  1242. [x1, y1, z1] = list(bond.GetOwningMol().GetConformer().GetAtomPosition(bond.GetBeginAtomIdx()))
  1243. [x2, y2, z2] = list(bond.GetOwningMol().GetConformer().GetAtomPosition(bond.GetEndAtomIdx()))
  1244. distance = [(math.sqrt((x2 - x1)**2 + (y2 - y1)**2 + (z2 - z1)**2))] # Distance
  1245. return distance+list(onehot_encoded_bondtype)
  1246. def bond_feature_pure_distance_with_bonded_atoms(bond):
  1247. onehot_encoded_bondtype = bonded_atoms_encoder.transform(np.array([[getBondElements(bond)]]))[0]
  1248. [x1, y1, z1] = list(bond.GetOwningMol().GetConformer().GetAtomPosition(bond.GetBeginAtomIdx()))
  1249. [x2, y2, z2] = list(bond.GetOwningMol().GetConformer().GetAtomPosition(bond.GetEndAtomIdx()))
  1250. distance = [(math.sqrt((x2 - x1)**2 + (y2 - y1)**2 + (z2 - z1)**2))] # Distance
  1251. return distance +list(onehot_encoded_bondtype)
  1252. def bond_feature_smart_distance_and_rdkit_type(bond):
  1253. onehot_encoded_bondtype = onehot_encoder.transform(np.array([[bond.GetBondType()]]))[0]
  1254. [x1, y1, z1] = list(bond.GetOwningMol().GetConformer().GetAtomPosition(bond.GetBeginAtomIdx()))
  1255. [x2, y2, z2] = list(bond.GetOwningMol().GetConformer().GetAtomPosition(bond.GetEndAtomIdx()))
  1256. ex_dist = getNaiveBondLength(bond)
  1257. distance = [(math.sqrt((x2 - x1)**2 + (y2 - y1)**2 + (z2 - z1)**2)) - ex_dist] # Distance
  1258. return distance+ list(onehot_encoded_bondtype)
  1259. def bond_feature_smart_distance_with_bonded_atoms(bond):
  1260. onehot_encoded_bondtype = bonded_atoms_encoder.transform(np.array([[getBondElements(bond)]]))[0]
  1261. [x1, y1, z1] = list(bond.GetOwningMol().GetConformer().GetAtomPosition(bond.GetBeginAtomIdx()))
  1262. [x2, y2, z2] = list(bond.GetOwningMol().GetConformer().GetAtomPosition(bond.GetEndAtomIdx()))
  1263. ex_dist = getNaiveBondLength(bond)
  1264. distance = [(math.sqrt((x2 - x1)**2 + (y2 - y1)**2 + (z2 - z1)**2)) - ex_dist] # Distance
  1265. return distance+list(onehot_encoded_bondtype)
  1266. def bond_features_random(bond):
  1267. return [random.random()]
  1268. def all_bond_feature_functions():
  1269. return {
  1270. "smart,atoms":bond_feature_smart_distance_with_bonded_atoms,
  1271. "pure,rdkit":bond_feature_pure_distance_and_rdkit_type,
  1272. "pure,atoms":bond_feature_pure_distance_with_bonded_atoms,
  1273. "smart,rdkit":bond_feature_smart_distance_and_rdkit_type,
  1274. "random":bond_features_random,
  1275. }
all_functions = all_bond_feature_functions()
bond_features_results ={}
# K-fold CV per bond-feature variant.  Unlike the node-feature loop above,
# the full (test_err, tloss, train_err) tuple is stored, not just [0].
# NOTE(review): indentation reconstructed from a whitespace-mangled source.
for name, method in all_functions.items():
    print(name)
    getter_results = []
    mol_graphs = scale_graph_data([turn_to_graph(mol, bond_features=method) for mol in training_data if mol])
    splits =chunk_into_n(mol_graphs, NO_SPLITS)
    for split in splits:
        train_data = []
        for s in splits:
            if s!=split:
                train_data+=s
        getter_results.append(train_model(train_data,split, 1-1.0/NO_SPLITS, 128, 500))
    bond_features_results[name] = getter_results
  1290. """### Hyperparameters"""
  1291. NO_SPLITS= 4
  1292. import pandas as pd
  1293. hp_results = pd.DataFrame(columns = ['wd', 'lr','m', 'accuracy']) #Dataframe in which all results are saved in
  1294. mol_graphs = scale_graph_data([turn_to_graph(mol) for mol in training_data if mol])
  1295. splits =chunk_into_n(mol_graphs, NO_SPLITS)
  1296. for m in range(4,8):
  1297. for lr in [0.0007, 0.001,0.0013]:
  1298. for wd in [ 0.0025, 0.005, 0.0075, 0.01, 0.015]:
  1299. for split in splits:
  1300. train_data = []
  1301. for s in splits:
  1302. if s!=split:
  1303. train_data+=s
  1304. test_err, tloss, train_err = train_model(train_data,split, 1-1.0/NO_SPLITS, 128, 500, weight_decay=wd, learning_rate=lr, NO_MP=m)
  1305. hp_results = hp_results.append({'accuracy': test_err, 'loss':tloss, "train_err":train_err, 'm':m,"lr":lr, "wd":wd},ignore_index=True)
  1306. """### Top N feature model - finding optimal n?"""
  1307. NO_SPLITS=4
  1308. def all_atom_feature_ordered_by_acc():
  1309. feature_getters = {}
  1310. feature_getters["neutrons"] = lambda atom: getMendeleevElement(atom.GetAtomicNum()).neutrons # No. of neutrons
  1311. feature_getters["atomic weight"] = lambda atom: getMendeleevElement(atom.GetAtomicNum()).atomic_weight # Atomic weight
  1312. feature_getters["covalent radius"] = lambda atom: getMendeleevElement(atom.GetAtomicNum()).covalent_radius # Covalent radius
  1313. feature_getters["vdw radius"] = lambda atom: getMendeleevElement(atom.GetAtomicNum()).vdw_radius # Van der Waals radius
  1314. feature_getters["electrophilicity index"] = lambda atom: getMendeleevElement(atom.GetAtomicNum()).electrophilicity() # Electrophilicity index
  1315. feature_getters["electronegativity"] = lambda atom: getMendeleevElement(atom.GetAtomicNum()).en_pauling # Electronegativity
  1316. feature_getters["atomic volume"] = lambda atom: getMendeleevElement(atom.GetAtomicNum()).atomic_volume # Atomic volume
  1317. feature_getters["atomic radius"]= lambda atom: getMendeleevElement(atom.GetAtomicNum()).atomic_radius or 0 # Atomic radius
  1318. feature_getters["electrons"] = lambda atom: getMendeleevElement(atom.GetAtomicNum()).electrons # No. of electrons
  1319. feature_getters["atomic num"] = lambda atom: atom.GetAtomicNum() # Atomic number
  1320. feature_getters["electron affinity"] = lambda atom: getMendeleevElement(atom.GetAtomicNum()).electron_affinity # Electron affinity
  1321. feature_getters["ohe atomic number"] = lambda atom:atomic_number_ohe.transform(np.array([[atom.GetAtomicNum()]]))[0] # Atomic number
  1322. feature_getters["dipole polarizability"] = lambda atom: getMendeleevElement(atom.GetAtomicNum()).dipole_polarizability # Dipole polarizability
  1323. feature_getters["hyb ohe"] = lambda atom: hybridization_ohe.transform(np.array([[atom.GetHybridization().name]]))[0]
  1324. feature_getters["valence"] = lambda atom: atom.GetTotalValence()
  1325. feature_getters["valence ohe"] = lambda atom: valence_ohe.transform(np.array([[atom.GetTotalValence()]]))[0]
  1326. feature_getters["hybridization"] = lambda atom: atom.GetHybridization()
  1327. feature_getters["inRing"] = lambda atom: atom.IsInRing()
  1328. feature_getters["isAromatic"] = lambda atom: atom.GetIsAromatic()
  1329. feature_getters["formal charge ohe"] = lambda atom: fc_ohe.transform(np.array([[atom.GetFormalCharge()]]))[0]
  1330. feature_getters["atomic charge"]= lambda atom: atom.GetFormalCharge() # Atomic charge
  1331. feature_getters["formal charge"] = lambda atom: atom.GetFormalCharge()
  1332. #feature_getters["gaisteigerCharge"] = lambda atom: 0 if np.isfinite(float(atom.GetProp('_GasteigerCharge'))) else float(atom.GetProp('_GasteigerCharge')) #partial charges
  1333. feature_getters["chiral tag"] = lambda atom: atom.GetChiralTag()
  1334. feature_getters["cooridates"] = lambda atom: list(atom.GetOwningMol().GetConformer().GetAtomPosition(atom.GetIdx())) # coordinates
  1335. return feature_getters
model_results_df = pd.DataFrame(columns = ['model', 'accuracy']) #Dataframe in which all results are saved in
all_getters = all_atom_feature_ordered_by_acc()
n_feature_results ={}
used_methods =[]
features=[]
# Greedy forward accumulation: add descriptors one at a time (best-first
# order) and cross-validate the growing feature set; results are keyed by
# the stringified feature-name list.
# NOTE(review): indentation reconstructed from a whitespace-mangled source.
for getters in all_getters.items():
    print(getters)
    used_methods.append(getters[1])
    features.append(getters[0])
    getter_results = []
    mol_graphs = scale_graph_data([turn_to_graph(mol, used_methods) for mol in training_data if mol])
    splits =chunk_into_n(mol_graphs, NO_SPLITS)
    for split in splits:
        train_data = []
        for s in splits:
            if s!=split:
                train_data+=s
        getter_results.append(train_model(train_data,split, 1-1.0/NO_SPLITS, 128, 500)[0])
    n_feature_results[str(features)] = getter_results
print(chunk_into_n([1,23,45123,1231,12,123,123,123], 2))  # quick sanity check of the chunking helper
  1356. """### 2019 model features"""
  1357. NO_SPLITS = 4
  1358. pt = Chem.GetPeriodicTable()
  1359. def atom_features_2019():
  1360. getters = {}
  1361. getters["atomic_num"] = lambda atom: atom.GetAtomicNum()
  1362. getters["atomic_num_ohe"] = lambda atom: atom.GetAtomicNum()
  1363. getters["valence"] = lambda atom: atom.GetTotalValence()
  1364. getters["total valence ohe"] = lambda atom: valence_ohe.transform(np.array([[atom.GetTotalValence()]]))[0]
  1365. getters["defualt valence ohe"] = lambda atom: valence_ohe.transform(np.array([[pt.GetDefaultValence(atom.GetAtomicNum())]]))[0]
  1366. getters["isAromatic"] = lambda atom: atom.GetIsAromatic()
  1367. getters["hyb ohe"] = lambda atom: hybridization_ohe.transform(np.array([[atom.GetHybridization().name]]))[0]
  1368. getters["formal charge ohe"] = lambda atom: fc_ohe.transform(np.array([[atom.GetFormalCharge()]]))[0]
  1369. getters["inRing"] = lambda atom: atom.IsInRing()
  1370. return getters
results_2019=[]
# Cross-validate the 2019 feature set; store the test error per fold.
# NOTE(review): indentation reconstructed from a whitespace-mangled source.
mol_graphs = scale_graph_data([turn_to_graph(mol, [getter for getter in atom_features_2019().values()]) for mol in training_data if mol])
splits =chunk_into_n(mol_graphs, NO_SPLITS)
for split in splits:
    train_data = []
    for s in splits:
        if s!=split:
            train_data+=s
    results_2019.append(train_model(train_data,split, 1-1.0/NO_SPLITS, 128, 500)[0])
  1380. """### Varying dataset sizes experiments"""
  1381. #Downloading all datasets
  1382. import requests
  1383. datasets ={
  1384. "c_full":"https://sourceforge.net/projects/nmrshiftdb2/files/data/nmrshiftdb2withsignals.sd/download",
  1385. "c_cdcl3":"https://nmrshiftdb.nmr.uni-koeln.de/nmrshiftdb2cdcl3.sd",
  1386. "c_dmso":"https://nmrshiftdb.nmr.uni-koeln.de/nmrshiftdb2dmso.sd",
  1387. "c_cd3od":"https://nmrshiftdb.nmr.uni-koeln.de/nmrshiftdb2cd3od.sd"
  1388. # nmrshiftdb2withsignals.sdf for cholorine(already downloaded in core block)
  1389. }
  1390. for name, url in datasets.items():
  1391. r = requests.get(url, allow_redirects=True)
  1392. with open(name+".sdf", 'wb') as f:
  1393. f.write(r.content)
target_atom_number= 9 #6 for carbon 9 for chlorine
supplier3d= Chem.rdmolfiles.SDMolSupplier("nmrshiftdb2withsignals.sd",True, False, True) #Flourine
NO_SPLITS=4
# NOTE(review): columns declare 'mse' but rows below are written with 'mae'/'rmse'/'std'/'missing' —
# the 'mse' column stays empty; confirm intended.
model_results_df_2 = pd.DataFrame(columns = ['model', "dataset_size", 'mse',"rmse"]) #Dataframe in which all results are saved in
all_data = list(supplier3d)
# HOSE-code nearest-neighbour baseline at several dataset sizes:
# build a HOSE->mean-shift table from the training folds, then predict each
# test atom from the longest HOSE prefix seen in training.
# get_molecule_hose_and_nmr_shifts / split_hose are defined earlier in the file (outside this chunk);
# each `trio` appears to be (atom?, hose_string, shift) — only [1] and [2] are used here.
# NOTE(review): indentation reconstructed from a whitespace-mangled source.
for size in (100,250,500, len(supplier3d)):
    random.Random(80).shuffle(all_data)
    for i in range(0, len(all_data), size):
        # NOTE(review): the loop variable `i` is reused by inner loops below;
        # safe only because `subset` is computed before they run.
        subset=all_data[i:min(i+size, len(all_data))]
        splits =chunk_into_n(subset, NO_SPLITS)
        for split in splits:
            train_data = []
            for s in splits:
                if s!=split:
                    train_data+=s
            # Map every HOSE-code prefix (radius 1..6) to the shifts observed under it.
            hose_map={}
            for mol in train_data:
                for trio in get_molecule_hose_and_nmr_shifts(mol, 6, target_atom_number ):
                    parts=split_hose(trio[1])
                    for i in range(len(parts)):
                        hose = "/".join(parts[:(i+1)])
                        if hose not in hose_map:
                            hose_map[hose] = []
                        hose_map[hose].append(trio[2])
            # Average the shifts per prefix.
            avg_map = {}
            for k,v in hose_map.items():
                avg_map[k] =sum(v)/len(v)
            predictions = []
            labels=[]
            familiar_radius = [] # for visualisation
            missing=0
            # Predict each test atom from its longest known HOSE prefix.
            for mol in split:
                for trio in get_molecule_hose_and_nmr_shifts(mol, 6,target_atom_number ):
                    parts=split_hose(trio[1])
                    # NOTE(review): if `parts` is ever empty, `is_match` is unbound here.
                    for i in range(len(parts),0,-1):
                        is_match = False
                        hose = "/".join(parts[:(i)])
                        if hose in hose_map:
                            predictions.append(avg_map[hose])
                            familiar_radius.append(i)
                            labels.append(trio[2])
                            is_match=True
                            break
                    if not is_match:
                        missing+=1
            # Signed errors; NOTE(review): ZeroDivisionError if no prediction matched.
            errors=[]
            for i in range(len(predictions)):
                errors.append(labels[i] - predictions[i])
            avg_error = sum(errors)/len(errors)
            #print(avg_error)
            model_results_df_2 = pd.concat([model_results_df_2,pd.DataFrame(data = {'model': ["Hose"], "dataset_size":[size],
                "mae":[sum([abs(err) for err in errors])/len(errors)],
                "rmse": [(sum([err**2 for err in errors])/len(errors))**0.5],
                "std": [(sum([(err-avg_error)**2 for err in errors]) /len(errors))**0.5],
                "missing":[ missing]
                })], ignore_index = True,axis=0, join='outer')
  1450. def get_top_feature_getters():
  1451. feature_getters = {}
  1452. feature_getters["neutrons"] = lambda atom: getMendeleevElement(atom.GetAtomicNum()).neutrons # No. of neutrons
  1453. feature_getters["atomic weight"] = lambda atom: getMendeleevElement(atom.GetAtomicNum()).atomic_weight # Atomic weight
  1454. feature_getters["covalent radius"] = lambda atom: getMendeleevElement(atom.GetAtomicNum()).covalent_radius # Covalent radius
  1455. feature_getters["vdw radius"] = lambda atom: getMendeleevElement(atom.GetAtomicNum()).vdw_radius # Van der Waals radius
  1456. feature_getters["electrophilicity index"] = lambda atom: getMendeleevElement(atom.GetAtomicNum()).electrophilicity() # Electrophilicity index
  1457. feature_getters["electronegativity"] = lambda atom: getMendeleevElement(atom.GetAtomicNum()).en_pauling # Electronegativity
  1458. feature_getters["atomic volume"] = lambda atom: getMendeleevElement(atom.GetAtomicNum()).atomic_volume # Atomic volume
  1459. feature_getters["atomic radius"]= lambda atom: getMendeleevElement(atom.GetAtomicNum()).atomic_radius or 0 # Atomic radius
  1460. feature_getters["electrons"] = lambda atom: getMendeleevElement(atom.GetAtomicNum()).electrons # No. of electrons
  1461. feature_getters["atomic num"] = lambda atom: atom.GetAtomicNum() # Atomic number
  1462. feature_getters["electron affinity"] = lambda atom: getMendeleevElement(atom.GetAtomicNum()).electron_affinity # Electron affinity
  1463. feature_getters["ohe atomic number"] = lambda atom:atomic_number_ohe.transform(np.array([[atom.GetAtomicNum()]]))[0] # Atomic number
  1464. feature_getters["dipole polarizability"] = lambda atom: getMendeleevElement(atom.GetAtomicNum()).dipole_polarizability # Dipole polarizability
  1465. feature_getters["hyb ohe"] = lambda atom: hybridization_ohe.transform(np.array([[atom.GetHybridization().name]]))[0]
  1466. feature_getters["valence"] = lambda atom: atom.GetTotalValence()
  1467. return feature_getters
random.seed(80)
# Shuffle all of our data, as input data might be sorted
all_data = list(supplier3d)
getters = list(get_top_feature_getters().values())
# Pre-build and scale all graphs once with the top-15 features and the
# smart-distance + RDKit-bond-type edge features.
all_data = scale_graph_data([turn_to_graph(molecule, atom_feature_getters=getters, bond_features= bond_feature_smart_distance_and_rdkit_type) for idx, molecule in enumerate(supplier3d) if molecule])
random.Random(80).shuffle(all_data)
# GNN counterpart of the HOSE dataset-size experiment: slide a window of
# `size` graphs over the data and cross-validate within each window.
# NOTE(review): indentation reconstructed from a whitespace-mangled source;
# in particular the `fixing_epochs` loop is assumed to follow the 500-epoch
# loop (extra training while train MAE > 2.5), not to sit inside it — confirm.
for size in ([100,250,500,len(all_data)]):
    method_name="top n descriptors"
    for i in range(0, len(all_data), size):
        print(i, size)
        # min(i, len-size) keeps the final window full-sized (it may overlap the previous one).
        mol_graphs=all_data[min(i,len(all_data)-size) :min(i+size, len(all_data))]
        splits =chunk_into_n(mol_graphs, NO_SPLITS)
        for split in splits:
            train_data = []
            for s in splits:
                if s!=split:
                    train_data+=s
            BATCH_SIZE = 128
            train_loader = DataLoader(train_data, batch_size = BATCH_SIZE)
            test_loader = DataLoader(split, batch_size = BATCH_SIZE)
            model = GNN_FULL_CLASS(6)  # 6 message-passing steps
            model.train()
            optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=.01)
            mae = torch.nn.L1Loss()
            mse = torch.nn.MSELoss()
            for epoch in range(500):
                tloss = train(model, mae, optimizer, train_loader)
                train_err = evaluate(model, mae, train_loader)
                test_err = evaluate(model, mae, test_loader)
            # Up to 200 extra epochs while training MAE stays above 2.5.
            fixing_epochs=0
            while train_err>2.5 and fixing_epochs <200:
                tloss = train(model, mae, optimizer, train_loader)
                train_err = evaluate(model, mae, train_loader)
                test_err = evaluate(model, mae, test_loader)
                fixing_epochs+=1
            # Collect predictions/labels over the test fold to compute the error std.
            preds = torch.tensor([])
            labels = torch.tensor([])
            with torch.no_grad():
                for batch in test_loader:
                    # Forward pass
                    labels = torch.cat((labels, batch.y), 0)
                    preds = torch.cat((preds,model(batch)),0)
            std = torch.std(torch.subtract(preds[torch.isfinite(labels)], labels[torch.isfinite(labels)]))
            model_results_df_2 = pd.concat([model_results_df_2,pd.DataFrame(data = {'model': [method_name], "dataset_size":[size],
                "mae":[evaluate(model, mae, test_loader)],
                "rmse": [evaluate(model, mse, test_loader)**0.5],
                "std": [float(std)],
                })], ignore_index = True,axis=0, join='outer')
            #model_results_df_2 = model_results_df_2.append({'model': method_name, "dataset_size":size, "mae": evaluate(model, mae, test_loader), "rmse": evaluate(model, mse, test_loader)**0.5, "std":float(std)} , ignore_index = True)