12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929 |
- # -*- coding: utf-8 -*-
- """NMR shift prediction from small data quantities.ipynb
- Automatically generated by Colaboratory.
- Original file is located at
- https://colab.research.google.com/drive/1yKTRjpWzR8T199eCokuJfd9Y5o2oNtPp
- # Data: NMRShiftDB
- ## Reading data set from SD file
- The nmrshiftdb2 data from https://sourceforge.net/projects/nmrshiftdb2/files/data/ has been downloaded and moved to our own domain in order to avoid the https://sourceforge.net/ madness (redirects, cookie banner, etc.) and ensure that colab can be rerun without manual steps or Google Drive dependencies by all collaborators.
- """
- # Install required dependencies
- !pip3 install rdkit # Used to read and parse nmrshiftdb2 SD file
- !pip3 install mendeleev # To to access various features related to atoms
- import torch
- # PyTorch dependencies to represent graph data
- !pip3 install torch-scatter -f https://data.pyg.org/whl/torch-{torch.__version__}.html
- !pip3 install torch-sparse -f https://data.pyg.org/whl/torch-{torch.__version__}.html
- !pip3 install torch-geometric
- # Use CUDA by default
- if torch.cuda.is_available():
- torch.set_default_tensor_type('torch.cuda.FloatTensor')
- # Downloads the nmrshiftdb2 database if it does not yet exist in our runtime
- !wget -nc https://www.dropbox.com/s/n122zxawpxii5b7/nmrshiftdb2withsignals.sd#Flourine
- #!wget -nc "https://sourceforge.net/projects/nmrshiftdb2/files/data/nmrshiftdb2withsignals.sd/download" #Carbon
- from rdkit import Chem
- supplier3d = Chem.rdmolfiles.SDMolSupplier("nmrshiftdb2withsignals.sd",True, False, True) #Flourine
- #supplier3d = Chem.rdmolfiles.SDMolSupplier("download") #Carbon
- print(f"In total there are {len(supplier3d)} molecules")
- from rdkit.Chem import AllChem
- mol= supplier3d[152]
- mol
- """## Graph transformation
- In order to use convolutional networks, we need to transform the molecule data into graphs.
- The atoms themselves will become nodes, and the bonds between the atoms will become the edges.
- """
- import mendeleev
- import torch
- import math
- from torch import tensor
- from torch_geometric.data import Data
- from torch_geometric.loader import DataLoader
- import os
- import numpy as np
- import time
- from sklearn.preprocessing import OneHotEncoder
- import random
# One-hot encoders for the categorical bond and atom descriptors.
# Every encoder is fitted once on a single-column matrix of its categories.

## Bonds
bond_idxes = np.array([
    Chem.BondType.SINGLE,
    Chem.BondType.DOUBLE,
    Chem.BondType.TRIPLE,
    Chem.BondType.AROMATIC,
]).reshape(-1, 1)
onehot_encoder = OneHotEncoder(sparse=False, handle_unknown="ignore")
onehot_encoder.fit(bond_idxes)

## Hybridization (fitted on the names of all RDKit hybridization types)
hybridization_idxes = np.array(list(Chem.HybridizationType.names)).reshape(-1, 1)
hybridization_ohe = OneHotEncoder(sparse=False)
hybridization_ohe.fit(hybridization_idxes)

## Valence (categories 1..7; other values encode as all zeros)
valences = np.arange(1, 8).reshape(-1, 1)
valence_ohe = OneHotEncoder(handle_unknown="ignore", sparse=False)
valence_ohe.fit(valences)

## Formal Charge
# NOTE(review): np.arange(-1, 1) yields only -1 and 0, so a charge of +1 is
# encoded as all zeros like any other unknown value - confirm this is intended.
fc = np.arange(-1, 1).reshape(-1, 1)
fc_ohe = OneHotEncoder(handle_unknown="ignore", sparse=False)
fc_ohe.fit(fc)

## Atomic number
atomic_nums = np.array([6, 1, 7, 8, 9, 17, 15, 11, 16]).reshape(-1, 1)
atomic_number_ohe = OneHotEncoder(handle_unknown="ignore", sparse=False)
atomic_number_ohe.fit(atomic_nums)
# Notebook sanity check: shows the encoding of hydrogen (atomic number 1).
atomic_number_ohe.transform(np.array([[1]]))
def get_molecule_solvent(molecule):
    """Return the value of the first molecule property whose key starts with "Solvent".

    Returns an empty dict when ``molecule`` is falsy and ``None`` when the
    molecule has no such property.
    """
    if not molecule:
        return {}
    properties = molecule.GetPropsAsDict()
    return next(
        (value for key, value in properties.items() if key.startswith("Solvent")),
        None,
    )
# Count how often each solvent occurs in the data set, then one-hot encode
# only the solvents seen more than 10 times (others encode as all zeros).
solvents = {}
for mol in supplier3d:
    solvent = get_molecule_solvent(mol)
    if not solvent:
        continue
    solvents[solvent] = solvents.get(solvent, 0) + 1

arr = [name for name, count in solvents.items() if count > 10]
solvents = np.array(arr)
solvents = solvents.reshape(len(solvents), 1)
solvent_ohe = OneHotEncoder(handle_unknown="ignore", sparse=False)
solvent_ohe.fit(solvents)
#solvent_ohe.transform(np.array([[1]]))
# Cache of mendeleev element objects, keyed by the value passed as ``nr``.
el_map = {}


def getMendeleevElement(nr):
    """Return ``mendeleev.element(nr)``, looking it up only once per ``nr``."""
    try:
        return el_map[nr]
    except KeyError:
        element = mendeleev.element(nr)
        el_map[nr] = element
        return element
def nmr_shift(atom):
    """Return the shift recorded for ``atom`` in its molecule's spectrum data.

    Every molecule property whose key starts with "Spectrum" is split on '|'
    into entries and each entry on ';'. The first three-field entry whose
    third field equals the atom index gives the result (its first field as a
    float). NaN is returned when no entry matches.
    """
    atom_index = f"{atom.GetIdx()}"
    properties = atom.GetOwningMol().GetPropsAsDict()
    for key, value in properties.items():
        if not key.startswith("Spectrum"):
            continue
        for entry in value.split('|'):
            fields = entry.split(';')
            if len(fields) == 3 and fields[2] == atom_index:
                return float(fields[0])
    # We use NaN for atoms we don't want to predict shifts
    return float("NaN")
def bond_features(bond):
    """Return the edge feature list for ``bond``.

    The list starts with the Euclidean distance between the positions of the
    bond's begin and end atoms in the owning molecule's conformer, followed
    by the one-hot encoding of the bond type.
    """
    bond_type_vector = onehot_encoder.transform(np.array([[bond.GetBondType()]]))[0]
    conformer = bond.GetOwningMol().GetConformer()
    x1, y1, z1 = list(conformer.GetAtomPosition(bond.GetBeginAtomIdx()))
    x2, y2, z2 = list(conformer.GetAtomPosition(bond.GetEndAtomIdx()))
    length = math.sqrt((x2 - x1)**2 + (y2 - y1)**2 + (z2 - z1)**2)
    return [length] + list(bond_type_vector)
def atom_features_random(atom, molecule=None):
    """Return a single random integer feature in [0, 9].

    Neither ``atom`` nor ``molecule`` is read.
    """
    return [random.randint(0, 9)]
def atom_features_update(atom, molecule=None):
    """Return a node feature list for ``atom``.

    Layout: solvent one-hot, hybridization one-hot, total-valence one-hot,
    then six element properties from mendeleev, the chiral tag and the
    ring-membership flag.
    """
    element = getMendeleevElement(atom.GetAtomicNum())
    solvent_vector = solvent_ohe.transform(np.array([[get_molecule_solvent(molecule)]]))[0]
    hybridization_vector = hybridization_ohe.transform(np.array([[atom.GetHybridization().name]]))[0]
    valence_vector = valence_ohe.transform(np.array([[atom.GetTotalValence()]]))[0]
    return [
        *solvent_vector,
        *hybridization_vector,
        *valence_vector,
        element.atomic_volume,          # Atomic volume
        element.dipole_polarizability,  # Dipole polarizability
        element.electron_affinity,      # Electron affinity
        element.en_pauling,             # Electronegativity
        element.electrons,              # No. of electrons
        element.neutrons,               # No. of neutrons
        atom.GetChiralTag(),
        atom.IsInRing(),
    ]
def atom_features_2019(atom, molecule=None):
    """Return a node feature list for ``atom`` (the "2019" feature set).

    Layout: raw atomic number, atomic-number one-hot, two valence one-hots,
    a valence value, aromaticity flag, hybridization one-hot, formal-charge
    one-hot and the ring-membership flag. ``molecule`` is accepted for
    signature compatibility with the other constructors and is not read.
    """
    # The accumulator was never created, so every call raised NameError.
    features = []
    features.append(atom.GetAtomicNum())
    features.extend(atomic_number_ohe.transform(np.array([[atom.GetAtomicNum()]]))[0])  # Atomic number
    features.extend(valence_ohe.transform(np.array([[atom.GetTotalValence()]]))[0])
    # NOTE(review): confirm that RDKit atoms expose GetDefaultValence() and
    # accept an argument to GetTotalValence(); both calls are kept unchanged.
    features.extend(valence_ohe.transform(np.array([[atom.GetDefaultValence()]]))[0])
    features.append(atom.GetTotalValence(atom.GetAtomicNum()))
    features.append(atom.GetIsAromatic())
    features.extend(hybridization_ohe.transform(np.array([[atom.GetHybridization().name]]))[0])
    features.extend(fc_ohe.transform(np.array([[atom.GetFormalCharge()]]))[0])
    features.append(atom.IsInRing())
    return features
def atom_features_top_single_performers(atom, molecule=None):
    """Return a compact node feature list for ``atom``.

    Layout: atomic-number one-hot, atomic radius (0 when missing), neutron
    count, electron affinity and Pauling electronegativity. ``molecule`` is
    not read.
    """
    element = getMendeleevElement(atom.GetAtomicNum())
    atomic_number_vector = atomic_number_ohe.transform(np.array([[atom.GetAtomicNum()]]))[0]
    return [
        *atomic_number_vector,       # Atomic number (one-hot)
        element.atomic_radius or 0,  # Atomic radius
        element.neutrons,            # No. of neutrons
        element.electron_affinity,   # Electron affinity
        element.en_pauling,          # Electronegativity
    ]
def atom_features_top_single_performers_with_solvents(atom, molecule=None):
    """Return the compact node feature list for ``atom`` with solvent information.

    Layout: solvent one-hot, atomic-number one-hot, atomic radius (0 when
    missing), neutron count, electron affinity, Pauling electronegativity
    and electron count.
    """
    element = getMendeleevElement(atom.GetAtomicNum())
    solvent_vector = solvent_ohe.transform(np.array([[get_molecule_solvent(molecule)]]))[0]
    atomic_number_vector = atomic_number_ohe.transform(np.array([[atom.GetAtomicNum()]]))[0]
    return [
        *solvent_vector,
        *atomic_number_vector,       # Atomic number (one-hot)
        element.atomic_radius or 0,  # Atomic radius
        element.neutrons,            # No. of neutrons
        element.electron_affinity,   # Electron affinity
        element.en_pauling,          # Electronegativity
        element.electrons,           # No. of electrons
    ]
def atom_features(atom, molecule=None):
    """Return the default node feature list for ``atom``.

    Layout: atomic-number one-hot, solvent one-hot, hybridization one-hot,
    formal-charge one-hot, eleven element properties from mendeleev, the
    chiral tag and the ring-membership flag.

    Aromaticity, total valence, atom coordinates and Gasteiger partial
    charges were tried as features at some point and are currently disabled.
    """
    element = getMendeleevElement(atom.GetAtomicNum())
    # TBD: Do we need to encode the atom itself, or is the one-hot atomic number sufficient?
    atomic_number_vector = atomic_number_ohe.transform(np.array([[atom.GetAtomicNum()]]))[0]
    solvent_vector = solvent_ohe.transform(np.array([[get_molecule_solvent(molecule)]]))[0]
    hybridization_vector = hybridization_ohe.transform(np.array([[atom.GetHybridization().name]]))[0]
    formal_charge_vector = fc_ohe.transform(np.array([[atom.GetFormalCharge()]]))[0]
    return [
        *atomic_number_vector,          # Atomic number (one-hot)
        *solvent_vector,
        *hybridization_vector,
        *formal_charge_vector,
        element.atomic_radius or 0,     # Atomic radius
        element.atomic_volume,          # Atomic volume
        element.atomic_weight,          # Atomic weight
        element.covalent_radius,        # Covalent radius
        element.vdw_radius,             # Van der Waals radius
        element.dipole_polarizability,  # Dipole polarizability
        element.electron_affinity,      # Electron affinity
        element.electrophilicity(),     # Electrophilicity index
        element.en_pauling,             # Electronegativity
        element.electrons,              # No. of electrons
        element.neutrons,               # No. of neutrons
        atom.GetChiralTag(),
        atom.IsInRing(),
    ]
def convert_to_graph(molecule, atom_feature_constructor=atom_features):
    """Convert ``molecule`` into a torch_geometric ``Data`` graph.

    Atoms become nodes (features from ``atom_feature_constructor``, targets
    from ``nmr_shift``) and bonds become edges (features from
    ``bond_features``). Each bond is stored in both directions so the graph
    is undirected. Returns ``None`` when any node feature is ``None``.
    """
    atoms = list(molecule.GetAtoms())
    bonds = list(molecule.GetBonds())

    node_features = [atom_feature_constructor(atom, molecule) for atom in atoms]
    node_targets = [nmr_shift(atom) for atom in atoms]
    forward_features = [bond_features(bond) for bond in bonds]
    forward_index = [[bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()] for bond in bonds]

    # Bonds are not directed: append the mirrored pairs and repeat the
    # features so that entry i and entry i + len(bonds) describe the same bond.
    edge_index = forward_index + [[end, begin] for begin, end in forward_index]
    edge_features = forward_features + forward_features

    # Some element properties are missing (None) for certain atoms; such
    # molecules cannot be turned into a float tensor.
    for feature_row in node_features:
        if None in feature_row:
            return None

    return Data(
        x=tensor(node_features, dtype=torch.float),
        edge_index=tensor(edge_index, dtype=torch.long).t().contiguous(),
        edge_attr=tensor(edge_features, dtype=torch.float),
        y=tensor([[target] for target in node_targets], dtype=torch.float),
    )
