# -*- coding: utf-8 -*-
"""NMR shift prediction from small data quantities.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1yKTRjpWzR8T199eCokuJfd9Y5o2oNtPp

# Data: NMRShiftDB

## Reading data set from SD file

The nmrshiftdb2 data from https://sourceforge.net/projects/nmrshiftdb2/files/data/ has been
downloaded and moved to our own domain in order to avoid the https://sourceforge.net/ madness
(redirects, cookie banner, etc.) and to ensure that the notebook can be rerun by all
collaborators without manual steps or Google Drive dependencies.
"""

# Install required dependencies
!pip3 install rdkit      # Used to read and parse the nmrshiftdb2 SD file
!pip3 install mendeleev  # To access various features related to atoms

import torch  # PyTorch dependencies to represent graph data
!pip3 install torch-scatter -f https://data.pyg.org/whl/torch-{torch.__version__}.html
!pip3 install torch-sparse -f https://data.pyg.org/whl/torch-{torch.__version__}.html
!pip3 install torch-geometric

# Use CUDA by default
if torch.cuda.is_available():
    torch.set_default_tensor_type('torch.cuda.FloatTensor')

# Download the nmrshiftdb2 database if it does not yet exist in our runtime
!wget -nc https://www.dropbox.com/s/n122zxawpxii5b7/nmrshiftdb2withsignals.sd  # Fluorine
#!wget -nc "https://sourceforge.net/projects/nmrshiftdb2/files/data/nmrshiftdb2withsignals.sd/download"  # Carbon

from rdkit import Chem

supplier3d = Chem.rdmolfiles.SDMolSupplier("nmrshiftdb2withsignals.sd", True, False, True)  # Fluorine
#supplier3d = Chem.rdmolfiles.SDMolSupplier("download")  # Carbon
print(f"In total there are {len(supplier3d)} molecules")

from rdkit.Chem import AllChem
mol = supplier3d[152]
mol

"""## Graph transformation

In order to use graph neural networks, we need to transform the molecule data into graphs.
The atoms themselves will become nodes, and the bonds between the atoms will become the edges.
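A minimal sketch of that mapping (illustrative only; `mol` is the molecule loaded above, and
the exact symbols and indices depend on the data set):

```python
# Nodes: one entry per atom; edges: one pair of atom indices per bond.
nodes = [atom.GetSymbol() for atom in mol.GetAtoms()]
edges = [(b.GetBeginAtomIdx(), b.GetEndAtomIdx()) for b in mol.GetBonds()]
print(nodes)
print(edges)
```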
""" import mendeleev import torch import math from torch import tensor from torch_geometric.data import Data from torch_geometric.loader import DataLoader import os import numpy as np import time from sklearn.preprocessing import OneHotEncoder import random # One hot encoding ## Bonds bond_idxes = np.array([Chem.BondType.SINGLE, Chem.BondType.DOUBLE, Chem.BondType.TRIPLE, Chem.BondType.AROMATIC]) bond_idxes = bond_idxes.reshape(len(bond_idxes), 1) onehot_encoder = OneHotEncoder(sparse=False, handle_unknown="ignore") onehot_encoder.fit(bond_idxes) ## Hybridization hybridization_idxes = np.array(list(Chem.HybridizationType.names)) hybridization_idxes = hybridization_idxes.reshape(len(hybridization_idxes), 1) hybridization_ohe = OneHotEncoder(sparse=False) hybridization_ohe.fit(hybridization_idxes) ## Valence valences = np.arange(1, 8); valences = valences.reshape(len(valences), 1) valence_ohe = OneHotEncoder(handle_unknown="ignore",sparse=False) valence_ohe.fit(valences) ## Formal Charge fc = np.arange(-1, 1); fc = fc.reshape(len(fc), 1) fc_ohe = OneHotEncoder(handle_unknown="ignore",sparse=False) fc_ohe.fit(fc) ## Atomic number atomic_nums = np.array([6,1,7,8,9,17,15,11, 16]) atomic_nums = atomic_nums.reshape(len(atomic_nums), 1) atomic_number_ohe = OneHotEncoder(handle_unknown="ignore",sparse=False) atomic_number_ohe.fit(atomic_nums) atomic_number_ohe.transform(np.array([[1]])) def get_molecule_solvent(molecule): if not molecule: return {} for key, value in molecule.GetPropsAsDict().items(): if key.startswith("Solvent"): return value return None solvents={} for mol in supplier3d: a = get_molecule_solvent(mol) if a: if a not in solvents: solvents[a] = 0 solvents[a] += 1 arr = [k for k, v in solvents.items() if v>10] solvents = np.array(arr) solvents = solvents.reshape(len(solvents), 1) solvent_ohe = OneHotEncoder(handle_unknown="ignore",sparse=False) solvent_ohe.fit(solvents) #solvent_ohe.transform(np.array([[1]])) el_map={} def getMendeleevElement(nr): if nr not in el_map: el_map[nr] = mendeleev.element(nr) return el_map[nr] def nmr_shift(atom): for key, value in atom.GetOwningMol().GetPropsAsDict().items(): if key.startswith("Spectrum"): for shift in value.split('|'): x = shift.split(';') if (len(x) == 3 and x[2] == f"{atom.GetIdx()}"): return float(x[0]) return float("NaN") # We use NaN for atoms we don't want to predict shifts def bond_features(bond): onehot_encoded_bondtype = onehot_encoder.transform(np.array([[bond.GetBondType()]]))[0] [x1, y1, z1] = list(bond.GetOwningMol().GetConformer().GetAtomPosition(bond.GetBeginAtomIdx())) [x2, y2, z2] = list(bond.GetOwningMol().GetConformer().GetAtomPosition(bond.GetEndAtomIdx())) distance = [(math.sqrt((x2 - x1)**2 + (y2 - y1)**2 + (z2 - z1)**2))] # Distance return distance+list(onehot_encoded_bondtype) def atom_features_random(atom, molecule=None): features = [] features.append(random.randint(0,9)) # Atomic number return features def atom_features_update(atom, molecule=None): features = [] me = getMendeleevElement(atom.GetAtomicNum()) features.extend(solvent_ohe.transform(np.array([[get_molecule_solvent(molecule)]]))[0]) features.extend(hybridization_ohe.transform(np.array([[atom.GetHybridization().name]]))[0]) features.extend(valence_ohe.transform(np.array([[atom.GetTotalValence()]]))[0]) features.append(me.atomic_volume) # Atomic volume features.append(me.dipole_polarizability) # Dipole polarizability features.append(me.electron_affinity) # Electron affinity features.append(me.en_pauling) # Electronegativity features.append(me.electrons) # 
def atom_features_update(atom, molecule=None):
    features = []
    me = getMendeleevElement(atom.GetAtomicNum())
    features.extend(solvent_ohe.transform(np.array([[get_molecule_solvent(molecule)]]))[0])
    features.extend(hybridization_ohe.transform(np.array([[atom.GetHybridization().name]]))[0])
    features.extend(valence_ohe.transform(np.array([[atom.GetTotalValence()]]))[0])
    features.append(me.atomic_volume)          # Atomic volume
    features.append(me.dipole_polarizability)  # Dipole polarizability
    features.append(me.electron_affinity)      # Electron affinity
    features.append(me.en_pauling)             # Electronegativity
    features.append(me.electrons)              # No. of electrons
    features.append(me.neutrons)               # No. of neutrons
    features.append(atom.GetChiralTag())
    features.append(atom.IsInRing())
    return features

def atom_features_2019(atom, molecule=None):
    features = []
    features.append(atom.GetAtomicNum())
    features.extend(atomic_number_ohe.transform(np.array([[atom.GetAtomicNum()]]))[0])  # Atomic number
    features.extend(valence_ohe.transform(np.array([[atom.GetTotalValence()]]))[0])
    # Default valence of the element (GetDefaultValence lives on the periodic table, not on the atom)
    features.extend(valence_ohe.transform(np.array([[Chem.GetPeriodicTable().GetDefaultValence(atom.GetAtomicNum())]]))[0])
    features.append(atom.GetTotalValence())
    features.append(atom.GetIsAromatic())
    features.extend(hybridization_ohe.transform(np.array([[atom.GetHybridization().name]]))[0])
    features.extend(fc_ohe.transform(np.array([[atom.GetFormalCharge()]]))[0])
    features.append(atom.IsInRing())
    return features

def atom_features_top_single_performers(atom, molecule=None):
    me = getMendeleevElement(atom.GetAtomicNum())
    features = []
    #features.append(atom.GetAtomicNum())
    features.extend(atomic_number_ohe.transform(np.array([[atom.GetAtomicNum()]]))[0])  # Atomic number
    features.append(me.atomic_radius or 0)  # Atomic radius
    features.append(me.neutrons)            # No. of neutrons
    #features.extend(hybridization_ohe.transform(np.array([[atom.GetHybridization().name]]))[0])
    features.append(me.electron_affinity)   # Electron affinity
    features.append(me.en_pauling)          # Electronegativity
    #features.append(me.electrons)          # No. of electrons
    return features

def atom_features_top_single_performers_with_solvents(atom, molecule=None):
    me = getMendeleevElement(atom.GetAtomicNum())
    features = []
    features.extend(solvent_ohe.transform(np.array([[get_molecule_solvent(molecule)]]))[0])
    #features.append(atom.GetAtomicNum())
    features.extend(atomic_number_ohe.transform(np.array([[atom.GetAtomicNum()]]))[0])  # Atomic number
    features.append(me.atomic_radius or 0)  # Atomic radius
    features.append(me.neutrons)            # No. of neutrons
    #features.extend(hybridization_ohe.transform(np.array([[atom.GetHybridization().name]]))[0])
    features.append(me.electron_affinity)   # Electron affinity
    features.append(me.en_pauling)          # Electronegativity
    features.append(me.electrons)           # No. of electrons
    return features

def atom_features(atom, molecule=None):
    me = getMendeleevElement(atom.GetAtomicNum())
    #[x, y, z] = list(atom.GetOwningMol().GetConformer().GetAtomPosition(atom.GetIdx()))
    features = []
    #TBD: Do we need to encode the atom itself, or is the atomic number sufficient? One-hot encode?
    features.extend(atomic_number_ohe.transform(np.array([[atom.GetAtomicNum()]]))[0])  # Atomic number
    #features.append(atom.GetIsAromatic())
    features.extend(solvent_ohe.transform(np.array([[get_molecule_solvent(molecule)]]))[0])
    features.extend(hybridization_ohe.transform(np.array([[atom.GetHybridization().name]]))[0])
    #features.extend(valence_ohe.transform(np.array([[atom.GetTotalValence()]]))[0])
    features.extend(fc_ohe.transform(np.array([[atom.GetFormalCharge()]]))[0])
    features.append(me.atomic_radius or 0)     # Atomic radius
    features.append(me.atomic_volume)          # Atomic volume
    features.append(me.atomic_weight)          # Atomic weight
    features.append(me.covalent_radius)        # Covalent radius
    features.append(me.vdw_radius)             # Van der Waals radius
    features.append(me.dipole_polarizability)  # Dipole polarizability
    features.append(me.electron_affinity)      # Electron affinity
    features.append(me.electrophilicity())     # Electrophilicity index
    features.append(me.en_pauling)             # Electronegativity
    features.append(me.electrons)              # No. of electrons
    features.append(me.neutrons)               # No. of neutrons
    #features.append(x)  # X coordinate - TBD: not sure this is a meaningful feature (but the paper used it)
    #features.append(y)  # Y coordinate - TBD: not sure this is a meaningful feature (but the paper used it)
    #features.append(z)  # Z coordinate - TBD: not sure this is a meaningful feature (but the paper used it)
    #features.append(float(atom.GetProp('_GasteigerCharge')) if np.isfinite(float(atom.GetProp('_GasteigerCharge'))) else 0)  # Partial charge
    features.append(atom.GetChiralTag())
    features.append(atom.IsInRing())
    return features
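# Quick comparison (illustrative): dimensionality of each feature constructor on a
# sample molecule from the data set (assumes supplier3d[1] parsed successfully).
sample = supplier3d[1]
for fn in (atom_features_random, atom_features_update, atom_features_2019,
           atom_features_top_single_performers, atom_features):
    vec = fn(sample.GetAtomWithIdx(0), sample)
    print(f"{fn.__name__}: {len(vec)} features")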
def convert_to_graph(molecule, atom_feature_constructor=atom_features):
    #Chem.rdPartialCharges.ComputeGasteigerCharges(molecule)
    node_features = [atom_feature_constructor(atom, molecule) for atom in molecule.GetAtoms()]
    node_targets = [nmr_shift(atom) for atom in molecule.GetAtoms()]
    edge_features = [bond_features(bond) for bond in molecule.GetBonds()]
    edge_index = [[bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()] for bond in molecule.GetBonds()]

    # Bonds are not directed, so add the reversed pair of every edge to make the graph undirected
    edge_index.extend([[b, a] for a, b in edge_index])
    edge_features.extend(edge_features)

    # Some molecules in the carbon data produced None node features, which aborted the long
    # graph-building run; skip those molecules instead.
    if any(None in sublist for sublist in node_features):
        return None

    return Data(
        x=tensor(node_features, dtype=torch.float),
        edge_index=tensor(edge_index, dtype=torch.long).t().contiguous(),
        edge_attr=tensor(edge_features, dtype=torch.float),
        y=tensor([[t] for t in node_targets], dtype=torch.float)
    )

def scale_graph_data(latent_graph_list, scaler=None):
    if scaler:
        node_mean, node_std, edge_mean, edge_std = scaler
        print(f"Using existing scaler: {scaler}")
    else:
        # Iterate through the graph list to stack all NODE and EDGE features
        node_stack = []
        edge_stack = []
        for g in latent_graph_list:
            node_stack.append(g.x)          # Append node features
            edge_stack.append(g.edge_attr)  # Append edge features

        node_cat = torch.cat(node_stack, dim=0)
        edge_cat = torch.cat(edge_stack, dim=0)

        node_mean = node_cat.mean(dim=0)
        node_std = node_cat.std(dim=0, unbiased=False)
        edge_mean = edge_cat.mean(dim=0)
        edge_std = edge_cat.std(dim=0, unbiased=False)

    # Apply zero-mean, unit-variance scaling and append each scaled graph to the list
    latent_graph_list_sc = []
    for g in latent_graph_list:
        x_sc = g.x - node_mean
        x_sc /= node_std
        ea_sc = g.edge_attr - edge_mean
        ea_sc /= edge_std
        ea_sc = torch.nan_to_num(ea_sc, posinf=1.0)
        x_sc = torch.nan_to_num(x_sc, posinf=1.0)
        temp_graph = Data(x=x_sc, edge_index=g.edge_index, edge_attr=ea_sc, y=g.y)
        latent_graph_list_sc.append(temp_graph)

    scaler = (node_mean, node_std, edge_mean, edge_std)
    return latent_graph_list_sc, scaler

"""# Graph Neural Network

## Separation into training set and test set
"""

TRAIN_TEST_SPLIT = 0.8

all_data = list(supplier3d)
random.Random(80).shuffle(all_data)

train_data = all_data[:int(TRAIN_TEST_SPLIT * len(all_data))]
test_data = all_data[int(TRAIN_TEST_SPLIT * len(all_data)):]
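# Sanity check (illustrative): convert one molecule and inspect the resulting graph sizes.
example_mol = next(m for m in train_data if m)
example_graph = convert_to_graph(example_mol)
if example_graph is not None:
    print(example_graph)  # e.g. Data(x=[num_atoms, num_node_features], edge_index=[2, 2*num_bonds], ...)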
# Fit the scaler on the training data only, then reuse it for the test data.
train_graphs, scaler = scale_graph_data(
    [g for g in (convert_to_graph(molecule, atom_feature_constructor=atom_features)
                 for molecule in train_data if molecule) if g is not None])
test_graphs, scaler = scale_graph_data(
    [g for g in (convert_to_graph(molecule, atom_feature_constructor=atom_features)
                 for molecule in test_data if molecule) if g is not None],
    scaler=scaler)
#all_data = [convert_to_graph(molecule) for idx, molecule in enumerate(supplier3d) if molecule and idx not in [24, 25, 30, 31, 32]]

print(f"Converted {len(supplier3d)} molecules to {len(train_graphs) + len(test_graphs)} graphs")
print(f"Found {sum([sum([1 for shift in graph.y if not math.isnan(shift[0])]) for graph in train_graphs + test_graphs])} individual NMR shifts")

"""## 2023 model"""

import torch
from torch.nn import Sequential as Seq, LazyLinear, LeakyReLU, LazyBatchNorm1d, LayerNorm
from torch_scatter import scatter_mean, scatter_add
from torch_geometric.nn import MetaLayer
from torch_geometric.data import Batch

NO_GRAPH_FEATURES = 128
ENCODING_NODE = 64
ENCODING_EDGE = 32
HIDDEN_NODE = 128
HIDDEN_EDGE = 64
HIDDEN_GRAPH = 128

def init_weights(m):
    if type(m) == torch.nn.Linear:
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)

class EdgeModel(torch.nn.Module):
    def __init__(self):
        super(EdgeModel, self).__init__()
        self.edge_mlp = Seq(LazyLinear(HIDDEN_EDGE), LeakyReLU(), LazyBatchNorm1d(),
                            LazyLinear(HIDDEN_EDGE), LeakyReLU(), LazyBatchNorm1d(),
                            LazyLinear(ENCODING_EDGE)).apply(init_weights)

    def forward(self, src, dest, edge_attr, u, batch):
        # src, dest: [E, F_x], where E is the number of edges.
        # edge_attr: [E, F_e]
        # u: [B, F_u], where B is the number of graphs.
        # batch: [E] with max entry B - 1.
        out = torch.cat([src, dest, edge_attr], 1)
        return self.edge_mlp(out)

class NodeModel(torch.nn.Module):
    def __init__(self):
        super(NodeModel, self).__init__()
        self.node_mlp_1 = Seq(LazyLinear(HIDDEN_NODE), LeakyReLU(), LazyBatchNorm1d(),
                              LazyLinear(HIDDEN_NODE), LeakyReLU(), LazyBatchNorm1d(),
                              #torch.nn.Dropout(0.17),
                              LazyLinear(HIDDEN_NODE)).apply(init_weights)
        self.node_mlp_2 = Seq(LazyLinear(HIDDEN_NODE), LeakyReLU(), LazyBatchNorm1d(),
                              #torch.nn.Dropout(0.17),
                              LazyLinear(HIDDEN_NODE), LeakyReLU(), LazyBatchNorm1d(),
                              LazyLinear(ENCODING_NODE)).apply(init_weights)

    def forward(self, x, edge_index, edge_attr, u, batch):
        # x: [N, F_x], where N is the number of nodes.
        # edge_index: [2, E] with max entry N - 1.
        # edge_attr: [E, F_e]
        # u: [B, F_u]
        # batch: [N] with max entry B - 1.
        row, col = edge_index
        out = torch.cat([x[row], edge_attr], dim=1)
        out = self.node_mlp_1(out)
        out = scatter_add(out, col, dim=0, dim_size=x.size(0))
        out = torch.cat([x, out], dim=1)
        return self.node_mlp_2(out)
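# Toy illustration (not part of the model): how scatter_add pools edge messages onto
# nodes inside NodeModel. Messages are summed at their target node index (col).
toy_msgs = torch.tensor([[1.0], [2.0], [4.0]])   # one message per edge
toy_col = torch.tensor([1, 1, 2])                # target node of each edge
print(scatter_add(toy_msgs, toy_col, dim=0, dim_size=4))  # -> [[0.], [3.], [4.], [0.]]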
class GlobalModel(torch.nn.Module):
    def __init__(self):
        super(GlobalModel, self).__init__()
        self.global_mlp = Seq(LazyLinear(HIDDEN_GRAPH), LeakyReLU(), LazyBatchNorm1d(),
                              #torch.nn.Dropout(0.17),
                              LazyLinear(HIDDEN_GRAPH), LeakyReLU(), LazyBatchNorm1d(),
                              LazyLinear(NO_GRAPH_FEATURES)).apply(init_weights)

    def forward(self, x, edge_index, edge_attr, u, batch):
        # x: [N, F_x], where N is the number of nodes.
        # edge_index: [2, E] with max entry N - 1.
        # edge_attr: [E, F_e]
        # u: [B, F_u]
        # batch: [N] with max entry B - 1.
        row, col = edge_index
        node_aggregate = scatter_add(x, batch, dim=0)
        edge_aggregate = scatter_add(edge_attr, batch[col], dim=0)
        out = torch.cat([node_aggregate, edge_aggregate], dim=1)
        return self.global_mlp(out)

class GNN_FULL_CLASS(torch.nn.Module):
    def __init__(self, NO_MP):
        super(GNN_FULL_CLASS, self).__init__()
        # Meta layer for message passing
        self.meta = MetaLayer(EdgeModel(), NodeModel(), GlobalModel())

        # Edge encoding MLP
        self.encoding_edge = Seq(LazyLinear(ENCODING_EDGE), LeakyReLU(), LazyBatchNorm1d(),
                                 LazyLinear(ENCODING_EDGE), LeakyReLU(), LazyBatchNorm1d(),
                                 LazyLinear(ENCODING_EDGE)).apply(init_weights)

        # Node encoding MLP
        self.encoding_node = Seq(LazyLinear(ENCODING_NODE), LeakyReLU(), LazyBatchNorm1d(),
                                 LazyLinear(ENCODING_NODE), LeakyReLU(), LazyBatchNorm1d(),
                                 LazyLinear(ENCODING_NODE)).apply(init_weights)

        self.mlp_last = Seq(LazyLinear(HIDDEN_NODE), LeakyReLU(),
                            #torch.nn.Dropout(0.10),
                            LazyBatchNorm1d(),
                            LazyLinear(HIDDEN_NODE), LeakyReLU(), LazyBatchNorm1d(),
                            LazyLinear(1)).apply(init_weights)

        self.no_mp = NO_MP

    def forward(self, dat):
        # Extract the data from the batch
        x, ei, ea, btc = dat.x, dat.edge_index, dat.edge_attr, dat.batch

        # Embed the node and edge features
        enc_x = self.encoding_node(x)
        enc_ea = self.encoding_edge(ea)

        # Create the empty graph-level feature vectors
        u = torch.full(size=(x.size()[0], 1), fill_value=0.1, dtype=torch.float)

        # Message passing
        for _ in range(self.no_mp):
            enc_x, enc_ea, u = self.meta(x=enc_x, edge_index=ei, edge_attr=enc_ea, u=u, batch=btc)

        targs = self.mlp_last(enc_x)
        return targs

# Additional helpful methods
def init_model(NO_MP, lr, wd):
    # Model
    model = GNN_FULL_CLASS(NO_MP)
    # Optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
    # Criterion
    #criterion = torch.nn.MSELoss()
    criterion = torch.nn.L1Loss()
    return model, optimizer, criterion

def train(model, criterion, optimizer, loader):
    loss_sum = 0
    for batch in loader:
        # Forward pass and gradient descent
        labels = batch.y
        predictions = model(batch)
        loss = criterion(predictions[torch.isfinite(labels)], labels[torch.isfinite(labels)])

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_sum += loss.item()
    return loss_sum / len(loader)

def evaluate(model, criterion, loader):
    loss_sum = 0
    with torch.no_grad():
        for batch in loader:
            # Forward pass
            labels = batch.y
            predictions = model(batch)
            loss = criterion(predictions[torch.isfinite(labels)], labels[torch.isfinite(labels)])
            loss_sum += loss.item()
    return loss_sum / len(loader)

def chunk_into_n(lst, n):
    size = math.ceil(len(lst) / n)
    return list(map(lambda x: lst[x * size:x * size + size], list(range(n))))
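# Example (illustrative): chunk_into_n splits a list into n near-equal parts,
# which is how the cross-validation folds below are produced.
print(chunk_into_n(list(range(10)), 4))  # -> [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]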
"""## Training"""

VALIDATION_SPLITS = 4
EPOCHS = 500
BATCH_SIZE = 128

splits = chunk_into_n(train_graphs, VALIDATION_SPLITS)
split_errors = []

for idx, split in enumerate(splits):
    split_train_data = []
    for s in splits:
        if s != split:
            split_train_data += s

    train_loader = DataLoader(split_train_data, batch_size=BATCH_SIZE)
    test_loader = DataLoader(split, batch_size=BATCH_SIZE)

    model, optimizer, criterion = init_model(6, 0.001, 0.1)

    loss_list = []
    train_err_list = []
    test_err_list = []

    model.train()
    print(f"Split {idx+1}. Training/test split size: {len(split_train_data)}/{len(split)}")
    for epoch in range(EPOCHS):
        tloss = train(model, criterion, optimizer, train_loader)
        train_err = evaluate(model, criterion, train_loader)
        test_err = evaluate(model, criterion, test_loader)
        loss_list.append(tloss)
        train_err_list.append(train_err)
        test_err_list.append(test_err)
        print('Epoch: {:03d}, Loss: {:.5f}, Train Err: {:.5f}, Test Err: {:.5f}'.format(epoch+1, tloss, train_err, test_err))

    # Sometimes the optimizer is still moving between local minima, so at epoch 500 the
    # solution may not have settled yet; train for up to 200 extra epochs in that case.
    extra_epochs = 0
    while extra_epochs < 200 and tloss > 2.5:
        tloss = train(model, criterion, optimizer, train_loader)
        train_err = evaluate(model, criterion, train_loader)
        test_err = evaluate(model, criterion, test_loader)
        loss_list.append(tloss)
        train_err_list.append(train_err)
        test_err_list.append(test_err)
        extra_epochs += 1

    print("\n")
    split_errors.append(test_err)

print(f"Split errors: {split_errors} with average error {sum(split_errors) / VALIDATION_SPLITS}")
"""## Evaluation on test set"""

train_loader = DataLoader(train_graphs, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_graphs, batch_size=BATCH_SIZE)

model, optimizer, criterion = init_model(6, 0.001, 0.1)

loss_list = []
train_err_list = []
test_err_list = []

model.train()
for epoch in range(EPOCHS):
    tloss = train(model, criterion, optimizer, train_loader)
    train_err = evaluate(model, criterion, train_loader)
    test_err = evaluate(model, criterion, test_loader)
    loss_list.append(tloss)
    train_err_list.append(train_err)
    test_err_list.append(test_err)
    print('Epoch: {:03d}, Loss: {:.5f}, Train Err: {:.5f}, Test Err: {:.5f}'.format(epoch+1, tloss, train_err, test_err))

# Sometimes the optimizer is still moving between local minima at epoch 500,
# so train for up to 200 extra epochs until the loss settles.
extra_epochs = 0
while extra_epochs < 200 and tloss > 2.5:
    tloss = train(model, criterion, optimizer, train_loader)
    train_err = evaluate(model, criterion, train_loader)
    test_err = evaluate(model, criterion, test_loader)
    loss_list.append(tloss)
    train_err_list.append(train_err)
    test_err_list.append(test_err)
    extra_epochs += 1

print("\n")

#print(len(test_loader))
#print(criterion, test_loader)
evaluate(model, criterion, test_loader)

# Score a single molecule (twice, to form a batch of two); use the same feature
# constructor and scaler the model was trained with.
single, _ = scale_graph_data(
    [convert_to_graph(molecule, atom_feature_constructor=atom_features)
     for molecule in [supplier3d[32], supplier3d[32]] if molecule],
    scaler=scaler)
evaluate(model, criterion, DataLoader(single, batch_size=2))

"""## Visualisations & Analysis"""

test_loader = DataLoader(test_graphs, batch_size=128)
with torch.no_grad():
    for batch in test_loader:
        # Forward pass
        labels = batch.y
        predictions = model(batch)
        diffs = torch.sub(predictions, labels)
        for pred, label, features, diff in zip(predictions, labels, batch.x, diffs):
            if diff > 50:
                print(pred, label, features[0], diff)
        loss = criterion(predictions[torch.isfinite(labels)], labels[torch.isfinite(labels)])
        print(loss)

import os
import matplotlib.pyplot as plt

def plot_loss(loss_list):
    fig, axs = plt.subplots()
    axs.plot(range(1, len(loss_list) + 1), loss_list, label="Loss")
    axs.set_xlabel("Epoch")
    axs.legend(loc=3)
    axs.set_yscale('log')
    axs.legend()

def plot_error(train_err_list, test_err_list):
    fig, axs = plt.subplots()
    axs.plot(range(1, len(train_err_list) + 1), train_err_list, label="Train Err")
    axs.plot(range(1, len(test_err_list) + 1), test_err_list, label="Test Err")
    axs.set_xlabel("Epoch")
    axs.legend(loc=3)
    axs.set_yscale('log')
    axs.legend()

plot_loss(loss_list)
plot_error(train_err_list, test_err_list)

"""Train the model on increasingly large subsets (100, 250, 500, all molecules) to see how
the error depends on the amount of available data."""

errs = []
for i in [100, 250, 500, len(all_data)]:
    SPLIT = 0.75  # Use 75% for training and 25% for validation

    # Shuffle all of our data, as the input data might be sorted
    random.Random(46).shuffle(all_data)

    # Convert the sampled molecules to graphs (raw Mol objects cannot go into a PyG DataLoader)
    # Training set
    training_data = [g for g in (convert_to_graph(m) for m in all_data[:int(i * SPLIT)] if m) if g is not None]
    # Validation set
    testing_data = [g for g in (convert_to_graph(m) for m in all_data[int(i * SPLIT):i] if m) if g is not None]

    BATCH_SIZE = 128
    train_loader = DataLoader(training_data, batch_size=BATCH_SIZE)
    test_loader = DataLoader(testing_data, batch_size=BATCH_SIZE)

    print(f"Number of train batches: {len(train_loader)}")
    print(f"Number of test batches: {len(test_loader)}")

    # Loading the model configuration from the paper
    # Model
    NO_MP = 7
    model = GNN_FULL_CLASS(NO_MP)

    # Optimizer
    LEARNING_RATE = 0.05
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=.005)

    # Criterion
    #criterion = torch.nn.MSELoss()
    criterion = torch.nn.L1Loss()

    EPOCHS = 500
    loss_list = []
    train_err_list = []
    test_err_list = []

    model.train()
    for epoch in range(EPOCHS):
        tloss = train(model, criterion, optimizer, train_loader)
        train_err = evaluate(model, criterion, train_loader)
        test_err = evaluate(model, criterion, test_loader)
        loss_list.append(tloss)
        train_err_list.append(train_err)
        test_err_list.append(test_err)
        print('Epoch: {:03d}, Loss: {:.5f}, Train Err: {:.5f}, Test Err: {:.5f}'.format(epoch+1, tloss, train_err, test_err))
    print("\n")

    test_err = evaluate(model, criterion, test_loader)
    print(test_err)
    errs.append(test_err)

print(errs)
"""# Existing methods

Looking at existing methods:

* Baseline for model comparison

## HOSE

### Downloads & Initialization
"""

from rdkit import Chem
from rdkit.Chem import AllChem
import matplotlib.pyplot as plt

!pip install git+https://github.com/Ratsemaat/HOSE_code_generator

import hosegen
hg = hosegen.HoseGenerator()

"""### Preparation"""

def get_hose_code(molecule, distance, atom_idx):
    return hg.get_Hose_codes(molecule, atom_idx, max_radius=distance)

get_hose_code(supplier3d[1], 5, 0)

def get_atom_indexes(molecule, element_nr):
    if not molecule:
        return []
    indexes = []
    for atom in molecule.GetAtoms():
        if atom.GetAtomicNum() == element_nr:
            indexes.append(atom.GetIdx())
    return indexes

mol = Chem.MolFromSmiles('C=CN(CC)CF')
print(get_atom_indexes(mol, 6))
print(get_atom_indexes(mol, 7))
print(get_atom_indexes(mol, 9))

def get_molecule_hose_codes(molecule, radius, element_nr):
    indexes = get_atom_indexes(molecule, element_nr)
    hoses = []
    for idx in indexes:
        hoses.append((idx, get_hose_code(molecule, radius, idx)))
    return hoses

print(get_molecule_hose_codes(supplier3d[1], 5, 9))

import random

SPLIT = 0.75  # Use 75% for training and 25% for validation
random.seed(42)

# Shuffle all of our data, as the input data might be sorted
train_idx = random.sample([i for i in range(len(supplier3d))], math.ceil(len(supplier3d) * SPLIT))
test_idx = set([i for i in range(len(supplier3d))]).difference(set(train_idx))

def get_molecule_nmr_shifts(molecule):
    if not molecule:
        return {}
    shifts = {}
    for key, value in molecule.GetPropsAsDict().items():
        if key.startswith("Spectrum"):
            for shift in value.split('|'):
                x = shift.split(';')
                if (len(x) == 3):
                    shifts[x[2]] = float(x[0])
    return shifts

print(get_molecule_nmr_shifts(supplier3d[1]))

def get_molecule_hose_and_nmr_shifts(molecule, distance, element_nr):
    arr = []
    idxs = set(get_atom_indexes(molecule, element_nr))
    shifts = get_molecule_nmr_shifts(molecule)

    # Only look at atoms where a shift is defined
    idxs = idxs.intersection(set([int(k) for k in shifts.keys()]))

    for idx in idxs:
        hoses = get_hose_code(molecule, distance, idx)
        arr.append((idx, hoses, shifts[str(idx)]))
    return arr

print(get_molecule_hose_and_nmr_shifts(supplier3d[1], 5, 9))

def split_hose(hose):
    return hose.replace('(', ' ').replace('/', ' ').replace(')', ' ').split()

"""### Training"""

hose_map = {}
for id in train_idx:
    for trio in get_molecule_hose_and_nmr_shifts(supplier3d[id], 6, 9):
        parts = split_hose(trio[1])
        for i in range(len(parts)):
            hose = "/".join(parts[:(i + 1)])
            if hose not in hose_map:
                hose_map[hose] = []
            hose_map[hose].append(trio[2])
print(hose_map)

avg_map = {}
for k, v in hose_map.items():
    avg_map[k] = sum(v) / len(v)

"""### Evaluation"""

predictions = []
labels = []
familiar_radius = []  # for visualisation
for id in test_idx:
    for trio in get_molecule_hose_and_nmr_shifts(supplier3d[id], 6, 9):
        parts = split_hose(trio[1])
        # Fall back from the most specific prefix (largest radius) to the least specific one
        for i in range(len(parts), 0, -1):
            hose = "/".join(parts[:i])
            if hose in hose_map:
                predictions.append(avg_map[hose])
                familiar_radius.append(i)
                labels.append(trio[2])
                break

print(labels)
print(predictions)

errors = []
for i in range(len(predictions)):
    errors.append(abs(labels[i] - predictions[i]))
print(sum(errors) / len(errors))
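# Illustrative: the same atom described with spheres of increasing radius. Shorter
# codes are less specific, which is what the prefix fallback above exploits.
# (Atom index 0 of supplier3d[1] is an arbitrary choice.)
for radius in range(1, 7):
    print(radius, get_hose_code(supplier3d[1], radius, 0))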
"""### Visualisations & Analysis"""

plt.hist(familiar_radius, bins=[0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5], align="mid")
plt.title("Longest exact HOSE code length in training set")
plt.show()

errors = [[] for i in range(6)]
for i in range(len(predictions)):
    errors[familiar_radius[i] - 1].append(abs(labels[i] - predictions[i]))
for i in range(len(errors)):
    errors[i] = sum(errors[i]) / len(errors[i])

plt.plot([i for i in range(6)], errors)
plt.title("Average error rate of HOSE code")
plt.show()

SPLIT = 0.75  # Use 75% for training and 25% for validation

hose_error_measurements = []
for repetition in range(10):
    err_hose = []
    random.Random(42).shuffle(all_data)
    for size in [100, 250, 500, len(all_data)]:
        train_idx = random.sample([i for i in range(size)], math.ceil(size * SPLIT))
        test_idx = set([i for i in range(size)]).difference(set(train_idx))

        hose_map = {}
        for id in train_idx:
            for trio in get_molecule_hose_and_nmr_shifts(supplier3d[id], 6, 9):
                parts = split_hose(trio[1])
                for i in range(len(parts)):
                    hose = "/".join(parts[:(i + 1)])
                    if hose not in hose_map:
                        hose_map[hose] = []
                    hose_map[hose].append(trio[2])

        avg_map = {}
        for k, v in hose_map.items():
            avg_map[k] = sum(v) / len(v)

        predictions = []
        labels = []
        familiar_radius = []  # for visualisation
        for id in test_idx:
            for trio in get_molecule_hose_and_nmr_shifts(supplier3d[id], 6, 9):
                parts = split_hose(trio[1])
                for i in range(len(parts), 0, -1):
                    hose = "/".join(parts[:i])
                    if hose in hose_map:
                        predictions.append(avg_map[hose])
                        familiar_radius.append(i)
                        labels.append(trio[2])
                        break

        errors = []
        for i in range(len(predictions)):
            errors.append(abs(labels[i] - predictions[i]))
        err_hose.append(sum(errors) / len(errors))
    hose_error_measurements.append(err_hose)

sum_100 = 0
sum_250 = 0
sum_500 = 0
sum_all = 0
for i in hose_error_measurements:
    sum_100 += i[0]
    sum_250 += i[1]
    sum_500 += i[2]
    sum_all += i[3]

avg_err_hose = [sum_100 / len(hose_error_measurements),
                sum_250 / len(hose_error_measurements),
                sum_500 / len(hose_error_measurements),
                sum_all / len(hose_error_measurements)]
err_gnn = [66.1449966430664, 30.329145431518555, 9.740877151489258, 10.084354400634766]
print(avg_err_hose)

plt.plot([100, 250, 500, 970], err_gnn, label="GNN model", color="red")
plt.plot([100, 250, 500, 970], avg_err_hose, label="HOSE code model", color="green")
plt.xlabel("Number of spectra")
plt.ylabel("ppm")
plt.title("Error in ppm in relation to the number of used examples")
plt.legend()
plt.show()

"""# Results

## Edge and node features

Testing different descriptors with fixed hyperparameters (weight decay = 0.005,
learning rate = 0.0003, NO_MP = 7). When testing node features the edge features were
fixed, and when testing edge features the node features were fixed; in other words, each
result was obtained by varying only the corresponding feature set. Each feature was
tested using 4-fold cross-validation on the training data (80% of the whole dataset).
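A minimal sketch of that protocol (hypothetical driver code; `single_feature_graphs`
stands for graphs built with only the descriptor under test, and `init_model`, `train`,
`evaluate`, `chunk_into_n` and `DataLoader` are the helpers defined above):

```python
folds = chunk_into_n(single_feature_graphs, 4)
fold_errors = []
for i, fold in enumerate(folds):
    fit = [g for j, f in enumerate(folds) if j != i for g in f]
    model, optimizer, criterion = init_model(7, 0.0003, 0.005)
    train_loader = DataLoader(fit, batch_size=128)
    for _ in range(500):
        train(model, criterion, optimizer, train_loader)
    fold_errors.append(evaluate(model, criterion, DataLoader(fold, batch_size=128)))
```

The four numbers per descriptor below are these four fold errors.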
""" node_features_results = {'ohe atomic number': [11.206879615783691, 10.606162071228027, 10.678353309631348, 14.426544666290283], 'isAromatic': [20.246356964111328, 21.289775848388672, 22.023090362548828, 21.481626510620117], 'hyb ohe': [14.719367980957031, 14.110363483428955, 17.14181661605835, 17.19037103652954], 'valence ohe': [16.09677743911743, 16.301819801330566, 18.414384841918945, 14.627246856689453], 'valence': [15.74299955368042, 15.868663311004639, 17.854907989501953, 14.021881580352783], 'inRing': [20.925933837890625, 19.979466438293457, 20.020546913146973, 19.80652141571045], 'hybridization': [15.455873012542725, 16.48819398880005, 16.359801769256592, 17.977892875671387], 'atomic num': [10.075446605682373, 10.49792766571045, 11.835340023040771, 12.882370471954346], 'atomic charge': [20.686185836791992, 22.381511688232422, 27.067391395568848, 20.001076698303223], 'atomic radius': [11.015857696533203, 10.15432596206665, 10.100831508636475, 13.576239109039307], 'atomic volume': [11.949167728424072, 8.738459587097168, 10.427918434143066, 13.632889747619629], 'atomic weight': [10.183413028717041, 9.64014196395874, 11.397083282470703, 11.427015781402588], 'covalent radius': [10.514074325561523, 9.394641399383545, 10.941582679748535, 11.956562995910645], 'vdw radius': [9.800463676452637, 9.21529245376587, 10.46424388885498, 13.804336071014404], 'dipole polarizability': [10.734958171844482, 10.39086389541626, 14.207355976104736, 14.054275035858154], 'electron affinity': [12.086753368377686, 9.388242721557617, 9.779114246368408, 14.187011241912842], 'electrophilicity index': [11.224877834320068, 8.167922496795654, 10.476778030395508, 13.519041061401367], 'electronegativity': [9.936093807220459, 10.196239948272705, 9.894521236419678, 14.182630062103271], 'electrons': [10.851109981536865, 9.980653285980225, 9.819206714630127, 14.225136280059814], 'neutrons': [9.413500308990479, 9.620219707489014, 9.220077753067017, 13.0865478515625], 'cooridates': [24.324893951416016, 23.782126426696777, 27.821096420288086, 22.696171760559082], 'formal charge': [21.54022789001465, 21.80484676361084, 25.49346923828125, 21.4456787109375], 'formal charge ohe': [19.7533016204834, 24.704838752746582, 23.275280952453613, 21.496562004089355], 'chiral tag': [22.99411392211914, 22.16276741027832, 26.671720504760742, 19.622788429260254], 'random': [27.85106086730957, 26.663583755493164, 27.1346435546875, 24.93829917907715] } bond_features_results = {'smart,atoms': [10.044915676116943, 8.644615173339844, 10.853336811065674, 13.183335781097412], 'pure,rdkit': [11.945857048034668, 9.031915664672852, 10.515658855438232, 13.471199989318848], 'pure,atoms': [11.971851825714111, 9.971860885620117, 11.541802883148193, 13.614596843719482], 'smart,rdkit': [9.997193813323975, 9.00916337966919, 9.385982513427734, 12.707754611968994]} import pandas as pd import seaborn as sns import matplotlib.pyplot as plt df = pd.DataFrame(columns = ['descriptor', 'accuracy']) df2 = pd.DataFrame(columns = ['descriptor', 'accuracy']) for key, value in bond_features_results.items(): temp_df = pd.DataFrame({'descriptor': [key]*len(value), 'accuracy': value}) df2 = df2.append(temp_df, ignore_index = True) for key, value in node_features_results.items(): temp_df = pd.DataFrame({'descriptor': [key]*len(value), 'accuracy': value}) df = df.append(temp_df, ignore_index = True) #print(df) my_order = df.groupby(by=["descriptor"])["accuracy"].mean().sort_values().index my_order_2 = df2.groupby(by=["descriptor"])["accuracy"].mean().sort_values().index 
fig, ax = plt.subplots(2, figsize=(8, 20))
sns.violinplot(data=df, y="descriptor", x="accuracy", ax=ax[0], order=my_order)
sns.violinplot(data=df2, y="descriptor", x="accuracy", ax=ax[1], order=my_order_2)
plt.show()

"""## Hyperparameter selection"""

!wget -nc https://www.dropbox.com/s/k6kgqag3daqv90y/hp_results.csv

with open("hp_results.csv") as f:
    hp_results = pd.read_csv(f)

with pd.option_context('display.max_rows', None,
                       'display.max_columns', None,
                       'display.precision', 3,
                       ):
    print(hp_results)

print(hp_results.groupby(["m", "lr", "wd"])['accuracy'].mean())
print(hp_results.groupby(["m"])['accuracy'].mean())
print(hp_results.groupby(["lr"])['accuracy'].mean())
print(hp_results.groupby(["wd"])['accuracy'].mean())

"""## Different models (collections of features)

* 2019 model features
* Top N features

### 2019 model features
"""

paper_2019_features_results = [8.671473503112793, 14.09905195236206, 12.203859806060791, 12.20280122756958]

"""### Top N features"""

n_feature_results = {
    "['neutrons']": [10.858759880065918, 7.9520745277404785, 10.061083316802979, 13.083988666534424],
    "['neutrons', 'atomic weight']": [10.279156684875488, 10.209141254425049, 9.95799732208252, 13.340500354766846],
    "['neutrons', 'atomic weight', 'covalent radius']": [10.954660892486572, 8.685662269592285, 8.889871597290039, 11.358663558959961],
    "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius']": [11.0108060836792, 10.469686031341553, 10.534903049468994, 12.8761568069458],
    "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index']": [10.096431255340576, 9.76186466217041, 9.827410221099854, 12.698175430297852],
    "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity']": [10.908030033111572, 8.09628176689148, 9.25707221031189, 11.68080759048462],
    "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity', 'atomic volume']": [9.88636827468872, 9.501356601715088, 9.424680233001709, 12.281991004943848],
    "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity', 'atomic volume', 'atomic radius']": [12.68621015548706, 9.14622163772583, 10.40596866607666, 11.856515407562256],
    "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity', 'atomic volume', 'atomic radius', 'electrons']": [10.266335487365723, 9.108633041381836, 11.96642780303955, 13.67209243774414],
    "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity', 'atomic volume', 'atomic radius', 'electrons', 'atomic num']": [10.689188957214355, 9.288278579711914, 10.050706624984741, 14.055354118347168],
    "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity', 'atomic volume', 'atomic radius', 'electrons', 'atomic num', 'electron affinity']": [10.014832019805908, 9.91117525100708, 9.540995359420776, 13.940445899963379],
    "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity', 'atomic volume', 'atomic radius', 'electrons', 'atomic num', 'electron affinity', 'ohe atomic number']": [10.295331001281738, 9.557802677154541, 11.178434371948242, 10.574723720550537],
    "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity', 'atomic volume', 'atomic radius', 'electrons', 'atomic num', 'electron affinity', 'ohe atomic number', 'dipole polarizability']": [10.669241905212402, 9.187427043914795, 9.547076225280762, 12.491205215454102],
    "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity', 'atomic volume', 'atomic radius', 'electrons', 'atomic num', 'electron affinity', 'ohe atomic number', 'dipole polarizability', 'hyb ohe']": [10.314173221588135, 10.875812530517578, 12.00346040725708, 11.990623950958252],
    "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity', 'atomic volume', 'atomic radius', 'electrons', 'atomic num', 'electron affinity', 'ohe atomic number', 'dipole polarizability', 'hyb ohe', 'valence']": [9.84071683883667, 9.872222900390625, 8.903747081756592, 11.116456031799316],
    "['neutrons', 
'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity', 'atomic volume', 'atomic radius', 'electrons', 'atomic num', 'electron affinity', 'ohe atomic number', 'dipole polarizability', 'hyb ohe', 'valence', 'valence ohe']": [11.037650108337402, 10.227884292602539, 11.569035053253174, 15.139636516571045], "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity', 'atomic volume', 'atomic radius', 'electrons', 'atomic num', 'electron affinity', 'ohe atomic number', 'dipole polarizability', 'hyb ohe', 'valence', 'valence ohe', 'hybridization']": [9.662659645080566, 10.27301549911499, 10.229931831359863, 12.425177097320557], "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity', 'atomic volume', 'atomic radius', 'electrons', 'atomic num', 'electron affinity', 'ohe atomic number', 'dipole polarizability', 'hyb ohe', 'valence', 'valence ohe', 'hybridization', 'inRing']": [12.145603656768799, 9.442389011383057, 10.04677677154541, 13.444748401641846], "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity', 'atomic volume', 'atomic radius', 'electrons', 'atomic num', 'electron affinity', 'ohe atomic number', 'dipole polarizability', 'hyb ohe', 'valence', 'valence ohe', 'hybridization', 'inRing', 'isAromatic']": [10.830190181732178, 9.220112323760986, 9.297557353973389, 13.27143907546997], "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity', 'atomic volume', 'atomic radius', 'electrons', 'atomic num', 'electron affinity', 'ohe atomic number', 'dipole polarizability', 'hyb ohe', 'valence', 'valence ohe', 'hybridization', 'inRing', 'isAromatic', 'formal charge ohe']": [10.53205156326294, 10.221848964691162, 10.479277610778809, 11.17846155166626], "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity', 'atomic volume', 'atomic radius', 'electrons', 'atomic num', 'electron affinity', 'ohe atomic number', 'dipole polarizability', 'hyb ohe', 'valence', 'valence ohe', 'hybridization', 'inRing', 'isAromatic', 'formal charge ohe', 'atomic charge']": [10.679495811462402, 9.242655277252197, 10.48193883895874, 12.0442533493042], "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity', 'atomic volume', 'atomic radius', 'electrons', 'atomic num', 'electron affinity', 'ohe atomic number', 'dipole polarizability', 'hyb ohe', 'valence', 'valence ohe', 'hybridization', 'inRing', 'isAromatic', 'formal charge ohe', 'atomic charge', 'formal charge']": [9.10008192062378, 9.387939929962158, 10.72908878326416, 12.20565414428711], "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity', 'atomic volume', 'atomic radius', 'electrons', 'atomic num', 'electron affinity', 'ohe atomic number', 'dipole polarizability', 'hyb ohe', 'valence', 'valence ohe', 'hybridization', 'inRing', 'isAromatic', 'formal charge ohe', 'atomic charge', 'formal charge', 'chiral tag']": [11.04087209701538, 8.146693706512451, 10.116810321807861, 10.707176208496094], "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity', 'atomic volume', 'atomic radius', 'electrons', 'atomic num', 'electron affinity', 'ohe atomic number', 'dipole polarizability', 'hyb ohe', 'valence', 'valence ohe', 
'hybridization', 'inRing', 'isAromatic', 'formal charge ohe', 'atomic charge', 'formal charge', 'chiral tag', 'coordinates']": [11.374996185302734, 10.41971206665039, 11.041788578033447, 13.181884288787842]
}

df = pd.DataFrame(columns=['descriptor', 'accuracy'])
for key, value in n_feature_results.items():
    temp_df = pd.DataFrame({'descriptor': str(len(key.split(","))), 'accuracy': value})
    df = pd.concat([df, temp_df], ignore_index=True)

fig, ax = plt.subplots(1, figsize=(15, 15))
sns.violinplot(data=df, y="descriptor", x="accuracy", ax=ax)
plt.show()

print(df.groupby(by=["descriptor"])["accuracy"].mean().sort_values())

"""### Comparison"""

df = pd.DataFrame(columns=['descriptor', 'accuracy'])
temp_df = pd.DataFrame({'descriptor': '2019 model features', 'accuracy': paper_2019_features_results})
df = pd.concat([df, temp_df], ignore_index=True)
for key, value in n_feature_results.items():
    temp_df = pd.DataFrame({'descriptor': str(len(key.split(","))), 'accuracy': value})
    df = pd.concat([df, temp_df], ignore_index=True)

fig, ax = plt.subplots(1, figsize=(15, 15))
comparison_df = df.loc[(df['descriptor'] == '15') | (df['descriptor'] == '2019 model features')]
sns.violinplot(data=comparison_df, y="descriptor", x="accuracy", ax=ax)
plt.show()

"""## Varying dataset sizes

### Fluorine
"""

!wget -nc https://www.dropbox.com/s/m1alwawphb1chbc/fluorine_results.csv

with open("fluorine_results.csv") as f:
    df = pd.read_csv(f)

df.groupby(["model", "dataset_size"]).mean()

fig, ax = plt.subplots(figsize=(8, 10))
sns.violinplot(data=df, y="mae", x="dataset_size", hue="model", ax=ax)
plt.show()

"""### Carbon with chloroform as solvent"""

# NB: this cell fetches the methanol results file; the dedicated chloroform results URL
# appears to be missing from this export.
!wget -nc https://www.dropbox.com/s/0nmj08b1hpn8jc0/methanol_results.csv

with open("chloroform_results.csv") as f:
    df = pd.read_csv(f)

print(df.groupby(["model", "dataset_size"]).mean())
print(df)

fig, ax = plt.subplots(figsize=(8, 10))
sns.violinplot(data=df, y="mae", x="dataset_size", hue="model", ax=ax)
plt.show()

"""### Carbon with methanol as solvent"""

!wget -nc https://www.dropbox.com/s/0nmj08b1hpn8jc0/methanol_results.csv

with open("methanol_results.csv") as f:
    df = pd.read_csv(f)

print(df.groupby(["model", "dataset_size"]).mean())

fig, ax = plt.subplots(figsize=(8, 10))
sns.violinplot(data=df, y="mae", x="dataset_size", hue="model", ax=ax)
plt.show()

"""### Carbon with DMSO as solvent"""

!wget -nc https://www.dropbox.com/s/gdwv8pspayssk27/dmso_results.csv

with open("dmso_results.csv") as f:
    df = pd.read_csv(f)

print(df.groupby(["model", "dataset_size"]).mean())

fig, ax = plt.subplots(figsize=(8, 10))
sns.violinplot(data=df, y="mae", x="dataset_size", hue="model", ax=ax)
plt.show()

"""# Reproducing results

The functions below produced the aforementioned results. To repeat any experiment:

* Run the core block to load the necessary helper methods.
* Run the corresponding experiment block to obtain the results.
## Core block
"""

TEST_SPLIT = 0.8

# Install required dependencies
!pip3 install rdkit      # Used to read and parse the nmrshiftdb2 SD file
!pip3 install mendeleev  # To access various features related to atoms

import torch  # PyTorch dependencies to represent graph data
!pip3 install torch-scatter -f https://data.pyg.org/whl/torch-{torch.__version__}.html
!pip3 install torch-sparse -f https://data.pyg.org/whl/torch-{torch.__version__}.html
!pip3 install torch-geometric

# Use CUDA by default
if torch.cuda.is_available():
    torch.set_default_tensor_type('torch.cuda.FloatTensor')

# Download the nmrshiftdb2 database if it does not yet exist in our runtime
!wget -nc https://www.dropbox.com/s/n122zxawpxii5b7/nmrshiftdb2withsignals.sd  # Fluorine

from torch.nn import Sequential as Seq, LazyLinear, LeakyReLU, LazyBatchNorm1d, LayerNorm
from torch_scatter import scatter_mean, scatter_add
from torch_geometric.nn import MetaLayer
from torch_geometric.data import Batch, Data
from torch import tensor
from torch_geometric.loader import DataLoader
import numpy as np
from rdkit import Chem
from sklearn.preprocessing import OneHotEncoder
import random
import math
import mendeleev
import pandas as pd

NO_GRAPH_FEATURES = 128
ENCODING_NODE = 64
ENCODING_EDGE = 32
HIDDEN_NODE = 128
HIDDEN_EDGE = 64
HIDDEN_GRAPH = 128

def init_weights(m):
    if type(m) == torch.nn.Linear:
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)

class EdgeModel(torch.nn.Module):
    def __init__(self):
        super(EdgeModel, self).__init__()
        self.edge_mlp = Seq(LazyLinear(HIDDEN_EDGE), LeakyReLU(), LazyBatchNorm1d(),
                            LazyLinear(HIDDEN_EDGE), LeakyReLU(), LazyBatchNorm1d(),
                            LazyLinear(ENCODING_EDGE)).apply(init_weights)

    def forward(self, src, dest, edge_attr, u, batch):
        # src, dest: [E, F_x], where E is the number of edges.
        # edge_attr: [E, F_e]
        # u: [B, F_u], where B is the number of graphs.
        # batch: [E] with max entry B - 1.
        out = torch.cat([src, dest, edge_attr], 1)
        return self.edge_mlp(out)

class NodeModel(torch.nn.Module):
    def __init__(self):
        super(NodeModel, self).__init__()
        self.node_mlp_1 = Seq(LazyLinear(HIDDEN_NODE), LeakyReLU(), LazyBatchNorm1d(),
                              LazyLinear(HIDDEN_NODE), LeakyReLU(), LazyBatchNorm1d(),
                              #torch.nn.Dropout(0.17),
                              LazyLinear(HIDDEN_NODE)).apply(init_weights)
        self.node_mlp_2 = Seq(LazyLinear(HIDDEN_NODE), LeakyReLU(), LazyBatchNorm1d(),
                              #torch.nn.Dropout(0.17),
                              LazyLinear(HIDDEN_NODE), LeakyReLU(), LazyBatchNorm1d(),
                              LazyLinear(ENCODING_NODE)).apply(init_weights)

    def forward(self, x, edge_index, edge_attr, u, batch):
        # x: [N, F_x], where N is the number of nodes.
        # edge_index: [2, E] with max entry N - 1.
        # edge_attr: [E, F_e]
        # u: [B, F_u]
        # batch: [N] with max entry B - 1.
        row, col = edge_index
        out = torch.cat([x[row], edge_attr], dim=1)
        out = self.node_mlp_1(out)
        out = scatter_add(out, col, dim=0, dim_size=x.size(0))
        out = torch.cat([x, out], dim=1)
        return self.node_mlp_2(out)
class GlobalModel(torch.nn.Module):
    def __init__(self):
        super(GlobalModel, self).__init__()
        self.global_mlp = Seq(LazyLinear(HIDDEN_GRAPH), LeakyReLU(), LazyBatchNorm1d(),
                              #torch.nn.Dropout(0.17),
                              LazyLinear(HIDDEN_GRAPH), LeakyReLU(), LazyBatchNorm1d(),
                              LazyLinear(NO_GRAPH_FEATURES)).apply(init_weights)

    def forward(self, x, edge_index, edge_attr, u, batch):
        # x: [N, F_x], where N is the number of nodes.
        # edge_index: [2, E] with max entry N - 1.
        # edge_attr: [E, F_e]
        # u: [B, F_u]
        # batch: [N] with max entry B - 1.
        row, col = edge_index
        node_aggregate = scatter_add(x, batch, dim=0)
        edge_aggregate = scatter_add(edge_attr, batch[col], dim=0)
        out = torch.cat([node_aggregate, edge_aggregate], dim=1)
        return self.global_mlp(out)

class GNN_FULL_CLASS(torch.nn.Module):
    def __init__(self, NO_MP):
        super(GNN_FULL_CLASS, self).__init__()
        # Meta layer for message passing
        self.meta = MetaLayer(EdgeModel(), NodeModel(), GlobalModel())

        # Edge encoding MLP
        self.encoding_edge = Seq(LazyLinear(ENCODING_EDGE), LeakyReLU(), LazyBatchNorm1d(),
                                 LazyLinear(ENCODING_EDGE), LeakyReLU(), LazyBatchNorm1d(),
                                 LazyLinear(ENCODING_EDGE)).apply(init_weights)

        # Node encoding MLP
        self.encoding_node = Seq(LazyLinear(ENCODING_NODE), LeakyReLU(), LazyBatchNorm1d(),
                                 LazyLinear(ENCODING_NODE), LeakyReLU(), LazyBatchNorm1d(),
                                 LazyLinear(ENCODING_NODE)).apply(init_weights)

        self.mlp_last = Seq(LazyLinear(HIDDEN_NODE), LeakyReLU(),
                            #torch.nn.Dropout(0.10),
                            LazyBatchNorm1d(),
                            LazyLinear(HIDDEN_NODE), LeakyReLU(), LazyBatchNorm1d(),
                            LazyLinear(1)).apply(init_weights)

        self.no_mp = NO_MP

    def forward(self, dat):
        # Extract the data from the batch
        x, ei, ea, btc = dat.x, dat.edge_index, dat.edge_attr, dat.batch

        # Embed the node and edge features
        enc_x = self.encoding_node(x)
        enc_ea = self.encoding_edge(ea)

        # Create the empty graph-level feature vectors
        u = torch.full(size=(x.size()[0], 1), fill_value=0.1, dtype=torch.float)

        # Message passing
        for _ in range(self.no_mp):
            enc_x, enc_ea, u = self.meta(x=enc_x, edge_index=ei, edge_attr=enc_ea, u=u, batch=btc)

        targs = self.mlp_last(enc_x)
        return targs

el_map = {}

def atom_features_default():
    feature_getters = {}
    feature_getters["ohe atomic number"] = lambda atom: atomic_number_ohe.transform(np.array([[atom.GetAtomicNum()]]))[0]  # Atomic number
    feature_getters["hyb ohe"] = lambda atom: hybridization_ohe.transform(np.array([[atom.GetHybridization().name]]))[0]
    feature_getters["valence ohe"] = lambda atom: valence_ohe.transform(np.array([[atom.GetTotalValence()]]))[0]
    feature_getters["hybridization"] = lambda atom: atom.GetHybridization()
    feature_getters["atomic radius"] = lambda atom: getMendeleevElement(atom.GetAtomicNum()).atomic_radius or 0    # Atomic radius
    feature_getters["atomic volume"] = lambda atom: getMendeleevElement(atom.GetAtomicNum()).atomic_volume         # Atomic volume
    feature_getters["atomic weight"] = lambda atom: getMendeleevElement(atom.GetAtomicNum()).atomic_weight         # Atomic weight
    feature_getters["dipole polarizability"] = lambda atom: getMendeleevElement(atom.GetAtomicNum()).dipole_polarizability  # Dipole polarizability
    feature_getters["electron affinity"] = lambda atom: getMendeleevElement(atom.GetAtomicNum()).electron_affinity # Electron affinity
    feature_getters["electronegativity"] = lambda atom: getMendeleevElement(atom.GetAtomicNum()).en_pauling        # Electronegativity
    feature_getters["electrons"] = lambda atom: getMendeleevElement(atom.GetAtomicNum()).electrons                 # No. of electrons
    feature_getters["neutrons"] = lambda atom: getMendeleevElement(atom.GetAtomicNum()).neutrons                   # No. of neutrons
    feature_getters["formal charge ohe"] = lambda atom: fc_ohe.transform(np.array([[atom.GetFormalCharge()]]))[0]
    #feature_getters["gaisteigerCharge"] = lambda atom: float(atom.GetProp('_GasteigerCharge')) if np.isfinite(float(atom.GetProp('_GasteigerCharge'))) else 0  # Partial charges
    feature_getters["chiral tag"] = lambda atom: atom.GetChiralTag()
    return feature_getters

def bond_feature_smart_distance_and_rdkit_type(bond):
    onehot_encoded_bondtype = onehot_encoder.transform(np.array([[bond.GetBondType()]]))[0]
    [x1, y1, z1] = list(bond.GetOwningMol().GetConformer().GetAtomPosition(bond.GetBeginAtomIdx()))
    [x2, y2, z2] = list(bond.GetOwningMol().GetConformer().GetAtomPosition(bond.GetEndAtomIdx()))
    ex_dist = getNaiveBondLength(bond)  # expected bond length (helper defined with the experiment blocks)
    distance = [math.sqrt((x2 - x1)**2 + (y2 - y1)**2 + (z2 - z1)**2) - ex_dist]  # Deviation from expected length
    return distance + list(onehot_encoded_bondtype)

def getMendeleevElement(nr):
    if nr not in el_map:
        el_map[nr] = mendeleev.element(nr)
    return el_map[nr]

def nmr_shift(atom):
    for key, value in atom.GetOwningMol().GetPropsAsDict().items():
        if key.startswith("Spectrum"):
            for shift in value.split('|'):
                x = shift.split(';')
                if (len(x) == 3 and x[2] == f"{atom.GetIdx()}"):
                    return float(x[0])
    return float("NaN")  # We use NaN for atoms whose shifts we do not want to predict

def bond_features_distance_only(bond):
    #onehot_encoded_bondtype = onehot_encoder.transform(np.array([[bond.GetBondType()]]))[0]
    [x1, y1, z1] = list(bond.GetOwningMol().GetConformer().GetAtomPosition(bond.GetBeginAtomIdx()))
    [x2, y2, z2] = list(bond.GetOwningMol().GetConformer().GetAtomPosition(bond.GetEndAtomIdx()))
    distance = [math.sqrt((x2 - x1)**2 + (y2 - y1)**2 + (z2 - z1)**2)]  # Distance
    return distance  #+ list(onehot_encoded_bondtype)

def flatten(l):
    ret = []
    for el in l:
        if isinstance(el, list) or isinstance(el, np.ndarray):
            ret.extend(el)
        else:
            ret.append(el)
    return ret
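# Example (illustrative): flatten merges scalar features and one-hot arrays into one flat list.
print(flatten([1.0, np.array([0.0, 1.0]), True]))  # -> [1.0, 0.0, 1.0, True]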
def turn_to_graph(molecule, atom_feature_getters=atom_features_default().values(), bond_features=bond_features_distance_only):
    node_features = [flatten([getter(atom) for getter in atom_feature_getters]) for atom in molecule.GetAtoms()]
    node_targets = [nmr_shift(atom) for atom in molecule.GetAtoms()]
    edge_features = [bond_features(bond) for bond in molecule.GetBonds()]
    edge_index = [[bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()] for bond in molecule.GetBonds()]

    # Bonds are not directed, so add the reversed pair of every edge to make the graph undirected
    edge_index.extend([[b, a] for a, b in edge_index])
    edge_features.extend(edge_features)

    # Some molecules in the carbon data produced None node features, which aborted the long
    # graph-building run; skip those molecules instead.
    if any(None in sublist for sublist in node_features):
        return None

    return Data(
        x=tensor(node_features, dtype=torch.float),
        edge_index=tensor(edge_index, dtype=torch.long).t().contiguous(),
        edge_attr=tensor(edge_features, dtype=torch.float),
        y=tensor([[t] for t in node_targets], dtype=torch.float)
    )

def init_model(NO_MP, lr, wd):
    # Model
    model = GNN_FULL_CLASS(NO_MP)
    # Optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
    # Criterion
    #criterion = torch.nn.MSELoss()
    criterion = torch.nn.L1Loss()
    return model, optimizer, criterion

def train(model, criterion, optimizer, loader):
    loss_sum = 0
    for batch in loader:
        # Forward pass and gradient descent
        labels = batch.y
        predictions = model(batch)
        loss = criterion(predictions[torch.isfinite(labels)], labels[torch.isfinite(labels)])

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_sum += loss.item()
    return loss_sum / len(loader)

def evaluate(model, criterion, loader):
    loss_sum = 0
    with torch.no_grad():
        for batch in loader:
            # Forward pass
            labels = batch.y
            predictions = model(batch)
            loss = criterion(predictions[torch.isfinite(labels)], labels[torch.isfinite(labels)])
            loss_sum += loss.item()
    return loss_sum / len(loader)

def chunk_into_n(lst, n):
    size = math.ceil(len(lst) / n)
    return list(map(lambda x: lst[x * size:x * size + size], list(range(n))))

def get_data_loaders(data, split, batch_size):
    random.Random().shuffle(data)
    training_data = data[:int(len(data) * split)]
    testing_data = data[int(len(data) * split):]
    train_loader = DataLoader(training_data, batch_size=batch_size)
    test_loader = DataLoader(testing_data, batch_size=batch_size)
    return train_loader, test_loader

def getBondElements(bond):
    a = bond.GetEndAtom().GetSymbol()
    b = bond.GetBeginAtom().GetSymbol()
    return a + b if a < b else b + a  # canonical (sorted) element pair -- assumed reconstruction of a garbled line

# Experiment-loop fragment: keeps training until the loss settles, then records
# MAE/RMSE/std for the current method and dataset size. It assumes model, optimizer,
# train_loader/test_loader, the mae/mse criteria, tloss, fixing_epochs, method_name,
# size and model_results_df_2 are defined by the surrounding experiment code.
while tloss > 2.5 and fixing_epochs < 200:
    tloss = train(model, mae, optimizer, train_loader)
    train_err = evaluate(model, mae, train_loader)
    test_err = evaluate(model, mae, test_loader)
    fixing_epochs += 1

preds = torch.tensor([])
labels = torch.tensor([])
with torch.no_grad():
    for batch in test_loader:
        # Forward pass
        labels = torch.cat((labels, batch.y), 0)
        preds = torch.cat((preds, model(batch)), 0)

std = torch.std(torch.subtract(preds[torch.isfinite(labels)], labels[torch.isfinite(labels)]))

model_results_df_2 = pd.concat([model_results_df_2, pd.DataFrame(data={
    'model': [method_name],
    "dataset_size": [size],
    "mae": [evaluate(model, mae, test_loader)],
    "rmse": [evaluate(model, mse, test_loader)**0.5],
    "std": [float(std)],
})], ignore_index=True, axis=0, join='outer')
#model_results_df_2 = model_results_df_2.append({'model': method_name, "dataset_size": size, "mae": evaluate(model, mae, test_loader), "rmse": evaluate(model, mse, test_loader)**0.5, "std": float(std)}, ignore_index=True)
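# Minimal end-to-end sketch (illustrative driver, not one of the recorded experiments):
# build graphs with the default getters, split, train briefly and report the test MAE.
# Feature scaling is omitted here for brevity.
supplier = Chem.rdmolfiles.SDMolSupplier("nmrshiftdb2withsignals.sd", True, False, True)
graphs = [g for g in (turn_to_graph(m) for m in supplier if m) if g is not None]
train_loader, test_loader = get_data_loaders(graphs, TEST_SPLIT, 128)
model, optimizer, criterion = init_model(7, 0.0003, 0.005)
for epoch in range(50):
    train(model, criterion, optimizer, train_loader)
print(evaluate(model, criterion, test_loader))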