# -*- coding: utf-8 -*-
"""NMR shift prediction from small data quantities.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1yKTRjpWzR8T199eCokuJfd9Y5o2oNtPp
# Data: NMRShiftDB
## Reading data set from SD file
The nmrshiftdb2 data from https://sourceforge.net/projects/nmrshiftdb2/files/data/ has been downloaded and mirrored on our own domain to avoid the SourceForge redirects and cookie banner, and to ensure that every collaborator can rerun this Colab without manual steps or Google Drive dependencies.
"""
# Install required dependencies
!pip3 install rdkit # Used to read and parse nmrshiftdb2 SD file
!pip3 install mendeleev # To access various features related to atoms
import torch
# PyTorch dependencies to represent graph data
!pip3 install torch-scatter -f https://data.pyg.org/whl/torch-{torch.__version__}.html
!pip3 install torch-sparse -f https://data.pyg.org/whl/torch-{torch.__version__}.html
!pip3 install torch-geometric
# Use CUDA by default
if torch.cuda.is_available():
torch.set_default_tensor_type('torch.cuda.FloatTensor')
# Downloads the nmrshiftdb2 database if it does not yet exist in our runtime
!wget -nc https://www.dropbox.com/s/n122zxawpxii5b7/nmrshiftdb2withsignals.sd#Fluorine
#!wget -nc "https://sourceforge.net/projects/nmrshiftdb2/files/data/nmrshiftdb2withsignals.sd/download" #Carbon
from rdkit import Chem
supplier3d = Chem.rdmolfiles.SDMolSupplier("nmrshiftdb2withsignals.sd", True, False, True) #Fluorine
#supplier3d = Chem.rdmolfiles.SDMolSupplier("download") #Carbon
print(f"In total there are {len(supplier3d)} molecules")
from rdkit.Chem import AllChem
mol= supplier3d[152]
mol
"""## Graph transformation
In order to use convolutional networks, we need to transform the molecule data into graphs.
The atoms themselves will become nodes, and the bonds between the atoms will become the edges.
"""
import mendeleev
import torch
import math
from torch import tensor
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
import os
import numpy as np
import time
from sklearn.preprocessing import OneHotEncoder
import random
# One hot encoding
## Bonds
bond_idxes = np.array([Chem.BondType.SINGLE, Chem.BondType.DOUBLE, Chem.BondType.TRIPLE, Chem.BondType.AROMATIC])
bond_idxes = bond_idxes.reshape(len(bond_idxes), 1)
onehot_encoder = OneHotEncoder(sparse=False, handle_unknown="ignore")
onehot_encoder.fit(bond_idxes)
## Hybridization
hybridization_idxes = np.array(list(Chem.HybridizationType.names))
hybridization_idxes = hybridization_idxes.reshape(len(hybridization_idxes), 1)
hybridization_ohe = OneHotEncoder(sparse=False)
hybridization_ohe.fit(hybridization_idxes)
## Valence
valences = np.arange(1, 8)
valences = valences.reshape(len(valences), 1)
valence_ohe = OneHotEncoder(handle_unknown="ignore",sparse=False)
valence_ohe.fit(valences)
## Formal Charge
fc = np.arange(-1, 1)  # NOTE: np.arange(-1, 1) covers only -1 and 0; a formal charge of +1 is encoded as all zeros via handle_unknown="ignore"
fc = fc.reshape(len(fc), 1)
fc_ohe = OneHotEncoder(handle_unknown="ignore", sparse=False)
fc_ohe.fit(fc)
## Atomic number
atomic_nums = np.array([6,1,7,8,9,17,15,11, 16])
atomic_nums = atomic_nums.reshape(len(atomic_nums), 1)
atomic_number_ohe = OneHotEncoder(handle_unknown="ignore",sparse=False)
atomic_number_ohe.fit(atomic_nums)
atomic_number_ohe.transform(np.array([[1]]))
def get_molecule_solvent(molecule):
    if not molecule:
        return None
    for key, value in molecule.GetPropsAsDict().items():
        if key.startswith("Solvent"):
            return value
    return None
solvents={}
for mol in supplier3d:
a = get_molecule_solvent(mol)
if a:
if a not in solvents:
solvents[a] = 0
solvents[a] += 1
arr = [k for k, v in solvents.items() if v>10]
solvents = np.array(arr)
solvents = solvents.reshape(len(solvents), 1)
solvent_ohe = OneHotEncoder(handle_unknown="ignore",sparse=False)
solvent_ohe.fit(solvents)
#solvent_ohe.transform(np.array([[1]]))
el_map={}
def getMendeleevElement(nr):
if nr not in el_map:
el_map[nr] = mendeleev.element(nr)
return el_map[nr]
def nmr_shift(atom):
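    # nmrshiftdb2 stores each spectrum as a property whose value is a '|'-separated list of
    # entries of the form "<shift>;<intensity/multiplicity>;<atom_idx>"; only the shift and
    # the atom index are used here (the exact meaning of the middle field is assumed).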
for key, value in atom.GetOwningMol().GetPropsAsDict().items():
if key.startswith("Spectrum"):
for shift in value.split('|'):
x = shift.split(';')
if (len(x) == 3 and x[2] == f"{atom.GetIdx()}"):
return float(x[0])
    return float("NaN") # We use NaN for atoms we don't want to predict shifts for
def bond_features(bond):
onehot_encoded_bondtype = onehot_encoder.transform(np.array([[bond.GetBondType()]]))[0]
[x1, y1, z1] = list(bond.GetOwningMol().GetConformer().GetAtomPosition(bond.GetBeginAtomIdx()))
[x2, y2, z2] = list(bond.GetOwningMol().GetConformer().GetAtomPosition(bond.GetEndAtomIdx()))
distance = [(math.sqrt((x2 - x1)**2 + (y2 - y1)**2 + (z2 - z1)**2))] # Distance
return distance+list(onehot_encoded_bondtype)
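# Quick sanity check of the bond featurisation (assuming the molecule at index 1
# parses): the 3D distance followed by the one-hot encoded bond type.
print(bond_features(supplier3d[1].GetBonds()[0]))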
def atom_features_random(atom, molecule=None):
features = []
    features.append(random.randint(0, 9)) # Random integer as a control (baseline) feature
return features
def atom_features_update(atom, molecule=None):
features = []
me = getMendeleevElement(atom.GetAtomicNum())
features.extend(solvent_ohe.transform(np.array([[get_molecule_solvent(molecule)]]))[0])
features.extend(hybridization_ohe.transform(np.array([[atom.GetHybridization().name]]))[0])
features.extend(valence_ohe.transform(np.array([[atom.GetTotalValence()]]))[0])
features.append(me.atomic_volume) # Atomic volume
features.append(me.dipole_polarizability) # Dipole polarizability
features.append(me.electron_affinity) # Electron affinity
features.append(me.en_pauling) # Electronegativity
features.append(me.electrons) # No. of electrons
features.append(me.neutrons) # No. of neutrons
features.append(atom.GetChiralTag())
features.append(atom.IsInRing())
return features
def atom_features_2019(atom, molecule=None):
    features = []
    features.append(atom.GetAtomicNum())
    features.extend(atomic_number_ohe.transform(np.array([[atom.GetAtomicNum()]]))[0]) # Atomic number
    features.extend(valence_ohe.transform(np.array([[atom.GetTotalValence()]]))[0])
    features.extend(valence_ohe.transform(np.array([[Chem.GetPeriodicTable().GetDefaultValence(atom.GetAtomicNum())]]))[0]) # Default valence
    features.append(atom.GetTotalValence())
    features.append(atom.GetIsAromatic())
    features.extend(hybridization_ohe.transform(np.array([[atom.GetHybridization().name]]))[0])
    features.extend(fc_ohe.transform(np.array([[atom.GetFormalCharge()]]))[0])
    features.append(atom.IsInRing())
    return features
def atom_features_top_single_performers(atom, molecule=None):
me = getMendeleevElement(atom.GetAtomicNum())
features = []
#features.append(atom.GetAtomicNum())
features.extend(atomic_number_ohe.transform(np.array([[atom.GetAtomicNum()]]))[0]) # Atomic number
    features.append(me.atomic_radius or 0) # Atomic radius
    features.append(me.neutrons) # No. of neutrons
#features.extend(hybridization_ohe.transform(np.array([[atom.GetHybridization().name]]))[0])
features.append(me.electron_affinity) # Electron affinity
features.append(me.en_pauling) # Electronegativity
#features.append(me.electrons) # No. of neutrons
return features
def atom_features_top_single_performers_with_solvents(atom, molecule=None):
me = getMendeleevElement(atom.GetAtomicNum())
features = []
features.extend(solvent_ohe.transform(np.array([[get_molecule_solvent(molecule)]]))[0])
#features.append(atom.GetAtomicNum())
features.extend(atomic_number_ohe.transform(np.array([[atom.GetAtomicNum()]]))[0]) # Atomic number
    features.append(me.atomic_radius or 0) # Atomic radius
    features.append(me.neutrons) # No. of neutrons
    #features.extend(hybridization_ohe.transform(np.array([[atom.GetHybridization().name]]))[0])
    features.append(me.electron_affinity) # Electron affinity
    features.append(me.en_pauling) # Electronegativity
    features.append(me.electrons) # No. of electrons
return features
def atom_features(atom, molecule=None):
me = getMendeleevElement(atom.GetAtomicNum())
#[x, y, z] = list(atom.GetOwningMol().GetConformer().GetAtomPosition(atom.GetIdx()))
features = []
#TBD: Do we need to encode molecule atom itself? Or is atomic number sufficient? One-hot encode?
features.extend(atomic_number_ohe.transform(np.array([[atom.GetAtomicNum()]]))[0]) # Atomic number
#features.append(atom.GetIsAromatic())
features.extend(solvent_ohe.transform(np.array([[get_molecule_solvent(molecule)]]))[0])
features.extend(hybridization_ohe.transform(np.array([[atom.GetHybridization().name]]))[0])
#features.extend(valence_ohe.transform(np.array([[atom.GetTotalValence()]]))[0])
features.extend(fc_ohe.transform(np.array([[atom.GetFormalCharge()]]))[0])
features.append(me.atomic_radius or 0) # Atomic radius
features.append(me.atomic_volume) # Atomic volume
features.append(me.atomic_weight) # Atomic weight
features.append(me.covalent_radius) # Covalent radius
features.append(me.vdw_radius) # Van der Waals radius
features.append(me.dipole_polarizability) # Dipole polarizability
features.append(me.electron_affinity) # Electron affinity
features.append(me.electrophilicity()) # Electrophilicity index
features.append(me.en_pauling) # Electronegativity
features.append(me.electrons) # No. of electrons
features.append(me.neutrons) # No. of neutrons
#features.append(x) # X coordinate - TBD: Not sure this is a meaningful feature (but they had in the paper)
#features.append(y) # Y coordinate - TBD: Not sure this is a meaningful feature (but they had in the paper)
#features.append(z) # Z coordinate - TBD: Not sure this is a meaningful feature (but they had in the paper)
#features.append(0 if np.isfinite(float(atom.GetProp('_GasteigerCharge'))) else float(atom.GetProp('_GasteigerCharge'))) #partial charges
features.append(atom.GetChiralTag())
features.append(atom.IsInRing())
return features
def convert_to_graph(molecule, atom_feature_constructor=atom_features):
#Chem.rdPartialCharges.ComputeGasteigerCharges(molecule)
node_features = [atom_feature_constructor(atom, molecule) for atom in molecule.GetAtoms()]
node_targets = [nmr_shift(atom) for atom in molecule.GetAtoms()]
edge_features = [bond_features(bond) for bond in molecule.GetBonds()]
edge_index = [[bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()] for bond in molecule.GetBonds()]
    # Bonds are not directed, so let's add the reversed pair for each bond to make the graph undirected
    edge_index.extend([[b, a] for a, b in edge_index])
    edge_features.extend(edge_features)
    # Some node features contain None values in the carbon data, which aborted the long graph-building run; skip such molecules.
    if any(None in sublist for sublist in node_features):
return None
return Data(
x=tensor(node_features, dtype=torch.float),
edge_index=tensor(edge_index, dtype=torch.long).t().contiguous(),
edge_attr=tensor(edge_features, dtype=torch.float),
y=tensor([[t] for t in node_targets], dtype=torch.float)
)
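# A quick look at one converted molecule (reusing index 152 from above, assuming the
# conversion succeeds): x is [num_atoms, num_features], edge_index is [2, 2*num_bonds],
# and y holds one (possibly NaN) shift target per atom.
_g = convert_to_graph(supplier3d[152])
print(_g.x.shape, _g.edge_index.shape, _g.y.shape)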
def scale_graph_data(latent_graph_list, scaler=None):
if scaler:
node_mean, node_std, edge_mean, edge_std = scaler
print(f"Using existing scaler: {scaler}")
else:
#Iterate through graph list to get stacked NODE and EDGE features
node_stack=[]
edge_stack=[]
for g in latent_graph_list:
node_stack.append(g.x) #Append node features
edge_stack.append(g.edge_attr) #Append edge features
node_cat=torch.cat(node_stack,dim=0)
edge_cat=torch.cat(edge_stack,dim=0)
node_mean=node_cat.mean(dim=0)
node_std=node_cat.std(dim=0,unbiased=False)
edge_mean=edge_cat.mean(dim=0)
edge_std=edge_cat.std(dim=0,unbiased=False)
#Apply zero-mean, unit variance scaling, append scaled graph to list
latent_graph_list_sc=[]
for g in latent_graph_list:
x_sc=g.x-node_mean
x_sc/=node_std
ea_sc=g.edge_attr-edge_mean
ea_sc/=edge_std
ea_sc=torch.nan_to_num(ea_sc, posinf=1.0)
x_sc=torch.nan_to_num(x_sc, posinf=1.0)
temp_graph=Data(x=x_sc,edge_index=g.edge_index,edge_attr=ea_sc, y=g.y)
latent_graph_list_sc.append(temp_graph)
scaler= (node_mean,node_std,edge_mean,edge_std)
return latent_graph_list_sc,scaler
"""# Graph Neural Network
## Separation into training set and test set
"""
TRAIN_TEST_SPLIT = 0.8
all_data = list(supplier3d)
random.Random(80).shuffle(all_data)
train_data =all_data[:int(TRAIN_TEST_SPLIT * len(supplier3d))]
test_data =all_data[int(TRAIN_TEST_SPLIT * len(supplier3d)):]
# The test set is scaled with the scaler fitted on the training data; molecules that fail to parse or convert are skipped.
train_graphs, scaler = scale_graph_data([g for g in (convert_to_graph(molecule, atom_feature_constructor=atom_features) for molecule in train_data if molecule) if g])
test_graphs, _ = scale_graph_data([g for g in (convert_to_graph(molecule, atom_feature_constructor=atom_features) for molecule in test_data if molecule) if g], scaler=scaler)
#all_data = [convert_to_graph(molecule) for idx, molecule in enumerate(supplier3d) if molecule and idx not in[24,25,30,31,32]]
print(f"Converted {len(supplier3d)} molecules to {len(train_graphs) + len(test_graphs)} graphs")
print(f"Found {sum([sum([1 for shift in graph.y if not math.isnan(shift[0])]) for graph in train_graphs+test_graphs])} individual NMR shifts")
"""##2023 model"""
import torch
from torch.nn import Sequential as Seq, LazyLinear, LeakyReLU, LazyBatchNorm1d, LayerNorm
from torch_scatter import scatter_mean, scatter_add
from torch_geometric.nn import MetaLayer
from torch_geometric.data import Batch
NO_GRAPH_FEATURES=128
ENCODING_NODE=64
ENCODING_EDGE=32
HIDDEN_NODE=128
HIDDEN_EDGE=64
HIDDEN_GRAPH=128
def init_weights(m):
if type(m) == torch.nn.Linear:
torch.nn.init.xavier_uniform_(m.weight)
m.bias.data.fill_(0.01)
class EdgeModel(torch.nn.Module):
def __init__(self):
super(EdgeModel, self).__init__()
self.edge_mlp = Seq(LazyLinear(HIDDEN_EDGE), LeakyReLU(),LazyBatchNorm1d(),
LazyLinear(HIDDEN_EDGE), LeakyReLU(),LazyBatchNorm1d(),
LazyLinear(ENCODING_EDGE)).apply(init_weights)
def forward(self, src, dest, edge_attr, u, batch):
# source, target: [E, F_x], where E is the number of edges.
# edge_attr: [E, F_e]
# u: [B, F_u], where B is the number of graphs.
# batch: [E] with max entry B - 1.
out = torch.cat([src, dest, edge_attr], 1)
return self.edge_mlp(out)
class NodeModel(torch.nn.Module):
def __init__(self):
super(NodeModel, self).__init__()
self.node_mlp_1 = Seq(LazyLinear(HIDDEN_NODE), LeakyReLU(), LazyBatchNorm1d(),
LazyLinear(HIDDEN_NODE), LeakyReLU(), LazyBatchNorm1d(), #torch.nn.Dropout(0.17),
LazyLinear(HIDDEN_NODE)).apply(init_weights)
self.node_mlp_2 = Seq(LazyLinear(HIDDEN_NODE), LeakyReLU(),LazyBatchNorm1d(), #torch.nn.Dropout(0.17),
LazyLinear(HIDDEN_NODE), LeakyReLU(),LazyBatchNorm1d(),
LazyLinear(ENCODING_NODE)).apply(init_weights)
def forward(self, x, edge_index, edge_attr, u, batch):
# x: [N, F_x], where N is the number of nodes.
# edge_index: [2, E] with max entry N - 1.
# edge_attr: [E, F_e]
# u: [B, F_u]
# batch: [N] with max entry B - 1.
row, col = edge_index
out = torch.cat([x[row], edge_attr], dim=1)
out = self.node_mlp_1(out)
out = scatter_add(out, col, dim=0, dim_size=x.size(0))
out = torch.cat([x, out], dim=1)
return self.node_mlp_2(out)
class GlobalModel(torch.nn.Module):
def __init__(self):
super(GlobalModel, self).__init__()
self.global_mlp = Seq(LazyLinear(HIDDEN_GRAPH), LeakyReLU(),LazyBatchNorm1d(), #torch.nn.Dropout(0.17),
LazyLinear(HIDDEN_GRAPH), LeakyReLU(),LazyBatchNorm1d(),
LazyLinear(NO_GRAPH_FEATURES)).apply(init_weights)
def forward(self, x, edge_index, edge_attr, u, batch):
# x: [N, F_x], where N is the number of nodes.
# edge_index: [2, E] with max entry N - 1.
# edge_attr: [E, F_e]
# u: [B, F_u]
# batch: [N] with max entry B - 1.
row,col=edge_index
node_aggregate = scatter_add(x, batch, dim=0)
edge_aggregate = scatter_add(edge_attr, batch[col], dim=0)
out = torch.cat([node_aggregate, edge_aggregate], dim=1)
return self.global_mlp(out)
class GNN_FULL_CLASS(torch.nn.Module):
def __init__(self, NO_MP):
super(GNN_FULL_CLASS,self).__init__()
#Meta Layer for Message Passing
self.meta = MetaLayer(EdgeModel(), NodeModel(), GlobalModel())
#Edge Encoding MLP
self.encoding_edge=Seq(LazyLinear(ENCODING_EDGE), LeakyReLU(), LazyBatchNorm1d(),
LazyLinear(ENCODING_EDGE), LeakyReLU(), LazyBatchNorm1d(),
LazyLinear(ENCODING_EDGE)).apply(init_weights)
self.encoding_node = Seq(LazyLinear(ENCODING_NODE), LeakyReLU(),LazyBatchNorm1d(),
LazyLinear(ENCODING_NODE), LeakyReLU(),LazyBatchNorm1d(),
LazyLinear(ENCODING_NODE)).apply(init_weights)
self.mlp_last = Seq(LazyLinear(HIDDEN_NODE), LeakyReLU(),#torch.nn.Dropout(0.10),
LazyBatchNorm1d(),
LazyLinear(HIDDEN_NODE), LeakyReLU(),
LazyBatchNorm1d(),
LazyLinear(1)).apply(init_weights)
self.no_mp = NO_MP
def forward(self,dat):
#Extract the data from the batch
x, ei, ea, u, btc = dat.x, dat.edge_index, dat.edge_attr, dat.y, dat.batch
# Embed the node and edge features
enc_x = self.encoding_node(x)
enc_ea = self.encoding_edge(ea)
        # Initialise the graph-level feature vector u with constant values
        u = torch.full(size=(x.size()[0], 1), fill_value=0.1, dtype=torch.float)
#Message-Passing
for _ in range(self.no_mp):
enc_x, enc_ea, u = self.meta(x = enc_x, edge_index = ei, edge_attr = enc_ea, u = u, batch = btc)
targs = self.mlp_last(enc_x)
return targs
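# Minimal smoke test (a sketch, assuming train_graphs is non-empty): batching two
# graphs and running a forward pass should yield one prediction per node.
_smoke_model = GNN_FULL_CLASS(NO_MP=2)
_smoke_batch = Batch.from_data_list(train_graphs[:2])
print(_smoke_model(_smoke_batch).shape)  # torch.Size([total_nodes_in_batch, 1])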
#Additional helpful methods
def init_model(NO_MP, lr, wd):
    # Model
    model = GNN_FULL_CLASS(NO_MP)
# Optimizer
LEARNING_RATE = lr
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=wd)
# Criterion
#criterion = torch.nn.MSELoss()
criterion = torch.nn.L1Loss()
return model, optimizer, criterion
def train(model, criterion, optimizer, loader):
loss_sum = 0
for batch in loader:
# Forward pass and gradient descent
labels = batch.y
predictions = model(batch)
loss = criterion(predictions[torch.isfinite(labels)], labels[torch.isfinite(labels)])
# Backpropagation
optimizer.zero_grad()
loss.backward()
optimizer.step()
loss_sum += loss.item()
return loss_sum/len(loader)
def evaluate(model, criterion, loader):
loss_sum = 0
with torch.no_grad():
for batch in loader:
# Forward pass
labels = batch.y
predictions = model(batch)
loss = criterion(predictions[torch.isfinite(labels)], labels[torch.isfinite(labels)])
loss_sum += loss.item()
return loss_sum/len(loader)
def chunk_into_n(lst, n):
size = math.ceil(len(lst) / n)
return list(
map(lambda x: lst[x * size:x * size + size],
list(range(n)))
)
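# chunk_into_n splits a list into n near-equal folds (the last fold may be shorter):
print(chunk_into_n(list(range(10)), 4))  # [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]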
"""## Training"""
VALIDATION_SPLITS = 4
EPOCHS = 500
BATCH_SIZE = 128
splits =chunk_into_n(train_graphs, VALIDATION_SPLITS)
split_errors=[]
for idx, split in enumerate(splits):
split_train_data = []
for s in splits:
if s!=split:
split_train_data+=s
train_loader = DataLoader(split_train_data, batch_size = BATCH_SIZE)
test_loader = DataLoader(split, batch_size = BATCH_SIZE)
model, optimizer, criterion = init_model(6,0.001,0.1)
loss_list = []
train_err_list = []
test_err_list = []
model.train()
print(f"Split {idx+1}. Training/test split size:{len(split_train_data)}/{len(split)}")
for epoch in range(EPOCHS):
tloss = train(model, criterion, optimizer, train_loader)
train_err = evaluate(model, criterion, train_loader)
test_err = evaluate(model, criterion, test_loader)
loss_list.append(tloss)
train_err_list.append(train_err)
test_err_list.append(test_err)
print('Epoch: {:03d}, Loss: {:.5f}, Train Err: {:.5f}, Test Err: {:.5f}'.format(epoch+1, tloss,
train_err, test_err))
    extra_epochs = 0
    # Sometimes the optimizer gets stuck in other local minima, so at epoch 500 the loss may not have converged yet; train for up to 200 extra epochs.
    while extra_epochs < 200 and tloss > 2.5:
        tloss = train(model, criterion, optimizer, train_loader)
        train_err = evaluate(model, criterion, train_loader)
        test_err = evaluate(model, criterion, test_loader)
        loss_list.append(tloss)
        train_err_list.append(train_err)
        test_err_list.append(test_err)
        extra_epochs += 1
print("\n")
split_errors.append(test_err)
print(f"Split errors: {split_errors} with average error {sum(split_errors) / VALIDATION_SPLITS}")
"""## Evaluation on test set"""
train_loader = DataLoader(train_graphs, batch_size = BATCH_SIZE)
test_loader = DataLoader(test_graphs, batch_size = BATCH_SIZE)
model, optimizer, criterion = init_model(6,0.001,0.1)
loss_list = []
train_err_list = []
test_err_list = []
model.train()
for epoch in range(EPOCHS):
tloss = train(model, criterion, optimizer, train_loader)
train_err = evaluate(model, criterion, train_loader)
test_err = evaluate(model, criterion, test_loader)
loss_list.append(tloss)
train_err_list.append(train_err)
test_err_list.append(test_err)
print('Epoch: {:03d}, Loss: {:.5f}, Train Err: {:.5f}, Test Err: {:.5f}'.format(epoch+1, tloss,
train_err, test_err))
extra_epochs = 0
# Sometimes the optimizer gets stuck in other local minima, so at epoch 500 the loss may not have converged yet; train for up to 200 extra epochs.
while extra_epochs < 200 and tloss > 2.5:
    tloss = train(model, criterion, optimizer, train_loader)
    train_err = evaluate(model, criterion, train_loader)
    test_err = evaluate(model, criterion, test_loader)
    loss_list.append(tloss)
    train_err_list.append(train_err)
    test_err_list.append(test_err)
    extra_epochs += 1
print("\n")
#print(len(test_loader))
#print( criterion, test_loader)
evaluate(model, criterion, test_loader)
# Reuse the training scaler and the same feature constructor the model was trained with, and unpack the (graphs, scaler) tuple.
single, _ = scale_graph_data([convert_to_graph(molecule, atom_feature_constructor=atom_features) for molecule in [supplier3d[32], supplier3d[32]] if molecule], scaler=scaler)
evaluate(model, criterion, DataLoader(single, batch_size=2))
"""## Visualisations & Analysis"""
test_loader = DataLoader(test_graphs, batch_size = 128)
with torch.no_grad():
for batch in test_loader:
# Forward pass
labels = batch.y
predictions = model(batch)
        diffs = torch.sub(predictions, labels)
        for pred, label, feats, diff in zip(predictions, labels, batch.x, diffs):
            if diff > 50:
                print(pred, label, feats[0], diff)
loss = criterion(predictions[torch.isfinite(labels)], labels[torch.isfinite(labels)])
print(loss)
#loss_sum += loss.item()
#return loss_sum/len(loader)
import os
import matplotlib.pyplot as plt
def plot_loss(loss_list):
fig, axs = plt.subplots()
axs.plot(range(1, len(loss_list) + 1), loss_list, label="Loss")
axs.set_xlabel("Epoch")
    axs.set_yscale('log')
    axs.legend(loc=3)
def plot_error(train_err_list, test_err_list):
fig, axs = plt.subplots()
axs.plot(range(1, len(train_err_list) + 1), train_err_list, label="Train Err")
axs.plot(range(1, len(test_err_list) + 1), test_err_list, label="Test Err")
axs.set_xlabel("Epoch")
    axs.set_yscale('log')
    axs.legend(loc=3)
plot_loss(loss_list)
plot_error(train_err_list, test_err_list)
"""Let's examine how the error develops as we train on increasingly large subsets of the data."""
errs=[]
for i in [100, 250, 500, len(all_data)]:
    SPLIT = 0.75 # Let's use 75% for training and 25% for validation
    # Shuffle all of our data, as input data might be sorted
    random.Random(46).shuffle(all_data)
    # Convert the first i molecules to graphs (DataLoader expects Data objects, not raw molecules); skip unparsable molecules and failed conversions
    subset_graphs = [g for g in (convert_to_graph(m) for m in all_data[:i] if m) if g]
    # Training set
    training_data = subset_graphs[:int(i * SPLIT)]
    # Validation set
    testing_data = subset_graphs[int(i * SPLIT):]
BATCH_SIZE = 128
train_loader = DataLoader(training_data, batch_size = BATCH_SIZE)
test_loader = DataLoader(testing_data, batch_size = BATCH_SIZE)
print(f"Number of train batches: {len(train_loader)}")
print(f"Number of test batches: {len(test_loader)}")
# Loading model from the paper
# Model
NO_MP = 7
model = GNN_FULL_CLASS(NO_MP)
# Optimizer
LEARNING_RATE = 0.05
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=.005)
# Criterion
#criterion = torch.nn.MSELoss()
criterion = torch.nn.L1Loss()
EPOCHS = 500
loss_list = []
train_err_list = []
test_err_list = []
model.train()
for epoch in range(EPOCHS):
tloss = train(model, criterion, optimizer, train_loader)
train_err = evaluate(model, criterion, train_loader)
test_err = evaluate(model, criterion, test_loader)
loss_list.append(tloss)
train_err_list.append(train_err)
test_err_list.append(test_err)
print('Epoch: {:03d}, Loss: {:.5f}, Train Err: {:.5f}, Test Err: {:.5f}'.format(epoch+1, tloss,
train_err, test_err))
print("\n")
test_err = evaluate(model, criterion, test_loader)
print(test_err)
errs.append(test_err)
print(errs)
"""# Existing methods
Looking at existing methods as a baseline for model comparison.
## HOSE codes
### Downloads & Initialization
"""
from rdkit import Chem
from rdkit.Chem import AllChem
import matplotlib.pyplot as plt
!pip install git+https://github.com/Ratsemaat/HOSE_code_generator
import hosegen
hg = hosegen.HoseGenerator()
"""### Preparation"""
def get_hose_code(molecule, distance, atom_idx):
return hg.get_Hose_codes(molecule, atom_idx, max_radius=distance)
get_hose_code(supplier3d[1],5,0)
def get_atom_indexes(molecule, element_nr):
    indexes = []
    if not molecule:
        return []
    for atom in molecule.GetAtoms():
        if atom.GetAtomicNum() == element_nr:
            indexes.append(atom.GetIdx())
    return indexes
mol = Chem.MolFromSmiles('C=CN(CC)CF')
print(get_atom_indexes(mol, 6))
print(get_atom_indexes(mol, 7))
print(get_atom_indexes(mol, 9))
def get_molecule_hose_codes(molecule, radius, element_nr):
    indexes = get_atom_indexes(molecule, element_nr)
    hoses = []
    for idx in indexes:
        hoses.append((idx, get_hose_code(molecule, radius, idx)))
    return hoses
print(get_molecule_hose_codes(supplier3d[1], 5,9))
import random
SPLIT = 0.75 # Let's use 75% for training and 25% for validation
random.seed(42)
# Shuffle all of our data, as input data might be sorted
train_idx = random.sample([i for i in range(len(supplier3d))],math.ceil(len(supplier3d)*SPLIT))
test_idx = set([i for i in range(len(supplier3d))]).difference(set(train_idx))
def get_molecule_nmr_shifts(molecule):
if not molecule:
return {}
map={}
for key, value in molecule.GetPropsAsDict().items():
if key.startswith("Spectrum"):
for shift in value.split('|'):
x = shift.split(';')
if (len(x) == 3):
map[x[2]] = float(x[0])
return map
print(get_molecule_nmr_shifts(supplier3d[1]))
def get_molecule_hose_and_nmr_shifts(molecule, distance, element_nr):
arr=[]
idxs = set(get_atom_indexes(molecule, element_nr))
shifts = get_molecule_nmr_shifts(molecule)
    #Only look at atoms where the shift is defined
arr2 = set([int(k) for k in shifts.keys()])
idxs = idxs.intersection(arr2)
for idx in idxs:
hoses = get_hose_code(molecule, distance, idx)
arr.append((idx, hoses, shifts[str(idx)]))
return arr
print(get_molecule_hose_and_nmr_shifts(supplier3d[1], 5, 9))
def split_hose(hose):
return hose.replace('(',' ').replace('/',' ').replace(')',' ').split()
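# HOSE codes describe an atom's neighbourhood in concentric shells; splitting on the
# bracket and slash separators yields roughly one token per shell, so a prefix of the
# split corresponds to a smaller radius. A quick look (assuming molecule 1 parses):
print(split_hose(get_hose_code(supplier3d[1], 5, 0)))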
"""### Training"""
map={}
for id in train_idx:
for trio in get_molecule_hose_and_nmr_shifts(supplier3d[id], 6, 9 ):
parts=split_hose(trio[1])
for i in range(len(parts)):
hose = "/".join(parts[:(i+1)])
if hose not in map:
map[hose] = []
map[hose].append(trio[2])
print(map)
avg_map = {}
for k,v in map.items():
avg_map[k] =sum(v)/len(v)
"""### Evaluation"""
predictions = []
labels=[]
familiar_radius = [] # for visualisation
for id in test_idx:
for trio in get_molecule_hose_and_nmr_shifts(supplier3d[id], 6, 9 ):
parts=split_hose(trio[1])
for i in range(len(parts),0,-1):
hose = "/".join(parts[:(i)])
if hose in map:
predictions.append(avg_map[hose])
familiar_radius.append(i)
labels.append(trio[2])
break
print(labels)
print(predictions)
errors=[]
for i in range(len(predictions)):
errors.append(abs(labels[i] - predictions[i]))
print(sum(errors)/len(errors))
"""### Visualisations & Analysis"""
plt.hist(familiar_radius, bins=[0.5, 1.5,2.5,3.5,4.5,5.5,6.5], align="mid")
plt.title("Longest exact HOSE code length in training set")
plt.show()
errors = [[] for i in range(6)]
for i in range(len(predictions)):
errors[familiar_radius[i]-1].append(abs(labels[i] - predictions[i]))
for i in range(len(errors)):
errors[i] = sum(errors[i])/len(errors[i])
plt.plot([i for i in range(6)], errors)
plt.title("Average error rate of HOSE code")
plt.show()
SPLIT = 0.75 # Let's use 75% for training and 25% for validation
hose_error_measurements = []
for trial in range(10):
    err_hose = []
    random.Random(42).shuffle(all_data)  # NOTE: this shuffle does not affect the experiment below, which indexes supplier3d directly
for size in [100,250,500,len(all_data)]:
train_idx = random.sample([i for i in range(size)],math.ceil(size*SPLIT))
test_idx = set([i for i in range(size)]).difference(set(train_idx))
map={}
for id in train_idx:
for trio in get_molecule_hose_and_nmr_shifts(supplier3d[id], 6, 9 ):
parts=split_hose(trio[1])
for i in range(len(parts)):
hose = "/".join(parts[:(i+1)])
if hose not in map:
map[hose] = []
map[hose].append(trio[2])
avg_map = {}
for k,v in map.items():
avg_map[k] =sum(v)/len(v)
predictions = []
labels=[]
familiar_radius = [] # for visualisation
for id in test_idx:
for trio in get_molecule_hose_and_nmr_shifts(supplier3d[id], 6, 9 ):
parts=split_hose(trio[1])
for i in range(len(parts),0,-1):
hose = "/".join(parts[:(i)])
if hose in map:
predictions.append(avg_map[hose])
familiar_radius.append(i)
labels.append(trio[2])
break
errors=[]
for i in range(len(predictions)):
errors.append(abs(labels[i] - predictions[i]))
err_hose.append(sum(errors)/len(errors))
hose_error_measurements.append(err_hose)
sum_100=0
sum_250=0
sum_500=0
sum_all=0
for i in hose_error_measurements:
sum_100 += i[0]
sum_250 += i[1]
sum_500 += i[2]
sum_all += i[3]
avg_err_hose = [sum_100/len(hose_error_measurements),sum_250/len(hose_error_measurements),sum_500/len(hose_error_measurements),sum_all/len(hose_error_measurements)]
err_gnn = [66.1449966430664, 30.329145431518555, 9.740877151489258, 10.084354400634766]
print(avg_err_hose)
plt.plot([100,250,500,970],err_gnn, label="GNN model", color="red")
plt.plot([100,250,500,970],avg_err_hose, label="HOSE code model", color="green")
plt.xlabel("Number of spectra")
plt.ylabel("ppm")
plt.title("Error in ppm in relation to the number of training examples")
plt.legend()
plt.show()
"""# Results
## Edge and node features
Testing different descriptors with fixed hyperparameters (weight decay=0.005, learning rate=0.0003, NO_MP=7). When testing node features the edge features were fixed, and when testing edge features the node features were fixed; in other words, each result was obtained by varying only the corresponding feature set. Each feature was tested using 4-fold cross-validation on the training data (80% of the whole dataset).
"""
node_features_results = {'ohe atomic number': [11.206879615783691, 10.606162071228027, 10.678353309631348, 14.426544666290283],
'isAromatic': [20.246356964111328, 21.289775848388672, 22.023090362548828, 21.481626510620117],
'hyb ohe': [14.719367980957031, 14.110363483428955, 17.14181661605835, 17.19037103652954],
'valence ohe': [16.09677743911743, 16.301819801330566, 18.414384841918945, 14.627246856689453],
'valence': [15.74299955368042, 15.868663311004639, 17.854907989501953, 14.021881580352783],
'inRing': [20.925933837890625, 19.979466438293457, 20.020546913146973, 19.80652141571045],
'hybridization': [15.455873012542725, 16.48819398880005, 16.359801769256592, 17.977892875671387],
'atomic num': [10.075446605682373, 10.49792766571045, 11.835340023040771, 12.882370471954346],
'atomic charge': [20.686185836791992, 22.381511688232422, 27.067391395568848, 20.001076698303223],
'atomic radius': [11.015857696533203, 10.15432596206665, 10.100831508636475, 13.576239109039307],
'atomic volume': [11.949167728424072, 8.738459587097168, 10.427918434143066, 13.632889747619629],
'atomic weight': [10.183413028717041, 9.64014196395874, 11.397083282470703, 11.427015781402588],
'covalent radius': [10.514074325561523, 9.394641399383545, 10.941582679748535, 11.956562995910645],
'vdw radius': [9.800463676452637, 9.21529245376587, 10.46424388885498, 13.804336071014404],
'dipole polarizability': [10.734958171844482, 10.39086389541626, 14.207355976104736, 14.054275035858154],
'electron affinity': [12.086753368377686, 9.388242721557617, 9.779114246368408, 14.187011241912842],
'electrophilicity index': [11.224877834320068, 8.167922496795654, 10.476778030395508, 13.519041061401367],
'electronegativity': [9.936093807220459, 10.196239948272705, 9.894521236419678, 14.182630062103271],
'electrons': [10.851109981536865, 9.980653285980225, 9.819206714630127, 14.225136280059814],
'neutrons': [9.413500308990479, 9.620219707489014, 9.220077753067017, 13.0865478515625],
'cooridates': [24.324893951416016, 23.782126426696777, 27.821096420288086, 22.696171760559082],
'formal charge': [21.54022789001465, 21.80484676361084, 25.49346923828125, 21.4456787109375],
'formal charge ohe': [19.7533016204834, 24.704838752746582, 23.275280952453613, 21.496562004089355],
'chiral tag': [22.99411392211914, 22.16276741027832, 26.671720504760742, 19.622788429260254],
'random': [27.85106086730957, 26.663583755493164, 27.1346435546875, 24.93829917907715]
}
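# A convenience view of the dictionary above: rank node descriptors by their mean
# 4-fold cross-validation error (lower is better).
for name, vals in sorted(node_features_results.items(), key=lambda kv: sum(kv[1]) / len(kv[1])):
    print(f"{name:24s} {sum(vals) / len(vals):6.2f}")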
bond_features_results = {'smart,atoms': [10.044915676116943, 8.644615173339844, 10.853336811065674, 13.183335781097412],
'pure,rdkit': [11.945857048034668, 9.031915664672852, 10.515658855438232, 13.471199989318848],
'pure,atoms': [11.971851825714111, 9.971860885620117, 11.541802883148193, 13.614596843719482],
'smart,rdkit': [9.997193813323975, 9.00916337966919, 9.385982513427734, 12.707754611968994]}
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
df = pd.DataFrame(columns=['descriptor', 'accuracy'])
df2 = pd.DataFrame(columns=['descriptor', 'accuracy'])
for key, value in bond_features_results.items():
    temp_df = pd.DataFrame({'descriptor': [key] * len(value),
                            'accuracy': value})
    df2 = pd.concat([df2, temp_df], ignore_index=True)  # DataFrame.append was removed in pandas 2.x
for key, value in node_features_results.items():
    temp_df = pd.DataFrame({'descriptor': [key] * len(value),
                            'accuracy': value})
    df = pd.concat([df, temp_df], ignore_index=True)
#print(df)
my_order = df.groupby(by=["descriptor"])["accuracy"].mean().sort_values().index
my_order_2 = df2.groupby(by=["descriptor"])["accuracy"].mean().sort_values().index
fig, ax = plt.subplots(2, figsize=(8, 20))
sns.violinplot(data=df, y="descriptor", x= "accuracy", ax=ax[0], order=my_order)
sns.violinplot(data=df2, y="descriptor", x= "accuracy", ax=ax[1], order=my_order_2)
plt.show()
"""## Hyperparameter selection"""
!wget -nc https://www.dropbox.com/s/k6kgqag3daqv90y/hp_results.csv
with open("hp_results.csv") as f:
hp_results = pd.read_csv(f)
with pd.option_context('display.max_rows', None,
'display.max_columns', None,
'display.precision', 3,
):
print(hp_results)
print(hp_results.groupby(["m", "lr", "wd"])['accuracy'].mean())
print(hp_results.groupby(["m"])['accuracy'].mean())
print(hp_results.groupby(["lr"])['accuracy'].mean())
print(hp_results.groupby(["wd"])['accuracy'].mean())
"""## Different models(collections of features)
* 2019 model features
* Top N features
### 2019 model features
"""
paper_2019_features_results = [8.671473503112793, 14.09905195236206, 12.203859806060791, 12.20280122756958]
"""### Top N features"""
n_feature_results = {"['neutrons']": [10.858759880065918, 7.9520745277404785, 10.061083316802979, 13.083988666534424], "['neutrons', 'atomic weight']": [10.279156684875488, 10.209141254425049, 9.95799732208252, 13.340500354766846], "['neutrons', 'atomic weight', 'covalent radius']": [10.954660892486572, 8.685662269592285, 8.889871597290039, 11.358663558959961], "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius']": [11.0108060836792, 10.469686031341553, 10.534903049468994, 12.8761568069458], "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index']": [10.096431255340576, 9.76186466217041, 9.827410221099854, 12.698175430297852], "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity']": [10.908030033111572, 8.09628176689148, 9.25707221031189, 11.68080759048462], "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity', 'atomic volume']": [9.88636827468872, 9.501356601715088, 9.424680233001709, 12.281991004943848], "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity', 'atomic volume', 'atomic radius']": [12.68621015548706, 9.14622163772583, 10.40596866607666, 11.856515407562256], "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity', 'atomic volume', 'atomic radius', 'electrons']": [10.266335487365723, 9.108633041381836, 11.96642780303955, 13.67209243774414], "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity', 'atomic volume', 'atomic radius', 'electrons', 'atomic num']": [10.689188957214355, 9.288278579711914, 10.050706624984741, 14.055354118347168], "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity', 'atomic volume', 'atomic radius', 'electrons', 'atomic num', 'electron affinity']": [10.014832019805908, 9.91117525100708, 9.540995359420776, 13.940445899963379], "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity', 'atomic volume', 'atomic radius', 'electrons', 'atomic num', 'electron affinity', 'ohe atomic number']": [10.295331001281738, 9.557802677154541, 11.178434371948242, 10.574723720550537], "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity', 'atomic volume', 'atomic radius', 'electrons', 'atomic num', 'electron affinity', 'ohe atomic number', 'dipole polarizability']": [10.669241905212402, 9.187427043914795, 9.547076225280762, 12.491205215454102], "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity', 'atomic volume', 'atomic radius', 'electrons', 'atomic num', 'electron affinity', 'ohe atomic number', 'dipole polarizability', 'hyb ohe']": [10.314173221588135, 10.875812530517578, 12.00346040725708, 11.990623950958252], "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity', 'atomic volume', 'atomic radius', 'electrons', 'atomic num', 'electron affinity', 'ohe atomic number', 'dipole polarizability', 'hyb ohe', 'valence']": [9.84071683883667, 9.872222900390625, 8.903747081756592, 11.116456031799316], "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity', 'atomic volume', 'atomic radius', 'electrons', 'atomic num', 'electron affinity', 
'ohe atomic number', 'dipole polarizability', 'hyb ohe', 'valence', 'valence ohe']": [11.037650108337402, 10.227884292602539, 11.569035053253174, 15.139636516571045], "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity', 'atomic volume', 'atomic radius', 'electrons', 'atomic num', 'electron affinity', 'ohe atomic number', 'dipole polarizability', 'hyb ohe', 'valence', 'valence ohe', 'hybridization']": [9.662659645080566, 10.27301549911499, 10.229931831359863, 12.425177097320557], "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity', 'atomic volume', 'atomic radius', 'electrons', 'atomic num', 'electron affinity', 'ohe atomic number', 'dipole polarizability', 'hyb ohe', 'valence', 'valence ohe', 'hybridization', 'inRing']": [12.145603656768799, 9.442389011383057, 10.04677677154541, 13.444748401641846], "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity', 'atomic volume', 'atomic radius', 'electrons', 'atomic num', 'electron affinity', 'ohe atomic number', 'dipole polarizability', 'hyb ohe', 'valence', 'valence ohe', 'hybridization', 'inRing', 'isAromatic']": [10.830190181732178, 9.220112323760986, 9.297557353973389, 13.27143907546997], "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity', 'atomic volume', 'atomic radius', 'electrons', 'atomic num', 'electron affinity', 'ohe atomic number', 'dipole polarizability', 'hyb ohe', 'valence', 'valence ohe', 'hybridization', 'inRing', 'isAromatic', 'formal charge ohe']": [10.53205156326294, 10.221848964691162, 10.479277610778809, 11.17846155166626], "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity', 'atomic volume', 'atomic radius', 'electrons', 'atomic num', 'electron affinity', 'ohe atomic number', 'dipole polarizability', 'hyb ohe', 'valence', 'valence ohe', 'hybridization', 'inRing', 'isAromatic', 'formal charge ohe', 'atomic charge']": [10.679495811462402, 9.242655277252197, 10.48193883895874, 12.0442533493042], "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity', 'atomic volume', 'atomic radius', 'electrons', 'atomic num', 'electron affinity', 'ohe atomic number', 'dipole polarizability', 'hyb ohe', 'valence', 'valence ohe', 'hybridization', 'inRing', 'isAromatic', 'formal charge ohe', 'atomic charge', 'formal charge']": [9.10008192062378, 9.387939929962158, 10.72908878326416, 12.20565414428711], "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity', 'atomic volume', 'atomic radius', 'electrons', 'atomic num', 'electron affinity', 'ohe atomic number', 'dipole polarizability', 'hyb ohe', 'valence', 'valence ohe', 'hybridization', 'inRing', 'isAromatic', 'formal charge ohe', 'atomic charge', 'formal charge', 'chiral tag']": [11.04087209701538, 8.146693706512451, 10.116810321807861, 10.707176208496094], "['neutrons', 'atomic weight', 'covalent radius', 'vdw radius', 'electrophilicity index', 'electronegativity', 'atomic volume', 'atomic radius', 'electrons', 'atomic num', 'electron affinity', 'ohe atomic number', 'dipole polarizability', 'hyb ohe', 'valence', 'valence ohe', 'hybridization', 'inRing', 'isAromatic', 'formal charge ohe', 'atomic charge', 'formal charge', 'chiral tag', 'cooridates']": [11.374996185302734, 10.41971206665039, 
11.041788578033447, 13.181884288787842]}
df = pd.DataFrame(columns=['descriptor', 'accuracy'])
for key, value in n_feature_results.items():
    temp_df = pd.DataFrame({'descriptor': str(len(key.split(","))),
                            'accuracy': value})
    df = pd.concat([df, temp_df], ignore_index=True)
fig, ax = plt.subplots(1,figsize=(15, 15))
sns.violinplot(data=df, y="descriptor", x= "accuracy", ax=ax)
plt.show()
print(df.groupby(by=["descriptor"])["accuracy"].mean().sort_values())
"""### Comparison"""
df = pd.DataFrame(columns=['descriptor', 'accuracy'])
temp_df = pd.DataFrame({'descriptor': '2019 model features', 'accuracy': paper_2019_features_results})
df = pd.concat([df, temp_df], ignore_index=True)
for key, value in n_feature_results.items():
    temp_df = pd.DataFrame({'descriptor': str(len(key.split(","))),
                            'accuracy': value})
    df = pd.concat([df, temp_df], ignore_index=True)
fig, ax = plt.subplots(1, figsize=(15, 15))
comparison_df = df.loc[(df['descriptor'] == '15') | (df['descriptor'] == '2019 model features')]
sns.violinplot(data=comparison_df, y="descriptor", x="accuracy", ax=ax)
plt.show()
"""## Varying dataset sizes
### Fluorine
"""
!wget -nc https://www.dropbox.com/s/m1alwawphb1chbc/fluorine_results.csv
with open("fluorine_results.csv") as f:
df = pd.read_csv(f)
df.groupby(["model","dataset_size"]).mean()
fig, ax = plt.subplots(figsize=(8, 10))
sns.violinplot(data=df, y="mae", x="dataset_size", hue="model" , ax=ax)
plt.show()
"""### Carbon with chloroform as solvent
"""
# NOTE: this URL points at the methanol results file; the link for chloroform_results.csv appears to be missing, so the file below must be provided separately.
!wget -nc https://www.dropbox.com/s/0nmj08b1hpn8jc0/methanol_results.csv
with open("chloroform_results.csv") as f:
df = pd.read_csv(f)
print(df.groupby(["model","dataset_size"]).mean())
print(df)
fig, ax = plt.subplots(figsize=(8, 10))
sns.violinplot(data=df, y="mae", x="dataset_size", hue="model" , ax=ax)
plt.show()
"""### Carbon with methanol as solvent """
!wget -nc https://www.dropbox.com/s/0nmj08b1hpn8jc0/methanol_results.csv
with open("methanol_results.csv") as f:
df = pd.read_csv(f)
print(df.groupby(["model","dataset_size"]).mean())
fig, ax = plt.subplots(figsize=(8, 10))
sns.violinplot(data=df, y="mae", x="dataset_size", hue="model" , ax=ax)
plt.show()
"""### Carbon with dmso as solvent """
!wget -nc https://www.dropbox.com/s/gdwv8pspayssk27/dmso_results.csv
with open("dmso_results.csv") as f:
df = pd.read_csv(f)
print(df.groupby(["model","dataset_size"]).mean())
fig, ax = plt.subplots(figsize=(8, 10))
sns.violinplot(data=df, y="mae", x="dataset_size", hue="model" , ax=ax)
plt.show()
"""# Reproducing results
Functions used to produce the aforementioned results.
To repeat any experiment you must:
* Run the core block to load necessary helper methods.
* Run the corresponding experiment block to obtain the results.
## Core block
"""
TEST_SPLIT = 0.8
# Install required dependencies
!pip3 install rdkit # Used to read and parse nmrshiftdb2 SD file
!pip3 install mendeleev # To access various features related to atoms
import torch
# PyTorch dependencies to represent graph data
!pip3 install torch-scatter -f https://data.pyg.org/whl/torch-{torch.__version__}.html
!pip3 install torch-sparse -f https://data.pyg.org/whl/torch-{torch.__version__}.html
!pip3 install torch-geometric
# Use CUDA by default
if torch.cuda.is_available():
torch.set_default_tensor_type('torch.cuda.FloatTensor')
# Downloads the nmrshiftdb2 database if it does not yet exist in our runtime
!wget -nc https://www.dropbox.com/s/n122zxawpxii5b7/nmrshiftdb2withsignals.sd#Fluorine
from torch.nn import Sequential as Seq, LazyLinear, LeakyReLU, LazyBatchNorm1d, LayerNorm
from torch_scatter import scatter_mean, scatter_add
from torch_geometric.nn import MetaLayer
from torch_geometric.data import Batch,Data
from torch import tensor
from torch_geometric.loader import DataLoader
import numpy as np
from rdkit import Chem
from sklearn.preprocessing import OneHotEncoder
import random
import math
import mendeleev
import pandas as pd
NO_GRAPH_FEATURES=128
ENCODING_NODE=64
ENCODING_EDGE=32
HIDDEN_NODE=128
HIDDEN_EDGE=64
HIDDEN_GRAPH=128
def init_weights(m):
if type(m) == torch.nn.Linear:
torch.nn.init.xavier_uniform_(m.weight)
m.bias.data.fill_(0.01)
class EdgeModel(torch.nn.Module):
def __init__(self):
super(EdgeModel, self).__init__()
self.edge_mlp = Seq(LazyLinear(HIDDEN_EDGE), LeakyReLU(),LazyBatchNorm1d(),
LazyLinear(HIDDEN_EDGE), LeakyReLU(),LazyBatchNorm1d(),
LazyLinear(ENCODING_EDGE)).apply(init_weights)
def forward(self, src, dest, edge_attr, u, batch):
# source, target: [E, F_x], where E is the number of edges.
# edge_attr: [E, F_e]
# u: [B, F_u], where B is the number of graphs.
# batch: [E] with max entry B - 1.
out = torch.cat([src, dest, edge_attr], 1)
return self.edge_mlp(out)
class NodeModel(torch.nn.Module):
def __init__(self):
super(NodeModel, self).__init__()
self.node_mlp_1 = Seq(LazyLinear(HIDDEN_NODE), LeakyReLU(), LazyBatchNorm1d(),
LazyLinear(HIDDEN_NODE), LeakyReLU(), LazyBatchNorm1d(), #torch.nn.Dropout(0.17),
LazyLinear(HIDDEN_NODE)).apply(init_weights)
self.node_mlp_2 = Seq(LazyLinear(HIDDEN_NODE), LeakyReLU(),LazyBatchNorm1d(), #torch.nn.Dropout(0.17),
LazyLinear(HIDDEN_NODE), LeakyReLU(),LazyBatchNorm1d(),
LazyLinear(ENCODING_NODE)).apply(init_weights)
def forward(self, x, edge_index, edge_attr, u, batch):
# x: [N, F_x], where N is the number of nodes.
# edge_index: [2, E] with max entry N - 1.
# edge_attr: [E, F_e]
# u: [B, F_u]
# batch: [N] with max entry B - 1.
row, col = edge_index
out = torch.cat([x[row], edge_attr], dim=1)
out = self.node_mlp_1(out)
out = scatter_add(out, col, dim=0, dim_size=x.size(0))
out = torch.cat([x, out], dim=1)
return self.node_mlp_2(out)
class GlobalModel(torch.nn.Module):
def __init__(self):
super(GlobalModel, self).__init__()
self.global_mlp = Seq(LazyLinear(HIDDEN_GRAPH), LeakyReLU(),LazyBatchNorm1d(), #torch.nn.Dropout(0.17),
LazyLinear(HIDDEN_GRAPH), LeakyReLU(),LazyBatchNorm1d(),
LazyLinear(NO_GRAPH_FEATURES)).apply(init_weights)
def forward(self, x, edge_index, edge_attr, u, batch):
# x: [N, F_x], where N is the number of nodes.
# edge_index: [2, E] with max entry N - 1.
# edge_attr: [E, F_e]
# u: [B, F_u]
# batch: [N] with max entry B - 1.
row,col=edge_index
node_aggregate = scatter_add(x, batch, dim=0)
edge_aggregate = scatter_add(edge_attr, batch[col], dim=0)
out = torch.cat([node_aggregate, edge_aggregate], dim=1)
return self.global_mlp(out)
class GNN_FULL_CLASS(torch.nn.Module):
def __init__(self, NO_MP):
super(GNN_FULL_CLASS,self).__init__()
#Meta Layer for Message Passing
self.meta = MetaLayer(EdgeModel(), NodeModel(), GlobalModel())
#Edge Encoding MLP
self.encoding_edge=Seq(LazyLinear(ENCODING_EDGE), LeakyReLU(), LazyBatchNorm1d(),
LazyLinear(ENCODING_EDGE), LeakyReLU(), LazyBatchNorm1d(),
LazyLinear(ENCODING_EDGE)).apply(init_weights)
self.encoding_node = Seq(LazyLinear(ENCODING_NODE), LeakyReLU(),LazyBatchNorm1d(),
LazyLinear(ENCODING_NODE), LeakyReLU(),LazyBatchNorm1d(),
LazyLinear(ENCODING_NODE)).apply(init_weights)
self.mlp_last = Seq(LazyLinear(HIDDEN_NODE), LeakyReLU(),#torch.nn.Dropout(0.10),
LazyBatchNorm1d(),
LazyLinear(HIDDEN_NODE), LeakyReLU(),
LazyBatchNorm1d(),
LazyLinear(1)).apply(init_weights)
self.no_mp = NO_MP
def forward(self,dat):
#Extract the data from the batch
x, ei, ea, u, btc = dat.x, dat.edge_index, dat.edge_attr, dat.y, dat.batch
# Embed the node and edge features
enc_x = self.encoding_node(x)
enc_ea = self.encoding_edge(ea)
        # Initialise the graph-level feature vector u with constant values
        u = torch.full(size=(x.size()[0], 1), fill_value=0.1, dtype=torch.float)
#Message-Passing
for _ in range(self.no_mp):
enc_x, enc_ea, u = self.meta(x = enc_x, edge_index = ei, edge_attr = enc_ea, u = u, batch = btc)
targs = self.mlp_last(enc_x)
return targs
el_map={}
def atom_features_default():
feature_getters = {}
feature_getters["ohe atomic number"] = lambda atom:atomic_number_ohe.transform(np.array([[atom.GetAtomicNum()]]))[0] # Atomic number
feature_getters["hyb ohe"] = lambda atom: hybridization_ohe.transform(np.array([[atom.GetHybridization().name]]))[0]
feature_getters["valence ohe"] = lambda atom: valence_ohe.transform(np.array([[atom.GetTotalValence()]]))[0]
feature_getters["hybridization"] = lambda atom: atom.GetHybridization()
feature_getters["atomic radius"]= lambda atom: getMendeleevElement(atom.GetAtomicNum()).atomic_radius or 0 # Atomic radius
feature_getters["atomic volume"] = lambda atom: getMendeleevElement(atom.GetAtomicNum()).atomic_volume # Atomic volume
feature_getters["atomic weight"] = lambda atom: getMendeleevElement(atom.GetAtomicNum()).atomic_weight # Atomic weight
feature_getters["dipole polarizability"] = lambda atom: getMendeleevElement(atom.GetAtomicNum()).dipole_polarizability # Dipole polarizability
feature_getters["electron affinity"] = lambda atom: getMendeleevElement(atom.GetAtomicNum()).electron_affinity # Electron affinity
feature_getters["electronegativity"] = lambda atom: getMendeleevElement(atom.GetAtomicNum()).en_pauling # Electronegativity
feature_getters["electrons"] = lambda atom: getMendeleevElement(atom.GetAtomicNum()).electrons # No. of electrons
feature_getters["neutrons"] = lambda atom: getMendeleevElement(atom.GetAtomicNum()).neutrons # No. of neutrons
feature_getters["formal charge ohe"] = lambda atom: fc_ohe.transform(np.array([[atom.GetFormalCharge()]]))[0]
#feature_getters["gaisteigerCharge"] = lambda atom: 0 if np.isfinite(float(atom.GetProp('_GasteigerCharge'))) else float(atom.GetProp('_GasteigerCharge')) #partial charges
feature_getters["chiral tag"] = lambda atom: atom.GetChiralTag()
return feature_getters
def bond_feature_smart_distance_and_rdkit_type(bond):
    onehot_encoded_bondtype = onehot_encoder.transform(np.array([[bond.GetBondType()]]))[0]
    [x1, y1, z1] = list(bond.GetOwningMol().GetConformer().GetAtomPosition(bond.GetBeginAtomIdx()))
    [x2, y2, z2] = list(bond.GetOwningMol().GetConformer().GetAtomPosition(bond.GetEndAtomIdx()))
    ex_dist = getNaiveBondLength(bond)  # NOTE: getNaiveBondLength (the expected bond length) is not defined in this export
    distance = [(math.sqrt((x2 - x1)**2 + (y2 - y1)**2 + (z2 - z1)**2)) - ex_dist] # Deviation of the actual distance from the expected bond length
    return distance + list(onehot_encoded_bondtype)
def getMendeleevElement(nr):
if nr not in el_map:
el_map[nr] = mendeleev.element(nr)
return el_map[nr]
def nmr_shift(atom):
for key, value in atom.GetOwningMol().GetPropsAsDict().items():
if key.startswith("Spectrum"):
for shift in value.split('|'):
x = shift.split(';')
if (len(x) == 3 and x[2] == f"{atom.GetIdx()}"):
return float(x[0])
    return float("NaN") # We use NaN for atoms we don't want to predict shifts for
def bond_features_distance_only(bond):
#onehot_encoded_bondtype = onehot_encoder.transform(np.array([[bond.GetBondType()]]))[0]
[x1, y1, z1] = list(bond.GetOwningMol().GetConformer().GetAtomPosition(bond.GetBeginAtomIdx()))
[x2, y2, z2] = list(bond.GetOwningMol().GetConformer().GetAtomPosition(bond.GetEndAtomIdx()))
distance = [(math.sqrt((x2 - x1)**2 + (y2 - y1)**2 + (z2 - z1)**2))] # Distance
return distance#+list(onehot_encoded_bondtype)
def flatten(l):
ret=[]
for el in l:
if isinstance(el, list) or isinstance(el, np.ndarray):
ret.extend(el)
else:
ret.append(el)
return ret
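# flatten merges one level of nesting (lists or numpy arrays), leaving scalars as-is:
print(flatten([1, [2, 3], np.array([4, 5])]))  # [1, 2, 3, 4, 5]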
def turn_to_graph(molecule, atom_feature_getters=atom_features_default().values(), bond_features=bond_features_distance_only):
    node_features = [flatten([getter(atom) for getter in atom_feature_getters]) for atom in molecule.GetAtoms()]
node_targets = [nmr_shift(atom) for atom in molecule.GetAtoms()]
edge_features = [bond_features(bond) for bond in molecule.GetBonds()]
edge_index = [[bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()] for bond in molecule.GetBonds()]
    # Bonds are not directed, so let's add the reversed pair for each bond to make the graph undirected
    edge_index.extend([[b, a] for a, b in edge_index])
    edge_features.extend(edge_features)
    # Some node features contain None values in the carbon data, which aborted the long graph-building run; skip such molecules.
    if any(None in sublist for sublist in node_features):
return None
return Data(
x=tensor(node_features, dtype=torch.float),
edge_index=tensor(edge_index, dtype=torch.long).t().contiguous(),
edge_attr=tensor(edge_features, dtype=torch.float),
y=tensor([[t] for t in node_targets], dtype=torch.float)
)
def init_model(NO_MP, lr, wd):
    # Model
    model = GNN_FULL_CLASS(NO_MP)
# Optimizer
LEARNING_RATE = lr
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=wd)
# Criterion
#criterion = torch.nn.MSELoss()
criterion = torch.nn.L1Loss()
return model, optimizer, criterion
def train(model, criterion, optimizer, loader):
loss_sum = 0
for batch in loader:
# Forward pass and gradient descent
labels = batch.y
predictions = model(batch)
loss = criterion(predictions[torch.isfinite(labels)], labels[torch.isfinite(labels)])
# Backpropagation
optimizer.zero_grad()
loss.backward()
optimizer.step()
loss_sum += loss.item()
return loss_sum/len(loader)
def evaluate(model, criterion, loader):
loss_sum = 0
with torch.no_grad():
for batch in loader:
# Forward pass
labels = batch.y
predictions = model(batch)
loss = criterion(predictions[torch.isfinite(labels)], labels[torch.isfinite(labels)])
loss_sum += loss.item()
return loss_sum/len(loader)
def chunk_into_n(lst, n):
size = math.ceil(len(lst) / n)
return list(
map(lambda x: lst[x * size:x * size + size],
list(range(n)))
)
def get_data_loaders(data, split, batch_size):
    random.Random().shuffle(data)
    training_data = data[:int(len(data) * split)]
    testing_data = data[int(len(data) * split):]
    train_loader = DataLoader(training_data, batch_size=batch_size)
    test_loader = DataLoader(testing_data, batch_size=batch_size)
    return train_loader, test_loader
def getBondElements(bond):
    a = bond.GetEndAtom().GetSymbol()
    b = bond.GetBeginAtom().GetSymbol()
    return a+b if a < b else b+a  # order-independent element pair
# (The experiment-loop code that defines model, mae, mse, optimizer, method_name, size, fixing_epochs and the data loaders is missing from this export.)
while tloss > 2.5 and fixing_epochs < 200:
tloss = train(model, mae, optimizer, train_loader)
train_err = evaluate(model, mae, train_loader)
test_err = evaluate(model, mae, test_loader)
fixing_epochs+=1
preds = torch.tensor([])
labels = torch.tensor([])
with torch.no_grad():
for batch in test_loader:
# Forward pass
labels = torch.cat((labels, batch.y), 0)
preds = torch.cat((preds,model(batch)),0)
std = torch.std(torch.subtract(preds[torch.isfinite(labels)], labels[torch.isfinite(labels)]))
model_results_df_2 = pd.concat([model_results_df_2,pd.DataFrame(data = {'model': [method_name], "dataset_size":[size],
"mae":[evaluate(model, mae, test_loader)],
"rmse": [evaluate(model, mse, test_loader)**0.5],
"std": [float(std)],
})], ignore_index = True,axis=0, join='outer')
#model_results_df_2 = model_results_df_2.append({'model': method_name, "dataset_size":size, "mae": evaluate(model, mae, test_loader), "rmse": evaluate(model, mse, test_loader)**0.5, "std":float(std)} , ignore_index = True)