def scale_graph_data(latent_graph_list, scaler=None):
    """Standardise node and edge features of every graph in ``latent_graph_list``.

    Args:
        latent_graph_list: graphs exposing ``x``, ``edge_index``,
            ``edge_attr`` and ``y``.
        scaler: optional ``(node_mean, node_std, edge_mean, edge_std)`` tuple
            from an earlier call. When it is falsy the statistics are
            computed from ``latent_graph_list`` itself.

    Returns:
        ``(scaled_graphs, scaler)``: new ``Data`` objects with scaled ``x``
        and ``edge_attr``, and the statistics tuple that was applied.
    """
    if scaler:
        node_mean, node_std, edge_mean, edge_std = scaler
        print(f"Using existing scaler: {scaler}")
    else:
        # Stack the features of all graphs to get per-column statistics.
        node_cat = torch.cat([g.x for g in latent_graph_list], dim=0)
        edge_cat = torch.cat([g.edge_attr for g in latent_graph_list], dim=0)
        node_mean = node_cat.mean(dim=0)
        node_std = node_cat.std(dim=0, unbiased=False)
        edge_mean = edge_cat.mean(dim=0)
        edge_std = edge_cat.std(dim=0, unbiased=False)

    # Apply zero-mean, unit-variance scaling. Columns with zero std divide by
    # zero: 0/0 (NaN) becomes 0 and +inf becomes 1. -inf is clamped to -1;
    # without neginf it would turn into the most negative finite float.
    latent_graph_list_sc = []
    for g in latent_graph_list:
        x_sc = torch.nan_to_num((g.x - node_mean) / node_std, posinf=1.0, neginf=-1.0)
        ea_sc = torch.nan_to_num((g.edge_attr - edge_mean) / edge_std, posinf=1.0, neginf=-1.0)
        latent_graph_list_sc.append(
            Data(x=x_sc, edge_index=g.edge_index, edge_attr=ea_sc, y=g.y)
        )

    scaler = (node_mean, node_std, edge_mean, edge_std)
    return latent_graph_list_sc, scaler
- """# Graph Neural Network
- ## Separation into training set and test set
- """
TRAIN_TEST_SPLIT = 0.8
all_data = list(supplier3d)
random.Random(80).shuffle(all_data)
split_at = int(TRAIN_TEST_SPLIT * len(supplier3d))
train_data = all_data[:split_at]
test_data = all_data[split_at:]

# convert_to_graph returns None for molecules with missing node features.
# Those entries are dropped here, because scale_graph_data reads attributes
# of every element and would fail on None.
train_candidates = (
    convert_to_graph(molecule, atom_feature_constructor=atom_features)
    for molecule in train_data if molecule
)
test_candidates = (
    convert_to_graph(molecule, atom_feature_constructor=atom_features)
    for molecule in test_data if molecule
)
train_graphs, scaler = scale_graph_data(
    [graph for graph in train_candidates if graph is not None]
)
# The test graphs are scaled with the statistics of the training graphs.
test_graphs, scaler = scale_graph_data(
    [graph for graph in test_candidates if graph is not None],
    scaler=scaler,
)
#all_data = [convert_to_graph(molecule) for idx, molecule in enumerate(supplier3d) if molecule and idx not in[24,25,30,31,32]]
print(f"Converted {len(supplier3d)} molecules to {len(train_graphs) + len(test_graphs)} graphs")
print(f"Found {sum(sum(1 for shift in graph.y if not math.isnan(shift[0])) for graph in train_graphs + test_graphs)} individual NMR shifts")
- """##2023 model"""
- import torch
- from torch.nn import Sequential as Seq, LazyLinear, LeakyReLU, LazyBatchNorm1d, LayerNorm
- from torch_scatter import scatter_mean, scatter_add
- from torch_geometric.nn import MetaLayer
- from torch_geometric.data import Batch
- NO_GRAPH_FEATURES=128
- ENCODING_NODE=64
- ENCODING_EDGE=32
- HIDDEN_NODE=128
- HIDDEN_EDGE=64
- HIDDEN_GRAPH=128
def init_weights(m):
    """Xavier-initialise the weight and set the bias to 0.01 for ``torch.nn.Linear``.

    Only modules whose exact type is ``torch.nn.Linear`` are touched; every
    other module is left unchanged.
    NOTE(review): the networks in this file are built from ``LazyLinear``,
    whose type is not exactly ``Linear`` at construction time - verify that
    this initialiser has the intended effect when used via ``.apply``.
    """
    if type(m) != torch.nn.Linear:
        return
    torch.nn.init.xavier_uniform_(m.weight)
    m.bias.data.fill_(0.01)
class EdgeModel(torch.nn.Module):
    """Edge update block: an MLP over source node, target node and edge features."""

    def __init__(self):
        super().__init__()
        layers = [
            LazyLinear(HIDDEN_EDGE), LeakyReLU(), LazyBatchNorm1d(),
            LazyLinear(HIDDEN_EDGE), LeakyReLU(), LazyBatchNorm1d(),
            LazyLinear(ENCODING_EDGE),
        ]
        self.edge_mlp = Seq(*layers).apply(init_weights)

    def forward(self, src, dest, edge_attr, u, batch):
        """Return updated edge features; ``u`` and ``batch`` are not read.

        Shapes as noted by the original author:
        src, dest: [E, F_x], where E is the number of edges.
        edge_attr: [E, F_e]
        u: [B, F_u], where B is the number of graphs.
        batch: [E] with max entry B - 1.
        """
        joined = torch.cat((src, dest, edge_attr), dim=1)
        return self.edge_mlp(joined)
class NodeModel(torch.nn.Module):
    """Node update block: per-edge messages are summed per target node, then combined with the node."""

    def __init__(self):
        super().__init__()
        message_layers = [
            LazyLinear(HIDDEN_NODE), LeakyReLU(), LazyBatchNorm1d(),
            LazyLinear(HIDDEN_NODE), LeakyReLU(), LazyBatchNorm1d(),  # torch.nn.Dropout(0.17),
            LazyLinear(HIDDEN_NODE),
        ]
        update_layers = [
            LazyLinear(HIDDEN_NODE), LeakyReLU(), LazyBatchNorm1d(),  # torch.nn.Dropout(0.17),
            LazyLinear(HIDDEN_NODE), LeakyReLU(), LazyBatchNorm1d(),
            LazyLinear(ENCODING_NODE),
        ]
        self.node_mlp_1 = Seq(*message_layers).apply(init_weights)
        self.node_mlp_2 = Seq(*update_layers).apply(init_weights)

    def forward(self, x, edge_index, edge_attr, u, batch):
        """Return updated node features; ``u`` and ``batch`` are not read.

        Shapes as noted by the original author:
        x: [N, F_x], where N is the number of nodes.
        edge_index: [2, E] with max entry N - 1.
        edge_attr: [E, F_e]
        u: [B, F_u]
        batch: [N] with max entry B - 1.
        """
        row, col = edge_index
        # One message per edge, built from the ``row`` node and the edge features.
        messages = self.node_mlp_1(torch.cat((x[row], edge_attr), dim=1))
        # Sum the messages at each ``col`` node; dim_size keeps one row per node.
        aggregated = scatter_add(messages, col, dim=0, dim_size=x.size(0))
        return self.node_mlp_2(torch.cat((x, aggregated), dim=1))
class GlobalModel(torch.nn.Module):
    """Graph-level block: an MLP over summed node and edge features per graph."""

    def __init__(self):
        super().__init__()
        layers = [
            LazyLinear(HIDDEN_GRAPH), LeakyReLU(), LazyBatchNorm1d(),  # torch.nn.Dropout(0.17),
            LazyLinear(HIDDEN_GRAPH), LeakyReLU(), LazyBatchNorm1d(),
            LazyLinear(NO_GRAPH_FEATURES),
        ]
        self.global_mlp = Seq(*layers).apply(init_weights)

    def forward(self, x, edge_index, edge_attr, u, batch):
        """Return graph-level features; the incoming ``u`` is not read.

        Shapes as noted by the original author:
        x: [N, F_x], where N is the number of nodes.
        edge_index: [2, E] with max entry N - 1.
        edge_attr: [E, F_e]
        u: [B, F_u]
        batch: [N] with max entry B - 1.
        """
        _, col = edge_index
        node_aggregate = scatter_add(x, batch, dim=0)
        # Each edge is assigned to the graph of its ``col`` node.
        edge_aggregate = scatter_add(edge_attr, batch[col], dim=0)
        pooled = torch.cat((node_aggregate, edge_aggregate), dim=1)
        return self.global_mlp(pooled)
class GNN_FULL_CLASS(torch.nn.Module):
    """Message-passing network that predicts one value per node.

    Node and edge features are embedded by two MLPs, passed ``NO_MP`` times
    through one shared ``MetaLayer`` and decoded per node by a final MLP.
    """

    def __init__(self, NO_MP):
        super().__init__()
        # Meta layer for message passing; the same instance is reused in every round.
        self.meta = MetaLayer(EdgeModel(), NodeModel(), GlobalModel())

        # Edge encoding MLP
        self.encoding_edge = Seq(
            LazyLinear(ENCODING_EDGE), LeakyReLU(), LazyBatchNorm1d(),
            LazyLinear(ENCODING_EDGE), LeakyReLU(), LazyBatchNorm1d(),
            LazyLinear(ENCODING_EDGE),
        ).apply(init_weights)

        # Node encoding MLP
        self.encoding_node = Seq(
            LazyLinear(ENCODING_NODE), LeakyReLU(), LazyBatchNorm1d(),
            LazyLinear(ENCODING_NODE), LeakyReLU(), LazyBatchNorm1d(),
            LazyLinear(ENCODING_NODE),
        ).apply(init_weights)

        # Per-node output head
        self.mlp_last = Seq(
            LazyLinear(HIDDEN_NODE), LeakyReLU(),  # torch.nn.Dropout(0.10),
            LazyBatchNorm1d(),
            LazyLinear(HIDDEN_NODE), LeakyReLU(),
            LazyBatchNorm1d(),
            LazyLinear(1),
        ).apply(init_weights)

        self.no_mp = NO_MP

    def forward(self, dat):
        """Return the per-node prediction tensor for the batch ``dat``."""
        # Extract the data from the batch. dat.y is read but immediately
        # replaced by the placeholder below.
        x = dat.x
        ei = dat.edge_index
        ea = dat.edge_attr
        u = dat.y
        btc = dat.batch

        # Embed the node and edge features
        enc_x = self.encoding_node(x)
        enc_ea = self.encoding_edge(ea)

        # Placeholder graph-level input: one row per node, filled with 0.1.
        u = torch.full((x.size()[0], 1), 0.1, dtype=torch.float)

        # Message passing
        for _ in range(self.no_mp):
            enc_x, enc_ea, u = self.meta(
                x=enc_x, edge_index=ei, edge_attr=enc_ea, u=u, batch=btc
            )

        return self.mlp_last(enc_x)
# Additional helper functions
def init_model(NO_MP, lr, wd):
    """Create a fresh model, optimizer and loss.

    Args:
        NO_MP: number of message-passing rounds passed to ``GNN_FULL_CLASS``.
        lr: learning rate for AdamW.
        wd: weight decay for AdamW.

    Returns:
        ``(model, optimizer, criterion)`` where ``criterion`` is an L1 loss.
    """
    model = GNN_FULL_CLASS(NO_MP)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
    # An MSE loss was tried as an alternative criterion.
    criterion = torch.nn.L1Loss()
    return model, optimizer, criterion
def train(model, criterion, optimizer, loader):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Only entries whose label is finite contribute to the loss, so atoms
    marked with NaN targets are ignored.
    """
    batch_losses = []
    for batch in loader:
        targets = batch.y
        outputs = model(batch)
        # Mask out the atoms that have no measured shift.
        valid = torch.isfinite(targets)
        loss = criterion(outputs[valid], targets[valid])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())
    return sum(batch_losses) / len(loader)
def evaluate(model, criterion, loader):
    """Return the mean batch loss of ``model`` over ``loader`` without training.

    The model is switched to evaluation mode for the duration of the call,
    so normalisation layers use their running statistics and do not update
    them on evaluation data. The previous mode is restored afterwards, even
    when an exception is raised. Only entries with a finite label contribute
    to the loss.
    """
    was_training = model.training
    model.eval()
    loss_sum = 0
    try:
        with torch.no_grad():
            for batch in loader:
                labels = batch.y
                predictions = model(batch)
                # Mask out the atoms that have no measured shift.
                finite = torch.isfinite(labels)
                loss = criterion(predictions[finite], labels[finite])
                loss_sum += loss.item()
    finally:
        model.train(was_training)
    return loss_sum / len(loader)
def chunk_into_n(lst, n):
    """Split ``lst`` into exactly ``n`` consecutive chunks.

    Each chunk holds ``ceil(len(lst) / n)`` items. The trailing chunks are
    shorter or empty when the items run out.
    """
    chunk_size = math.ceil(len(lst) / n)
    return [
        lst[start:start + chunk_size]
        for start in (index * chunk_size for index in range(n))
    ]
- """## Training"""
VALIDATION_SPLITS = 4
EPOCHS = 500
BATCH_SIZE = 128

# K-fold cross-validation over the training graphs: each fold is held out
# once while a fresh model is trained on the remaining folds.
splits = chunk_into_n(train_graphs, VALIDATION_SPLITS)
split_errors = []
for idx, split in enumerate(splits):
    # Select the other folds by position instead of comparing lists by value.
    split_train_data = [
        graph
        for other_idx, other in enumerate(splits)
        if other_idx != idx
        for graph in other
    ]
    train_loader = DataLoader(split_train_data, batch_size=BATCH_SIZE)
    test_loader = DataLoader(split, batch_size=BATCH_SIZE)
    model, optimizer, criterion = init_model(6, 0.001, 0.1)
    loss_list = []
    train_err_list = []
    test_err_list = []
    model.train()
    print(f"Split {idx+1}. Training/test split size:{len(split_train_data)}/{len(split)}")
    for epoch in range(EPOCHS):
        tloss = train(model, criterion, optimizer, train_loader)
        train_err = evaluate(model, criterion, train_loader)
        test_err = evaluate(model, criterion, test_loader)
        loss_list.append(tloss)
        train_err_list.append(train_err)
        test_err_list.append(test_err)
        print('Epoch: {:03d}, Loss: {:.5f}, Train Err: {:.5f}, Test Err: {:.5f}'.format(
            epoch+1, tloss, train_err, test_err))

    # Sometimes the optimizer is still moving between local minima after the
    # regular epochs, so train up to 200 more epochs while the loss is high.
    extra_epochs = 0
    while extra_epochs < 200 and tloss > 2.5:
        tloss = train(model, criterion, optimizer, train_loader)
        train_err = evaluate(model, criterion, train_loader)
        test_err = evaluate(model, criterion, test_loader)
        loss_list.append(tloss)
        train_err_list.append(train_err)
        test_err_list.append(test_err)
        # The counter was never incremented (`extra_epochs+1`), so the
        # 200-epoch cap had no effect and the loop could run forever.
        extra_epochs += 1
    print("\n")
    split_errors.append(test_err)
print(f"Split errors: {split_errors} with average error {sum(split_errors) / VALIDATION_SPLITS}")
- """## Evaluation on test set"""
# Train a fresh model on all training graphs and track the test-set error.
train_loader = DataLoader(train_graphs, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_graphs, batch_size=BATCH_SIZE)
model, optimizer, criterion = init_model(6, 0.001, 0.1)
loss_list = []
train_err_list = []
test_err_list = []
model.train()
for epoch in range(EPOCHS):
    tloss = train(model, criterion, optimizer, train_loader)
    train_err = evaluate(model, criterion, train_loader)
    test_err = evaluate(model, criterion, test_loader)
    loss_list.append(tloss)
    train_err_list.append(train_err)
    test_err_list.append(test_err)
    print('Epoch: {:03d}, Loss: {:.5f}, Train Err: {:.5f}, Test Err: {:.5f}'.format(
        epoch+1, tloss, train_err, test_err))

# Sometimes the optimizer is still moving between local minima after the
# regular epochs, so train up to 200 more epochs while the loss is high.
extra_epochs = 0
while extra_epochs < 200 and tloss > 2.5:
    tloss = train(model, criterion, optimizer, train_loader)
    train_err = evaluate(model, criterion, train_loader)
    test_err = evaluate(model, criterion, test_loader)
    loss_list.append(tloss)
    train_err_list.append(train_err)
    test_err_list.append(test_err)
    # The counter was never incremented (`extra_epochs+1`), so the
    # 200-epoch cap had no effect and the loop could run forever.
    extra_epochs += 1
print("\n")
#print(len(test_loader))
#print( criterion, test_loader)
evaluate(model, criterion, test_loader)

# Spot check on a single molecule (duplicated to form a batch of two).
# scale_graph_data returns (graphs, scaler), so the graph list is unpacked
# before it is handed to the DataLoader. The graphs use the same feature
# constructor and scaler as the training graphs, since the model's input
# width is fixed by the data it was trained on.
single_candidates = [
    convert_to_graph(molecule, atom_feature_constructor=atom_features)
    for molecule in [supplier3d[32], supplier3d[32]] if molecule
]
single, _ = scale_graph_data(
    [graph for graph in single_candidates if graph is not None],
    scaler=scaler,
)
evaluate(model, criterion, DataLoader(single, batch_size=2))
- """## Visualisations & Analysis"""
# Inspect large prediction errors on the held-out test graphs.
# `testing_data` is only assigned further down in this file, so the test
# graphs the current model was evaluated on are used here instead.
test_loader = DataLoader(test_graphs, batch_size=128)
# Use running statistics in the normalisation layers during inference.
model.eval()
with torch.no_grad():
    for batch in test_loader:
        # Forward pass
        labels = batch.y
        predictions = model(batch)
        errors = torch.sub(predictions, labels)
        # Print every atom whose prediction exceeds its label by more than 50.
        # Comparisons with NaN are False, so atoms without a label are skipped.
        for prediction, label, node_features, error in zip(predictions, labels, batch.x, errors):
            if error > 50:
                print(prediction, label, node_features[0], error)
        finite = torch.isfinite(labels)
        loss = criterion(predictions[finite], labels[finite])
        print(loss)
- import os
- import matplotlib.pyplot as plt
def plot_loss(loss_list):
    """Plot the loss per epoch on a logarithmic y axis."""
    _, axis = plt.subplots()
    epochs = range(1, len(loss_list) + 1)
    axis.plot(epochs, loss_list, label="Loss")
    axis.set_xlabel("Epoch")
    axis.legend(loc=3)
    axis.set_yscale('log')
    # NOTE(review): this second call rebuilds the legend without loc=3.
    axis.legend()
def plot_error(train_err_list, test_err_list):
    """Plot train and test error per epoch on a logarithmic y axis."""
    _, axis = plt.subplots()
    for errors, label in ((train_err_list, "Train Err"), (test_err_list, "Test Err")):
        axis.plot(range(1, len(errors) + 1), errors, label=label)
    axis.set_xlabel("Epoch")
    axis.legend(loc=3)
    axis.set_yscale('log')
    # NOTE(review): this second call rebuilds the legend without loc=3.
    axis.legend()


# Curves of the most recent training run.
plot_loss(loss_list)
plot_error(train_err_list, test_err_list)
- """We can draw molecules easily for 2D coordinates, so let's download the dataset containing 2D coordinates."""
# Learning-curve experiment: train on increasingly large subsets of the data
# and record the final test error for each subset size.
errs = []
# `all_data` holds the raw molecules, which a DataLoader cannot batch, so the
# experiment runs on the converted and scaled graphs instead.
graph_pool = train_graphs + test_graphs
for i in [100, 250, 500, len(graph_pool)]:
    SPLIT = 0.75  # Use 75% for training and 25% for validation
    # Shuffle all of our data, as input data might be sorted
    random.Random(46).shuffle(graph_pool)
    # Training set
    training_data = graph_pool[:int(i*SPLIT)]
    # Validation set
    testing_data = graph_pool[int(i*SPLIT):i]
    BATCH_SIZE = 128
    train_loader = DataLoader(training_data, batch_size=BATCH_SIZE)
    test_loader = DataLoader(testing_data, batch_size=BATCH_SIZE)
    print(f"Number of train batches: {len(train_loader)}")
    print(f"Number of test batches: {len(test_loader)}")

    # Model
    NO_MP = 7
    model = GNN_FULL_CLASS(NO_MP)
    # Optimizer
    LEARNING_RATE = 0.05
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=.005)
    # Criterion
    #criterion = torch.nn.MSELoss()
    criterion = torch.nn.L1Loss()

    EPOCHS = 500
    loss_list = []
    train_err_list = []
    test_err_list = []
    model.train()
    for epoch in range(EPOCHS):
        tloss = train(model, criterion, optimizer, train_loader)
        train_err = evaluate(model, criterion, train_loader)
        test_err = evaluate(model, criterion, test_loader)
        loss_list.append(tloss)
        train_err_list.append(train_err)
        test_err_list.append(test_err)
        print('Epoch: {:03d}, Loss: {:.5f}, Train Err: {:.5f}, Test Err: {:.5f}'.format(
            epoch+1, tloss, train_err, test_err))
    print("\n")
    test_err = evaluate(model, criterion, test_loader)
    print(test_err)
    errs.append(test_err)
print(errs)
- """# Existing methods
- Looking at existing methods.
- * Baseline for model comparison
- ## Hose
- ### Downloads & Initialization
- """
- from rdkit import Chem
- from rdkit.Chem import AllChem
- import matplotlib.pyplot as plt
- !pip install git+https://github.com/Ratsemaat/HOSE_code_generator
- import hosegen
- hg = hosegen.HoseGenerator()
- """### Preparation"""
def get_hose_code(molecule, distance, atom_idx):
    """Return the HOSE code produced by the shared generator ``hg``.

    ``distance`` is forwarded as the generator's ``max_radius`` and
    ``atom_idx`` as the atom the code is generated for.
    """
    hose_code = hg.get_Hose_codes(molecule, atom_idx, max_radius=distance)
    return hose_code
- get_hose_code(supplier3d[1],5,0)
def get_atom_indexes(molecul, element_nr):
    """Return the indexes of all atoms in ``molecul`` with atomic number ``element_nr``.

    Returns an empty list when ``molecul`` is falsy.
    """
    if not molecul:
        return []
    return [
        atom.GetIdx()
        for atom in molecul.GetAtoms()
        if atom.GetAtomicNum() == element_nr
    ]
# Sanity check on a small molecule: print the atom indices for the
# atomic numbers 6, 7 and 9 in turn.
mol = Chem.MolFromSmiles('C=CN(CC)CF')
for element_nr in (6, 7, 9):
    print(get_atom_indexes(mol, element_nr))
def get_molecule_hose_codes(molecule, radius, atom_idx):
    """Return ``(atom index, HOSE code)`` pairs for every atom of one element.

    Args:
        molecule: Molecule passed on to ``get_atom_indexes`` and
            ``get_hose_code``.
        radius: HOSE distance forwarded to ``get_hose_code``.
        atom_idx: Atomic number of the element to collect. Despite its
            name (kept for backward compatibility) this is the element
            selector; it used to be ignored in favour of a hard-coded 9,
            so calls that pass 9 behave exactly as before.

    Returns:
        A list of ``(idx, hose_code)`` tuples, one per matching atom.
    """
    return [
        (idx, get_hose_code(molecule, radius, idx))
        for idx in get_atom_indexes(molecule, atom_idx)
    ]
- print(get_molecule_hose_codes(supplier3d[1], 5,9))
- import random
SPLIT = 0.75  # Let's 75% for training and 25% for validation
random.seed(42)
# Draw the training indices at random, as input data might be sorted;
# every index that was not drawn ends up in the test set.
all_idx = range(len(supplier3d))
train_idx = random.sample(all_idx, math.ceil(len(supplier3d) * SPLIT))
test_idx = set(all_idx).difference(set(train_idx))
def get_molecule_nmr_shifts(molecule):
    """Collect the shifts stored in a molecule's ``Spectrum*`` properties.

    Every property whose name starts with ``"Spectrum"`` is split on ``|``
    into entries and every entry on ``;`` into fields. Entries with exactly
    three fields contribute ``fields[2] -> float(fields[0])``; all other
    entries are ignored. A falsy molecule yields an empty dict.
    """
    if not molecule:
        return {}
    shifts = {}
    for name, raw in molecule.GetPropsAsDict().items():
        if not name.startswith("Spectrum"):
            continue
        for entry in raw.split('|'):
            fields = entry.split(';')
            if len(fields) == 3:
                shifts[fields[2]] = float(fields[0])
    return shifts
- print(get_molecule_nmr_shifts(supplier3d[1]))
def get_molecule_hose_and_nmr_shifts(molecule, distance, element_nr):
    """Return ``(atom index, HOSE code, shift)`` triples for one element.

    Only atoms that have atomic number ``element_nr`` and also have a
    shift recorded in the molecule's spectrum properties are included.
    """
    candidates = set(get_atom_indexes(molecule, element_nr))
    shifts = get_molecule_nmr_shifts(molecule)
    # Only look at atoms where a shift is defined
    atoms_with_shift = set([int(key) for key in shifts.keys()])
    selected = candidates.intersection(atoms_with_shift)
    return [
        (idx, get_hose_code(molecule, distance, idx), shifts[str(idx)])
        for idx in selected
    ]
- print(get_molecule_hose_and_nmr_shifts(supplier3d[1], 5, 9))
def split_hose(hose):
    """Split a HOSE code into its parts, treating ``(``, ``/`` and ``)`` as separators."""
    separators_to_space = str.maketrans("(/)", "   ")
    return hose.translate(separators_to_space).split()
- """### Training"""
# Training: for every atom of element 9 in the training molecules, record
# its shift under every cumulative prefix of its HOSE code (radius 6).
# NOTE: the name `map` shadows the builtin but is kept because the
# evaluation cell below reads it. The loop variable no longer rebinds `id`.
map = {}
for mol_idx in train_idx:
    for _, hose_code, shift in get_molecule_hose_and_nmr_shifts(supplier3d[mol_idx], 6, 9):
        parts = split_hose(hose_code)
        for depth in range(1, len(parts) + 1):
            map.setdefault("/".join(parts[:depth]), []).append(shift)
print(map)
# The prediction for a HOSE prefix is the mean of the shifts recorded for it.
avg_map = {hose: sum(shifts) / len(shifts) for hose, shifts in map.items()}
- """### Evaluation"""
# Evaluation: predict each test atom's shift from the longest HOSE prefix
# that was seen during training. The loop variable no longer rebinds the
# builtin `id`.
predictions = []
labels = []
familiar_radius = []  # for visualisation
for mol_idx in test_idx:
    for _, hose_code, shift in get_molecule_hose_and_nmr_shifts(supplier3d[mol_idx], 6, 9):
        parts = split_hose(hose_code)
        # Back off from the full code to ever shorter prefixes. An atom with
        # no known prefix at all is skipped and gets no prediction.
        for depth in range(len(parts), 0, -1):
            hose = "/".join(parts[:depth])
            if hose in avg_map:  # avg_map has exactly the keys of `map`
                predictions.append(avg_map[hose])
                familiar_radius.append(depth)
                labels.append(shift)
                break
print(labels)
print(predictions)
# Mean absolute error over all predicted atoms.
errors = [abs(label - prediction) for label, prediction in zip(labels, predictions)]
print(sum(errors) / len(errors))
- """### Visualisations & Analysis"""
plt.hist(familiar_radius, bins=[0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5], align="mid")
plt.title("Longest exact HOSE code length in training set")
plt.show()
# Mean absolute error per matched prefix length (1..6).
radii = list(range(1, 7))
errors_by_radius = [[] for _ in radii]
for label, prediction, radius in zip(labels, predictions, familiar_radius):
    errors_by_radius[radius - 1].append(abs(label - prediction))
# A length that never matched has no error to average; use NaN so the plot
# simply leaves a gap there instead of raising ZeroDivisionError.
errors = [
    sum(bucket) / len(bucket) if bucket else float("nan")
    for bucket in errors_by_radius
]
# Bucket k holds prefix length k + 1, so the x axis must run 1..6, not 0..5.
plt.plot(radii, errors)
plt.title("Average error rate of HOSE code")
plt.show()
SPLIT = 0.75  # Let's 75% for training and 25% for validation


def _hose_prefixes(hose_code):
    """Return the cumulative '/'-joined prefixes of a HOSE code, shortest first."""
    parts = split_hose(hose_code)
    return ["/".join(parts[:depth]) for depth in range(1, len(parts) + 1)]


def _fit_hose_table(molecule_indices):
    """Map every HOSE prefix seen in the given molecules to its mean shift."""
    shifts_by_hose = {}
    for mol_idx in molecule_indices:
        for _, hose_code, shift in get_molecule_hose_and_nmr_shifts(supplier3d[mol_idx], 6, 9):
            for prefix in _hose_prefixes(hose_code):
                shifts_by_hose.setdefault(prefix, []).append(shift)
    return {hose: sum(values) / len(values) for hose, values in shifts_by_hose.items()}


def _hose_mean_abs_error(mean_shift_by_hose, molecule_indices):
    """Return the mean absolute error of longest-known-prefix predictions."""
    abs_errors = []
    for mol_idx in molecule_indices:
        for _, hose_code, shift in get_molecule_hose_and_nmr_shifts(supplier3d[mol_idx], 6, 9):
            # Longest prefix first; atoms without any known prefix are skipped.
            for prefix in reversed(_hose_prefixes(hose_code)):
                if prefix in mean_shift_by_hose:
                    abs_errors.append(abs(shift - mean_shift_by_hose[prefix]))
                    break
    return sum(abs_errors) / len(abs_errors)


# Repeat the experiment ten times for each dataset size. The helpers above
# keep the builtins `map` and `id` intact and stop the inner loops from
# reusing the repetition counter.
hose_error_measurements = []
for _ in range(10):
    err_hose = []
    # Kept for its side effect on `all_data`. NOTE(review): the HOSE
    # evaluation below indexes `supplier3d`, so this shuffle does not appear
    # to influence it - confirm whether it is still wanted.
    random.Random(42).shuffle(all_data)
    for size in [100, 250, 500, len(all_data)]:
        train_idx = random.sample(range(size), math.ceil(size * SPLIT))
        test_idx = set(range(size)).difference(set(train_idx))
        err_hose.append(_hose_mean_abs_error(_fit_hose_table(train_idx), test_idx))
    hose_error_measurements.append(err_hose)
# Average each of the four dataset-size columns over the repeated runs.
run_count = len(hose_error_measurements)
column_totals = [0, 0, 0, 0]
for run in hose_error_measurements:
    for column in range(4):
        column_totals[column] += run[column]
avg_err_hose = [total / run_count for total in column_totals]
err_gnn = [66.1449966430664, 30.329145431518555, 9.740877151489258, 10.084354400634766]
print(avg_err_hose)
# Plot both models against the number of spectra used.
spectra_counts = [100, 250, 500, 970]
for series, series_label, colour in (
    (err_gnn, "GNN model", "red"),
    (avg_err_hose, "HOSE code model", "green"),
):
    plt.plot(spectra_counts, series, label=series_label, color=colour)
plt.xlabel("Number of spectra")
plt.ylabel("ppm")
plt.title("Now. Error in ppm in relation to used examples.")
plt.legend()
plt.show()
- """# Results
- Since obtaining the results
- ## Edge and node features
- Testing different descriptors with fixed hyperparameters (weight decay=0.005, learning rate=0.0003, NO_MP=7). When testing node features, the edge features were fixed, and when testing edge features, the node features were fixed. In other words, the results for each feature were obtained by varying only the corresponding values. Each feature was tested using 4-fold cross-validation on the training data (80% of the whole dataset).
- """
- node_features_results = {'ohe atomic number': [11.206879615783691, 10.606162071228027, 10.678353309631348, 14.426544666290283],
- 'isAromatic': [20.246356964111328, 21.289775848388672, 22.023090362548828, 21.481626510620117],
- 'hyb ohe': [14.719367980957031, 14.110363483428955, 17.14181661605835, 17.19037103652954],
- 'valence ohe': [16.09677743911743, 16.301819801330566, 18.414384841918945, 14.627246856689453],
- 'valence': [15.74299955368042, 15.868663311004639, 17.854907989501953, 14.021881580352783],
- 'inRing': [20.925933837890625, 19.979466438293457, 20.020546913146973, 19.80652141571045],
- 'hybridization': [15.455873012542725, 16.48819398880005, 16.359801769256592, 17.977892875671387],
- 'atomic num': [10.075446605682373, 10.49792766571045, 11.835340023040771, 12.882370471954346],
- 'atomic charge': [20.686185836791992, 22.381511688232422, 27.067391395568848, 20.001076698303223],
- 'atomic radius': [11.015857696533203, 10.15432596206665, 10.100831508636475, 13.576239109039307],
- 'atomic volume': [11.949167728424072, 8.738459587097168, 10.427918434143066, 13.632889747619629],
- 'atomic weight': [10.183413028717041, 9.64014196395874, 11.397083282470703, 11.427015781402588],
- 'covalent radius': [10.514074325561523, 9.394641399383545, 10.941582679748535, 11.956562995910645],
- 'vdw radius': [9.800463676452637, 9.21529245376587, 10.46424388885498, 13.804336071014404],
- 'dipole polarizability': [10.734958171844482, 10.39086389541626, 14.207355976104736, 14.054275035858154],
- 'electron affinity': [12.086753368377686, 9.388242721557617, 9.779114246368408, 14.187011241912842],
- 'electrophilicity index': [11.224877834320068, 8.167922496795654, 10.476778030395508, 13.519041061401367],
- 'electronegativity': [9.936093807220459, 10.196239948272705, 9.894521236419678, 14.182630062103271],
- 'electrons': [10.851109981536865, 9.980653285980225, 9.819206714630127, 14.225136280059814],
- 'neutrons': [9.413500308990479, 9.620219707489014, 9.220077753067017, 13.0865478515625],
- 'cooridates': [24.324893951416016, 23.782126426696777, 27.821096420288086, 22.696171760559082],
- 'formal charge': [21.54022789001465, 21.80484676361084, 25.49346923828125, 21.4456787109375],
- 'formal charge ohe': [19.7533016204834, 24.704838752746582, 23.275280952453613, 21.496562004089355],
- 'chiral tag': [22.99411392211914, 22.16276741027832, 26.671720504760742, 19.622788429260254],
- 'random': [27.85106086730957, 26.663583755493164, 27.1346435546875, 24.93829917907715]
- }
- bond_features_results = {'smart,atoms': [10.044915676116943, 8.644615173339844, 10.853336811065674, 13.183335781097412],
- 'pure,rdkit': [11.945857048034668, 9.031915664672852, 10.515658855438232, 13.471199989318848],
- 'pure,atoms': [11.971851825714111, 9.971860885620117, 11.541802883148193, 13.614596843719482],
- 'smart,rdkit': [9.997193813323975, 9.00916337966919, 9.385982513427734, 12.707754611968994]}
- import pandas as pd
- import seaborn as sns
- import matplotlib.pyplot as plt
# Reshape the per-descriptor result lists into long-format frames with one
# row per (descriptor, fold). pd.concat replaces DataFrame.append, which was
# removed in pandas 2.0.
df = pd.concat(
    [
        pd.DataFrame({'descriptor': [key] * len(value), 'accuracy': value})
        for key, value in node_features_results.items()
    ],
    ignore_index=True,
)
df2 = pd.concat(
    [
        pd.DataFrame({'descriptor': [key] * len(value), 'accuracy': value})
        for key, value in bond_features_results.items()
    ],
    ignore_index=True,
)
# Order the violins by mean value, smallest first.
my_order = df.groupby(by=["descriptor"])["accuracy"].mean().sort_values().index
my_order_2 = df2.groupby(by=["descriptor"])["accuracy"].mean().sort_values().index
fig, ax = plt.subplots(2, figsize=(8, 20))
sns.violinplot(data=df, y="descriptor", x="accuracy", ax=ax[0], order=my_order)
sns.violinplot(data=df2, y="descriptor", x="accuracy", ax=ax[1], order=my_order_2)
plt.show()
- """## Hyperparameter selection"""
- !wget -nc https://www.dropbox.com/s/k6kgqag3daqv90y/hp_results.csv
# Load the hyperparameter search results and print them in full, followed by
# the mean of the 'accuracy' column for each grouping.
with open("hp_results.csv") as f:
    hp_results = pd.read_csv(f)
full_display = (
    'display.max_rows', None,
    'display.max_columns', None,
    'display.precision', 3,
)
# NOTE(review): the source's indentation was lost; only the full-frame print
# is assumed to sit inside the option context - confirm.
with pd.option_context(*full_display):
    print(hp_results)
for group_keys in (["m", "lr", "wd"], ["m"], ["lr"], ["wd"]):
    print(hp_results.groupby(group_keys)['accuracy'].mean())
- """## Different models(collections of features)
- 2019 <br>
- Top N Features
- ### 2019 model features
- """
- paper_2019_features_results = [8.671473503112793, 14.09905195236206, 12.203859806060791, 12.20280122756958]
- """### Top N features"""
- n_feature_results = {"['neutrons']": [10.858759880065918, 7.9520745277404785, 10.061083316802979, 13.083988666534424], "['neutrons', 'atomic weight']": [10.279156684875488, 10.209141254425049, 9.95799732208252, 13.340500354766846], "['neutrons', 'atomic weight', 'covalent radius']": [10.954660892486572, 8.685662269592285, 8.889871597290039, 11.358663558959961], "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius']": [11.0108060836792, 10.469686031341553, 10.534903049468994, 12.8761568069458], "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index']": [10.096431255340576, 9.76186466217041, 9.827410221099854, 12.698175430297852], "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity']": [10.908030033111572, 8.09628176689148, 9.25707221031189, 11.68080759048462], "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity', 'atomic volume']": [9.88636827468872, 9.501356601715088, 9.424680233001709, 12.281991004943848], "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity', 'atomic volume', 'atomic radius']": [12.68621015548706, 9.14622163772583, 10.40596866607666, 11.856515407562256], "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity', 'atomic volume', 'atomic radius', 'electrons']": [10.266335487365723, 9.108633041381836, 11.96642780303955, 13.67209243774414], "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity', 'atomic volume', 'atomic radius', 'electrons', 'atomic num']": [10.689188957214355, 9.288278579711914, 10.050706624984741, 14.055354118347168], "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity', 'atomic volume', 'atomic radius', 'electrons', 'atomic num', 'electron affinity']": 
[10.014832019805908, 9.91117525100708, 9.540995359420776, 13.940445899963379], "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity', 'atomic volume', 'atomic radius', 'electrons', 'atomic num', 'electron affinity', 'ohe atomic number']": [10.295331001281738, 9.557802677154541, 11.178434371948242, 10.574723720550537], "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity', 'atomic volume', 'atomic radius', 'electrons', 'atomic num', 'electron affinity', 'ohe atomic number', 'dipole polarizability']": [10.669241905212402, 9.187427043914795, 9.547076225280762, 12.491205215454102], "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity', 'atomic volume', 'atomic radius', 'electrons', 'atomic num', 'electron affinity', 'ohe atomic number', 'dipole polarizability', 'hyb ohe']": [10.314173221588135, 10.875812530517578, 12.00346040725708, 11.990623950958252], "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity', 'atomic volume', 'atomic radius', 'electrons', 'atomic num', 'electron affinity', 'ohe atomic number', 'dipole polarizability', 'hyb ohe', 'valence']": [9.84071683883667, 9.872222900390625, 8.903747081756592, 11.116456031799316], "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity', 'atomic volume', 'atomic radius', 'electrons', 'atomic num', 'electron affinity', 'ohe atomic number', 'dipole polarizability', 'hyb ohe', 'valence', 'valence ohe']": [11.037650108337402, 10.227884292602539, 11.569035053253174, 15.139636516571045], "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity', 'atomic volume', 'atomic radius', 'electrons', 'atomic num', 'electron affinity', 'ohe atomic number', 'dipole polarizability', 'hyb ohe', 'valence', 
'valence ohe', 'hybridization']": [9.662659645080566, 10.27301549911499, 10.229931831359863, 12.425177097320557], "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity', 'atomic volume', 'atomic radius', 'electrons', 'atomic num', 'electron affinity', 'ohe atomic number', 'dipole polarizability', 'hyb ohe', 'valence', 'valence ohe', 'hybridization', 'inRing']": [12.145603656768799, 9.442389011383057, 10.04677677154541, 13.444748401641846], "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity', 'atomic volume', 'atomic radius', 'electrons', 'atomic num', 'electron affinity', 'ohe atomic number', 'dipole polarizability', 'hyb ohe', 'valence', 'valence ohe', 'hybridization', 'inRing', 'isAromatic']": [10.830190181732178, 9.220112323760986, 9.297557353973389, 13.27143907546997], "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity', 'atomic volume', 'atomic radius', 'electrons', 'atomic num', 'electron affinity', 'ohe atomic number', 'dipole polarizability', 'hyb ohe', 'valence', 'valence ohe', 'hybridization', 'inRing', 'isAromatic', 'formal charge ohe']": [10.53205156326294, 10.221848964691162, 10.479277610778809, 11.17846155166626], "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity', 'atomic volume', 'atomic radius', 'electrons', 'atomic num', 'electron affinity', 'ohe atomic number', 'dipole polarizability', 'hyb ohe', 'valence', 'valence ohe', 'hybridization', 'inRing', 'isAromatic', 'formal charge ohe', 'atomic charge']": [10.679495811462402, 9.242655277252197, 10.48193883895874, 12.0442533493042], "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity', 'atomic volume', 'atomic radius', 'electrons', 'atomic num', 'electron affinity', 'ohe atomic number', 'dipole polarizability', 'hyb 
ohe', 'valence', 'valence ohe', 'hybridization', 'inRing', 'isAromatic', 'formal charge ohe', 'atomic charge', 'formal charge']": [9.10008192062378, 9.387939929962158, 10.72908878326416, 12.20565414428711], "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity', 'atomic volume', 'atomic radius', 'electrons', 'atomic num', 'electron affinity', 'ohe atomic number', 'dipole polarizability', 'hyb ohe', 'valence', 'valence ohe', 'hybridization', 'inRing', 'isAromatic', 'formal charge ohe', 'atomic charge', 'formal charge', 'chiral tag']": [11.04087209701538, 8.146693706512451, 10.116810321807861, 10.707176208496094], "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity', 'atomic volume', 'atomic radius', 'electrons', 'atomic num', 'electron affinity', 'ohe atomic number', 'dipole polarizability', 'hyb ohe', 'valence', 'valence ohe', 'hybridization', 'inRing', 'isAromatic', 'formal charge ohe', 'atomic charge', 'formal charge', 'chiral tag', 'cooridates']": [11.374996185302734, 10.41971206665039, 11.041788578033447, 13.181884288787842]}
# One row per (feature count, fold). The feature count is recovered from the
# number of comma-separated names in the stringified feature list.
# pd.concat replaces DataFrame.append, which was removed in pandas 2.0.
df = pd.concat(
    [
        pd.DataFrame({'descriptor': str(len(key.split(","))), 'accuracy': value})
        for key, value in n_feature_results.items()
    ],
    ignore_index=True,
)
fig, ax = plt.subplots(1, figsize=(15, 15))
sns.violinplot(data=df, y="descriptor", x="accuracy", ax=ax)
plt.show()
print(df.groupby(by=["descriptor"])["accuracy"].mean().sort_values())
- """### Comparison"""
# Long-format frame holding the 2019 feature set followed by every top-N
# feature set. pd.concat replaces DataFrame.append, which was removed in
# pandas 2.0.
frames = [pd.DataFrame({'descriptor': '2019 model features', 'accuracy': paper_2019_features_results})]
frames.extend(
    pd.DataFrame({'descriptor': str(len(key.split(","))), 'accuracy': value})
    for key, value in n_feature_results.items()
)
df = pd.concat(frames, ignore_index=True)
fig, ax = plt.subplots(1, figsize=(15, 15))
# Compare the 15-feature set against the 2019 feature set only.
comparsion_df = df.loc[(df['descriptor'] == '15') | (df['descriptor'] == '2019 model features')]
sns.violinplot(data=comparsion_df, y="descriptor", x="accuracy", ax=ax)
plt.show()
- """## Varying dataset sizes
- ### Fluorine
- """
- !wget -nc https://www.dropbox.com/s/m1alwawphb1chbc/fluorine_results.csv
with open("fluorine_results.csv") as f:
    df = pd.read_csv(f)
# Print the per-model, per-dataset-size means. As a bare expression in the
# middle of the cell the summary was computed and then discarded; the other
# result cells print it.
print(df.groupby(["model", "dataset_size"]).mean())
fig, ax = plt.subplots(figsize=(8, 10))
sns.violinplot(data=df, y="mae", x="dataset_size", hue="model", ax=ax)
plt.show()
- """### Carbon with chloroform as solvent
- """
- !wget -nc https://www.dropbox.com/s/0nmj08b1hpn8jc0/methanol_results.csv
# NOTE(review): the download line just above this cell fetches
# methanol_results.csv, not chloroform_results.csv, so this open() relies on
# the chloroform file already being present - confirm the intended URL.
with open("chloroform_results.csv") as results_file:
    df = pd.read_csv(results_file)
mean_by_model_and_size = df.groupby(["model", "dataset_size"]).mean()
print(mean_by_model_and_size)
print(df)
fig, ax = plt.subplots(figsize=(8, 10))
sns.violinplot(x="dataset_size", y="mae", hue="model", data=df, ax=ax)
plt.show()
- """### Carbon with methanol as solvent """
- !wget -nc https://www.dropbox.com/s/0nmj08b1hpn8jc0/methanol_results.csv
# Load the methanol results, print the mean of every column per
# (model, dataset_size) group and plot the MAE distribution per dataset size.
with open("methanol_results.csv") as results_file:
    df = pd.read_csv(results_file)
mean_by_model_and_size = df.groupby(["model", "dataset_size"]).mean()
print(mean_by_model_and_size)
fig, ax = plt.subplots(figsize=(8, 10))
sns.violinplot(x="dataset_size", y="mae", hue="model", data=df, ax=ax)
plt.show()
- """### Carbon with dmso as solvent """
- !wget -nc https://www.dropbox.com/s/gdwv8pspayssk27/dmso_results.csv
# Load the DMSO results, print the mean of every column per
# (model, dataset_size) group and plot the MAE distribution per dataset size.
with open("dmso_results.csv") as results_file:
    df = pd.read_csv(results_file)
mean_by_model_and_size = df.groupby(["model", "dataset_size"]).mean()
print(mean_by_model_and_size)
fig, ax = plt.subplots(figsize=(8, 10))
sns.violinplot(x="dataset_size", y="mae", hue="model", data=df, ax=ax)
plt.show()
- """# Reproducing results
- Functions that produced the aforementioned results.
- To repeat any experiment you must:
- * Run the core block to load necessary helper methods.
- * Run the corresponding experiment block to obtain the results.
- ## Core block
- """
- TEST_SPLIT= 0.8
- # Install required dependencies
- !pip3 install rdkit # Used to read and parse nmrshiftdb2 SD file
- !pip3 install mendeleev # To to access various features related to atoms
- import torch
- # PyTorch dependencies to represent graph data
- !pip3 install torch-scatter -f https://data.pyg.org/whl/torch-{torch.__version__}.html
- !pip3 install torch-sparse -f https://data.pyg.org/whl/torch-{torch.__version__}.html
- !pip3 install torch-geometric
- # Use CUDA by default
- if torch.cuda.is_available():
- torch.set_default_tensor_type('torch.cuda.FloatTensor')
- # Downloads the nmrshiftdb2 database if it does not yet exist in our runtime
- !wget -nc https://www.dropbox.com/s/n122zxawpxii5b7/nmrshiftdb2withsignals.sd#Flourine
- from torch.nn import Sequential as Seq, LazyLinear, LeakyReLU, LazyBatchNorm1d, LayerNorm
- from torch_scatter import scatter_mean, scatter_add
- from torch_geometric.nn import MetaLayer
- from torch_geometric.data import Batch,Data
- from torch import tensor
- from torch_geometric.loader import DataLoader
- import numpy as np
- from rdkit import Chem
- from sklearn.preprocessing import OneHotEncoder
- import random
- import math
- import mendeleev
- import pandas as pd
- NO_GRAPH_FEATURES=128
- ENCODING_NODE=64
- ENCODING_EDGE=32
- HIDDEN_NODE=128
- HIDDEN_EDGE=64
- HIDDEN_GRAPH=128
def init_weights(m):
    """Initialise a plain ``torch.nn.Linear`` layer in place.

    The weight gets Xavier-uniform values and the bias is filled with 0.01.
    The check is an exact type match, so every other module type, including
    subclasses of ``Linear``, is left untouched.
    """
    if type(m) is not torch.nn.Linear:
        return
    torch.nn.init.xavier_uniform_(m.weight)
    m.bias.data.fill_(0.01)
class EdgeModel(torch.nn.Module):
    """Edge update step: an MLP over the concatenated source-node, target-node and edge features."""

    def __init__(self):
        super().__init__()
        layers = [
            LazyLinear(HIDDEN_EDGE), LeakyReLU(), LazyBatchNorm1d(),
            LazyLinear(HIDDEN_EDGE), LeakyReLU(), LazyBatchNorm1d(),
            LazyLinear(ENCODING_EDGE),
        ]
        self.edge_mlp = Seq(*layers).apply(init_weights)

    def forward(self, src, dest, edge_attr, u, batch):
        # source, target: [E, F_x], where E is the number of edges.
        # edge_attr: [E, F_e]
        # u: [B, F_u], where B is the number of graphs (not used here).
        # batch: [E] with max entry B - 1 (not used here).
        edge_input = torch.cat((src, dest, edge_attr), dim=1)
        return self.edge_mlp(edge_input)
class NodeModel(torch.nn.Module):
    """Node update step: build per-edge messages, sum them per target node, then update the node."""

    def __init__(self):
        super().__init__()
        message_layers = [
            LazyLinear(HIDDEN_NODE), LeakyReLU(), LazyBatchNorm1d(),
            LazyLinear(HIDDEN_NODE), LeakyReLU(), LazyBatchNorm1d(),
            LazyLinear(HIDDEN_NODE),
        ]
        update_layers = [
            LazyLinear(HIDDEN_NODE), LeakyReLU(), LazyBatchNorm1d(),
            LazyLinear(HIDDEN_NODE), LeakyReLU(), LazyBatchNorm1d(),
            LazyLinear(ENCODING_NODE),
        ]
        self.node_mlp_1 = Seq(*message_layers).apply(init_weights)
        self.node_mlp_2 = Seq(*update_layers).apply(init_weights)

    def forward(self, x, edge_index, edge_attr, u, batch):
        # x: [N, F_x], where N is the number of nodes.
        # edge_index: [2, E] with max entry N - 1.
        # edge_attr: [E, F_e]
        # u: [B, F_u] (not used here)
        # batch: [N] with max entry B - 1 (not used here).
        row, col = edge_index
        # One message per edge from its `row` node's features and the edge features.
        messages = self.node_mlp_1(torch.cat([x[row], edge_attr], dim=1))
        # Sum the messages arriving at each `col` node.
        aggregated = scatter_add(messages, col, dim=0, dim_size=x.size(0))
        return self.node_mlp_2(torch.cat([x, aggregated], dim=1))
class GlobalModel(torch.nn.Module):
    """Graph-level update step: an MLP over the per-graph sums of node and edge features."""

    def __init__(self):
        super().__init__()
        layers = [
            LazyLinear(HIDDEN_GRAPH), LeakyReLU(), LazyBatchNorm1d(),
            LazyLinear(HIDDEN_GRAPH), LeakyReLU(), LazyBatchNorm1d(),
            LazyLinear(NO_GRAPH_FEATURES),
        ]
        self.global_mlp = Seq(*layers).apply(init_weights)

    def forward(self, x, edge_index, edge_attr, u, batch):
        # x: [N, F_x], where N is the number of nodes.
        # edge_index: [2, E] with max entry N - 1.
        # edge_attr: [E, F_e]
        # u: [B, F_u] (not used here)
        # batch: [N] with max entry B - 1.
        _, col = edge_index
        node_aggregate = scatter_add(x, batch, dim=0)
        # An edge is assigned to the graph of its `col` node.
        edge_aggregate = scatter_add(edge_attr, batch[col], dim=0)
        graph_input = torch.cat([node_aggregate, edge_aggregate], dim=1)
        return self.global_mlp(graph_input)
class GNN_FULL_CLASS(torch.nn.Module):
    """Message-passing network that outputs one value per node.

    Node and edge features are first embedded by two MLPs, then refined by
    ``NO_MP`` rounds of the same ``MetaLayer`` (shared weights), and the
    final node embeddings are mapped to a single output by ``mlp_last``.
    """

    def __init__(self, NO_MP):
        super().__init__()
        # Meta layer for message passing
        self.meta = MetaLayer(EdgeModel(), NodeModel(), GlobalModel())

        # Edge encoding MLP
        self.encoding_edge = Seq(
            LazyLinear(ENCODING_EDGE), LeakyReLU(), LazyBatchNorm1d(),
            LazyLinear(ENCODING_EDGE), LeakyReLU(), LazyBatchNorm1d(),
            LazyLinear(ENCODING_EDGE),
        ).apply(init_weights)
        # Node encoding MLP
        self.encoding_node = Seq(
            LazyLinear(ENCODING_NODE), LeakyReLU(), LazyBatchNorm1d(),
            LazyLinear(ENCODING_NODE), LeakyReLU(), LazyBatchNorm1d(),
            LazyLinear(ENCODING_NODE),
        ).apply(init_weights)
        # Readout MLP producing one value per node
        self.mlp_last = Seq(
            LazyLinear(HIDDEN_NODE), LeakyReLU(), LazyBatchNorm1d(),
            LazyLinear(HIDDEN_NODE), LeakyReLU(), LazyBatchNorm1d(),
            LazyLinear(1),
        ).apply(init_weights)

        self.no_mp = NO_MP

    def forward(self, dat):
        """Return the per-node outputs for a batch ``dat``.

        The targets in ``dat.y`` are not read here; the former assignment of
        ``dat.y`` to ``u`` was overwritten before any use and has been dropped.
        """
        x, ei, ea, btc = dat.x, dat.edge_index, dat.edge_attr, dat.batch
        # Embed the node and edge features
        enc_x = self.encoding_node(x)
        enc_ea = self.encoding_edge(ea)

        # Initial global state: constant 0.1 with one row per node. Each
        # message-passing round replaces it with the global model's output.
        u = torch.full(size=(x.size()[0], 1), fill_value=0.1, dtype=torch.float)
        # Message passing
        for _ in range(self.no_mp):
            enc_x, enc_ea, u = self.meta(x=enc_x, edge_index=ei, edge_attr=enc_ea, u=u, batch=btc)

        return self.mlp_last(enc_x)
- el_map={}
def atom_features_default():
    """Return the default node feature getters, keyed by feature name.

    Every value is a callable that takes an atom and returns either a scalar
    or a one-hot encoded row. Insertion order matters: callers that iterate
    ``.values()`` get the features in exactly this order.
    """
    def element(atom):
        # Element record for the atom's atomic number
        return getMendeleevElement(atom.GetAtomicNum())

    return {
        "ohe atomic number": lambda atom: atomic_number_ohe.transform(np.array([[atom.GetAtomicNum()]]))[0],
        "hyb ohe": lambda atom: hybridization_ohe.transform(np.array([[atom.GetHybridization().name]]))[0],
        "valence ohe": lambda atom: valence_ohe.transform(np.array([[atom.GetTotalValence()]]))[0],
        "hybridization": lambda atom: atom.GetHybridization(),
        "atomic radius": lambda atom: element(atom).atomic_radius or 0,  # missing radius counts as 0
        "atomic volume": lambda atom: element(atom).atomic_volume,
        "atomic weight": lambda atom: element(atom).atomic_weight,
        "dipole polarizability": lambda atom: element(atom).dipole_polarizability,
        "electron affinity": lambda atom: element(atom).electron_affinity,
        "electronegativity": lambda atom: element(atom).en_pauling,
        "electrons": lambda atom: element(atom).electrons,
        "neutrons": lambda atom: element(atom).neutrons,
        "formal charge ohe": lambda atom: fc_ohe.transform(np.array([[atom.GetFormalCharge()]]))[0],
        "chiral tag": lambda atom: atom.GetChiralTag(),
    }
def bond_feature_smart_distance_and_rdkit_type(bond):
    """Return ``[measured length - naive length]`` followed by the one-hot bond type.

    The measured length is the Euclidean distance between the two bonded
    atoms' positions in the owning molecule's conformer; the naive length
    comes from ``getNaiveBondLength``.
    """
    bond_type_ohe = onehot_encoder.transform(np.array([[bond.GetBondType()]]))[0]
    conformer = bond.GetOwningMol().GetConformer()
    x1, y1, z1 = list(conformer.GetAtomPosition(bond.GetBeginAtomIdx()))
    x2, y2, z2 = list(conformer.GetAtomPosition(bond.GetEndAtomIdx()))
    expected_length = getNaiveBondLength(bond)
    measured_length = math.sqrt((x2 - x1)**2 + (y2 - y1)**2 + (z2 - z1)**2)
    return [measured_length - expected_length] + list(bond_type_ohe)
def getMendeleevElement(nr):
    """Return the mendeleev element for atomic number ``nr``.

    Results are cached in the module-level ``el_map`` so each element is
    looked up through ``mendeleev.element`` only once.
    """
    if nr in el_map:
        return el_map[nr]
    element = mendeleev.element(nr)
    el_map[nr] = element
    return element
def nmr_shift(atom):
    """Return the shift recorded for ``atom``, or NaN when there is none.

    Looks through the owning molecule's ``Spectrum*`` properties for the
    first three-field ``;``-separated entry whose third field equals the
    atom's index and returns its first field as a float.
    """
    properties = atom.GetOwningMol().GetPropsAsDict()
    for name, raw in properties.items():
        if not name.startswith("Spectrum"):
            continue
        for entry in raw.split('|'):
            fields = entry.split(';')
            if len(fields) == 3 and fields[2] == f"{atom.GetIdx()}":
                return float(fields[0])
    return float("NaN")  # We use NaN for atoms we don't want to predict shifts
def bond_features_distance_only(bond):
    """Return a one-element list with the Euclidean distance between the bonded atoms.

    Positions are read from the owning molecule's conformer.
    """
    x1, y1, z1 = list(bond.GetOwningMol().GetConformer().GetAtomPosition(bond.GetBeginAtomIdx()))
    x2, y2, z2 = list(bond.GetOwningMol().GetConformer().GetAtomPosition(bond.GetEndAtomIdx()))
    dx, dy, dz = x2 - x1, y2 - y1, z2 - z1
    return [math.sqrt(dx**2 + dy**2 + dz**2)]
def flatten(l):
    """Flatten one level: lists and NumPy arrays are spliced in, everything else is kept as is."""
    ret = []
    for el in l:
        is_nested = isinstance(el, (list, np.ndarray))
        ret.extend(el if is_nested else [el])
    return ret
def turn_to_graph(molecule, atom_feature_getters=atom_features_default().values(), bond_features=bond_features_distance_only):
    """Convert a molecule into a ``Data`` graph, or return ``None``.

    Args:
        molecule: Molecule providing ``GetAtoms()`` and ``GetBonds()``.
        atom_feature_getters: Iterable of callables, each mapping an atom to
            a scalar or a list/array of values; the results are flattened
            into that atom's feature row.
        bond_features: Callable mapping a bond to its list of features.

    Returns:
        A ``Data`` object with node features ``x``, ``edge_index`` holding
        every bond in both directions, matching ``edge_attr`` and the
        per-atom targets ``y`` (NaN where no shift is recorded), or ``None``
        when any node feature is ``None``.
    """
    # Materialise once so a one-shot iterable is not exhausted after the first atom.
    getters = list(atom_feature_getters)
    node_features = [flatten([getter(atom) for getter in getters]) for atom in molecule.GetAtoms()]
    # Some node_features had null values in carbon data and then the long
    # graph compilation process was stopped, so bail out early.
    if any(None in features for features in node_features):
        return None
    node_targets = [nmr_shift(atom) for atom in molecule.GetAtoms()]
    bonds = list(molecule.GetBonds())
    forward_features = [bond_features(bond) for bond in bonds]
    forward_edges = [[bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()] for bond in bonds]
    # Bonds are not directed, so add the reverse of every edge (as real
    # lists rather than `reversed` iterators) with the same features.
    backward_edges = [[end, begin] for begin, end in forward_edges]
    edge_index = forward_edges + backward_edges
    edge_features = forward_features + forward_features
    return Data(
        x=tensor(node_features, dtype=torch.float),
        edge_index=tensor(edge_index, dtype=torch.long).t().contiguous(),
        edge_attr=tensor(edge_features, dtype=torch.float),
        y=tensor([[t] for t in node_targets], dtype=torch.float)
    )
def init_model(NO_MP, lr, wd):
    """Build the GNN together with its AdamW optimizer and L1 loss.

    Returns:
        The tuple ``(model, optimizer, criterion)``.
    """
    model = GNN_FULL_CLASS(NO_MP)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
    criterion = torch.nn.L1Loss()
    return model, optimizer, criterion
def train(model, criterion, optimizer, loader):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Only positions whose label is finite contribute to the loss, so NaN
    labels are skipped.
    """
    running_total = 0
    for batch in loader:
        # Forward pass on the labelled positions only
        labels = batch.y
        predictions = model(batch)
        labelled = torch.isfinite(labels)
        loss = criterion(predictions[labelled], labels[labelled])
        # Backpropagation and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_total += loss.item()
    return running_total / len(loader)
def evaluate(model, criterion, loader):
    """Return the mean batch loss of ``model`` over ``loader`` without training.

    Only positions whose label is finite contribute to the loss. A
    ``torch.nn.Module`` that is in training mode is switched to evaluation
    mode for the duration of the call, so layers such as BatchNorm use their
    running statistics instead of the evaluated batch. The previous mode is
    restored afterwards, even if an error occurs. Gradients are not tracked.
    """
    switch_mode = isinstance(model, torch.nn.Module) and model.training
    if switch_mode:
        model.eval()
    loss_sum = 0
    try:
        with torch.no_grad():
            for batch in loader:
                # Forward pass on the labelled positions only
                labels = batch.y
                predictions = model(batch)
                labelled = torch.isfinite(labels)
                loss = criterion(predictions[labelled], labels[labelled])
                loss_sum += loss.item()
    finally:
        if switch_mode:
            model.train()
    return loss_sum / len(loader)
def chunk_into_n(lst, n):
    """Split ``lst`` into exactly ``n`` consecutive chunks.

    Each chunk holds ``ceil(len(lst) / n)`` items, except that trailing
    chunks may be shorter or empty. A list comprehension is used instead of
    the builtin ``map`` because script cells in this file rebind the name
    ``map`` to a dict, which made the original call fail.

    Raises:
        ZeroDivisionError: if ``n`` is 0.
    """
    size = math.ceil(len(lst) / n)
    return [lst[start * size:start * size + size] for start in range(n)]
def get_data_loaders(data, split, batch_size):
    """Shuffle ``data`` in place and return ``(train_loader, test_loader)``.

    Args:
        data: List of graphs. It is shuffled in place with a fresh,
            unseeded ``random.Random``.
        split: Fraction of ``data`` that goes into the training loader.
        batch_size: Batch size for both loaders.

    The split is taken from ``data`` itself. It used to be taken from the
    global ``all_data``, which ignored the argument that had just been
    shuffled.
    """
    random.Random().shuffle(data)
    cut = int(len(data) * split)
    train_loader = DataLoader(data[:cut], batch_size=batch_size)
    test_loader = DataLoader(data[cut:], batch_size=batch_size)
    return train_loader, test_loader
def getBondElements(bond):
    """Return the two bonded atoms' symbols concatenated in ascending string order."""
    first, second = sorted((bond.GetEndAtom().GetSymbol(), bond.GetBeginAtom().GetSymbol()))
    return first + second
def getNaiveBondLength(bond):
    """Return the sum of both atoms' ``atomic_radius`` values, each divided by 200.

    A missing (falsy) radius counts as 0.
    """
    end_radius = getMendeleevElement(bond.GetEndAtom().GetAtomicNum()).atomic_radius or 0
    begin_radius = getMendeleevElement(bond.GetBeginAtom().GetAtomicNum()).atomic_radius or 0
    end_part = end_radius / 200.0
    begin_part = begin_radius / 200.0
    return end_part + begin_part
def train_model(train_set, test_set, split, batch_size, epochs, weight_decay=0.005, learning_rate=0.0003, NO_MP=7):
    """Train a fresh GNN and return its final errors.

    Args:
        train_set: Graphs used for optimisation.
        test_set: Graphs used for the reported test error.
        split: Unused; kept so existing calls keep working.
        batch_size: Batch size of both data loaders.
        epochs: Number of passes over ``train_set``.
        weight_decay: AdamW weight decay.
        learning_rate: AdamW learning rate.
        NO_MP: Number of message-passing rounds of the model.

    Returns:
        ``(test_err, tloss, train_err)`` measured after the last epoch.
        ``tloss`` is NaN when ``epochs`` is 0.
    """
    model = GNN_FULL_CLASS(NO_MP)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    criterion = torch.nn.L1Loss()
    model.train()
    train_loader = DataLoader(train_set, batch_size=batch_size)
    test_loader = DataLoader(test_set, batch_size=batch_size)
    tloss = float("nan")
    for _ in range(epochs):
        tloss = train(model, criterion, optimizer, train_loader)
    # Only the values after the final epoch are returned, so the two
    # evaluation passes run once here instead of once per epoch.
    train_err = evaluate(model, criterion, train_loader)
    test_err = evaluate(model, criterion, test_loader)
    return (test_err, tloss, train_err)
def scale_graph_data(latent_graph_list):
    """Standardise node and edge features across a list of graphs.

    Per-feature mean and (population) standard deviation are computed over the
    node features ``x`` and edge features ``edge_attr`` of all graphs together.
    Each graph is then rebuilt as a new ``Data`` object with the scaled
    features; ``edge_index`` and ``y`` are carried over unchanged. NaN entries
    produced by the scaling (e.g. a feature with zero spread) are set to 0.

    Returns:
        A new list of scaled graphs, in the input order.
    """
    # Stack the features of every graph so the statistics are dataset-wide.
    all_nodes = torch.cat([graph.x for graph in latent_graph_list], dim=0)
    all_edges = torch.cat([graph.edge_attr for graph in latent_graph_list], dim=0)

    node_mean = all_nodes.mean(dim=0)
    node_std = all_nodes.std(dim=0, unbiased=False)
    edge_mean = all_edges.mean(dim=0)
    edge_std = all_edges.std(dim=0, unbiased=False)

    scaled_graphs = []
    for graph in latent_graph_list:
        nodes = (graph.x - node_mean) / node_std
        edges = (graph.edge_attr - edge_mean) / edge_std
        edges[torch.isnan(edges)] = 0
        nodes[torch.isnan(nodes)] = 0
        scaled_graphs.append(
            Data(x=nodes, edge_index=graph.edge_index, edge_attr=edges, y=graph.y)
        )
    return scaled_graphs
def _dense_one_hot_encoder(**kwargs):
    """Create a ``OneHotEncoder`` with dense output on old and new scikit-learn.

    The ``sparse`` keyword was renamed to ``sparse_output`` in scikit-learn 1.2
    and removed in 1.4, so the new name is tried first and the old one is the
    fallback.
    """
    try:
        return OneHotEncoder(sparse_output=False, **kwargs)
    except TypeError:
        return OneHotEncoder(sparse=False, **kwargs)


## Bond type
bond_idxes = np.array([Chem.BondType.SINGLE, Chem.BondType.DOUBLE, Chem.BondType.TRIPLE, Chem.BondType.AROMATIC])
bond_idxes = bond_idxes.reshape(len(bond_idxes), 1)
onehot_encoder = _dense_one_hot_encoder(handle_unknown="ignore")
onehot_encoder.fit(bond_idxes)
## Hybridization
hybridization_idxes = np.array(list(Chem.HybridizationType.names))
hybridization_idxes = hybridization_idxes.reshape(len(hybridization_idxes), 1)
# No handle_unknown here: an unseen category raises instead of encoding to zeros.
hybridization_ohe = _dense_one_hot_encoder()
hybridization_ohe.fit(hybridization_idxes)
## Valence (categories 1..7)
valences = np.arange(1, 8)
valences = valences.reshape(len(valences), 1)
valence_ohe = _dense_one_hot_encoder(handle_unknown="ignore")
valence_ohe.fit(valences)
## Formal Charge
# NOTE(review): np.arange(-1, 1) yields only [-1, 0], so a +1 charge encodes to
# all zeros. Left as-is because widening it changes the feature dimension.
fc = np.arange(-1, 1)
fc = fc.reshape(len(fc), 1)
fc_ohe = _dense_one_hot_encoder(handle_unknown="ignore")
fc_ohe.fit(fc)
## Atomic number
atomic_nums = np.array([6, 1, 7, 8, 9, 17, 15, 11, 16])
atomic_nums = atomic_nums.reshape(len(atomic_nums), 1)
atomic_number_ohe = _dense_one_hot_encoder(handle_unknown="ignore")
atomic_number_ohe.fit(atomic_nums)
# Notebook leftover: encodes hydrogen once; the result is discarded.
atomic_number_ohe.transform(np.array([[1]]))
# Read every record of the SD file, shuffle with a fixed seed and cut the list
# into a leading and a trailing part at the TEST_SPLIT fraction.
# NOTE(review): TEST_SPLIT is used as the size of the *training* part — confirm
# that is the intended meaning of the constant.
supplier3d = Chem.rdmolfiles.SDMolSupplier(
    "nmrshiftdb2withsignals.sd",
    True,
    False,
    True,
)
all_data = list(supplier3d)
random.Random(80).shuffle(all_data)

number_of_elements = len(all_data)
_train_count = int(number_of_elements * TEST_SPLIT)
training_data = all_data[:_train_count]
testing_data = all_data[_train_count:]
- """## Experiment blocks
- ### Node features
- """
NO_SPLITS = 4

# One callable per candidate atom descriptor; each takes an RDKit atom.
# "atomic charge" and "formal charge" both call GetFormalCharge().
# A Gasteiger partial-charge descriptor was sketched here earlier but is disabled.
descriptor_dict = {
    "random": lambda atom: random.random(),
    "ohe atomic number": lambda atom: atomic_number_ohe.transform(np.array([[atom.GetAtomicNum()]]))[0],
    "isAromatic": lambda atom: atom.GetIsAromatic(),
    "hyb ohe": lambda atom: hybridization_ohe.transform(np.array([[atom.GetHybridization().name]]))[0],
    "valence ohe": lambda atom: valence_ohe.transform(np.array([[atom.GetTotalValence()]]))[0],
    "valence": lambda atom: atom.GetTotalValence(),
    "inRing": lambda atom: atom.IsInRing(),
    "hybridization": lambda atom: atom.GetHybridization(),
    "atomic num": lambda atom: atom.GetAtomicNum(),
    "atomic charge": lambda atom: atom.GetFormalCharge(),
    "atomic radius": lambda atom: getMendeleevElement(atom.GetAtomicNum()).atomic_radius or 0,
    "atomic volume": lambda atom: getMendeleevElement(atom.GetAtomicNum()).atomic_volume,
    "atomic weight": lambda atom: getMendeleevElement(atom.GetAtomicNum()).atomic_weight,
    "covalent radius": lambda atom: getMendeleevElement(atom.GetAtomicNum()).covalent_radius,
    "vdw radius": lambda atom: getMendeleevElement(atom.GetAtomicNum()).vdw_radius,
    "dipole polarizability": lambda atom: getMendeleevElement(atom.GetAtomicNum()).dipole_polarizability,
    "electron affinity": lambda atom: getMendeleevElement(atom.GetAtomicNum()).electron_affinity,
    "electrophilicity index": lambda atom: getMendeleevElement(atom.GetAtomicNum()).electrophilicity(),
    "electronegativity": lambda atom: getMendeleevElement(atom.GetAtomicNum()).en_pauling,
    "electrons": lambda atom: getMendeleevElement(atom.GetAtomicNum()).electrons,
    "neutrons": lambda atom: getMendeleevElement(atom.GetAtomicNum()).neutrons,
    "cooridates": lambda atom: list(atom.GetOwningMol().GetConformer().GetAtomPosition(atom.GetIdx())),
    "formal charge": lambda atom: atom.GetFormalCharge(),
    "formal charge ohe": lambda atom: fc_ohe.transform(np.array([[atom.GetFormalCharge()]]))[0],
    "chiral tag": lambda atom: atom.GetChiralTag(),
}

# Cross-validate a model trained on each descriptor alone; keep the test error
# (first element of train_model's result) of every fold.
node_features_results = {}
for getter in descriptor_dict.items():
    descriptor, func = getter
    print(f"Evaluating descriptor: {descriptor}")
    mol_graphs = scale_graph_data(
        [turn_to_graph(mol, [func]) for mol in training_data if mol]
    )
    splits = chunk_into_n(mol_graphs, NO_SPLITS)
    getter_results = []
    for split in splits:
        # Training data = every chunk except the held-out one.
        train_data = []
        for s in splits:
            if s == split:
                continue
            train_data.extend(s)
        fold_result = train_model(train_data, split, 1 - 1.0 / NO_SPLITS, 128, 500)
        getter_results.append(fold_result[0])
    node_features_results[descriptor] = getter_results
- """### Edge features"""
NO_SPLITS = 4
# Element pairs recognised by the bonded-atoms encoder. Each entry is the two
# element symbols ordered alphabetically (CO rather than OC), which is the form
# getBondElements() produces.
BONDS = ["CC", "CF", "CH", "CCl", "CN", "FN", "FH", "HN", "ClN", "HO", "CO", "NO"]
bonded_atoms_idxs = np.array(BONDS)
bonded_atoms_idxs = bonded_atoms_idxs.reshape(len(bonded_atoms_idxs), 1)
# ``sparse`` was renamed ``sparse_output`` in scikit-learn 1.2 and removed in
# 1.4; try the new keyword first and fall back for older releases.
try:
    bonded_atoms_encoder = OneHotEncoder(sparse_output=False, handle_unknown="ignore")
except TypeError:
    bonded_atoms_encoder = OneHotEncoder(sparse=False, handle_unknown="ignore")
bonded_atoms_encoder.fit(bonded_atoms_idxs)
def getBondElements(bond):
    """Concatenate the symbols of the bond's two atoms in ascending string
    order, giving an orientation-independent key such as ``"CO"``."""
    symbols = [bond.GetEndAtom().GetSymbol(), bond.GetBeginAtom().GetSymbol()]
    symbols.sort()
    return "".join(symbols)
def getNaiveBondLength(bond):
    """Return a reference bond length built from tabulated atomic radii.

    The ``atomic_radius`` of the end atom and of the begin atom (via
    ``getMendeleevElement``) are each divided by 200 and summed; a falsy radius
    is treated as 0.
    """
    radii = [
        getMendeleevElement(atom.GetAtomicNum()).atomic_radius or 0
        for atom in (bond.GetEndAtom(), bond.GetBeginAtom())
    ]
    return radii[0] / 200.0 + radii[1] / 200.0
-
-
def bond_feature_pure_distance_and_rdkit_type(bond):
    """Edge features: ``[distance] + one-hot(RDKit bond type)``.

    The distance is the Euclidean distance between the conformer positions of
    the bond's begin and end atoms.
    """
    bond_type_ohe = onehot_encoder.transform(np.array([[bond.GetBondType()]]))[0]
    bx, by, bz = bond.GetOwningMol().GetConformer().GetAtomPosition(bond.GetBeginAtomIdx())
    ex, ey, ez = bond.GetOwningMol().GetConformer().GetAtomPosition(bond.GetEndAtomIdx())
    length = math.sqrt((ex - bx)**2 + (ey - by)**2 + (ez - bz)**2)
    return [length, *bond_type_ohe]
def bond_feature_pure_distance_with_bonded_atoms(bond):
    """Edge features: ``[distance] + one-hot(bonded element pair)``.

    The element pair comes from ``getBondElements``; the distance is the
    Euclidean distance between the conformer positions of the two atoms.
    """
    atom_pair_ohe = bonded_atoms_encoder.transform(np.array([[getBondElements(bond)]]))[0]
    bx, by, bz = bond.GetOwningMol().GetConformer().GetAtomPosition(bond.GetBeginAtomIdx())
    ex, ey, ez = bond.GetOwningMol().GetConformer().GetAtomPosition(bond.GetEndAtomIdx())
    length = math.sqrt((ex - bx)**2 + (ey - by)**2 + (ez - bz)**2)
    return [length, *atom_pair_ohe]
def bond_feature_smart_distance_and_rdkit_type(bond):
    """Edge features: ``[distance - naive length] + one-hot(RDKit bond type)``.

    The first entry is the measured begin–end atom distance in the conformer
    minus the estimate from ``getNaiveBondLength``.
    """
    bond_type_ohe = onehot_encoder.transform(np.array([[bond.GetBondType()]]))[0]
    bx, by, bz = bond.GetOwningMol().GetConformer().GetAtomPosition(bond.GetBeginAtomIdx())
    ex, ey, ez = bond.GetOwningMol().GetConformer().GetAtomPosition(bond.GetEndAtomIdx())
    expected = getNaiveBondLength(bond)
    deviation = math.sqrt((ex - bx)**2 + (ey - by)**2 + (ez - bz)**2) - expected
    return [deviation, *bond_type_ohe]
def bond_feature_smart_distance_with_bonded_atoms(bond):
    """Edge features: ``[distance - naive length] + one-hot(bonded element pair)``.

    Combines the distance deviation (measured conformer distance minus
    ``getNaiveBondLength``) with the element-pair encoding of
    ``getBondElements``.
    """
    atom_pair_ohe = bonded_atoms_encoder.transform(np.array([[getBondElements(bond)]]))[0]
    bx, by, bz = bond.GetOwningMol().GetConformer().GetAtomPosition(bond.GetBeginAtomIdx())
    ex, ey, ez = bond.GetOwningMol().GetConformer().GetAtomPosition(bond.GetEndAtomIdx())
    expected = getNaiveBondLength(bond)
    deviation = math.sqrt((ex - bx)**2 + (ey - by)**2 + (ez - bz)**2) - expected
    return [deviation, *atom_pair_ohe]
def bond_features_random(bond):
    """Baseline edge feature: a single uniform random number; ``bond`` is ignored."""
    noise = random.random()
    return [noise]
def all_bond_feature_functions():
    """Map experiment labels to the bond-feature functions under comparison.

    The label reads "<distance variant>,<categorical encoding>"; "random" is
    the noise baseline. Insertion order is the order the experiments run in.
    """
    registry = {}
    registry["smart,atoms"] = bond_feature_smart_distance_with_bonded_atoms
    registry["pure,rdkit"] = bond_feature_pure_distance_and_rdkit_type
    registry["pure,atoms"] = bond_feature_pure_distance_with_bonded_atoms
    registry["smart,rdkit"] = bond_feature_smart_distance_and_rdkit_type
    registry["random"] = bond_features_random
    return registry
# Cross-validate one model per bond-feature function. Unlike the node-feature
# experiment, the full (test_err, tloss, train_err) tuple of each fold is kept.
all_functions = all_bond_feature_functions()
bond_features_results = {}
for name, method in all_functions.items():
    print(name)
    mol_graphs = scale_graph_data(
        [turn_to_graph(mol, bond_features=method) for mol in training_data if mol]
    )
    splits = chunk_into_n(mol_graphs, NO_SPLITS)
    getter_results = []
    for split in splits:
        # Training data = every chunk except the held-out one.
        train_data = []
        for s in splits:
            if s == split:
                continue
            train_data.extend(s)
        getter_results.append(
            train_model(train_data, split, 1 - 1.0 / NO_SPLITS, 128, 500)
        )
    bond_features_results[name] = getter_results
- """### Hyperparameters"""
NO_SPLITS = 4
import pandas as pd

# Grid search over message-passing depth (m), learning rate (lr) and weight
# decay (wd); every combination is cross-validated on the same chunks.
mol_graphs = scale_graph_data([turn_to_graph(mol) for mol in training_data if mol])
splits = chunk_into_n(mol_graphs, NO_SPLITS)

# Rows are collected in a list and turned into a DataFrame once at the end:
# DataFrame.append was removed in pandas 2.0.
hp_rows = []
for m in range(4, 8):
    for lr in [0.0007, 0.001, 0.0013]:
        for wd in [0.0025, 0.005, 0.0075, 0.01, 0.015]:
            for split in splits:
                # Training data = every chunk except the held-out one.
                train_data = []
                for s in splits:
                    if s != split:
                        train_data += s
                test_err, tloss, train_err = train_model(
                    train_data, split, 1 - 1.0 / NO_SPLITS, 128, 500,
                    weight_decay=wd, learning_rate=lr, NO_MP=m,
                )
                hp_rows.append({'accuracy': test_err, 'loss': tloss, "train_err": train_err,
                                'm': m, "lr": lr, "wd": wd})

# Dataframe in which all results are saved ('accuracy' holds the test error).
hp_results = pd.DataFrame(hp_rows, columns=['wd', 'lr', 'm', 'accuracy', 'loss', 'train_err'])
- """### Top N feature model - finding optimal n?"""
NO_SPLITS = 4


def all_atom_feature_ordered_by_acc():
    """Return the atom descriptor getters as an insertion-ordered dict.

    Keys are descriptor names, values are callables taking an RDKit atom. The
    insertion order is significant: the caller adds the descriptors to the
    model cumulatively in exactly this order.
    """
    def _element(atom):
        # Periodic-table record for the atom's element.
        return getMendeleevElement(atom.GetAtomicNum())

    # A Gasteiger partial-charge getter used to sit before "chiral tag"; it is disabled.
    return {
        "neutrons": lambda atom: _element(atom).neutrons,
        "atomic weight": lambda atom: _element(atom).atomic_weight,
        "covalent radius": lambda atom: _element(atom).covalent_radius,
        "vdw radius": lambda atom: _element(atom).vdw_radius,
        "electrophilicity index": lambda atom: _element(atom).electrophilicity(),
        "electronegativity": lambda atom: _element(atom).en_pauling,
        "atomic volume": lambda atom: _element(atom).atomic_volume,
        "atomic radius": lambda atom: _element(atom).atomic_radius or 0,
        "electrons": lambda atom: _element(atom).electrons,
        "atomic num": lambda atom: atom.GetAtomicNum(),
        "electron affinity": lambda atom: _element(atom).electron_affinity,
        "ohe atomic number": lambda atom: atomic_number_ohe.transform(np.array([[atom.GetAtomicNum()]]))[0],
        "dipole polarizability": lambda atom: _element(atom).dipole_polarizability,
        "hyb ohe": lambda atom: hybridization_ohe.transform(np.array([[atom.GetHybridization().name]]))[0],
        "valence": lambda atom: atom.GetTotalValence(),
        "valence ohe": lambda atom: valence_ohe.transform(np.array([[atom.GetTotalValence()]]))[0],
        "hybridization": lambda atom: atom.GetHybridization(),
        "inRing": lambda atom: atom.IsInRing(),
        "isAromatic": lambda atom: atom.GetIsAromatic(),
        "formal charge ohe": lambda atom: fc_ohe.transform(np.array([[atom.GetFormalCharge()]]))[0],
        "atomic charge": lambda atom: atom.GetFormalCharge(),
        "formal charge": lambda atom: atom.GetFormalCharge(),
        "chiral tag": lambda atom: atom.GetChiralTag(),
        "cooridates": lambda atom: list(atom.GetOwningMol().GetConformer().GetAtomPosition(atom.GetIdx())),
    }
# Dataframe in which all results are saved (not written to in this section).
model_results_df = pd.DataFrame(columns=['model', 'accuracy'])

# Add the descriptors one at a time, in the order returned by
# all_atom_feature_ordered_by_acc(), and cross-validate after each addition.
# Results are keyed by the string form of the feature-name list used so far.
all_getters = all_atom_feature_ordered_by_acc()
n_feature_results = {}
used_methods = []
features = []
for getters in all_getters.items():
    print(getters)
    feature_name, feature_fn = getters
    used_methods.append(feature_fn)
    features.append(feature_name)
    mol_graphs = scale_graph_data(
        [turn_to_graph(mol, used_methods) for mol in training_data if mol]
    )
    splits = chunk_into_n(mol_graphs, NO_SPLITS)
    getter_results = []
    for split in splits:
        # Training data = every chunk except the held-out one.
        train_data = []
        for s in splits:
            if s == split:
                continue
            train_data.extend(s)
        fold_result = train_model(train_data, split, 1 - 1.0 / NO_SPLITS, 128, 500)
        getter_results.append(fold_result[0])
    n_feature_results[str(features)] = getter_results

# Quick manual check of chunk_into_n.
print(chunk_into_n([1, 23, 45123, 1231, 12, 123, 123, 123], 2))
- """### 2019 model features"""
# Number of cross-validation chunks for the 2019-feature experiment.
NO_SPLITS = 4
# RDKit periodic table; atom_features_2019() uses it to look up default valences.
pt = Chem.GetPeriodicTable()
def atom_features_2019():
    """Return the atom feature getters of the "2019 model" feature set.

    Keys are feature names, values are callables taking an RDKit atom;
    insertion order is the feature order.

    ``"atomic_num_ohe"`` now returns the one-hot encoded atomic number.
    Previously it returned the raw atomic number, i.e. an exact duplicate of
    ``"atomic_num"``.
    """
    getters = {}
    getters["atomic_num"] = lambda atom: atom.GetAtomicNum()
    getters["atomic_num_ohe"] = lambda atom: atomic_number_ohe.transform(np.array([[atom.GetAtomicNum()]]))[0]
    getters["valence"] = lambda atom: atom.GetTotalValence()
    getters["total valence ohe"] = lambda atom: valence_ohe.transform(np.array([[atom.GetTotalValence()]]))[0]
    # Default valence of the element, taken from the module-level periodic table ``pt``.
    getters["defualt valence ohe"] = lambda atom: valence_ohe.transform(np.array([[pt.GetDefaultValence(atom.GetAtomicNum())]]))[0]
    getters["isAromatic"] = lambda atom: atom.GetIsAromatic()
    getters["hyb ohe"] = lambda atom: hybridization_ohe.transform(np.array([[atom.GetHybridization().name]]))[0]
    getters["formal charge ohe"] = lambda atom: fc_ohe.transform(np.array([[atom.GetFormalCharge()]]))[0]
    getters["inRing"] = lambda atom: atom.IsInRing()
    return getters
# Cross-validate a model that uses the complete 2019 feature set; keep the test
# error (first element of train_model's result) of every fold.
results_2019 = []
mol_graphs = scale_graph_data(
    [turn_to_graph(mol, list(atom_features_2019().values())) for mol in training_data if mol]
)
splits = chunk_into_n(mol_graphs, NO_SPLITS)
for split in splits:
    # Training data = every chunk except the held-out one.
    train_data = []
    for s in splits:
        if s == split:
            continue
        train_data.extend(s)
    fold_result = train_model(train_data, split, 1 - 1.0 / NO_SPLITS, 128, 500)
    results_2019.append(fold_result[0])
- """### Varying dataset sizes experiments"""
# Downloading all datasets
import requests

# Local file stem -> download URL; each file is saved as "<stem>.sdf".
datasets = {
    "c_full": "https://sourceforge.net/projects/nmrshiftdb2/files/data/nmrshiftdb2withsignals.sd/download",
    "c_cdcl3": "https://nmrshiftdb.nmr.uni-koeln.de/nmrshiftdb2cdcl3.sd",
    "c_dmso": "https://nmrshiftdb.nmr.uni-koeln.de/nmrshiftdb2dmso.sd",
    "c_cd3od": "https://nmrshiftdb.nmr.uni-koeln.de/nmrshiftdb2cd3od.sd"
    # nmrshiftdb2withsignals.sdf for chlorine (already downloaded in core block)
    # NOTE(review): the experiment below sets target_atom_number = 9 (fluorine),
    # not chlorine — confirm which nucleus this remark refers to.
}
for name, url in datasets.items():
    # A timeout prevents hanging forever on a stalled server.
    r = requests.get(url, allow_redirects=True, timeout=120)
    # Fail loudly on HTTP errors instead of saving an error page as a data file.
    r.raise_for_status()
    with open(name + ".sdf", 'wb') as f:
        f.write(r.content)
# Atomic number of the nucleus whose shifts are predicted: 6 = carbon, 9 = fluorine.
target_atom_number = 9
supplier3d = Chem.rdmolfiles.SDMolSupplier("nmrshiftdb2withsignals.sd", True, False, True)
NO_SPLITS = 4
# Dataframe in which all results are saved.
# NOTE(review): the 'mse' column is never filled; the rows added below carry
# 'mae', 'rmse', 'std' and 'missing'. The column list is kept unchanged in case
# later cells refer to it.
model_results_df_2 = pd.DataFrame(columns=['model', "dataset_size", 'mse', "rmse"])
all_data = list(supplier3d)
# NOTE(review): unlike the graph experiments, unparsable (falsy) records are not
# filtered out here — confirm get_molecule_hose_and_nmr_shifts accepts them.
for size in (100, 250, 500, len(supplier3d)):
    # Same seed on every pass, applied to the already shuffled list.
    random.Random(80).shuffle(all_data)
    for i in range(0, len(all_data), size):
        subset = all_data[i:min(i + size, len(all_data))]
        splits = chunk_into_n(subset, NO_SPLITS)
        for split in splits:
            # Training data = every chunk except the held-out one.
            train_data = []
            for s in splits:
                if s != split:
                    train_data += s

            # Collect the observed shifts for every HOSE-code prefix (one entry
            # per depth) seen in the training molecules.
            hose_map = {}
            for mol in train_data:
                for trio in get_molecule_hose_and_nmr_shifts(mol, 6, target_atom_number):
                    parts = split_hose(trio[1])
                    for depth in range(1, len(parts) + 1):
                        hose = "/".join(parts[:depth])
                        hose_map.setdefault(hose, []).append(trio[2])
            avg_map = {k: sum(v) / len(v) for k, v in hose_map.items()}

            predictions = []
            labels = []
            familiar_radius = []  # for visualisation
            missing = 0
            for mol in split:
                for trio in get_molecule_hose_and_nmr_shifts(mol, 6, target_atom_number):
                    parts = split_hose(trio[1])
                    # Reset per atom: previously the flag was only assigned
                    # inside the loop, so an empty ``parts`` reused a stale
                    # value (or raised NameError on the very first atom).
                    is_match = False
                    # Try the longest prefix first, then back off.
                    for depth in range(len(parts), 0, -1):
                        hose = "/".join(parts[:depth])
                        if hose in avg_map:
                            predictions.append(avg_map[hose])
                            familiar_radius.append(depth)
                            labels.append(trio[2])
                            is_match = True
                            break
                    if not is_match:
                        missing += 1

            errors = [label - prediction for label, prediction in zip(labels, predictions)]
            if errors:
                avg_error = sum(errors) / len(errors)
                mae_value = sum(abs(err) for err in errors) / len(errors)
                rmse_value = (sum(err**2 for err in errors) / len(errors))**0.5
                std_value = (sum((err - avg_error)**2 for err in errors) / len(errors))**0.5
            else:
                # No atom of this fold could be predicted; record NaNs instead
                # of dividing by zero.
                mae_value = rmse_value = std_value = float("nan")
            model_results_df_2 = pd.concat([model_results_df_2, pd.DataFrame(data={
                'model': ["Hose"], "dataset_size": [size],
                "mae": [mae_value],
                "rmse": [rmse_value],
                "std": [std_value],
                "missing": [missing],
            })], ignore_index=True, axis=0, join='outer')
def get_top_feature_getters():
    """Return the 15 atom descriptor getters used by the final model.

    Keys are descriptor names, values are callables taking an RDKit atom;
    insertion order is the feature order.
    """
    return {
        "neutrons": lambda atom: getMendeleevElement(atom.GetAtomicNum()).neutrons,
        "atomic weight": lambda atom: getMendeleevElement(atom.GetAtomicNum()).atomic_weight,
        "covalent radius": lambda atom: getMendeleevElement(atom.GetAtomicNum()).covalent_radius,
        "vdw radius": lambda atom: getMendeleevElement(atom.GetAtomicNum()).vdw_radius,
        "electrophilicity index": lambda atom: getMendeleevElement(atom.GetAtomicNum()).electrophilicity(),
        "electronegativity": lambda atom: getMendeleevElement(atom.GetAtomicNum()).en_pauling,
        "atomic volume": lambda atom: getMendeleevElement(atom.GetAtomicNum()).atomic_volume,
        "atomic radius": lambda atom: getMendeleevElement(atom.GetAtomicNum()).atomic_radius or 0,
        "electrons": lambda atom: getMendeleevElement(atom.GetAtomicNum()).electrons,
        "atomic num": lambda atom: atom.GetAtomicNum(),
        "electron affinity": lambda atom: getMendeleevElement(atom.GetAtomicNum()).electron_affinity,
        "ohe atomic number": lambda atom: atomic_number_ohe.transform(np.array([[atom.GetAtomicNum()]]))[0],
        "dipole polarizability": lambda atom: getMendeleevElement(atom.GetAtomicNum()).dipole_polarizability,
        "hyb ohe": lambda atom: hybridization_ohe.transform(np.array([[atom.GetHybridization().name]]))[0],
        "valence": lambda atom: atom.GetTotalValence(),
    }
random.seed(80)
# Build one scaled graph per parsable molecule with the top descriptors and the
# "smart distance + RDKit bond type" edge features. (The earlier
# ``all_data = list(supplier3d)`` was dropped: it parsed the whole file only to
# be overwritten by the next statement.)
getters = list(get_top_feature_getters().values())
all_data = scale_graph_data([
    turn_to_graph(
        molecule,
        atom_feature_getters=getters,
        bond_features=bond_feature_smart_distance_and_rdkit_type,
    )
    for molecule in supplier3d
    if molecule
])
# Shuffle all of our data, as input data might be sorted
random.Random(80).shuffle(all_data)
for size in [100, 250, 500, len(all_data)]:
    method_name = "top n descriptors"
    for i in range(0, len(all_data), size):
        print(i, size)
        # The last window is shifted back so it still holds ``size`` graphs.
        # max(0, ...) stops a negative start (size > len(all_data)) from
        # selecting only the tail of the list.
        window_start = max(0, min(i, len(all_data) - size))
        mol_graphs = all_data[window_start:min(i + size, len(all_data))]
        splits = chunk_into_n(mol_graphs, NO_SPLITS)
        for split in splits:
            # Training data = every chunk except the held-out one.
            train_data = []
            for s in splits:
                if s != split:
                    train_data += s
            BATCH_SIZE = 128
            train_loader = DataLoader(train_data, batch_size=BATCH_SIZE)
            test_loader = DataLoader(split, batch_size=BATCH_SIZE)
            model = GNN_FULL_CLASS(6)
            model.train()
            optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=.01)
            mae = torch.nn.L1Loss()
            mse = torch.nn.MSELoss()
            for epoch in range(500):
                tloss = train(model, mae, optimizer, train_loader)
                train_err = evaluate(model, mae, train_loader)
                test_err = evaluate(model, mae, test_loader)

            # Up to 200 extra epochs while the training MAE is still above 2.5.
            fixing_epochs = 0
            while train_err > 2.5 and fixing_epochs < 200:
                tloss = train(model, mae, optimizer, train_loader)
                train_err = evaluate(model, mae, train_loader)
                test_err = evaluate(model, mae, test_loader)
                fixing_epochs += 1

            # Gather all test predictions and labels to compute the spread of
            # the errors over entries with a finite label.
            preds = torch.tensor([])
            labels = torch.tensor([])
            with torch.no_grad():
                for batch in test_loader:
                    # Forward pass
                    labels = torch.cat((labels, batch.y), 0)
                    preds = torch.cat((preds, model(batch)), 0)
            finite = torch.isfinite(labels)
            std = torch.std(torch.subtract(preds[finite], labels[finite]))
            model_results_df_2 = pd.concat([model_results_df_2, pd.DataFrame(data={
                'model': [method_name], "dataset_size": [size],
                "mae": [evaluate(model, mae, test_loader)],
                "rmse": [evaluate(model, mse, test_loader)**0.5],
                "std": [float(std)],
            })], ignore_index=True, axis=0, join='outer')
|