diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..5195c1705731d972c425d3cfc0977f8240b3f3bd
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,13 @@
+FROM python:3.9.7
+
+WORKDIR /app
+COPY requirements.txt .
+RUN pip install -r requirements.txt
+# preload models
+RUN python -c '\
+from transformers import BartForConditionalGeneration, AutoTokenizer;\
+AutoTokenizer.from_pretrained("ibm/materials.selfies-ted");\
+BartForConditionalGeneration.from_pretrained("ibm/materials.selfies-ted")'
+COPY . .
+
+CMD ["python", "app.py"]
\ No newline at end of file
diff --git a/Dockerfile-conda b/Dockerfile-conda
new file mode 100644
index 0000000000000000000000000000000000000000..98d0dfefb3c6b01820abeaea921c62a8246e8e6a
--- /dev/null
+++ b/Dockerfile-conda
@@ -0,0 +1,13 @@
+FROM condaforge/miniforge3
+
+WORKDIR /app
+SHELL ["/bin/bash", "-i", "-c"]
+RUN apt-get update && \
+    apt-get install -y build-essential libxrender1 libxext-dev
+RUN conda create --name fm4m python=3.9.7
+RUN conda activate fm4m
+COPY requirements.txt .
+RUN pip install -r requirements.txt
+COPY . .
+
+CMD ["python", "app.py"]
\ No newline at end of file
diff --git a/README.md b/README.md
index 119afc463ff2dc915b31145bf2b2489d393dbc0c..ef50b227222b4d50684a8c88b96950448e2aeaa5 100644
--- a/README.md
+++ b/README.md
@@ -1,8 +1,8 @@
---
-title: Fm4m
-emoji: 👁
-colorFrom: pink
-colorTo: purple
+title: Fix Fm4m Kit
+emoji: 🐢
+colorFrom: indigo
+colorTo: blue
sdk: gradio
sdk_version: 5.4.0
app_file: app.py
diff --git a/app.py b/app.py
index aeb07d471df4bcc6959c758f4a63f8cbbafd03ea..97d96b99fd8f6f9bccb0d1b90af7db7c2fd93af5 100644
--- a/app.py
+++ b/app.py
@@ -1,142 +1,103 @@
import gradio as gr
-from huggingface_hub import InferenceClient
import matplotlib.pyplot as plt
-from PIL import Image
-from rdkit.Chem import Descriptors, QED, Draw
-from rdkit.Chem.Crippen import MolLogP
+import numpy as np
+import os
import pandas as pd
-from rdkit.Contrib.SA_Score import sascorer
-from rdkit.Chem import DataStructs, AllChem
-from transformers import BartForConditionalGeneration, AutoTokenizer, AutoModel
-from transformers.modeling_outputs import BaseModelOutput
+import re
import selfies as sf
-from rdkit import Chem
import torch
-import numpy as np
-import umap
-import pickle
import xgboost as xgb
-from sklearn.svm import SVR
-from sklearn.linear_model import LinearRegression
+from PIL import Image
+from rdkit import Chem, RDLogger
+from rdkit.Chem import DataStructs, AllChem, Descriptors, QED, Draw
+from rdkit.Chem.Crippen import MolLogP
+from rdkit.Contrib.SA_Score import sascorer
from sklearn.kernel_ridge import KernelRidge
-import json
-
-import os
+from sklearn.linear_model import LinearRegression
+from sklearn.svm import SVR
+from transformers import BartForConditionalGeneration, AutoTokenizer
+from transformers.modeling_outputs import BaseModelOutput
os.environ["OMP_MAX_ACTIVE_LEVELS"] = "1"
-# my_theme = gr.Theme.from_hub("ysharma/steampunk")
-# my_theme = gr.themes.Glass()
-
-"""
-# カスタムテーマ設定
-theme = gr.themes.Default().set(
-    body_background_fill="#000000", # 背景色を黒に設定
-    text_color="#FFFFFF", # テキスト色を白に設定
-)
-"""
-"""
-import sys
-sys.path.append("models")
-sys.path.append("../models")
-sys.path.append("../")"""
-
-
-# Get the current file's directory
-base_dir = os.path.dirname(__file__)
-print("Base Dir : ", base_dir)
-
import models.fm4m as fm4m
+RDLogger.logger().setLevel(RDLogger.ERROR)
+
# Function to display molecule image from SMILES
def smiles_to_image(smiles):
    mol =
Chem.MolFromSmiles(smiles) - if mol: - img = Draw.MolToImage(mol) - return img - return None - - -# Function to get canonical SMILES -def get_canonical_smiles(smiles): - mol = Chem.MolFromSmiles(smiles) - if mol: - return Chem.MolToSmiles(mol, canonical=True) - return None + return Draw.MolToImage(mol) if mol else None # Dictionary for SMILES strings and corresponding images (you can replace with your actual image paths) smiles_image_mapping = { - "Mol 1": {"smiles": "C=C(C)CC(=O)NC[C@H](CO)NC(=O)C=Cc1ccc(C)c(Cl)c1", "image": "img/img1.png"}, + "Mol 1": { + "smiles": "C=C(C)CC(=O)NC[C@H](CO)NC(=O)C=Cc1ccc(C)c(Cl)c1", + "image": "img/img1.png", + }, # Example SMILES for ethanol - "Mol 2": {"smiles": "C=CC1(CC(=O)NC[C@@H](CCCC)NC(=O)c2cc(Cl)cc(Br)c2)CC1", "image": "img/img2.png"}, + "Mol 2": { + "smiles": "C=CC1(CC(=O)NC[C@@H](CCCC)NC(=O)c2cc(Cl)cc(Br)c2)CC1", + "image": "img/img2.png", + }, # Example SMILES for butane - "Mol 3": {"smiles": "C=C(C)C[C@H](NC(C)=O)C(=O)N1CC[C@H](NC(=O)[C@H]2C[C@@]2(C)Br)C(C)(C)C1", - "image": "img/img3.png"}, # Example SMILES for ethylamine - "Mol 4": {"smiles": "C=C1CC(CC(=O)N[C@H]2CCN(C(=O)c3ncccc3SC)C23CC3)C1", "image": "img/img4.png"}, + "Mol 3": { + "smiles": "C=C(C)C[C@H](NC(C)=O)C(=O)N1CC[C@H](NC(=O)[C@H]2C[C@@]2(C)Br)C(C)(C)C1", + "image": "img/img3.png", + }, # Example SMILES for ethylamine + "Mol 4": { + "smiles": "C=C1CC(CC(=O)N[C@H]2CCN(C(=O)c3ncccc3SC)C23CC3)C1", + "image": "img/img4.png", + }, # Example SMILES for diethyl ether - "Mol 5": {"smiles": "C=CCS[C@@H](C)CC(=O)OCC", "image": "img/img5.png"} # Example SMILES for chloroethane + "Mol 5": { + "smiles": "C=CCS[C@@H](C)CC(=O)OCC", + "image": "img/img5.png", + }, # Example SMILES for chloroethane } datasets = [" ", "BACE", "ESOL", "Load Custom Dataset"] -models_enabled = ["SELFIES-TED", "MHG-GED", "MolFormer", "SMI-TED"] +models_enabled = [ + "SELFIES-TED", + "MHG-GED", + "MolFormer", + "SMI-TED", + "Mordred", + "MorganFingerprint", +] fusion_available = ["Concat"] -global log_df -log_df = pd.DataFrame(columns=["Selected Models", "Dataset", "Task", "Result"]) - - -def log_selection(models, dataset, task_type, result, log_df): - # Append the new entry to the DataFrame - new_entry = {"Selected Models": str(models), "Dataset": dataset, "Task": task_type, "Result": result} - updated_log_df = log_df.append(new_entry, ignore_index=True) - return updated_log_df - # Function to handle evaluation and logging -def save_rep(models, dataset, task_type, eval_output): - return -def evaluate_and_log(models, dataset, task_type, eval_output): +def evaluate_and_log(models, dataset, task_type, eval_output, state): task_dic = {'Classification': 'CLS', 'Regression': 'RGR'} - result = f"{eval_output}"#display_eval(models, dataset, task_type, fusion_type=None) + result = f"{eval_output}" result = result.replace(" Score", "") - new_entry = {"Selected Models": str(models), "Dataset": dataset, "Task": task_dic[task_type], "Result": result} + new_entry = { + "Selected Models": str(models), + "Dataset": dataset, + "Task": task_dic[task_type], + "Result": result, + } new_entry_df = pd.DataFrame([new_entry]) - log_df = pd.read_csv('log.csv', index_col=0) - log_df = pd.concat([new_entry_df, log_df]) - - log_df.to_csv('log.csv') - - return log_df - - -try: - log_df = pd.read_csv('log.csv', index_col=0) -except: - log_df = pd.DataFrame({"":[], - 'Selected Models': [], - 'Dataset': [], - 'Task': [], - 'Result': [] - }) - csv_file_path = 'log.csv' - log_df.to_csv(csv_file_path, index=False) + state["log_df"] = 
pd.concat([new_entry_df, state["log_df"]]) + return state["log_df"] # Load images for selection def load_image(path): try: - return Image.open(smiles_image_mapping[path]["image"])# Image.1open(path) + return Image.open(smiles_image_mapping[path]["image"]) except: pass - # Function to handle image selection def handle_image_selection(image_key): smiles = smiles_image_mapping[image_key]["smiles"] @@ -160,59 +121,55 @@ def calculate_tanimoto(smiles1, smiles2): mol1 = Chem.MolFromSmiles(smiles1) mol2 = Chem.MolFromSmiles(smiles2) if mol1 and mol2: - # fp1 = FingerprintMols.FingerprintMol(mol1) - # fp2 = FingerprintMols.FingerprintMol(mol2) fp1 = AllChem.GetMorganFingerprintAsBitVect(mol1, 2) fp2 = AllChem.GetMorganFingerprintAsBitVect(mol2, 2) return round(DataStructs.FingerprintSimilarity(fp1, fp2), 2) return None -#with open("models/selfies_model/bart-2908.pickle", "rb") as input_file: -# gen_model, gen_tokenizer = pickle.load(input_file) - gen_tokenizer = AutoTokenizer.from_pretrained("ibm/materials.selfies-ted") gen_model = BartForConditionalGeneration.from_pretrained("ibm/materials.selfies-ted") def generate(latent_vector, mask): encoder_outputs = BaseModelOutput(latent_vector) - decoder_output = gen_model.generate(encoder_outputs=encoder_outputs, attention_mask=mask, - max_new_tokens=64, do_sample=True, top_k=5, top_p=0.95, num_return_sequences=1) + decoder_output = gen_model.generate( + encoder_outputs=encoder_outputs, + attention_mask=mask, + max_new_tokens=64, + do_sample=True, + top_k=5, + top_p=0.95, + num_return_sequences=1, + ) selfies = gen_tokenizer.batch_decode(decoder_output, skip_special_tokens=True) - outs = [] - for i in selfies: - try: - print("Generated SELFIES : ", i) - decoded = sf.decoder(i.replace("] [", "][")) - print("Generated SMILES : ", decoded) - outs.append(decoded) - #except selfies.exceptions.DecoderError: - # print(f"Error decoding SELFIES string: {i}") - except: - pass - - #outs.append(sf.decoder(i.replace("] [", "]["))) - return outs + return [sf.decoder(re.sub(r'\]\s*(.*?)\s*\[', r']\1[', i)) for i in selfies] def perturb_latent(latent_vecs, noise_scale=0.5): - modified_vec = torch.tensor(np.random.uniform(0, 1, latent_vecs.shape) * noise_scale, - dtype=torch.float32) + latent_vecs - return modified_vec + return ( + torch.tensor( + np.random.uniform(0, 1, latent_vecs.shape) * noise_scale, + dtype=torch.float32, + ) + + latent_vecs + ) def encode(selfies): - encoding = gen_tokenizer(selfies, return_tensors='pt', max_length=128, truncation=True, padding='max_length') + encoding = gen_tokenizer( + selfies, + return_tensors='pt', + max_length=128, + truncation=True, + padding='max_length', + ) input_ids = encoding['input_ids'] attention_mask = encoding['attention_mask'] - outputs = gen_model.model.encoder(input_ids=input_ids, attention_mask=attention_mask) + outputs = gen_model.model.encoder( + input_ids=input_ids, attention_mask=attention_mask + ) model_output = outputs.last_hidden_state - - """input_mask_expanded = attention_mask.unsqueeze(-1).expand(model_output.size()).float() - sum_embeddings = torch.sum(model_output * input_mask_expanded, 1) - sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9) - model_output = sum_embeddings / sum_mask""" return model_output, attention_mask @@ -227,8 +184,13 @@ def generate_canonical(smiles): noise = i / 10 perturbed_latent = perturb_latent(latent_vec, noise_scale=noise) gen = generate(perturbed_latent, mask) - gen_mol = Chem.MolToSmiles(Chem.MolFromSmiles(gen[0])) - if gen_mol != 
Chem.MolToSmiles(Chem.MolFromSmiles(smiles)): break + mol = Chem.MolFromSmiles(gen[0]) + if mol: + gen_mol = Chem.MolToSmiles(mol) + if gen_mol != Chem.MolToSmiles(Chem.MolFromSmiles(smiles)): + break + else: + print('Abnormal molecule:', gen[0]) if gen_mol: # Calculate properties for ref and gen molecules @@ -240,9 +202,20 @@ def generate_canonical(smiles): # Prepare the table with ref mol and gen mol data = { "Property": ["QED", "SA", "LogP", "Mol Wt", "Tanimoto Similarity"], - "Reference Mol": [ref_properties[0], ref_properties[1], ref_properties[2], ref_properties[3], - tanimoto_similarity], - "Generated Mol": [gen_properties[0], gen_properties[1], gen_properties[2], gen_properties[3], ""] + "Reference Mol": [ + ref_properties[0], + ref_properties[1], + ref_properties[2], + ref_properties[3], + tanimoto_similarity, + ], + "Generated Mol": [ + gen_properties[0], + gen_properties[1], + gen_properties[2], + gen_properties[3], + "", + ], } df = pd.DataFrame(data) @@ -255,7 +228,7 @@ def generate_canonical(smiles): # Function to display evaluation score -def display_eval(selected_models, dataset, task_type, downstream, fusion_type): +def display_eval(selected_models, dataset, task_type, downstream, fusion_type, state): result = None try: @@ -270,72 +243,87 @@ def display_eval(selected_models, dataset, task_type, downstream, fusion_type): downstream_model = downstream_model.rstrip() params = None - - - try: if not selected_models: return "Please select at least one enabled model." - if task_type == "Classification": - global roc_auc, fpr, tpr, x_batch, y_batch - elif task_type == "Regression": - global RMSE, y_batch_test, y_prob - if len(selected_models) > 1: if task_type == "Classification": - #result, roc_auc, fpr, tpr, x_batch, y_batch = fm4m.multi_modal(model_list=selected_models, - # downstream_model="XGBClassifier", - # dataset=dataset.lower()) if downstream_model == "Default Settings": downstream_model = "DefaultClassifier" params = None - result, roc_auc, fpr, tpr, x_batch, y_batch = fm4m.multi_modal(model_list=selected_models, - downstream_model=downstream_model, - params = params, - dataset=dataset) - elif task_type == "Regression": - #result, RMSE, y_batch_test, y_prob = fm4m.multi_modal(model_list=selected_models, - # downstream_model="XGBRegressor", - # dataset=dataset.lower()) + ( + result, + state["roc_auc"], + state["fpr"], + state["tpr"], + state["x_batch"], + state["y_batch"], + ) = fm4m.multi_modal( + model_list=selected_models, + downstream_model=downstream_model, + params=params, + dataset=dataset, + ) + elif task_type == "Regression": if downstream_model == "Default Settings": downstream_model = "DefaultRegressor" params = None - result, RMSE, y_batch_test, y_prob, x_batch, y_batch = fm4m.multi_modal(model_list=selected_models, - downstream_model=downstream_model, - params=params, - dataset=dataset) + ( + result, + state["RMSE"], + state["y_batch_test"], + state["y_prob"], + state["x_batch"], + state["y_batch"], + ) = fm4m.multi_modal( + model_list=selected_models, + downstream_model=downstream_model, + params=params, + dataset=dataset, + ) else: if task_type == "Classification": - #result, roc_auc, fpr, tpr, x_batch, y_batch = fm4m.single_modal(model=selected_models[0], - # downstream_model="XGBClassifier", - # dataset=dataset.lower()) if downstream_model == "Default Settings": downstream_model = "DefaultClassifier" params = None - result, roc_auc, fpr, tpr, x_batch, y_batch = fm4m.single_modal(model=selected_models[0], - downstream_model=downstream_model, - 
params=params, - dataset=dataset) + ( + result, + state["roc_auc"], + state["fpr"], + state["tpr"], + state["x_batch"], + state["y_batch"], + ) = fm4m.single_modal( + model=selected_models[0], + downstream_model=downstream_model, + params=params, + dataset=dataset, + ) elif task_type == "Regression": - #result, RMSE, y_batch_test, y_prob = fm4m.single_modal(model=selected_models[0], - # downstream_model="XGBRegressor", - # dataset=dataset.lower()) - if downstream_model == "Default Settings": downstream_model = "DefaultRegressor" params = None - result, RMSE, y_batch_test, y_prob, x_batch, y_batch = fm4m.single_modal(model=selected_models[0], - downstream_model=downstream_model, - params=params, - dataset=dataset) + ( + result, + state["RMSE"], + state["y_batch_test"], + state["y_prob"], + state["x_batch"], + state["y_batch"], + ) = fm4m.single_modal( + model=selected_models[0], + downstream_model=downstream_model, + params=params, + dataset=dataset, + ) if result == None: result = "Data & Model Setting is incorrect" @@ -345,23 +333,15 @@ def display_eval(selected_models, dataset, task_type, downstream, fusion_type): # Function to handle plot display -def display_plot(plot_type): +def display_plot(plot_type, state): fig, ax = plt.subplots() if plot_type == "Latent Space": - global x_batch, y_batch + x_batch, y_batch = state.get("x_batch"), state.get("y_batch") ax.set_title("T-SNE Plot") - # reducer = umap.UMAP(metric='euclidean', n_neighbors= 10, n_components=2, low_memory=True, min_dist=0.1, verbose=False) - # features_umap = reducer.fit_transform(x_batch[:500]) - # x = y_batch.values[:500] - # index_0 = [index for index in range(len(x)) if x[index] == 0] - # index_1 = [index for index in range(len(x)) if x[index] == 1] - class_0 = x_batch # features_umap[index_0] - class_1 = y_batch # features_umap[index_1] - - """with open("latent_multi_bace.pkl", "rb") as f: - class_0, class_1 = pickle.load(f) - """ + class_0 = x_batch + class_1 = y_batch + plt.scatter(class_1[:, 0], class_1[:, 1], c='red', label='Class 1') plt.scatter(class_0[:, 0], class_0[:, 1], c='blue', label='Class 0') @@ -370,10 +350,16 @@ def display_plot(plot_type): ax.set_title('Dataset Distribution') elif plot_type == "ROC-AUC": - global roc_auc, fpr, tpr + roc_auc, fpr, tpr = state.get("roc_auc"), state.get("fpr"), state.get("tpr") ax.set_title("ROC-AUC Curve") try: - ax.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (area = {roc_auc:.4f})') + ax.plot( + fpr, + tpr, + color='darkorange', + lw=2, + label=f'ROC curve (area = {roc_auc:.4f})', + ) ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--') ax.set_xlim([0.0, 1.0]) ax.set_ylim([0.0, 1.05]) @@ -385,7 +371,11 @@ def display_plot(plot_type): ax.legend(loc='lower right') elif plot_type == "Parity Plot": - global RMSE, y_batch_test, y_prob + RMSE, y_batch_test, y_prob = ( + state.get("RMSE"), + state.get("y_batch_test"), + state.get("y_prob"), + ) ax.set_title("Parity plot") # change format @@ -394,7 +384,12 @@ def display_plot(plot_type): print(y_prob) y_batch_test = np.array(y_batch_test, dtype=float) y_prob = np.array(y_prob, dtype=float) - ax.scatter(y_batch_test, y_prob, color="blue", label=f"Predicted vs Actual (RMSE: {RMSE:.4f})") + ax.scatter( + y_batch_test, + y_prob, + color="blue", + label=f"Predicted vs Actual (RMSE: {RMSE:.4f})", + ) min_val = min(min(y_batch_test), min(y_prob)) max_val = max(max(y_batch_test), max(y_prob)) ax.plot([min_val, max_val], [min_val, max_val], 'r-') @@ -407,10 +402,6 @@ def display_plot(plot_type): 
print(y_batch_test) print(y_prob) - - - - ax.set_xlabel('Actual Values') ax.set_ylabel('Predicted Values') @@ -429,13 +420,25 @@ predefined_datasets = { # Function to load a predefined dataset from the local path def load_predefined_dataset(dataset_name): val = predefined_datasets.get(dataset_name) - try: file_path = val.split(",")[0] - except:file_path=False + try: + file_path = val.split(",")[0] + except: + file_path = False if file_path: df = pd.read_csv(file_path) - return df.head(), gr.update(choices=list(df.columns)), gr.update(choices=list(df.columns)), f"{dataset_name.lower()}" - return pd.DataFrame(), gr.update(choices=[]), gr.update(choices=[]), f"Dataset not found" + return ( + df.head(), + gr.update(choices=list(df.columns)), + gr.update(choices=list(df.columns)), + f"{dataset_name.lower()}", + ) + return ( + pd.DataFrame(), + gr.update(choices=[]), + gr.update(choices=[]), + f"Dataset not found", + ) # Function to display the head of the uploaded CSV file @@ -443,7 +446,11 @@ def display_csv_head(file): if file is not None: # Load the CSV file into a DataFrame df = pd.read_csv(file.name) - return df.head(), gr.update(choices=list(df.columns)), gr.update(choices=list(df.columns)) + return ( + df.head(), + gr.update(choices=list(df.columns)), + gr.update(choices=list(df.columns)), + ) return pd.DataFrame(), gr.update(choices=[]), gr.update(choices=[]) @@ -451,28 +458,54 @@ def display_csv_head(file): def handle_dataset_selection(selected_dataset): if selected_dataset == "Custom Dataset": # Show file upload fields for train and test datasets if "Custom Dataset" is selected - return gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update( - visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True) + return ( + gr.update(visible=True), + gr.update(visible=True), + gr.update(visible=True), + gr.update(visible=True), + gr.update(visible=True), + gr.update(visible=False), + gr.update(visible=True), + gr.update(visible=True), + ) else: - return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update( - visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False) + return ( + gr.update(visible=True), + gr.update(visible=False), + gr.update(visible=False), + gr.update(visible=False), + gr.update(visible=False), + gr.update(visible=False), + gr.update(visible=False), + gr.update(visible=False), + ) # Function to select input and output columns and display a message -def select_columns(input_column, output_column, train_data, test_data,dataset_name): +def select_columns(input_column, output_column, train_data, test_data, dataset_name): if input_column and output_column: return f"{train_data.name},{test_data.name},{input_column},{output_column},{dataset_name}" return "Please select both input and output columns." 
-def set_dataname(dataset_name, dataset_selector ): + +def set_dataname(dataset_name, dataset_selector): if dataset_selector == "Custom Dataset": return f"{dataset_name}" return f"{dataset_selector}" + # Function to create model based on user input -def create_model(model_name, max_depth=None, n_estimators=None, alpha=None, degree=None, kernel=None): +def create_model( + model_name, max_depth=None, n_estimators=None, alpha=None, degree=None, kernel=None +): if model_name == "XGBClassifier": - model = xgb.XGBClassifier(objective='binary:logistic',eval_metric= 'auc', max_depth=max_depth, n_estimators=n_estimators, alpha=alpha) + model = xgb.XGBClassifier( + objective='binary:logistic', + eval_metric='auc', + max_depth=max_depth, + n_estimators=n_estimators, + alpha=alpha, + ) elif model_name == "SVR": model = SVR(degree=degree, kernel=kernel) elif model_name == "Kernel Ridge": @@ -486,224 +519,339 @@ def create_model(model_name, max_depth=None, n_estimators=None, alpha=None, degr return "Model not supported." return f"{model_name} * {model.get_params()}" -def model_selector(model_name): - # Dynamically return the appropriate hyperparameter components based on the selected model - if model_name == "XGBClassifier": - return ( - gr.Slider(1, 10, label="max_depth"), - gr.Slider(50, 500, label="n_estimators"), - gr.Slider(0.1, 10.0, step=0.1, label="alpha") - ) - elif model_name == "SVR": - return ( - gr.Slider(1, 5, label="degree"), - gr.Dropdown(["rbf", "poly", "linear"], label="kernel") - ) - elif model_name == "Kernel Ridge": - return ( - gr.Slider(0.1, 10.0, step=0.1, label="alpha"), - gr.Slider(1, 5, label="degree"), - gr.Dropdown(["rbf", "poly", "linear"], label="kernel") - ) - elif model_name == "Linear Regression": - return () # No hyperparameters for Linear Regression - else: - return () - # Define the Gradio layout -# with gr.Blocks(theme=my_theme) as demo: with gr.Blocks() as demo: + log_df = pd.DataFrame( + {"": [], 'Selected Models': [], 'Dataset': [], 'Task': [], 'Result': []} + ) + state = gr.State({"log_df": log_df}) with gr.Row(): # Left Column with gr.Column(): - gr.HTML(''' + gr.HTML( + ''' <div style="background-color: #6A8EAE; color: #FFFFFF; padding: 10px;"> <h3 style="color: #FFFFFF; margin: 0;font-size: 20px;"> Data & Model Setting</h3> </div> - ''') - # gr.Markdown("## Data & Model Setting") - #dataset_dropdown = gr.Dropdown(choices=datasets, label="Select Dat") - + ''' + ) # Dropdown menu for predefined datasets including "Custom Dataset" option - dataset_selector = gr.Dropdown(label="Select Dataset", - choices=list(predefined_datasets.keys()) + ["Custom Dataset"]) + dataset_selector = gr.Dropdown( + label="Select Dataset", + choices=list(predefined_datasets.keys()) + ["Custom Dataset"], + ) # Display the message for selected columns - selected_columns_message = gr.Textbox(label="Selected Columns Info", visible=False) + selected_columns_message = gr.Textbox( + label="Selected Columns Info", visible=False + ) with gr.Accordion("Dataset Settings", open=True): # File upload options for custom dataset (train and test) dataset_name = gr.Textbox(label="Dataset Name", visible=False) - train_file = gr.File(label="Upload Custom Train Dataset", file_types=[".csv"], visible=False) - train_display = gr.Dataframe(label="Train Dataset Preview (First 5 Rows)", visible=False, interactive=False) + train_file = gr.File( + label="Upload Custom Train Dataset", + file_types=[".csv"], + visible=False, + ) + train_display = gr.Dataframe( + label="Train Dataset Preview (First 5 Rows)", + 
visible=False, + interactive=False, + ) - test_file = gr.File(label="Upload Custom Test Dataset", file_types=[".csv"], visible=False) - test_display = gr.Dataframe(label="Test Dataset Preview (First 5 Rows)", visible=False, interactive=False) + test_file = gr.File( + label="Upload Custom Test Dataset", + file_types=[".csv"], + visible=False, + ) + test_display = gr.Dataframe( + label="Test Dataset Preview (First 5 Rows)", + visible=False, + interactive=False, + ) # Predefined dataset displays - predefined_display = gr.Dataframe(label="Predefined Dataset Preview (First 5 Rows)", visible=False, - interactive=False) - - + predefined_display = gr.Dataframe( + label="Predefined Dataset Preview (First 5 Rows)", + visible=False, + interactive=False, + ) # Dropdowns for selecting input and output columns for the custom dataset - input_column_selector = gr.Dropdown(label="Select Input Column", choices=[], visible=False) - output_column_selector = gr.Dropdown(label="Select Output Column", choices=[], visible=False) - - #selected_columns_message = gr.Textbox(label="Selected Columns Info", visible=True) + input_column_selector = gr.Dropdown( + label="Select Input Column", choices=[], visible=False + ) + output_column_selector = gr.Dropdown( + label="Select Output Column", choices=[], visible=False + ) # When a dataset is selected, show either file upload fields (for custom) or load predefined datasets - dataset_selector.change(handle_dataset_selection, - inputs=dataset_selector, - outputs=[dataset_name, train_file, train_display, test_file, test_display, predefined_display, - input_column_selector, output_column_selector]) + dataset_selector.change( + handle_dataset_selection, + inputs=dataset_selector, + outputs=[ + dataset_name, + train_file, + train_display, + test_file, + test_display, + predefined_display, + input_column_selector, + output_column_selector, + ], + ) # When a predefined dataset is selected, load its head and update column selectors - dataset_selector.change(load_predefined_dataset, - inputs=dataset_selector, - outputs=[predefined_display, input_column_selector, output_column_selector, selected_columns_message]) + dataset_selector.change( + load_predefined_dataset, + inputs=dataset_selector, + outputs=[ + predefined_display, + input_column_selector, + output_column_selector, + selected_columns_message, + ], + ) # When a custom train file is uploaded, display its head and update column selectors - train_file.change(display_csv_head, inputs=train_file, - outputs=[train_display, input_column_selector, output_column_selector]) + train_file.change( + display_csv_head, + inputs=train_file, + outputs=[ + train_display, + input_column_selector, + output_column_selector, + ], + ) # When a custom test file is uploaded, display its head - test_file.change(display_csv_head, inputs=test_file, - outputs=[test_display, input_column_selector, output_column_selector]) + test_file.change( + display_csv_head, + inputs=test_file, + outputs=[ + test_display, + input_column_selector, + output_column_selector, + ], + ) - dataset_selector.change(set_dataname, - inputs=[dataset_name, dataset_selector], - outputs=dataset_name) + dataset_selector.change( + set_dataname, + inputs=[dataset_name, dataset_selector], + outputs=dataset_name, + ) # Update the selected columns information when dropdown values are changed - input_column_selector.change(select_columns, - inputs=[input_column_selector, output_column_selector, train_file, test_file, dataset_name], - outputs=selected_columns_message) - - 
output_column_selector.change(select_columns, - inputs=[input_column_selector, output_column_selector, train_file, test_file, dataset_name], - outputs=selected_columns_message) + input_column_selector.change( + select_columns, + inputs=[ + input_column_selector, + output_column_selector, + train_file, + test_file, + dataset_name, + ], + outputs=selected_columns_message, + ) - model_checkbox = gr.CheckboxGroup(choices=models_enabled, label="Select Model") + output_column_selector.change( + select_columns, + inputs=[ + input_column_selector, + output_column_selector, + train_file, + test_file, + dataset_name, + ], + outputs=selected_columns_message, + ) - # Add disabled checkboxes for GNN and FNN - # gnn_checkbox = gr.Checkbox(label="GNN (Disabled)", value=False, interactive=False) - # fnn_checkbox = gr.Checkbox(label="FNN (Disabled)", value=False, interactive=False) + model_checkbox = gr.CheckboxGroup( + choices=models_enabled, label="Select Model" + ) - task_radiobutton = gr.Radio(choices=["Classification", "Regression"], label="Task Type") + task_radiobutton = gr.Radio( + choices=["Classification", "Regression"], label="Task Type" + ) ####### adding hyper parameter tuning ########### - model_name = gr.Dropdown(["Default - Auto", "XGBClassifier", "SVR", "Kernel Ridge", "Linear Regression"], label="Select Downstream Model") + model_name = gr.Dropdown( + [ + "Default - Auto", + "XGBClassifier", + "SVR", + "Kernel Ridge", + "Linear Regression", + ], + label="Select Downstream Model", + ) with gr.Accordion("Downstream Hyperparameter Settings", open=True): # Create placeholders for hyperparameter components - max_depth = gr.Slider(1, 20, step=1,visible=False, label="max_depth") - n_estimators = gr.Slider(100, 5000, step=100, visible=False, label="n_estimators") + max_depth = gr.Slider(1, 20, step=1, visible=False, label="max_depth") + n_estimators = gr.Slider( + 100, 5000, step=100, visible=False, label="n_estimators" + ) alpha = gr.Slider(0.1, 10.0, step=0.1, visible=False, label="alpha") - degree = gr.Slider(1, 20, step=1,visible=False, label="degree") - kernel = gr.Dropdown(choices=["rbf", "poly", "linear"], visible=False, label="kernel") + degree = gr.Slider(1, 20, step=1, visible=False, label="degree") + kernel = gr.Dropdown( + choices=["rbf", "poly", "linear"], visible=False, label="kernel" + ) # Output textbox output = gr.Textbox(label="Loaded Parameters") - # Dynamically show relevant hyperparameters based on selected model def update_hyperparameters(model_name): if model_name == "XGBClassifier": - return gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update( - visible=False), gr.update(visible=False) + return ( + gr.update(visible=True), + gr.update(visible=True), + gr.update(visible=True), + gr.update(visible=False), + gr.update(visible=False), + ) elif model_name == "SVR": - return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update( - visible=True), gr.update(visible=True) + return ( + gr.update(visible=False), + gr.update(visible=False), + gr.update(visible=False), + gr.update(visible=True), + gr.update(visible=True), + ) elif model_name == "Kernel Ridge": - return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update( - visible=True), gr.update(visible=True) + return ( + gr.update(visible=False), + gr.update(visible=False), + gr.update(visible=True), + gr.update(visible=True), + gr.update(visible=True), + ) elif model_name == "Linear Regression": - return gr.update(visible=False), 
gr.update(visible=False), gr.update(visible=False), gr.update( - visible=False), gr.update(visible=False) + return ( + gr.update(visible=False), + gr.update(visible=False), + gr.update(visible=False), + gr.update(visible=False), + gr.update(visible=False), + ) elif model_name == "Default - Auto": - return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update( - visible=False), gr.update(visible=False) - + return ( + gr.update(visible=False), + gr.update(visible=False), + gr.update(visible=False), + gr.update(visible=False), + gr.update(visible=False), + ) # When model is selected, update which hyperparameters are visible - model_name.change(update_hyperparameters, inputs=[model_name], - outputs=[max_depth, n_estimators, alpha, degree, kernel]) + model_name.change( + update_hyperparameters, + inputs=[model_name], + outputs=[max_depth, n_estimators, alpha, degree, kernel], + ) # Submit button to create the model with selected hyperparameters submit_button = gr.Button("Create Downstream Model") - # Function to handle model creation based on input parameters def on_submit(model_name, max_depth, n_estimators, alpha, degree, kernel): if model_name == "XGBClassifier": - return create_model(model_name, max_depth=max_depth, n_estimators=n_estimators, alpha=alpha) + return create_model( + model_name, + max_depth=max_depth, + n_estimators=n_estimators, + alpha=alpha, + ) elif model_name == "SVR": return create_model(model_name, degree=degree, kernel=kernel) elif model_name == "Kernel Ridge": - return create_model(model_name, alpha=alpha, degree=degree, kernel=kernel) + return create_model( + model_name, alpha=alpha, degree=degree, kernel=kernel + ) elif model_name == "Linear Regression": return create_model(model_name) elif model_name == "Default - Auto": return create_model(model_name) # When the submit button is clicked, run the on_submit function - submit_button.click(on_submit, inputs=[model_name, max_depth, n_estimators, alpha, degree, kernel], - outputs=output) + submit_button.click( + on_submit, + inputs=[model_name, max_depth, n_estimators, alpha, degree, kernel], + outputs=output, + ) ###### End of hyper param tuning ######### fusion_radiobutton = gr.Radio(choices=fusion_available, label="Fusion Type") - - eval_button = gr.Button("Train downstream model") - #eval_button.style(css_class="custom-button-left") # Middle Column with gr.Column(): - gr.HTML(''' + gr.HTML( + ''' <div style="background-color: #8F9779; color: #FFFFFF; padding: 10px;"> <h3 style="color: #FFFFFF; margin: 0;font-size: 20px;"> Downstream Task 1: Property Prediction</h3> </div> - ''') - # gr.Markdown("## Downstream task Result") + ''' + ) eval_output = gr.Textbox(label="Train downstream model") - plot_radio = gr.Radio(choices=["ROC-AUC", "Parity Plot", "Latent Space"], label="Select Plot Type") - plot_output = gr.Plot(label="Visualization")#, height=250, width=250) - - #download_rep = gr.Button("Download representation") + plot_radio = gr.Radio( + choices=["ROC-AUC", "Parity Plot", "Latent Space"], + label="Select Plot Type", + ) + plot_output = gr.Plot(label="Visualization") create_log = gr.Button("Store log") - log_table = gr.Dataframe(value=log_df, label="Log of Selections and Results", interactive=False) - - eval_button.click(display_eval, - inputs=[model_checkbox, selected_columns_message, task_radiobutton, output, fusion_radiobutton], - outputs=eval_output) - - plot_radio.change(display_plot, inputs=plot_radio, outputs=plot_output) - + log_table = gr.Dataframe( + value=log_df, 
label="Log of Selections and Results", interactive=False + ) + + eval_button.click( + display_eval, + inputs=[ + model_checkbox, + selected_columns_message, + task_radiobutton, + output, + fusion_radiobutton, + state, + ], + outputs=eval_output, + ) + + plot_radio.change( + display_plot, inputs=[plot_radio, state], outputs=plot_output + ) # Function to gather selected models def gather_selected_models(*models): selected = [model for model in models if model] return selected - - create_log.click(evaluate_and_log, inputs=[model_checkbox, dataset_name, task_radiobutton, eval_output], - outputs=log_table) - #download_rep.click(save_rep, inputs=[model_checkbox, dataset_name, task_radiobutton, eval_output], - # outputs=None) - + create_log.click( + evaluate_and_log, + inputs=[ + model_checkbox, + dataset_name, + task_radiobutton, + eval_output, + state, + ], + outputs=log_table, + ) # Right Column with gr.Column(): - gr.HTML(''' + gr.HTML( + ''' <div style="background-color: #D2B48C; color: #FFFFFF; padding: 10px;"> <h3 style="color: #FFFFFF; margin: 0;font-size: 20px;"> Downstream Task 2: Molecule Generation</h3> </div> - ''') - # gr.Markdown("## Molecular Generation") + ''' + ) smiles_input = gr.Textbox(label="Input SMILES String") image_display = gr.Image(label="Molecule Image", height=250, width=250) # Show images for selection @@ -712,24 +860,32 @@ with gr.Blocks() as demo: choices=list(smiles_image_mapping.keys()), label="Select from sample molecules", value=None, - #item_images=[load_image(smiles_image_mapping[key]["image"]) for key in smiles_image_mapping.keys()] ) image_selector.change(load_image, image_selector, image_display) generate_button = gr.Button("Generate") - gen_image_display = gr.Image(label="Generated Molecule Image", height=250, width=250) + gen_image_display = gr.Image( + label="Generated Molecule Image", height=250, width=250 + ) generated_output = gr.Textbox(label="Generated Output") property_table = gr.Dataframe(label="Molecular Properties Comparison") - - # Handle image selection - image_selector.change(handle_image_selection, inputs=image_selector, outputs=[smiles_input, image_display]) - smiles_input.change(smiles_to_image, inputs=smiles_input, outputs=image_display) + image_selector.change( + handle_image_selection, + inputs=image_selector, + outputs=[smiles_input, image_display], + ) + smiles_input.change( + smiles_to_image, inputs=smiles_input, outputs=image_display + ) # Generate button to display canonical SMILES and molecule image - generate_button.click(generate_canonical, inputs=smiles_input, - outputs=[property_table, generated_output, gen_image_display]) + generate_button.click( + generate_canonical, + inputs=smiles_input, + outputs=[property_table, generated_output, gen_image_display], + ) if __name__ == "__main__": - demo.launch(share=True) + demo.launch(server_name="0.0.0.0") diff --git a/data/lce/test.csv b/data/lce/test.csv new file mode 100644 index 0000000000000000000000000000000000000000..95272b1fce743d785b366bce48b49270184452b7 --- /dev/null +++ b/data/lce/test.csv @@ -0,0 +1,31 @@ +smi1,conc1,smi2,conc2,smi3,conc3,smi4,conc4,smi5,conc5,smi6,conc6,LCE +C1C(OC(=O)O1)F,0.733,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.267,O,0.0,O,0.0,O,0.0,O,0.0,1.629 +C1C(OC(=O)O1)F,0.497,COC(=O)OC,0.431,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0.0,O,0.0,O,0.0,1.085 +COC(=O)OC,0.299,C(C(F)(F)F)OCC(F)(F)F,0.598,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.103,O,0.0,O,0.0,O,0.0,2.056 
+COCCOC,0.358,O1CCOC1,0.532,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.074,[Li+].[N+](=O)([O-])[O-],,O,0.0,O,0.0,1.658 +C1COC(=O)O1,0.197,COC(=O)OC,0.156,COCCOCCOCCOCCOC,0.59,[Li+].F[P-](F)(F)(F)(F)F,0.026,[Li+].[N+](=O)([O-])[O-],0.031,O,0.0,1.638 +C1COC(=O)O1,0.496,COC(=O)OC,0.43,O1CCOCCOCCOCC1,0.002,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0.0,O,0.0,1.276 +O1CCOC1,0.368,COCCOC,0.547,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.076,CSi(C)(C)([N+]).C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)C(F)(F)C(F)(F)C(F)(F)F,0.008,O,0.0,O,0.0,1.569 +COCCOC,0.507,COC(C(F)(F)F)C(F)(F)F,0.399,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.095,O,0.0,O,0.0,O,0.0,2.268 +C1COC(=O)O1,0.425,O=C(OCC)OCC(F)(F)F,0.481,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.094,O,0.0,O,0.0,O,0.0,1.602 +C1C(OC(=O)O1)F,0.318,CCOC(=O)OC,0.504,COC(=O)OC,0.094,B(O[Si](C)(C)C)(O[Si](C)(C)C)O[Si](C)(C),0.083,[Li+].F[P-](F)(F)(F)(F)F,0.001,O,0.0,1.678 +O=S1(=O)CCCC1,0.359,C(C(F)(F)F)OC(C(F)F)(F)F,0.504,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.133,[Li+].[N+](=O)([O-])[O-],0.004,O,0.0,O,0.0,2.0 +C1COC(=O)O1,0.594,O=C(OCC)OCC,0.327,[Li+].F[P-](F)(F)(F)(F)F,0.079,O,0.0,O,0.0,O,0.0,0.921 +C1COC(=O)O1,0.331,O=C(OCC)OCC,0.577,[Li+].[B-]1(OC(=O)C(=O)O1)(F)F,0.092,O,0.0,O,0.0,O,0.0,1.301 +C1COC(=O)O1,0.507,COC(=O)OC,0.402,C1=COC(=O)O1,0.022,[Li+].C(C(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(C(F)(F)F)(F)F)(F)(F)F,0.069,O,0.0,O,0.0,0.854 +C1C(OC(=O)O1)F,0.107,C1COC(=O)O1,0.526,O=C(OCC)OCC,0.289,[Li+].F[P-](F)(F)(F)(F)F,0.078,O,0.0,O,0.0,1.108 +O1CCOC1,0.322,COCCOC,0.478,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)C(F)(F)C(F)(F)C(F)(F)F,0.2,O,0.0,O,0.0,O,0.0,1.523 +CC1COC(=O)O1,0.595,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.405,O,0.0,O,0.0,O,0.0,O,0.0,1.921 +CC1COC(=O)O1,0.702,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.298,O,0.0,O,0.0,O,0.0,O,0.0,1.602 +O1CCOC1,0.375,COCCOC,0.557,[Li+][S-]SSS[S-][Li+],,[Li+].[N+](=O)([O-])[O-],0.008,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.061,O,0.0,1.523 +COC(=O)OC,0.161,FC(F)C(F)(F)COC(F)(F)C(F)F,0.355,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.484,O,0.0,O,0.0,O,0.0,2.155 +C1COC(=O)O1,0.338,COC(=O)OC,0.625,[Li+].[O-]P(=O)(F)F,0.008,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.03,O,0.0,O,0.0,1.26 +CN(C)C(=O)C(F)(F)F,0.362,C1C(OC(=O)O1)F,0.556,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.081,O,0.0,O,0.0,O,0.0,2.155 +C1C(OC(=O)O1)F,0.497,COC(=O)OC,0.43,O1CCOCCOCCOCC1,0.0,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0.0,O,0.0,1.225 +COCCOC,0.231,FC1CCCCC1,0.577,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.192,O,0.0,O,0.0,O,0.0,2.155 +COCCOC,0.277,FC(F)C(F)(F)COC(F)(F)C(F)F,0.555,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.168,O,0.0,O,0.0,O,0.0,2.155 +O1C(C)CCC1,0.331,FC(F)C(F)(F)COC(F)(F)C(F)F,0.498,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.171,O,0.0,O,0.0,O,0.0,2.301 +COCC(F)(F)C(F)(F)COC,0.864,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.136,O,0.0,O,0.0,O,0.0,O,0.0,1.991 +COC(=O)OC,0.29,C(C(F)(F)F)OCC(F)(F)F,0.589,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.121,O,0.0,O,0.0,O,0.0,2.301 +C1COC(=O)O1,0.425,O=C(OCC)OCC,0.234,[Li+].F[P-](F)(F)(F)(F)F,0.34,O,0.0,O,0.0,O,0.0,1.398 +COCCOC,0.707,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.147,[Li+].[B-]1(OC(=O)C(=O)O1)(F)F,0.147,O,0.0,O,0.0,O,0.0,1.268 diff --git a/data/lce/test_data.csv b/data/lce/test_data.csv new file mode 100644 index 0000000000000000000000000000000000000000..ddbd100e66ad449ecfaf7026b091fb84ded3fce8 --- /dev/null +++ b/data/lce/test_data.csv @@ -0,0 +1,14 @@ +smiles1,conc1,mol1,smiles2,conc2,mol2,smiles3,conc3,mol3,smiles4,conc4,mol4,smiles5,conc5,mol5,smiles6,conc6,LCE_Predicted,LCE 
+C1COC(=O)O1,0.519,51.92400559,COC(=O)OC,0.411,41.14791596,[Li+].[N-](S(=O)(=O)F)S(=O)(=O)F,0.069,6.928078454,O,0,0,O,0,0,O,0,1.187,1.094 +COCCOC,0.596,59.5609428,COCCOCCOCCOCCOC,0.281,28.07124115,[Li+].[N-](S(=O)(=O)F)S(=O)(=O)F,0.124,12.36781605,O,0,0,O,0,0,O,0,1.691,1.384 +C1COC(=O)O1,0.285,28.50894036,C1C(OC(=O)O1)F,0.261,26.07552384,[Li+].[N-](S(=O)(=O)F)S(=O)(=O)F,0.228,22.82322096,COC(=O)OC,0.226,22.59231484,O,0,0,O,0,1.508,1.468 +COCCOC,0.434,43.4423376,COCCOCCOCCOCCOC,0.205,20.47449683,[Li+].[N-](S(=O)(=O)F)S(=O)(=O)F,0.361,36.08316557,O,0,0,O,0,0,O,0,1.882,1.71 +C1C(OC(=O)O1)F,0.187,18.72872664,COC(=O)OC,0.162,16.22691423,[Li+].[N-](S(=O)(=O)F)S(=O)(=O)F,0.109,10.92850826,FC(F)C(F)(F)COC(F)(F)C(F)F,0.541,54.11585087,O,0,0,O,0,2.103,1.832 +C1COC(=O)O1,0.134,13.35070843,C1C(OC(=O)O1)F,0.122,12.2111419,[Li+].[N-](S(=O)(=O)F)S(=O)(=O)F,0.107,10.72028474,COC(=O)OC,0.106,10.57995858,FC(F)C(F)(F)COC(F)(F)C(F)F,0.531,53.13790635,O,0,2.077,2.104 +COCCOC,0.096,9.614613177,COCCOCCOCCOCCOC,0.045,4.53139444,[Li+].[N-](S(=O)(=O)F)S(=O)(=O)F,0.12,12.01491409,C1COCO1,0.143,14.28400162,FC(F)C(F)(F)COC(F)(F)C(F)F,0.596,59.55507668,O,0,2.211,2.274 +C1COC(=O)O1,0.519,51.92400559,COC(=O)OC,0.411,41.14791596,[Li+].F[P-](F)(F)(F)(F)F,0.069,6.928078454,O,0,0,O,0,0,O,0,1.17,1.071 +C1COC(=O)O1,0.519,51.92400559,COC(=O)OC,0.411,41.14791596,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.069,6.928078454,O,0,0,O,0,0,O,0,1.077,1.166 +C1COC(=O)O1,0.519,51.85215842,COC(=O)OC,0.411,41.09097965,[Li+].[N-](S(=O)(=O)F)S(=O)(=O)F,0.069,6.918492083,[Li+].[N+](=O)([O-])[O-],0.001,0.138369842,O,0,0,O,0,1.19,1.335 +C1COC(=O)O1,0.513,51.33049845,COC(=O)OC,0.407,40.6775828,[Li+].[N-](S(=O)(=O)F)S(=O)(=O)F,0.069,6.9173773,C1=COC(=O)O1,0.011,1.07454145,O,0,0,O,0,1.114,1.129 +COCCOC,0.53,53.00533987,COCCOCCOCCOCCOC,0.25,24.98156691,[Li+].[N-](S(=O)(=O)F)S(=O)(=O)F,0.22,22.01309322,O,0,0,O,0,0,O,0,1.758,1.501 +COCCOC,0.477,47.74974224,COCCOCCOCCOCCOC,0.225,22.50458884,[Li+].[N-](S(=O)(=O)F)S(=O)(=O)F,0.297,29.74566892,O,0,0,O,0,0,O,0,1.821,1.663 \ No newline at end of file diff --git a/data/lce/train.csv b/data/lce/train.csv new file mode 100644 index 0000000000000000000000000000000000000000..3ba4d26d7016d2390f934922abf3cd650f734da9 --- /dev/null +++ b/data/lce/train.csv @@ -0,0 +1,121 @@ +smi1,conc1,smi2,conc2,smi3,conc3,smi4,conc4,smi5,conc5,smi6,conc6,LCE +C1COC(=O)O1,0.327,O=C(OCC)OCC,0.594,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.079,O,0.0,O,0.0,O,0.0,1.155 +C1COC(=O)O1,0.356,COC(=O)OC,0.566,FC(F)(F)COB(OCC(F)(F)F)OCC(F)(F)F,0.007,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0.0,O,0.0,1.046 +O=S1(=O)CCCC1,0.25,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.75,O,0.0,O,0.0,O,0.0,O,0.0,1.569 +C1COC(=O)O1,0.331,O=C(OCC)OCC,0.577,[Li+].F[P-](F)(F)(F)(F)F,0.092,O,0.0,O,0.0,O,0.0,0.886 +COCCOC,0.763,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.237,O,0.0,O,0.0,O,0.0,O,0.0,1.367 +COCCOC,0.2,FC(F)C(F)(F)COC(F)(F)C(F)F,0.6,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.2,O,0.0,O,0.0,O,0.0,2.301 +C1C(OC(=O)O1)F,0.873,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.127,O,0.0,O,0.0,O,0.0,O,0.0,1.489 +COCCOC,0.706,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.008,[Li+].[O-]P(=O)(F)F,0.286,O,0.0,O,0.0,O,0.0,1.244 +C1COC(=O)O1,0.3,CCOC(=O)OC,0.593,C1=COC(=O)O1,0.026,[Li+].F[P-](F)(F)(F)(F)F,0.081,O,0.0,O,0.0,0.745 +COCCOC,0.763,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.174,[Li+].[O-]P(=O)(F)F,0.063,O,0.0,O,0.0,O,0.0,1.292 +CCOCC,0.313,C(C(F)(F)F)OCC(F)(F)F,0.51,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.177,O,0.0,O,0.0,O,0.0,2.301 
+O=S1(=O)CCCC1,0.75,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.25,O,0.0,O,0.0,O,0.0,O,0.0,1.745 +COC(=O)OC,0.29,C(C(F)(F)F)OCC(F)(F)F,0.589,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.121,O,0.0,O,0.0,O,0.0,1.745 +C1COC(=O)O1,0.682,CCOC(=O)OC,0.247,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.043,[Li+].O=C1O[B-]2(OC1=O)OC(=O)C(=O)O2,0.028,O,0.0,O,0.0,1.076 +C1COC(=O)O1,0.359,COC(=O)OC,0.569,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0.0,O,0.0,O,0.0,0.854 +C1COC(=O)O1,0.305,COC(=O)OC,0.242,COCCOCCOCCOCCOC,0.392,[Li+].F[P-](F)(F)(F)(F)F,0.041,[Li+].[N+](=O)([O-])[O-],0.02,O,0.0,1.678 +FC(F)(F)COCCOCC,0.838,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.162,O,0.0,O,0.0,O,0.0,O,0.0,2.155 +CC#N,0.882,FC,0.065,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,,O,0.0,O,0.0,O,0.0,2.222 +COC(C)C(C)OC,0.879,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.121,O,0.0,O,0.0,O,0.0,O,0.0,1.638 +CCOP(=O)(OCC)OCC,0.728,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.272,O,0.0,O,0.0,O,0.0,O,0.0,2.0 +COC(=O)OC,0.375,FC(F)C(F)(F)COC(F)(F)C(F)F,0.375,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.25,O,0.0,O,0.0,O,0.0,1.854 +O1CCOC1,0.371,COCCOC,0.552,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.077,O,0.0,O,0.0,O,0.0,1.959 +C1C(OC(=O)O1)F,0.774,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.226,O,0.0,O,0.0,O,0.0,O,0.0,1.587 +CC1COC(=O)O1,0.875,C1C(OC(=O)O1)F,0.051,[Li+].[O-]Cl(=O)(=O)=O,0.074,O,0.0,O,0.0,O,0.0,0.699 +C1C(OC(=O)O1)F,0.264,COC(=O)OCCF,0.479,C(C(F)(F)F)OC(C(F)F)(F)F,0.155,[Li+].F[P-](F)(F)(F)(F)F,0.103,O,0.0,O,0.0,2.097 +C1C(OC(=O)O1)F,0.413,O=C(OCC)OCC,0.497,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.09,O,0.0,O,0.0,O,0.0,1.59 +C1C(OC(=O)O1)F,0.106,C1COC(=O)O1,0.522,O=C(OCC)OCC,0.287,[Li+].F[P-](F)(F)(F)(F)F,0.077,[Rb+].[O-][N+]([O-])=O,0.004,O1CCOCCOCCOCCOCCOCC1,0.004,1.252 +COCCOC,0.259,B(OCC(F)(F)F)(OCC(F)(F)F)OCC(F)(F)F,0.556,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.185,O,0.0,O,0.0,O,0.0,1.337 +C1CCOC1,0.925,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.075,O,0.0,O,0.0,O,0.0,O,0.0,1.377 +C1C(OC(=O)O1)F,0.82,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.18,O,0.0,O,0.0,O,0.0,O,0.0,1.544 +CCOP(=O)(OCC)OCC,0.5,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.5,O,0.0,O,0.0,O,0.0,O,0.0,2.097 +COCCOC,0.731,[Li+].[O-]P(=O)(F)F,0.064,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.205,O,0.0,O,0.0,O,0.0,1.215 +COCCOCCOCCOCCOC,0.819,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.181,O,0.0,O,0.0,O,0.0,O,0.0,1.222 +C1COC(=O)O1,0.338,COC(=O)OC,0.625,[Li+].[O-]P(=O)(F)F,0.008,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.03,O,0.0,O,0.0,1.194 +O1CCOC1,0.463,COCCOC,0.312,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.194,[Li+].[N+](=O)([O-])[O-],0.03,O,0.0,O,0.0,1.824 +C1C(OC(=O)O1)F,0.496,COC(=O)OC,0.43,O1CCOCCOCCOCC1,0.002,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0.0,O,0.0,1.333 +O1CCOC1,0.539,COCCOC,0.363,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.075,[Li+].[N+](=O)([O-])[O-],0.023,O,0.0,O,0.0,1.824 +COCCOC,0.257,C(C(F)(F)F)OCC(F)(F)F,0.508,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.235,O,0.0,O,0.0,O,0.0,2.051 +COCCOC,0.906,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.047,[Li+].FP(F)(=O)([O-]),0.047,O,0.0,O,0.0,O,0.0,1.444 +O1CCOC1,0.478,COCCOC,0.322,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.134,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.067,O,0.0,O,0.0,1.854 +CCOCC,0.707,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.293,O,0.0,O,0.0,O,0.0,O,0.0,2.046 +C1COC(=O)O1,0.563,O=C(OCC)OCC,0.31,C1C(OC(=O)O1)F,0.052,[Li+].F[P-](F)(F)(F)(F)F,0.075,O,0.0,O,0.0,1.301 +C1CCOC1,0.942,FC,0.029,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,,O,0.0,O,0.0,O,0.0,2.222 +O1CCOC1,0.478,COCCOC,0.322,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.2,O,0.0,O,0.0,O,0.0,1.903 +COCCOC,0.906,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.094,O,0.0,O,0.0,O,0.0,O,0.0,1.561 
+C1C(OC(=O)O1)F,0.149,COC(=O)OCCF,0.178,C(C(F)(F)F)OC(C(F)F)(F)F,0.564,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.108,O,0.0,O,0.0,1.735 +FC(F)COCCOCC(F)(F),0.845,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.155,O,0.0,O,0.0,O,0.0,O,0.0,2.301 +C1C(OC(=O)O1)F,0.495,COC(=O)OC,0.429,O1CCOCCOCCOCC1,0.003,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0.0,O,0.0,1.498 +C1COC(=O)O1,0.507,COC(=O)OC,0.402,C1=COC(=O)O1,0.022,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.069,O,0.0,O,0.0,0.745 +O=S1(=O)CCCC1,0.758,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.235,[Li+].[N+](=O)([O-])[O-],0.007,O,0.0,O,0.0,O,0.0,1.824 +CCOCC,0.856,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.144,O,0.0,O,0.0,O,0.0,O,0.0,2.0 +O=C(OCC)C,0.105,ClCCl,0.64,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.255,O,0.0,O,0.0,O,0.0,1.456 +COCCOCCOCC(F)(F)OC(F)(F)OC(F)(F)COCCOCCOC,0.708,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.292,O,0.0,O,0.0,O,0.0,O,0.0,1.301 +COCCOC,0.583,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.278,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.139,O,0.0,O,0.0,O,0.0,1.678 +C1C(OC(=O)O1)F,0.662,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.338,O,0.0,O,0.0,O,0.0,O,0.0,1.646 +O1CCOC1,0.397,COCCOC,0.589,[Li+][S-]SSS[S-][Li+],,[Li+].[N+](=O)([O-])[O-],0.012,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.002,O,0.0,1.301 +C1COC(=O)O1,0.308,O=C(OCC)OCC(F)(F)F,0.349,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.343,O,0.0,O,0.0,O,0.0,2.046 +C1COC(=O)O1,0.362,O=C(OCC)OCC,0.548,[Li+].F[P-](F)(F)(F)(F)F,0.09,O,0.0,O,0.0,O,0.0,0.788 +C1C(OC(=O)O1)F,0.497,COC(=O)OC,0.43,O1CCOCCOCCOCC1,0.001,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0.0,O,0.0,1.373 +O1CCOCC1,0.912,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.088,O,0.0,O,0.0,O,0.0,O,0.0,1.602 +CC#N,0.621,C1=COC(=O)O1,0.056,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.323,O,0.0,O,0.0,O,0.0,1.854 +COC(=O)OC,0.684,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.316,O,0.0,O,0.0,O,0.0,O,0.0,2.097 +O=S1(=O)CCCC1,0.714,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.286,O,0.0,O,0.0,O,0.0,O,0.0,1.699 +FC(F)(F)COCCOCC(F)(F)(F),0.838,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.162,O,0.0,O,0.0,O,0.0,O,0.0,2.155 +CCOCC,0.64,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.36,O,0.0,O,0.0,O,0.0,O,0.0,2.208 +COC(=O)OC,0.6,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.4,O,0.0,O,0.0,O,0.0,O,0.0,1.77 +CC1COC(=O)O1,0.887,[Li+].F[As-](F)(F)(F)(F)F,0.113,O,0.0,O,0.0,O,0.0,O,0.0,0.824 +C1COC(=O)O1,0.5,CCOC(=O)OC,0.423,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.046,[Li+].O=C1O[B-]2(OC1=O)OC(=O)C(=O)O2,0.031,O,0.0,O,0.0,0.924 +CCOP(=O)(OCC)OCC,0.214,C(C(F)(F)F)OCC(F)(F)F,0.642,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.144,O,0.0,O,0.0,O,0.0,2.097 +COCCOC,0.682,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.318,O,0.0,O,0.0,O,0.0,O,0.0,2.108 +CC1COC(=O)O1,0.922,[LI+].F[B-](F)(F)OC(C(F)(F)(F))(C(F)(F)(F))C(F)(F)(F),0.078,O,0.0,O,0.0,O,0.0,O,0.0,0.712 +C1COC(=O)O1,0.854,CCOC(=O)OC,0.08,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.039,[Li+].O=C1O[B-]2(OC1=O)OC(=O)C(=O)O2,0.026,O,0.0,O,0.0,1.081 +C1COC(=O)O1,0.519,O=C(OCC)OCC,0.387,[Li+].F[P-](F)(F)(F)(F)F,0.082,[Li+].[O-]P(=O)(F)F,0.012,O,0.0,O,0.0,1.319 +COC(=O)CC(F)(F)F,0.768,C1C(OC(=O)O1)F,0.134,[Li+].F[P-](F)(F)(F)(F)F,0.098,O,0.0,O,0.0,O,0.0,1.62 +C1C(OC(=O)O1)F,0.144,COC(=O)OCCF,0.173,C(C(F)(F)F)OC(C(F)F)(F)F,0.548,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.135,O,0.0,O,0.0,2.222 +C1COC(=O)O1,0.326,COC(=O)OC,0.602,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0.0,O,0.0,O,0.0,0.777 +CCOCC,0.877,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.123,O,0.0,O,0.0,O,0.0,O,0.0,2.018 +COC(=O)OC,0.664,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.336,O,0.0,O,0.0,O,0.0,O,0.0,1.886 +C1COC(=O)O1,0.507,COC(=O)OC,0.402,C1=COC(=O)O1,0.022,[Li+].F[B-](F)(F)F,0.069,O,0.0,O,0.0,0.699 
+CCOP(=O)(OCC)OCC,0.648,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.352,O,0.0,O,0.0,O,0.0,O,0.0,1.569 +C1C(OC(=O)O1)F,0.481,O=C(OCC)OCC,0.432,[Li+].F[P-](F)(F)(F)(F)F,0.087,O,0.0,O,0.0,O,0.0,1.523 +COCCOC,0.231,FC(F)C(F)(F)COC(F)(F)C(F)F,0.577,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.192,O,0.0,O,0.0,O,0.0,2.155 +C1C(OC(=O)O1)F,0.496,COC(=O)OC,0.43,O1CCOCCOCCOCC1,0.001,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0.0,O,0.0,1.488 +O1CCOC1,0.453,COCCOC,0.305,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.127,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.063,[Li+].[N+](=O)([O-])[O-],0.051,O,0.0,2.046 +C1C(OC(=O)O1)F,0.932,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.068,O,0.0,O,0.0,O,0.0,O,0.0,1.41 +COCCOC,0.139,COCC(F)(F)C(F)(F)C(F)(F)C(F)(F)COC,0.692,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.169,O,0.0,O,0.0,O,0.0,2.222 +C1C(OC(=O)O1)F,0.497,COC(=O)OC,0.431,O1CCOCCOCCOCC1,0.0,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0.0,O,0.0,1.559 +COCCOC,0.231,FC(COC(OCC(F)(F)F)OCC(F)(F)F)(F)F,0.577,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.192,O,0.0,O,0.0,O,0.0,2.301 +CN(C)S(=O)(=O)F,0.921,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.079,O,0.0,O,0.0,O,0.0,O,0.0,1.672 +C1C(OC(=O)O1)F,0.105,C1COC(=O)O1,0.518,O=C(OCC)OCC,0.285,[Li+].F[P-](F)(F)(F)(F)F,0.077,[Rb+].[O-][N+]([O-])=O,0.008,O1CCOCCOCCOCCOCCOCC1,0.008,1.538 +CC1CCC(C)O1,0.893,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.107,O,0.0,O,0.0,O,0.0,O,0.0,1.796 +C1C(OC(=O)O1)F,0.496,COC(=O)OC,0.43,O1CCOCCOCCOCC1,0.002,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0.0,O,0.0,1.355 +C1COC(=O)O1,0.444,C1COS(=O)O1,0.497,[Li+].[O-]Cl(=O)(=O)=O,0.059,O,0.0,O,0.0,O,0.0,1.523 +COCCOC,0.371,O1CCOC1,0.552,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.031,[Li+].[N+](=O)([O-])[O-],0.046,O,0.0,O,0.0,1.78 +O=S1(=O)CCCC1,0.764,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.236,O,0.0,O,0.0,O,0.0,O,0.0,1.456 +O1C(C)CCC1,0.908,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.092,O,0.0,O,0.0,O,0.0,O,0.0,1.745 +O1CCOC1,0.362,C(C(F)(F)F)OCC(F)(F)F,0.59,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.048,O,0.0,O,0.0,O,0.0,1.967 +COC(=O)OC,0.543,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.457,O,0.0,O,0.0,O,0.0,O,0.0,2.097 +COCCOC,0.73,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.27,O,0.0,O,0.0,O,0.0,O,0.0,1.143 +O1CCOC1,0.552,COCCOC,0.371,[Li+].[N+](=O)([O-])[O-],0.039,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.039,O,0.0,O,0.0,1.523 +COCCOC,0.242,FC(COC(OCC(F)(F)F)OCC(F)(F)F)(F)F,0.604,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.154,O,0.0,O,0.0,O,0.0,2.301 +CCOP(=O)(OCC)OCC,0.6,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.4,O,0.0,O,0.0,O,0.0,O,0.0,2.155 +C1C(OC(=O)O1)F,0.318,CCOC(=O)OC,0.504,COC(=O)OC,0.094,[Li+].F[P-](F)(F)(F)(F)F,0.083,O,0.0,O,0.0,1.301 +COCCOC,0.231,C(C(F)(F)F)OCC(F)(F)F,0.577,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.192,O,0.0,O,0.0,O,0.0,2.222 +C1COC(=O)O1,0.507,COC(=O)OC,0.402,C1=COC(=O)O1,0.022,[Li+].F[P-](F)(F)(F)(F)F,0.069,O,0.0,O,0.0,0.699 +COCCOC,0.231,C(C(F)(F)F)OC(=O)OCC(F)(F)F,0.577,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.192,O,0.0,O,0.0,O,0.0,1.495 +C1COC(=O)O1,0.32,COC(=O)OC,0.253,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.427,O,0.0,O,0.0,O,0.0,2.155 +C1C(OC(=O)O1)F,0.312,O=C1OCCC1,0.599,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.068,[Li+].[N+](=O)([O-])[O-],0.021,O,0.0,O,0.0,1.921 +COC(=O)OC,0.478,FC(F)C(F)(F)COC(F)(F)C(F)F,0.322,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.067,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.134,O,0.0,O,0.0,1.886 +CCOP(=O)(OCC)OCC,0.259,FC(F)C(F)(F)COC(F)(F)C(F)F,0.556,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.185,O,0.0,O,0.0,O,0.0,2.046 +COCCOC,0.677,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.323,O,0.0,O,0.0,O,0.0,O,0.0,1.745 +C1C(OC(=O)O1)F,0.696,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.304,O,0.0,O,0.0,O,0.0,O,0.0,1.633 
+C1CCOC1,0.47,O1C(C)CCC1,0.378,[Li+].F[P-](F)(F)(F)(F)F,0.152,O,0.0,O,0.0,O,0.0,2.097 +FC(F)COCCOCC(F)(F)(F),0.838,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.162,O,0.0,O,0.0,O,0.0,O,0.0,2.301 +C1COC(=O)O1,0.496,COC(=O)OC,0.393,C1C(OC(=O)O1)F,0.045,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.066,O,0.0,O,0.0,1.108 +C1C(OC(=O)O1)F,0.62,C(C(F)(F)F)OC(=O)OCC(F)(F)F,0.291,[Li+].F[P-](F)(F)(F)(F)F,0.089,O,0.0,O,0.0,O,0.0,1.62 +CCOCC,0.906,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.094,O,0.0,O,0.0,O,0.0,O,0.0,1.959 +C1COC(=O)O1,0.526,O=C(OCC)OCC,0.392,[Li+].F[P-](F)(F)(F)(F)F,0.083,O,0.0,O,0.0,O,0.0,1.013 +C1COC(=O)O1,0.05,CCOC(=O)OC,0.237,C(C(F)(F)F)OCC(F)(F)F,0.575,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.123,[Li+].[B-]1(OC(=O)C(=O)O1)(F)F,0.015,O,0.0,1.824 +O=S1(=O)CCCC1,0.429,FC(F)C(F)(F)COC(F)(F)C(F)F,0.429,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.143,O,0.0,O,0.0,O,0.0,1.921 diff --git a/data/lce/train_data.csv b/data/lce/train_data.csv new file mode 100644 index 0000000000000000000000000000000000000000..26cdcb3434b884dde32d05a77c2f112c72214680 --- /dev/null +++ b/data/lce/train_data.csv @@ -0,0 +1,148 @@ +smiles1,conc1,smiles2,conc2,smiles3,conc3,smiles4,conc4,smiles5,conc5,smiles6,conc6,LCE +CC1COC(=O)O1,0.875,C1C(OC(=O)O1)F,0.051,[Li+].[O-]Cl(=O)(=O)=O,0.074,O,0,O,0,O,0,0.699 +C1COC(=O)O1,0.507,COC(=O)OC,0.402,C1=COC(=O)O1,0.022,[Li+].F[P-](F)(F)(F)(F)F,0.069,O,0,O,0,0.699 +FC(F)COCCOCC(F)(F),0.845,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.155,O,0,O,0,O,0,O,0,2.301 +FC(F)COCCOCC(F)(F)(F),0.838,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.162,O,0,O,0,O,0,O,0,2.301 +CN(C)C(=O)C(F)(F)F,0.362,C1C(OC(=O)O1)F,0.556,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.081,O,0,O,0,O,0,2.155 +COCCOC,0.231,FC1CCCCC1,0.577,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.192,O,0,O,0,O,0,2.155 +CCOP(=O)(OCC)OCC,0.6,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.4,O,0,O,0,O,0,O,0,2.155 +O1CCOC1,0.362,C(C(F)(F)F)OCC(F)(F)F,0.59,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.048,O,0,O,0,O,0,1.967 +COCC(F)(F)C(F)(F)COC,0.864,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.136,O,0,O,0,O,0,O,0,1.991 +C1C(OC(=O)O1)F,0.662,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.338,O,0,O,0,O,0,O,0,1.646 +COCCOC,0.358,O1CCOC1,0.532,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.074,[Li+].[N+](=O)([O-])[O-],0.035,O,0,O,0,1.658 +CN(C)S(=O)(=O)F,0.921,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.079,O,0,O,0,O,0,O,0,1.672 +C1C(OC(=O)O1)F,0.106,C1COC(=O)O1,0.522,O=C(OCC)OCC,0.287,[Li+].F[P-](F)(F)(F)(F)F,0.077,[Rb+].[O-][N+]([O-])=O,0.004,O1CCOCCOCCOCCOCCOCC1,0.004,1.252 +C1COC(=O)O1,0.32,COC(=O)OC,0.253,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.427,O,0,O,0,O,0,2.155 +COCCOC,0.277,FC(F)C(F)(F)COC(F)(F)C(F)F,0.555,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.168,O,0,O,0,O,0,2.155 +COC(=O)OC,0.161,FC(F)C(F)(F)COC(F)(F)C(F)F,0.355,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.484,O,0,O,0,O,0,2.155 +FC(F)(F)COCCOCC,0.838,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.162,O,0,O,0,O,0,O,0,2.155 +FC(F)(F)COCCOCC(F)(F)(F),0.838,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.162,O,0,O,0,O,0,O,0,2.155 +CCOCC,0.64,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.36,O,0,O,0,O,0,O,0,2.208 +C1C(OC(=O)O1)F,0.144,COC(=O)OCCF,0.173,C(C(F)(F)F)OC(C(F)F)(F)F,0.548,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.135,O,0,O,0,2.222 +CC#N,0.882,FC,0.065,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.054,O,0,O,0,O,0,2.222 +C1CCOC1,0.942,FC,0.029,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.029,O,0,O,0,O,0,2.222 +COCCOC,0.139,COCC(F)(F)C(F)(F)C(F)(F)C(F)(F)COC,0.692,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.169,O,0,O,0,O,0,2.222 +COCCOC,0.231,C(C(F)(F)F)OCC(F)(F)F,0.577,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.192,O,0,O,0,O,0,2.222 
+COCCOC,0.507,COC(C(F)(F)F)C(F)(F)F,0.399,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.095,O,0,O,0,O,0,2.268 +CCOCC,0.313,C(C(F)(F)F)OCC(F)(F)F,0.51,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.177,O,0,O,0,O,0,2.301 +COC(=O)OC,0.29,C(C(F)(F)F)OCC(F)(F)F,0.589,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.121,O,0,O,0,O,0,2.301 +COCCOC,0.242,FC(COC(OCC(F)(F)F)OCC(F)(F)F)(F)F,0.604,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.154,O,0,O,0,O,0,2.301 +O1C(C)CCC1,0.331,FC(F)C(F)(F)COC(F)(F)C(F)F,0.498,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.171,O,0,O,0,O,0,2.301 +COCCOC,0.2,FC(F)C(F)(F)COC(F)(F)C(F)F,0.6,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.2,O,0,O,0,O,0,2.301 +COCCOC,0.231,FC(COC(OCC(F)(F)F)OCC(F)(F)F)(F)F,0.577,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.192,O,0,O,0,O,0,2.301 +O=S1(=O)CCCC1,0.359,C(C(F)(F)F)OC(C(F)F)(F)F,0.504,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.133,[Li+].[N+](=O)([O-])[O-],0.004,O,0,O,0,2 +CCOCC,0.856,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.144,O,0,O,0,O,0,O,0,2 +CCOCC,0.877,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.123,O,0,O,0,O,0,O,0,2.018 +CCOCC,0.707,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.293,O,0,O,0,O,0,O,0,2.046 +C1COC(=O)O1,0.308,O=C(OCC)OCC(F)(F)F,0.349,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.343,O,0,O,0,O,0,2.046 +O1CCOC1,0.453,COCCOC,0.305,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.127,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.063,[Li+].[N+](=O)([O-])[O-],0.051,O,0,2.046 +CCOP(=O)(OCC)OCC,0.259,FC(F)C(F)(F)COC(F)(F)C(F)F,0.556,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.185,O,0,O,0,O,0,2.046 +COCCOC,0.257,C(C(F)(F)F)OCC(F)(F)F,0.508,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.235,O,0,O,0,O,0,2.051 +COC(=O)OC,0.299,C(C(F)(F)F)OCC(F)(F)F,0.598,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.103,O,0,O,0,O,0,2.056 +CCOP(=O)(OCC)OCC,0.214,C(C(F)(F)F)OCC(F)(F)F,0.642,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.144,O,0,O,0,O,0,2.097 +COC(=O)OC,0.684,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.316,O,0,O,0,O,0,O,0,2.097 +C1CCOC1,0.47,O1C(C)CCC1,0.378,[Li+].F[P-](F)(F)(F)(F)F,0.152,O,0,O,0,O,0,2.097 +C1C(OC(=O)O1)F,0.264,COC(=O)OCCF,0.479,C(C(F)(F)F)OC(C(F)F)(F)F,0.155,[Li+].F[P-](F)(F)(F)(F)F,0.103,O,0,O,0,2.097 +CCOP(=O)(OCC)OCC,0.5,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.5,O,0,O,0,O,0,O,0,2.097 +COC(=O)OC,0.543,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.457,O,0,O,0,O,0,O,0,2.097 +COCCOC,0.682,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.318,O,0,O,0,O,0,O,0,2.108 +COCCOC,0.231,FC(F)C(F)(F)COC(F)(F)C(F)F,0.577,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.192,O,0,O,0,O,0,2.155 +CCOP(=O)(OCC)OCC,0.728,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.272,O,0,O,0,O,0,O,0,2 +COCCOC,0.583,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.278,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.139,O,0,O,0,O,0,1.678 +C1COC(=O)O1,0.305,COC(=O)OC,0.242,COCCOCCOCCOCCOC,0.392,[Li+].F[P-](F)(F)(F)(F)F,0.041,[Li+].[N+](=O)([O-])[O-],0.02,O,0,1.678 +C1C(OC(=O)O1)F,0.318,CCOC(=O)OC,0.504,COC(=O)OC,0.094,B(O[Si](C)(C)C)(O[Si](C)(C)C)O[Si](C)(C),0.083,[Li+].F[P-](F)(F)(F)(F)F,0.001,O,0,1.678 +O=S1(=O)CCCC1,0.714,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.286,O,0,O,0,O,0,O,0,1.699 +C1C(OC(=O)O1)F,0.149,COC(=O)OCCF,0.178,C(C(F)(F)F)OC(C(F)F)(F)F,0.564,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.108,O,0,O,0,1.735 +O=S1(=O)CCCC1,0.75,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.25,O,0,O,0,O,0,O,0,1.745 +COC(=O)OC,0.29,C(C(F)(F)F)OCC(F)(F)F,0.589,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.121,O,0,O,0,O,0,1.745 +COCCOC,0.677,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.323,O,0,O,0,O,0,O,0,1.745 +O1C(C)CCC1,0.908,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.092,O,0,O,0,O,0,O,0,1.745 +COC(=O)OC,0.6,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.4,O,0,O,0,O,0,O,0,1.77 +COCCOC,0.371,O1CCOC1,0.552,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.031,[Li+].[N+](=O)([O-])[O-],0.046,O,0,O,0,1.78 
+CC1CCC(C)O1,0.893,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.107,O,0,O,0,O,0,O,0,1.796 +C1COC(=O)O1,0.05,CCOC(=O)OC,0.237,C(C(F)(F)F)OCC(F)(F)F,0.575,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.123,[Li+].[B-]1(OC(=O)C(=O)O1)(F)F,0.015,O,0,1.824 +O=S1(=O)CCCC1,0.758,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.235,[Li+].[N+](=O)([O-])[O-],0.007,O,0,O,0,O,0,1.824 +O1CCOC1,0.463,COCCOC,0.312,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.194,[Li+].[N+](=O)([O-])[O-],0.03,O,0,O,0,1.824 +O1CCOC1,0.539,COCCOC,0.363,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.075,[Li+].[N+](=O)([O-])[O-],0.023,O,0,O,0,1.824 +COC(=O)OC,0.375,FC(F)C(F)(F)COC(F)(F)C(F)F,0.375,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.25,O,0,O,0,O,0,1.854 +O1CCOC1,0.478,COCCOC,0.322,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.134,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.067,O,0,O,0,1.854 +CC#N,0.621,C1=COC(=O)O1,0.056,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.323,O,0,O,0,O,0,1.854 +COC(=O)OC,0.478,FC(F)C(F)(F)COC(F)(F)C(F)F,0.322,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.067,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.134,O,0,O,0,1.886 +COC(=O)OC,0.664,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.336,O,0,O,0,O,0,O,0,1.886 +O1CCOC1,0.478,COCCOC,0.322,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.2,O,0,O,0,O,0,1.903 +O=S1(=O)CCCC1,0.429,FC(F)C(F)(F)COC(F)(F)C(F)F,0.429,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.143,O,0,O,0,O,0,1.921 +C1C(OC(=O)O1)F,0.312,O=C1OCCC1,0.599,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.068,[Li+].[N+](=O)([O-])[O-],0.021,O,0,O,0,1.921 +CC1COC(=O)O1,0.595,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.405,O,0,O,0,O,0,O,0,1.921 +O1CCOC1,0.371,COCCOC,0.552,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.077,O,0,O,0,O,0,1.959 +CCOCC,0.906,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.094,O,0,O,0,O,0,O,0,1.959 +C1CCOC1,0.925,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.075,O,0,O,0,O,0,O,0,1.377 +C1COC(=O)O1,0.425,O=C(OCC)OCC,0.234,[Li+].F[P-](F)(F)(F)(F)F,0.34,O,0,O,0,O,0,1.398 +C1C(OC(=O)O1)F,0.932,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.068,O,0,O,0,O,0,O,0,1.41 +COCCOC,0.906,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.047,[Li+].FP(F)(=O)([O-]),0.047,O,0,O,0,O,0,1.444 +O=S1(=O)CCCC1,0.764,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.236,O,0,O,0,O,0,O,0,1.456 +O=C(OCC)C,0.105,ClCCl,0.64,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.255,O,0,O,0,O,0,1.456 +C1C(OC(=O)O1)F,0.496,COC(=O)OC,0.43,O1CCOCCOCCOCC1,0.001,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0,O,0,1.488 +C1C(OC(=O)O1)F,0.873,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.127,O,0,O,0,O,0,O,0,1.489 +COCCOC,0.231,C(C(F)(F)F)OC(=O)OCC(F)(F)F,0.577,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.192,O,0,O,0,O,0,1.495 +C1C(OC(=O)O1)F,0.495,COC(=O)OC,0.429,O1CCOCCOCCOCC1,0.003,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0,O,0,1.498 +C1C(OC(=O)O1)F,0.481,O=C(OCC)OCC,0.432,[Li+].F[P-](F)(F)(F)(F)F,0.087,O,0,O,0,O,0,1.523 +O1CCOC1,0.322,COCCOC,0.478,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)C(F)(F)C(F)(F)C(F)(F)F,0.2,O,0,O,0,O,0,1.523 +O1CCOC1,0.552,COCCOC,0.371,[Li+].[N+](=O)([O-])[O-],0.039,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.039,O,0,O,0,1.523 +C1COC(=O)O1,0.444,C1COS(=O)O1,0.497,[Li+].[O-]Cl(=O)(=O)=O,0.059,O,0,O,0,O,0,1.523 +C1C(OC(=O)O1)F,0.105,C1COC(=O)O1,0.518,O=C(OCC)OCC,0.285,[Li+].F[P-](F)(F)(F)(F)F,0.077,[Rb+].[O-][N+]([O-])=O,0.008,O1CCOCCOCCOCCOCCOCC1,0.008,1.538 +C1C(OC(=O)O1)F,0.82,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.18,O,0,O,0,O,0,O,0,1.544 +C1C(OC(=O)O1)F,0.497,COC(=O)OC,0.431,O1CCOCCOCCOCC1,0,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0,O,0,1.559 +COCCOC,0.906,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.094,O,0,O,0,O,0,O,0,1.561 +CCOP(=O)(OCC)OCC,0.648,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.352,O,0,O,0,O,0,O,0,1.569 
+O=S1(=O)CCCC1,0.25,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.75,O,0,O,0,O,0,O,0,1.569 +C1C(OC(=O)O1)F,0.774,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.226,O,0,O,0,O,0,O,0,1.587 +C1C(OC(=O)O1)F,0.413,O=C(OCC)OCC,0.497,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.09,O,0,O,0,O,0,1.59 +C1COC(=O)O1,0.425,O=C(OCC)OCC(F)(F)F,0.481,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.094,O,0,O,0,O,0,1.602 +CC1COC(=O)O1,0.702,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.298,O,0,O,0,O,0,O,0,1.602 +O1CCOCC1,0.912,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.088,O,0,O,0,O,0,O,0,1.602 +C1C(OC(=O)O1)F,0.62,C(C(F)(F)F)OC(=O)OCC(F)(F)F,0.291,[Li+].F[P-](F)(F)(F)(F)F,0.089,O,0,O,0,O,0,1.62 +COC(=O)CC(F)(F)F,0.768,C1C(OC(=O)O1)F,0.134,[Li+].F[P-](F)(F)(F)(F)F,0.098,O,0,O,0,O,0,1.62 +C1C(OC(=O)O1)F,0.733,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.267,O,0,O,0,O,0,O,0,1.629 +C1C(OC(=O)O1)F,0.696,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.304,O,0,O,0,O,0,O,0,1.633 +COC(C)C(C)OC,0.879,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.121,O,0,O,0,O,0,O,0,1.638 +C1COC(=O)O1,0.197,COC(=O)OC,0.156,COCCOCCOCCOCCOC,0.59,[Li+].F[P-](F)(F)(F)(F)F,0.026,[Li+].[N+](=O)([O-])[O-],0.031,O,0,1.638 +C1COC(=O)O1,0.338,COC(=O)OC,0.625,[Li+].[O-]P(=O)(F)F,0.008,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.03,O,0,O,0,1.26 +COCCOC,0.707,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.147,[Li+].[B-]1(OC(=O)C(=O)O1)(F)F,0.147,O,0,O,0,O,0,1.268 +C1COC(=O)O1,0.496,COC(=O)OC,0.43,O1CCOCCOCCOCC1,0.002,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0,O,0,1.276 +COCCOC,0.763,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.174,[Li+].[O-]P(=O)(F)F,0.063,O,0,O,0,O,0,1.292 +C1COC(=O)O1,0.563,O=C(OCC)OCC,0.31,C1C(OC(=O)O1)F,0.052,[Li+].F[P-](F)(F)(F)(F)F,0.075,O,0,O,0,1.301 +COCCOCCOCC(F)(F)OC(F)(F)OC(F)(F)COCCOCCOC,0.708,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.292,O,0,O,0,O,0,O,0,1.301 +C1COC(=O)O1,0.331,O=C(OCC)OCC,0.577,[Li+].[B-]1(OC(=O)C(=O)O1)(F)F,0.092,O,0,O,0,O,0,1.301 +C1C(OC(=O)O1)F,0.318,CCOC(=O)OC,0.504,COC(=O)OC,0.094,[Li+].F[P-](F)(F)(F)(F)F,0.083,O,0,O,0,1.301 +C1COC(=O)O1,0.519,O=C(OCC)OCC,0.387,[Li+].F[P-](F)(F)(F)(F)F,0.082,[Li+].[O-]P(=O)(F)F,0.012,O,0,O,0,1.319 +C1C(OC(=O)O1)F,0.496,COC(=O)OC,0.43,O1CCOCCOCCOCC1,0.002,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0,O,0,1.333 +COCCOC,0.259,B(OCC(F)(F)F)(OCC(F)(F)F)OCC(F)(F)F,0.556,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.185,O,0,O,0,O,0,1.337 +C1C(OC(=O)O1)F,0.496,COC(=O)OC,0.43,O1CCOCCOCCOCC1,0.002,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0,O,0,1.355 +COCCOC,0.763,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.237,O,0,O,0,O,0,O,0,1.367 +C1C(OC(=O)O1)F,0.497,COC(=O)OC,0.43,O1CCOCCOCCOCC1,0.001,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0,O,0,1.373 +C1COC(=O)O1,0.507,COC(=O)OC,0.402,C1=COC(=O)O1,0.022,[Li+].F[B-](F)(F)F,0.069,O,0,O,0,0.699 +CC1COC(=O)O1,0.922,[Li+].F[B-](F)(F)OC(C(F)(F)(F))(C(F)(F)(F))C(F)(F)(F),0.078,O,0,O,0,O,0,O,0,0.712 +C1COC(=O)O1,0.3,CCOC(=O)OC,0.593,C1=COC(=O)O1,0.026,[Li+].F[P-](F)(F)(F)(F)F,0.081,O,0,O,0,0.745 +C1COC(=O)O1,0.507,COC(=O)OC,0.402,C1=COC(=O)O1,0.022,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.069,O,0,O,0,0.745 +C1COC(=O)O1,0.326,COC(=O)OC,0.602,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0,O,0,O,0,0.777 +C1COC(=O)O1,0.362,O=C(OCC)OCC,0.548,[Li+].F[P-](F)(F)(F)(F)F,0.09,O,0,O,0,O,0,0.788 +CC1COC(=O)O1,0.887,[Li+].F[As-](F)(F)(F)(F)F,0.113,O,0,O,0,O,0,O,0,0.824 +C1COC(=O)O1,0.507,COC(=O)OC,0.402,C1=COC(=O)O1,0.022,[Li+].C(C(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(C(F)(F)F)(F)F)(F)(F)F,0.069,O,0,O,0,0.854 +C1COC(=O)O1,0.359,COC(=O)OC,0.569,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0,O,0,O,0,0.854 +C1COC(=O)O1,0.331,O=C(OCC)OCC,0.577,[Li+].F[P-](F)(F)(F)(F)F,0.092,O,0,O,0,O,0,0.886 
+C1COC(=O)O1,0.594,O=C(OCC)OCC,0.327,[Li+].F[P-](F)(F)(F)(F)F,0.079,O,0,O,0,O,0,0.921 +C1COC(=O)O1,0.5,CCOC(=O)OC,0.423,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.046,[Li+].O=C1O[B-]2(OC1=O)OC(=O)C(=O)O2,0.031,O,0,O,0,0.924 +C1COC(=O)O1,0.526,O=C(OCC)OCC,0.392,[Li+].F[P-](F)(F)(F)(F)F,0.083,O,0,O,0,O,0,1.013 +C1COC(=O)O1,0.356,COC(=O)OC,0.566,FC(F)(F)COB(OCC(F)(F)F)OCC(F)(F)F,0.007,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0,O,0,1.046 +C1COC(=O)O1,0.682,CCOC(=O)OC,0.247,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.043,[Li+].O=C1O[B-]2(OC1=O)OC(=O)C(=O)O2,0.028,O,0,O,0,1.076 +C1COC(=O)O1,0.854,CCOC(=O)OC,0.08,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.039,[Li+].O=C1O[B-]2(OC1=O)OC(=O)C(=O)O2,0.026,O,0,O,0,1.081 +C1C(OC(=O)O1)F,0.497,COC(=O)OC,0.431,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0,O,0,O,0,1.085 +C1C(OC(=O)O1)F,0.107,C1COC(=O)O1,0.526,O=C(OCC)OCC,0.289,[Li+].F[P-](F)(F)(F)(F)F,0.078,O,0,O,0,1.108 +C1COC(=O)O1,0.496,COC(=O)OC,0.393,C1C(OC(=O)O1)F,0.045,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.066,O,0,O,0,1.108 +COCCOC,0.73,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.27,O,0,O,0,O,0,O,0,1.143 +C1COC(=O)O1,0.327,O=C(OCC)OCC,0.594,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.079,O,0,O,0,O,0,1.155 +C1COC(=O)O1,0.338,COC(=O)OC,0.625,[Li+].[O-]P(=O)(F)F,0.008,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.03,O,0,O,0,1.194 +COCCOC,0.731,[Li+].[O-]P(=O)(F)F,0.064,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.205,O,0,O,0,O,0,1.215 +COCCOCCOCCOCCOC,0.819,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.181,O,0,O,0,O,0,O,0,1.222 +C1C(OC(=O)O1)F,0.497,COC(=O)OC,0.43,O1CCOCCOCCOCC1,0,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0,O,0,1.225 +COCCOC,0.706,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.008,[Li+].[O-]P(=O)(F)F,0.286,O,0,O,0,O,0,1.244 \ No newline at end of file diff --git a/models/.DS_Store b/models/.DS_Store index fd17398928f0c1b0fa649a92a2a26b9146b6c77c..3ac199b56b75f6545409adfd7db92d3027f7c5a1 100644 Binary files a/models/.DS_Store and b/models/.DS_Store differ diff --git a/models/.gitattributes b/models/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..7aa4044c51cb3d662ba09fbc6be3c5a681e8e99f --- /dev/null +++ b/models/.gitattributes @@ -0,0 +1,3 @@ +*.csv filter=lfs diff=lfs merge=lfs -text +*.png filter=lfs diff=lfs merge=lfs -text +*.pdf filter=lfs diff=lfs merge=lfs -text diff --git a/models/__pycache__/fm4m.cpython-310.pyc b/models/__pycache__/fm4m.cpython-310.pyc index fb711a24de31394268d64abfb60c2aee40d5d4e6..9253b88bbbcb2344a74db4a55fa11242a7c8cca3 100644 Binary files a/models/__pycache__/fm4m.cpython-310.pyc and b/models/__pycache__/fm4m.cpython-310.pyc differ diff --git a/models/fm4m.py b/models/fm4m.py index 31c9012fd76669b23e2f3ee8e169eb9e2306b16c..15d98be8fb2f261a2a68cd340ae81fd03f0982e5 100644 --- a/models/fm4m.py +++ b/models/fm4m.py @@ -25,9 +25,17 @@ from sklearn.preprocessing import MinMaxScaler import torch from transformers import AutoTokenizer, AutoModel -from .selfies_model.load import SELFIES as bart -from .mhg_model import load as mhg -from .smi_ted.smi_ted_light.load import load_smi_ted +import sys +sys.path.append("models/") + +from models.selfies_ted.load import SELFIES as bart +from models.mhg_model import load as mhg +from models.smi_ted.smi_ted_light.load import load_smi_ted + +import mordred +from mordred import Calculator, descriptors +from rdkit import Chem +from rdkit.Chem import AllChem datasets = {} models = {} @@ -48,7 +56,7 @@ def avail_models_data(): models = [{"Name": "bart","Model Name": "SELFIES-TED","Description": "BART 
model for string based SELFIES modality", "Timestamp": "2024-06-21 12:32:20"}, - {"Name": "mol-xl","Model Name": "Molformer", "Description": "MolFormer model for string based SMILES modality", "Timestamp": "2024-06-21 12:35:56"}, + {"Name": "mol-xl","Model Name": "MolFormer", "Description": "MolFormer model for string based SMILES modality", "Timestamp": "2024-06-21 12:35:56"}, {"Name": "mhg", "Model Name": "MHG-GED","Description": "Molecular hypergraph model", "Timestamp": "2024-07-10 00:09:42"}, {"Name": "smi-ted", "Model Name": "SMI-TED","Description": "SMILES based encoder decoder model", "Timestamp": "2024-07-10 00:09:42"}] @@ -58,8 +66,10 @@ def avail_models(raw=False): models = [{"Name": "smi-ted", "Model Name": "SMI-TED","Description": "SMILES based encoder decoder model"}, {"Name": "bart","Model Name": "SELFIES-TED","Description": "BART model for string based SELFIES modality"}, - {"Name": "mol-xl","Model Name": "Molformer", "Description": "MolFormer model for string based SMILES modality"}, + {"Name": "mol-xl","Model Name": "MolFormer", "Description": "MolFormer model for string based SMILES modality"}, {"Name": "mhg", "Model Name": "MHG-GED","Description": "Molecular hypergraph model"}, + {"Name": "Mordred", "Model Name": "Mordred","Description": "Baseline: A descriptor-calculation software application that can calculate more than 1800 two- and three-dimensional descriptors"}, + {"Name": "MorganFingerprint", "Model Name": "MorganFingerprint","Description": "Baseline: Circular atom environments based descriptor"} ] @@ -70,12 +80,22 @@ def avail_models(raw=False): return models -def avail_downstream_models(): +def avail_downstream_models(raw=False): global downstream_models - with open("downstream_models.json", "r") as outfile: - downstream_models = json.load(outfile) - return downstream_models + downstream_models = [{"Name": "XGBClassifier", "Task Type": "Classfication"}, + {"Name": "DefaultClassifier", "Task Type": "Classfication"}, + {"Name": "SVR", "Task Type": "Regression"}, + {"Name": "Kernel Ridge", "Task Type": "Regression"}, + {"Name": "Linear Regression", "Task Type": "Regression"}, + {"Name": "DefaultRegressor", "Task Type": "Regression"}, + ] + + if raw: return downstream_models + else: + return pd.DataFrame(downstream_models) + + def avail_datasets(): global datasets @@ -178,13 +198,15 @@ def update_downstream_model_list(list_model): avail_models_data() + + def get_representation(train_data,test_data,model_type, return_tensor=True): alias = {"MHG-GED": "mhg", "SELFIES-TED": "bart", "MolFormer": "mol-xl", "Molformer": "mol-xl", "SMI-TED": "smi-ted"} if model_type in alias.keys(): model_type = alias[model_type] if model_type == "mhg": - model = mhg.load("models/mhg_model/pickles/mhggnn_pretrained_model_0724_2023.pickle") + model = mhg.load("../models/mhg_model/pickles/mhggnn_pretrained_model_0724_2023.pickle") with torch.no_grad(): train_emb = model.encode(train_data) x_batch = torch.stack(train_emb) @@ -196,7 +218,6 @@ def get_representation(train_data,test_data,model_type, return_tensor=True): x_batch_test = pd.DataFrame(x_batch_test) - elif model_type == "bart": model = bart() model.load() @@ -204,7 +225,7 @@ def get_representation(train_data,test_data,model_type, return_tensor=True): x_batch_test = model.encode(test_data, return_tensor=return_tensor) elif model_type == "smi-ted": - model = load_smi_ted(folder='./models/smi_ted/smi_ted_light', ckpt_filename='smi-ted-Light_40.pt') + model = load_smi_ted(folder='../models/smi_ted/smi_ted_light', 
ckpt_filename='smi-ted-Light_40.pt') with torch.no_grad(): x_batch = model.encode(train_data, return_torch=return_tensor) x_batch_test = model.encode(test_data, return_torch=return_tensor) @@ -237,35 +258,78 @@ def get_representation(train_data,test_data,model_type, return_tensor=True): if not return_tensor: x_batch = pd.DataFrame(x_batch) x_batch_test = pd.DataFrame(x_batch_test) - + + elif model_type == 'Mordred': + all_data = train_data + test_data + calc = Calculator(descriptors, ignore_3D=True) + mol_list = [Chem.MolFromSmiles(sm) for sm in all_data] + x_all = calc.pandas(mol_list) + print (f'original mordred fv dim: {x_all.shape}') + + for j in x_all.columns: + for k in range(len(x_all[j])): + i = x_all.loc[k, j] + if type(i) is mordred.error.Missing or type(i) is mordred.error.Error: + x_all.loc[k, j] = np.nan + + x_all.dropna(how="any", axis = 1, inplace=True) + print (f'Nan excluded mordred fv dim: {x_all.shape}') + + x_batch = x_all.iloc[:len(train_data)] + x_batch_test = x_all.iloc[len(train_data):] + # print(f'x_batch: {len(x_batch)}, x_batch_test: {len(x_batch_test)}') + + elif model_type == 'MorganFingerprint': + params = {'radius':2, 'nBits':1024} + + mol_train = [Chem.MolFromSmiles(sm) for sm in train_data] + mol_test = [Chem.MolFromSmiles(sm) for sm in test_data] + + x_batch = [] + for mol in mol_train: + info = {} + fp = AllChem.GetMorganFingerprintAsBitVect(mol, **params, bitInfo=info) + vector = list(fp) + x_batch.append(vector) + x_batch = pd.DataFrame(x_batch) + + x_batch_test = [] + for mol in mol_test: + info = {} + fp = AllChem.GetMorganFingerprintAsBitVect(mol, **params, bitInfo=info) + vector = list(fp) + x_batch_test.append(vector) + x_batch_test = pd.DataFrame(x_batch_test) return x_batch, x_batch_test -def single_modal(model,dataset, downstream_model,params): +def single_modal(model,dataset=None, downstream_model=None, params=None, x_train=None, x_test=None, y_train=None, y_test=None): print(model) - alias = {"MHG-GED":"mhg", "SELFIES-TED": "bart", "MolFormer":"mol-xl", "SMI-TED": "smi-ted"} + alias = {"MHG-GED":"mhg", "SELFIES-TED": "bart", "MolFormer":"mol-xl", "Molformer": "mol-xl", "SMI-TED": "smi-ted"} data = avail_models(raw=True) df = pd.DataFrame(data) - print(list(df["Name"].values)) - if alias[model] in list(df["Name"].values): - if model in alias.keys(): + #print(list(df["Name"].values)) + + if model in list(df["Name"].values): + model_type = model + elif alias[model] in list(df["Name"].values): model_type = alias[model] - else: - model_type = model else: print("Model not available") return + data = avail_datasets() df = pd.DataFrame(data) - print(list(df["Dataset"].values)) + #print(list(df["Dataset"].values)) if dataset in list(df["Dataset"].values): task = dataset - with open(f"./representation/{task}_{model_type}.pkl", "rb") as f1: + with open(f"representation/{task}_{model_type}.pkl", "rb") as f1: x_batch, y_batch, x_batch_test, y_batch_test = pickle.load(f1) print(f" Representation loaded successfully") - else: + + elif x_train==None: print("Custom Dataset") #return @@ -283,14 +347,40 @@ def single_modal(model,dataset, downstream_model,params): print(f" Representation loaded successfully") + else: - - + y_batch = y_train + y_batch_test = y_test + x_batch, x_batch_test = get_representation(x_train, x_test, model_type) + + # exclude row containing Nan value + if isinstance(x_batch, torch.Tensor): + x_batch = pd.DataFrame(x_batch) + nan_indices = x_batch.index[x_batch.isna().any(axis=1)] + if len(nan_indices) > 0: + x_batch.dropna(inplace = True) 
+ for index in sorted(nan_indices, reverse=True): + del y_batch[index] + print(f'x_batch Nan index: {nan_indices}') + print(f'x_batch shape: {x_batch.shape}, y_batch len: {len(y_batch)}') + + if isinstance(x_batch_test, torch.Tensor): + x_batch_test = pd.DataFrame(x_batch_test) + nan_indices = x_batch_test.index[x_batch_test.isna().any(axis=1)] + if len(nan_indices) > 0: + x_batch_test.dropna(inplace = True) + for index in sorted(nan_indices, reverse=True): + del y_batch_test[index] + print(f'x_batch_test Nan index: {nan_indices}') + print(f'x_batch_test shape: {x_batch_test.shape}, y_batch_test len: {len(y_batch_test)}') print(f" Calculating ROC AUC Score ...") if downstream_model == "XGBClassifier": - xgb_predict_concat = XGBClassifier(**params) # n_estimators=5000, learning_rate=0.01, max_depth=10 + if params == None: + xgb_predict_concat = XGBClassifier() + else: + xgb_predict_concat = XGBClassifier(**params) # n_estimators=5000, learning_rate=0.01, max_depth=10 xgb_predict_concat.fit(x_batch, y_batch) y_prob = xgb_predict_concat.predict_proba(x_batch_test)[:, 1] @@ -300,21 +390,26 @@ def single_modal(model,dataset, downstream_model,params): print(f"ROC-AUC Score: {roc_auc:.4f}") try: - with open(f"./plot_emb/{task}_{model_type}.pkl", "rb") as f1: + with open(f"plot_emb/{task}_{model_type}.pkl", "rb") as f1: class_0,class_1 = pickle.load(f1) except: print("Generating latent plots") reducer = umap.UMAP(metric='euclidean', n_neighbors=10, n_components=2, low_memory=True, min_dist=0.1, verbose=False) n_samples = np.minimum(1000, len(x_batch)) - features_umap = reducer.fit_transform(x_batch[:n_samples]) + try:x = y_batch.values[:n_samples] - except:x = y_batch[:n_samples] + except: x = y_batch[:n_samples] index_0 = [index for index in range(len(x)) if x[index] == 0] index_1 = [index for index in range(len(x)) if x[index] == 1] - class_0 = features_umap[index_0] - class_1 = features_umap[index_1] + try: + features_umap = reducer.fit_transform(x_batch[:n_samples]) + class_0 = features_umap[index_0] + class_1 = features_umap[index_1] + except: + class_0 = [] + class_1 = [] print("Generating latent plots : Done") #vizualize(roc_auc,fpr, tpr, x_batch, y_batch ) @@ -334,20 +429,29 @@ def single_modal(model,dataset, downstream_model,params): print(f"ROC-AUC Score: {roc_auc:.4f}") try: - with open(f"./plot_emb/{task}_{model_type}.pkl", "rb") as f1: + with open(f"plot_emb/{task}_{model_type}.pkl", "rb") as f1: class_0,class_1 = pickle.load(f1) except: print("Generating latent plots") reducer = umap.UMAP(metric='euclidean', n_neighbors= 10, n_components=2, low_memory=True, min_dist=0.1, verbose=False) n_samples = np.minimum(1000,len(x_batch)) - features_umap = reducer.fit_transform(x_batch[:n_samples]) - try:x = y_batch.values[:n_samples] - except:x = y_batch[:n_samples] - index_0 = [index for index in range(len(x)) if x[index] == 0] - index_1 = [index for index in range(len(x)) if x[index] == 1] - class_0 = features_umap[index_0] - class_1 = features_umap[index_1] + try: + x = y_batch.values[:n_samples] + except: + x = y_batch[:n_samples] + + try: + features_umap = reducer.fit_transform(x_batch[:n_samples]) + index_0 = [index for index in range(len(x)) if x[index] == 0] + index_1 = [index for index in range(len(x)) if x[index] == 1] + + class_0 = features_umap[index_0] + class_1 = features_umap[index_1] + except: + class_0 = [] + class_1 = [] + print("Generating latent plots : Done") #vizualize(roc_auc,fpr, tpr, x_batch, y_batch ) @@ -355,16 +459,19 @@ def single_modal(model,dataset, 
downstream_model,params): result = f"ROC-AUC Score: {roc_auc:.4f}" return result, roc_auc,fpr, tpr, class_0, class_1 - + elif downstream_model == "SVR": - regressor = SVR(**params) + if params == None: + regressor = SVR() + else: + regressor = SVR(**params) model = TransformedTargetRegressor(regressor= regressor, transformer = MinMaxScaler(feature_range=(-1, 1)) ).fit(x_batch,y_batch) - + y_prob = model.predict(x_batch_test) RMSE_score = np.sqrt(mean_squared_error(y_batch_test, y_prob)) - + print(f"RMSE Score: {RMSE_score:.4f}") result = f"RMSE Score: {RMSE_score:.4f}" @@ -372,20 +479,28 @@ def single_modal(model,dataset, downstream_model,params): reducer = umap.UMAP(metric='euclidean', n_neighbors=10, n_components=2, low_memory=True, min_dist=0.1, verbose=False) n_samples = np.minimum(1000, len(x_batch)) - features_umap = reducer.fit_transform(x_batch[:n_samples]) - try:x = y_batch.values[:n_samples] - except:x = y_batch[:n_samples] + + try: x = y_batch.values[:n_samples] + except: x = y_batch[:n_samples] #index_0 = [index for index in range(len(x)) if x[index] == 0] #index_1 = [index for index in range(len(x)) if x[index] == 1] - class_0 = features_umap#[index_0] - class_1 = features_umap#[index_1] + try: + features_umap = reducer.fit_transform(x_batch[:n_samples]) + class_0 = features_umap#[index_0] + class_1 = features_umap#[index_1] + except: + class_0 = [] + class_1 = [] print("Generating latent plots : Done") - + return result, RMSE_score,y_batch_test, y_prob, class_0, class_1 elif downstream_model == "Kernel Ridge": - regressor = KernelRidge(**params) + if params == None: + regressor = KernelRidge() + else: + regressor = KernelRidge(**params) model = TransformedTargetRegressor(regressor=regressor, transformer=MinMaxScaler(feature_range=(-1, 1)) ).fit(x_batch, y_batch) @@ -401,8 +516,8 @@ def single_modal(model,dataset, downstream_model,params): verbose=False) n_samples = np.minimum(1000, len(x_batch)) features_umap = reducer.fit_transform(x_batch[:n_samples]) - try:x = y_batch.values[:n_samples] - except:x = y_batch[:n_samples] + try: x = y_batch.values[:n_samples] + except: x = y_batch[:n_samples] # index_0 = [index for index in range(len(x)) if x[index] == 0] # index_1 = [index for index in range(len(x)) if x[index] == 1] @@ -414,7 +529,10 @@ def single_modal(model,dataset, downstream_model,params): elif downstream_model == "Linear Regression": - regressor = LinearRegression(**params) + if params == None: + regressor = LinearRegression() + else: + regressor = LinearRegression(**params) model = TransformedTargetRegressor(regressor=regressor, transformer=MinMaxScaler(feature_range=(-1, 1)) ).fit(x_batch, y_batch) @@ -431,7 +549,7 @@ def single_modal(model,dataset, downstream_model,params): n_samples = np.minimum(1000, len(x_batch)) features_umap = reducer.fit_transform(x_batch[:n_samples]) try:x = y_batch.values[:n_samples] - except:x = y_batch[:n_samples] + except: x = y_batch[:n_samples] # index_0 = [index for index in range(len(x)) if x[index] == 0] # index_1 = [index for index in range(len(x)) if x[index] == 1] @@ -460,7 +578,7 @@ def single_modal(model,dataset, downstream_model,params): n_samples = np.minimum(1000, len(x_batch)) features_umap = reducer.fit_transform(x_batch[:n_samples]) try:x = y_batch.values[:n_samples] - except:x = y_batch[:n_samples] + except: x = y_batch[:n_samples] # index_0 = [index for index in range(len(x)) if x[index] == 0] # index_1 = [index for index in range(len(x)) if x[index] == 1] @@ -469,10 +587,10 @@ def single_modal(model,dataset, 
downstream_model,params): print("Generating latent plots : Done") return result, RMSE_score, y_batch_test, y_prob, class_0, class_1 + - -def multi_modal(model_list,dataset, downstream_model,params): - print(model_list) +def multi_modal(model_list,dataset=None, downstream_model=None,params=None, x_train=None, x_test=None, y_train=None, y_test=None): + #print(model_list) data = avail_datasets() df = pd.DataFrame(data) list(df["Dataset"].values) @@ -480,7 +598,7 @@ def multi_modal(model_list,dataset, downstream_model,params): if dataset in list(df["Dataset"].values): task = dataset predefined = True - else: + elif x_train==None: predefined = False components = dataset.split(",") train_data = pd.read_csv(components[0])[components[2]] @@ -490,13 +608,18 @@ def multi_modal(model_list,dataset, downstream_model,params): y_batch_test = pd.read_csv(components[1])[components[3]] print("Custom Dataset loaded") - + else: + predefined = False + y_batch = y_train + y_batch_test = y_test + train_data = x_train + test_data = x_test data = avail_models(raw=True) df = pd.DataFrame(data) list(df["Name"].values) - alias = {"MHG-GED":"mhg", "SELFIES-TED": "bart", "MolFormer":"mol-xl", "SMI-TED":"smi-ted"} + alias = {"MHG-GED":"mhg", "SELFIES-TED": "bart", "MolFormer":"mol-xl", "Molformer": "mol-xl","SMI-TED":"smi-ted", "Mordred": "Mordred", "MorganFingerprint": "MorganFingerprint"} #if set(model_list).issubset(list(df["Name"].values)): if set(model_list).issubset(list(alias.keys())): for i, model in enumerate(model_list): @@ -507,7 +630,7 @@ def multi_modal(model_list,dataset, downstream_model,params): if i == 0: if predefined: - with open(f"./representation/{task}_{model_type}.pkl", "rb") as f1: + with open(f"representation/{task}_{model_type}.pkl", "rb") as f1: x_batch, y_batch, x_batch_test, y_batch_test = pickle.load(f1) print(f" Loaded representation/{task}_{model_type}.pkl") else: @@ -517,7 +640,7 @@ def multi_modal(model_list,dataset, downstream_model,params): else: if predefined: - with open(f"./representation/{task}_{model_type}.pkl", "rb") as f1: + with open(f"representation/{task}_{model_type}.pkl", "rb") as f1: x_batch_1, y_batch_1, x_batch_test_1, y_batch_test_1 = pickle.load(f1) print(f" Loaded representation/{task}_{model_type}.pkl") else: @@ -528,7 +651,6 @@ def multi_modal(model_list,dataset, downstream_model,params): x_batch = pd.concat([x_batch, x_batch_1], axis=1) x_batch_test = pd.concat([x_batch_test, x_batch_test_1], axis=1) - else: print("Model not available") return @@ -538,11 +660,31 @@ def multi_modal(model_list,dataset, downstream_model,params): num_columns = x_batch.shape[1] x_batch.columns = [f'{i + 1}' for i in range(num_columns)] - + + # exclude row containing Nan value + if isinstance(x_batch, torch.Tensor): + x_batch = pd.DataFrame(x_batch) + nan_indices = x_batch.index[x_batch.isna().any(axis=1)] + if len(nan_indices) > 0: + x_batch.dropna(inplace = True) + for index in sorted(nan_indices, reverse=True): + del y_batch[index] + print(f'x_batch Nan index: {nan_indices}') + print(f'x_batch shape: {x_batch.shape}, y_batch len: {len(y_batch)}') + + if isinstance(x_batch_test, torch.Tensor): + x_batch_test = pd.DataFrame(x_batch_test) + nan_indices = x_batch_test.index[x_batch_test.isna().any(axis=1)] + if len(nan_indices) > 0: + x_batch_test.dropna(inplace = True) + for index in sorted(nan_indices, reverse=True): + del y_batch_test[index] + print(f'x_batch_test Nan index: {nan_indices}') + print(f'x_batch_test shape: {x_batch_test.shape}, y_batch_test len: {len(y_batch_test)}') 
print(f"Representations loaded successfully") try: - with open(f"./plot_emb/{task}_multi.pkl", "rb") as f1: + with open(f"plot_emb/{task}_multi.pkl", "rb") as f1: class_0, class_1 = pickle.load(f1) except: print("Generating latent plots") @@ -552,7 +694,7 @@ def multi_modal(model_list,dataset, downstream_model,params): features_umap = reducer.fit_transform(x_batch[:n_samples]) if "Classifier" in downstream_model: - try:x = y_batch.values[:n_samples] + try: x = y_batch.values[:n_samples] except: x = y_batch[:n_samples] index_0 = [index for index in range(len(x)) if x[index] == 0] index_1 = [index for index in range(len(x)) if x[index] == 1] @@ -570,7 +712,10 @@ def multi_modal(model_list,dataset, downstream_model,params): if downstream_model == "XGBClassifier": - xgb_predict_concat = XGBClassifier(**params)#n_estimators=5000, learning_rate=0.01, max_depth=10) + if params == None: + xgb_predict_concat = XGBClassifier() + else: + xgb_predict_concat = XGBClassifier(**params)#n_estimators=5000, learning_rate=0.01, max_depth=10) xgb_predict_concat.fit(x_batch, y_batch) y_prob = xgb_predict_concat.predict_proba(x_batch_test)[:, 1] @@ -608,21 +753,27 @@ def multi_modal(model_list,dataset, downstream_model,params): return result, roc_auc,fpr, tpr, class_0, class_1 elif downstream_model == "SVR": - regressor = SVR(**params) + if params == None: + regressor = SVR() + else: + regressor = SVR(**params) model = TransformedTargetRegressor(regressor= regressor, transformer = MinMaxScaler(feature_range=(-1, 1)) ).fit(x_batch,y_batch) - + y_prob = model.predict(x_batch_test) RMSE_score = np.sqrt(mean_squared_error(y_batch_test, y_prob)) - + print(f"RMSE Score: {RMSE_score:.4f}") result = f"RMSE Score: {RMSE_score:.4f}" - + return result, RMSE_score,y_batch_test, y_prob, class_0, class_1 elif downstream_model == "Linear Regression": - regressor = LinearRegression(**params) + if params == None: + regressor = LinearRegression() + else: + regressor = LinearRegression(**params) model = TransformedTargetRegressor(regressor=regressor, transformer=MinMaxScaler(feature_range=(-1, 1)) ).fit(x_batch, y_batch) @@ -636,7 +787,10 @@ def multi_modal(model_list,dataset, downstream_model,params): return result, RMSE_score, y_batch_test, y_prob, class_0, class_1 elif downstream_model == "Kernel Ridge": - regressor = KernelRidge(**params) + if params == None: + regressor = KernelRidge() + else: + regressor = KernelRidge(**params) model = TransformedTargetRegressor(regressor=regressor, transformer=MinMaxScaler(feature_range=(-1, 1)) ).fit(x_batch, y_batch) @@ -665,6 +819,144 @@ def multi_modal(model_list,dataset, downstream_model,params): +def finetune_optuna(x_batch,y_batch, x_batch_test, y_test ): + print(f" Finetuning with Optuna and calculating ROC AUC Score ...") + X_train = x_batch.values + y_train = y_batch.values + X_test = x_batch_test.values + y_test = y_test.values + def objective(trial): + # Define parameters to be optimized + params = { + # 'objective': 'binary:logistic', + 'eval_metric': 'auc', + 'verbosity': 0, + 'n_estimators': trial.suggest_int('n_estimators', 1000, 10000), + # 'booster': trial.suggest_categorical('booster', ['gbtree', 'gblinear', 'dart']), + # 'lambda': trial.suggest_loguniform('lambda', 1e-8, 1.0), + 'alpha': trial.suggest_loguniform('alpha', 1e-8, 1.0), + 'max_depth': trial.suggest_int('max_depth', 1, 12), + # 'eta': trial.suggest_loguniform('eta', 1e-8, 1.0), + # 'gamma': trial.suggest_loguniform('gamma', 1e-8, 1.0), + # 'grow_policy': trial.suggest_categorical('grow_policy', ['depthwise', 
'lossguide']), + # "subsample": trial.suggest_float("subsample", 0.05, 1.0), + # "colsample_bytree": trial.suggest_float("colsample_bytree", 0.05, 1.0), + } + + # Train XGBoost model + dtrain = xgb.DMatrix(X_train, label=y_train) + dtest = xgb.DMatrix(X_test, label=y_test) + + model = xgb.train(params, dtrain) + + # Predict probabilities + y_pred = model.predict(dtest) + + # Calculate ROC AUC score + roc_auc = roc_auc_score(y_test, y_pred) + print("ROC_AUC : ", roc_auc) + + return roc_auc + +def add_new_model(): + models = avail_models(raw=True) + + # Function to display models + def display_models(): + for model in models: + model_display = f"Name: {model['Name']}, Description: {model['Description']}, Timestamp: {model['Timestamp']}" + print(model_display) + + # Function to update models + def update_models(new_name, new_description, new_path): + new_model = { + "Name": new_name, + "Description": new_description, + "Timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"), + #"path": new_path + } + models.append(new_model) + with open("models.json", "w") as outfile: + json.dump(models, outfile) + + print("Model uploaded and updated successfully!") + list_models() + #display_models() + + # Widgets + name_text = widgets.Text(description="Name:", layout=Layout(width='50%')) + description_text = widgets.Text(description="Description:", layout=Layout(width='50%')) + path_text = widgets.Text(description="Path:", layout=Layout(width='50%')) + + def browse_callback(b): + root = tk.Tk() + root.withdraw() # Hide the main window + file_path = filedialog.askopenfilename(title="Select a Model File") + if file_path: + path_text.value = file_path + + browse_button = widgets.Button(description="Browse") + browse_button.on_click(browse_callback) + + def submit_callback(b): + update_models(name_text.value, description_text.value, path_text.value) + + submit_button = widgets.Button(description="Submit") + submit_button.on_click(submit_callback) + + # Display widgets + display(VBox([name_text, description_text, path_text, browse_button, submit_button])) + + +def add_new_dataset(): + # Sample data + datasets = avail_datasets() + + # Function to display models + def display_datasets(): + for dataset in datasets: + dataset_display = f"Name: {dataset['Dataset']}, Input: {dataset['Input']},Output: {dataset['Output']},Path: {dataset['Path']}, Timestamp: {dataset['Timestamp']}" + + # Function to update models + def update_datasets(new_dataset, new_input, new_output, new_path): + new_model = { + "Dataset": new_dataset, + "Input": new_input, + "Output": new_output, + "Timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"), + "Path": os.path.basename(new_path) + } + datasets.append(new_model) + with open("datasets.json", "w") as outfile: + json.dump(datasets, outfile) + + print("Dataset uploaded and updated successfully!") + list_data() + + + # Widgets + dataset_text = widgets.Text(description="Dataset:", layout=Layout(width='50%')) + input_text = widgets.Text(description="Input:", layout=Layout(width='50%')) + output_text = widgets.Text(description="Output:", layout=Layout(width='50%')) + path_text = widgets.Text(description="Path:", layout=Layout(width='50%')) + + def browse_callback(b): + root = tk.Tk() + root.withdraw() # Hide the main window + file_path = filedialog.askopenfilename(title="Select a Dataset File") + if file_path: + path_text.value = file_path + + browse_button = widgets.Button(description="Browse") + browse_button.on_click(browse_callback) + + def submit_callback(b): + 
update_datasets(dataset_text.value, input_text.value, output_text.value, path_text.value) + + submit_button = widgets.Button(description="Submit") + submit_button.on_click(submit_callback) + + display(VBox([dataset_text, input_text, output_text, path_text, browse_button, submit_button])) diff --git a/models/mhg_model/README.md b/models/mhg_model/README.md index b855ff28edd655aedc5097cae88fbb812dd06f76..339698f2033bd48e9e66a67c7c8ba6ce5cb9a626 100644 --- a/models/mhg_model/README.md +++ b/models/mhg_model/README.md @@ -27,7 +27,7 @@ In addition, the decoder inherits the theoretical guarantee of MHG on always gen ### Pretrained Models and Training Logs -We provide checkpoints of the MHG-GNN model pre-trained on a dataset of ~1.34M molecules curated from PubChem. (later) For model weights: [HuggingFace Link]() +We provide checkpoints of the MHG-GNN model pre-trained on a dataset of ~1.34M molecules curated from PubChem. (later) For model weights: [HuggingFace Link](https://huggingface.co/ibm/materials.mhg-ged/blob/main/mhggnn_pretrained_model_0724_2023.pickle) Add the MHG-GNN `pre-trained weights.pt` to the `models/` directory according to your needs. diff --git a/models/mhg_model/images/mhg_example.png b/models/mhg_model/images/mhg_example.png index 3a7dd8ce73476fba75ed242e67147946d99740eb..816da8f712d997ef5f45ae365bd32b2da0ca62f1 100644 Binary files a/models/mhg_model/images/mhg_example.png and b/models/mhg_model/images/mhg_example.png differ diff --git a/models/mhg_model/images/mhg_example1.png b/models/mhg_model/images/mhg_example1.png index 150b71f10580655433a6f59a60cbc2afc07d8dc8..089cdde868fc15c8c9dfce84f3bcbbb650901da1 100644 Binary files a/models/mhg_model/images/mhg_example1.png and b/models/mhg_model/images/mhg_example1.png differ diff --git a/models/mhg_model/images/mhg_example2.png b/models/mhg_model/images/mhg_example2.png index b00f97a7fb3bec25c0e6e42990d18aaa216eff2d..87c8ebad807ef7dff641d217a0997ae47ca24ed5 100644 Binary files a/models/mhg_model/images/mhg_example2.png and b/models/mhg_model/images/mhg_example2.png differ diff --git a/models/mhg_model/load.py b/models/mhg_model/load.py index 20f21ea1a5584d002739ef9b212f3f120f37d240..322b43ce2e683a8d977dba6412c020553b469838 100644 --- a/models/mhg_model/load.py +++ b/models/mhg_model/load.py @@ -17,6 +17,7 @@ from typing_extensions import Self from .graph_grammar.io.smi import hg_to_mol from .models.mhgvae import GrammarGINVAE + from huggingface_hub import hf_hub_download @@ -73,12 +74,30 @@ class PretrainedModelWrapper: return output -def load(model_name: str = "models/mhg_model/pickles/mhggnn_pretrained_model_0724_2023.pickle") -> Optional[ +def load(model_name: str = "mhg_model/pickles/mhggnn_pretrained_model_0724_2023.pickle") -> Optional[ PretrainedModelWrapper]: + repo_id = "ibm/materials.mhg-ged" filename = "pytorch_model.bin" #"mhggnn_pretrained_model_0724_2023.pickle" file_path = hf_hub_download(repo_id=repo_id, filename=filename) with open(file_path, "rb") as f: model_dict = torch.load(f) return PretrainedModelWrapper(model_dict) + + + """try: + if os.path.isfile(model_name): + with open(model_name, "rb") as f: + model_dict = pickle.load(f) + print("MHG Model Loaded") + return PretrainedModelWrapper(model_dict) + + except: + + for p in sys.path: + file = p + "/" + model_name + if os.path.isfile(file): + with open(file, "rb") as f: + model_dict = pickle.load(f) + return PretrainedModelWrapper(model_dict)""" return None diff --git a/models/mhg_model/paper/MHG-GNN_Combination of Molecular Hypergraph Grammar with Graph 
Neural Network.pdf b/models/mhg_model/paper/MHG-GNN_Combination of Molecular Hypergraph Grammar with Graph Neural Network.pdf index a7dcc1270d1f444f77366013ad2d3d93ebb426ab..4bf1999e79d46e23da49a337a02dd6f189f4086a 100644
Binary files a/models/mhg_model/paper/MHG-GNN_Combination of Molecular Hypergraph Grammar with Graph Neural Network.pdf and b/models/mhg_model/paper/MHG-GNN_Combination of Molecular Hypergraph Grammar with Graph Neural Network.pdf differ
diff --git a/models/selfies_model/selfies-ted.png b/models/selfies_model/selfies-ted.png
index a71127e0a6baf0110d8074e63a08ba060c931121..d1c3561c3751785b0507d966a01ea1fe3b859fa3 100644
Binary files a/models/selfies_model/selfies-ted.png and b/models/selfies_model/selfies-ted.png differ
diff --git a/models/selfies_ted/README.md b/models/selfies_ted/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..01f70f739727bf443957cdef04353175e0c4f47f
--- /dev/null
+++ b/models/selfies_ted/README.md
@@ -0,0 +1,87 @@
+---
+license: apache-2.0
+library_name: transformers
+pipeline_tag: feature-extraction
+tags:
+- chemistry
+---
+
+# selfies-ted
+
+selfies-ted is a project for encoding SMILES (Simplified Molecular Input Line Entry System) into SELFIES (SELF-referencing Embedded Strings) and generating embeddings for molecular representations.
+
+## Model Architecture
+
+Configuration details:
+
+- Encoder and decoder FFN dimensions: 256
+- Number of attention heads: 4
+- Number of encoder and decoder layers: 2
+- Total number of hidden layers: 6
+- Maximum position embeddings: 128
+- Model dimension (d_model): 256
+
+## Pretrained Models and Training Logs
+
+We provide checkpoints of the selfies-ted model pre-trained on a dataset of molecules curated from PubChem. The pre-trained model shows competitive performance on molecular representation tasks. For model weights: "HuggingFace link".
+
+To install and use the pre-trained model, download the `selfies_ted_model.pkl` file from the "HuggingFace link" and add it to the `models/` directory. The directory structure should look like the following:
+
+```
+models/
+└── selfies_ted_model.pkl
+```
+
+## Installation
+
+To use this project, you'll need to install the required dependencies. We recommend using a virtual environment:
+
+```bash
+python -m venv venv
+source venv/bin/activate  # On Windows use `venv\Scripts\activate`
+```
+
+Install the required dependencies:
+
+```
+pip install -r requirements.txt
+```
+
+## Usage
+
+### Import
+
+```
+from load import SELFIES
+```
+
+### Training the Model
+
+To train the model, use the train.py script:
+
+```
+python train.py -f <path_to_your_data_file>
+```
+
+Note: The actual usage may depend on the specific implementation in load.py. Please refer to the source code for detailed functionality.
+
+### Load the model and tokenizer
+
+```
+model = SELFIES()
+model.load(checkpoint="bart-2908.pickle")
+```
+
+### Encode SMILES strings
+
+```
+smiles_list = ["COC", "CCO"]
+embeddings = model.encode(smiles_list)
+```
+
+## Example Notebook
+
+An example notebook for this project is `selfies-ted-example.ipynb`.
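
As a convenience, here is a minimal end-to-end sketch of the encoding API. It is based on the `SELFIES` class in `load.py` below and on `selfies-ted-example.ipynb`; it assumes `models/selfies_ted` is the working directory and that the `ibm/materials.selfies-ted` checkpoint can be downloaded from Hugging Face.

```python
# Minimal end-to-end sketch (assumes the SELFIES class from models/selfies_ted/load.py
# and network access to the ibm/materials.selfies-ted checkpoint on Hugging Face).
from load import SELFIES

model = SELFIES()
model.load()  # pulls the tokenizer and encoder-decoder weights from ibm/materials.selfies-ted

smiles_list = ["CCO", "O=C=O", "OC(=O)c1ccccc1C(=O)O"]

# Default: a pandas DataFrame with one embedding row per molecule.
embeddings = model.encode(smiles_list)
print(embeddings.shape)  # (3, 1024) with this checkpoint

# Alternatively request a torch.Tensor; SMILES that cannot be parsed are returned as NaN rows.
emb_tensor = model.encode(smiles_list, return_tensor=True)
```

This is the same pattern `models/fm4m.py` follows when `model_type == "bart"`: instantiate, `load()`, then `encode()` the train and test SMILES.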
diff --git a/models/selfies_ted/load.py b/models/selfies_ted/load.py new file mode 100644 index 0000000000000000000000000000000000000000..5ec37d3df32957e79b29e9cf71ee49d3690f9a23 --- /dev/null +++ b/models/selfies_ted/load.py @@ -0,0 +1,92 @@ +import os +import sys +import torch +import selfies as sf # selfies>=2.1.1 +import pickle +import pandas as pd +import numpy as np +from datasets import Dataset +from rdkit import Chem +from transformers import AutoTokenizer, AutoModel + + +class SELFIES(torch.nn.Module): + + def __init__(self): + super().__init__() + self.model = None + self.tokenizer = None + self.invalid = [] + + def get_selfies(self, smiles_list): + self.invalid = [] + spaced_selfies_batch = [] + for i, smiles in enumerate(smiles_list): + try: + selfies = sf.encoder(smiles.rstrip()) + except: + try: + smiles = Chem.MolToSmiles(Chem.MolFromSmiles(smiles.rstrip())) + selfies = sf.encoder(smiles) + except: + selfies = "[]" + self.invalid.append(i) + + spaced_selfies_batch.append(selfies.replace('][', '] [')) + + return spaced_selfies_batch + + + def get_embedding(self, selfies): + encoding = self.tokenizer(selfies["selfies"], return_tensors='pt', max_length=128, truncation=True, padding='max_length') + input_ids = encoding['input_ids'] + attention_mask = encoding['attention_mask'] + outputs = self.model.encoder(input_ids=input_ids, attention_mask=attention_mask) + model_output = outputs.last_hidden_state + + input_mask_expanded = attention_mask.unsqueeze(-1).expand(model_output.size()).float() + sum_embeddings = torch.sum(model_output * input_mask_expanded, 1) + sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9) + model_output = sum_embeddings / sum_mask + + del encoding['input_ids'] + del encoding['attention_mask'] + + encoding["embedding"] = model_output + + return encoding + + + def load(self, checkpoint="bart-2908.pickle"): + """ + inputs : + checkpoint (pickle object) + """ + + self.tokenizer = AutoTokenizer.from_pretrained("ibm/materials.selfies-ted") + self.model = AutoModel.from_pretrained("ibm/materials.selfies-ted") + + + + + + # TODO: remove `use_gpu` argument in validation pipeline + def encode(self, smiles_list=[], use_gpu=False, return_tensor=False): + """ + inputs : + checkpoint (pickle object) + :return: embedding + """ + selfies = self.get_selfies(smiles_list) + selfies_df = pd.DataFrame(selfies,columns=["selfies"]) + data = Dataset.from_pandas(selfies_df) + embedding = data.map(self.get_embedding, batched=True, num_proc=1, batch_size=128) + emb = np.asarray(embedding["embedding"].copy()) + + for idx in self.invalid: + emb[idx] = np.nan + print("Cannot encode {0} to selfies and embedding replaced by NaN".format(smiles_list[idx])) + + if return_tensor: + return torch.tensor(emb) + return pd.DataFrame(emb) diff --git a/models/selfies_ted/requirements.txt b/models/selfies_ted/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..9183360cca79111e2e64fe4b65849e4df75c195f --- /dev/null +++ b/models/selfies_ted/requirements.txt @@ -0,0 +1,12 @@ +torch>=2.1.0 +transformers>=4.38 +numpy>=1.26.1 +datasets>=2.13.1 +evaluate>=0.4.0 +selfies>=2.1.0 +scikit-learn>=1.2.1 +pyarrow>=14.0.1 +requests>=2.31.0 +urllib3>=2.0.7 +aiohttp>=3.9.0 +zipp>=3.17.0 \ No newline at end of file diff --git a/models/selfies_ted/selfies-ted-example.ipynb b/models/selfies_ted/selfies-ted-example.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..856f98cde351f6fa5b3fcfdebd2c5ad6726fc380 --- /dev/null +++ 
b/models/selfies_ted/selfies-ted-example.ipynb @@ -0,0 +1,136 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "9d9b6eb8-9edb-44bd-9e5a-3a6ea67f5117", + "metadata": {}, + "source": [ + "### Import library" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "c3ac4418", + "metadata": {}, + "outputs": [], + "source": [ + "from load import SELFIES" + ] + }, + { + "cell_type": "markdown", + "id": "790061cf-5470-4564-987e-aa2e492337db", + "metadata": {}, + "source": [ + "### Initialize and load" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "85847f26-e2f4-475a-a88e-41fd9cccfc0f", + "metadata": {}, + "outputs": [], + "source": [ + "model = SELFIES()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "095e864c", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "model.load(checkpoint=\"bart-2908.pickle\")" + ] + }, + { + "cell_type": "markdown", + "id": "55f1a68c-c462-4dee-9139-9befb469f176", + "metadata": {}, + "source": [ + "### Example to get embeddings" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "2357ef0a", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "b494cbf9878a4f5c8f4093e38fb82fd5", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Map: 0%| | 0/3 [00:00<?, ? examples/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "smiles_list = [\"CCO\", \"O=C=O\", \"OC(=O)c1ccccc1C(=O)O\"]\n", + "embeddings = model.encode(smiles_list)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "3871c513-d0a9-4e70-9c18-3f0b491e07b2", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(3, 1024)" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "embeddings.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "289a8795-d6d8-4828-b2b2-b4d4a97a4604", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.8" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/models/selfies_ted/selfies-ted.png b/models/selfies_ted/selfies-ted.png new file mode 100644 index 0000000000000000000000000000000000000000..d1c3561c3751785b0507d966a01ea1fe3b859fa3 --- /dev/null +++ b/models/selfies_ted/selfies-ted.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1229d74cd9473344d9907f5b8b2ae22694bdd77e94d3ae8f1f8dadacf538ee9e +size 47631 diff --git a/models/smi_ted/.gitignore b/models/smi_ted/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..594f1b8f3000cc613695e1704763213e5697f4f8 --- /dev/null +++ b/models/smi_ted/.gitignore @@ -0,0 +1,18 @@ +# Model weights +inference/smi_ted_light/smi-ted-Light_40.pt + +# pyenv +.python-version + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# editor files +.vscode/ +.DS_Store diff --git a/models/smi_ted/README.md b/models/smi_ted/README.md new file mode 100644 index 0000000000000000000000000000000000000000..a4bd49c6aa2586e33c3baa2cd5caf438d0253756 --- /dev/null +++ 
b/models/smi_ted/README.md @@ -0,0 +1,138 @@ +# SMILES-based Transformer Encoder-Decoder (SMI-TED) + +This repository provides PyTorch source code associated with our publication, "A Large Encoder-Decoder Family of Foundation Models for Chemical Language". + +**Paper:** [Arxiv Link](https://arxiv.org/abs/2407.20267) + +**HuggingFace:** [HuggingFace Link](https://huggingface.co/ibm/materials.smi-ted) + +For more information contact: eduardo.soares@ibm.com or evital@br.ibm.com. + + + +## Introduction + +We present a large encoder-decoder chemical foundation model, SMILES-based Transformer Encoder-Decoder (SMI-TED), pre-trained on a curated dataset of 91 million SMILES samples sourced from PubChem, equivalent to 4 billion molecular tokens. SMI-TED supports various complex tasks, including quantum property prediction, with two main variants ($289M$ and $8 \times 289M$). Our experiments across multiple benchmark datasets demonstrate state-of-the-art performance for various tasks. Model weights are available at: [HuggingFace Link](https://huggingface.co/ibm/materials.smi-ted). + +## Table of Contents + +1. [Getting Started](#getting-started) + 1. [Pretrained Models and Training Logs](#pretrained-models-and-training-logs) + 2. [Replicating Conda Environment](#replicating-conda-environment) +2. [Pretraining](#pretraining) +3. [Finetuning](#finetuning) +4. [Feature Extraction](#feature-extraction) +5. [Citations](#citations) + +## Getting Started + +**This code and environment have been tested on Nvidia V100s and Nvidia A100s** + +### Pretrained Models and Training Logs + +We provide checkpoints of the SMI-TED model pre-trained on a dataset of ~91M molecules curated from PubChem. The pre-trained model shows competitive performance on classification and regression benchmarks from MoleculeNet. For model weights: [HuggingFace Link](https://huggingface.co/ibm/materials.smi-ted) + +Add the SMI-TED `pre-trained weights.pt` to the `inference/` or `finetune/` directory according to your needs. The directory structure should look like the following: + +``` +inference/ +├── smi_ted_light +│ ├── smi_ted_light.pt +│ ├── bert_vocab_curated.txt +│ └── load.py +``` +and/or: + +``` +finetune/ +├── smi_ted_light +│ ├── smi_ted_light.pt +│ ├── bert_vocab_curated.txt +│ └── load.py +``` + +### Replicating Conda Environment + +Follow these steps to replicate our Conda environment and install the necessary libraries: + +#### Create and Activate Conda Environment + +``` +conda create --name smi-ted-env python=3.10 +conda activate smi-ted-env +``` + +#### Install Packages with Conda + +``` +conda install pytorch=2.1.0 pytorch-cuda=11.8 -c pytorch -c nvidia +``` + +#### Install Packages with Pip + +``` +pip install -r requirements.txt +pip install pytorch-fast-transformers +``` + +## Pretraining + +For pretraining, we use two strategies: the masked language model method to train the encoder part and an encoder-decoder strategy to refine SMILES reconstruction and improve the generated latent space. + +SMI-TED is pre-trained on canonicalized and curated 91M SMILES from PubChem with the following constraints: + +- Compounds are filtered to a maximum length of 202 tokens during preprocessing. +- A 95/5/0 split is used for encoder training, with 5% of the data for decoder pretraining. +- A 100/0/0 split is also used to train the encoder and decoder directly, enhancing model performance. + +The pretraining code provides examples of data processing and model training on a smaller dataset, requiring 8 A100 GPUs. 
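
The length filter and the encoder/decoder split described above can be pictured with a short sketch. This is a simplified illustration only, not the repository's preprocessing script; the `tokenize` callable, the `SMILES` column name, and the `filter_and_split` helper are assumptions introduced here.

```python
# Illustrative sketch of the stated pretraining constraints -- not the actual pipeline.
# Assumptions: a pandas DataFrame with a "SMILES" column and a `tokenize` callable.
import pandas as pd


def filter_and_split(df: pd.DataFrame, tokenize, max_tokens: int = 202, encoder_frac: float = 0.95):
    """Drop molecules longer than `max_tokens`, then split ~95/5 for encoder vs. decoder pretraining."""
    keep = df["SMILES"].apply(lambda s: len(tokenize(s)) <= max_tokens)
    filtered = df[keep].reset_index(drop=True)
    n_enc = int(len(filtered) * encoder_frac)
    return filtered.iloc[:n_enc], filtered.iloc[n_enc:]
```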
+ +To pre-train the two variants of the SMI-TED model, run: + +``` +bash training/run_model_light_training.sh +``` +or +``` +bash training/run_model_large_training.sh +``` + +Use `train_model_D.py` to train only the decoder or `train_model_ED.py` to train both the encoder and decoder. + +## Finetuning + +The finetuning datasets and environment can be found in the [finetune](finetune/) directory. After setting up the environment, you can run a finetuning task with: + +``` +bash finetune/smi_ted_light/esol/run_finetune_esol.sh +``` + +Finetuning training/checkpointing resources will be available in directories named `checkpoint_<measure_name>`. + +## Feature Extraction + +The example notebook [smi_ted_encoder_decoder_example.ipynb](notebooks/smi_ted_encoder_decoder_example.ipynb) contains code to load checkpoint files and use the pre-trained model for encoder and decoder tasks. It also includes examples of classification and regression tasks. For model weights: [HuggingFace Link](https://huggingface.co/ibm/materials.smi-ted) + +To load smi-ted, you can simply use: + +```python +model = load_smi_ted( + folder='../inference/smi_ted_light', + ckpt_filename='smi_ted_light.pt' +) +``` + +To encode SMILES into embeddings, you can use: + +```python +with torch.no_grad(): + encoded_embeddings = model.encode(df['SMILES'], return_torch=True) +``` +For decoder, you can use the function, so you can return from embeddings to SMILES strings: + +```python +with torch.no_grad(): + decoded_smiles = model.decode(encoded_embeddings) +``` + + diff --git a/models/smi_ted/finetune/args.py b/models/smi_ted/finetune/args.py new file mode 100644 index 0000000000000000000000000000000000000000..5a698274dafef1671059da9456dbdce404caaafc --- /dev/null +++ b/models/smi_ted/finetune/args.py @@ -0,0 +1,337 @@ +import argparse + + +def get_parser(parser=None): + if parser is None: + parser = argparse.ArgumentParser() + + # Model + # model_arg = parser.add_argument_group('Model') + parser.add_argument("--n_head", type=int, default=8, help="GPT number of heads") + parser.add_argument("--n_layer", type=int, default=12, help="GPT number of layers") + parser.add_argument( + "--q_dropout", type=float, default=0.5, help="Encoder layers dropout" + ) + parser.add_argument( + "--d_dropout", type=float, default=0.1, help="Decoder layers dropout" + ) + parser.add_argument( + "--n_embd", type=int, default=768, help="Latent vector dimensionality" + ) + parser.add_argument( + "--fc_h", type=int, default=512, help="Fully connected hidden dimensionality" + ) + parser.add_argument("--n_output", type=int, default=1) + + # Train + # train_arg = parser.add_argument_group('Train') + parser.add_argument("--n_batch", type=int, default=512, help="Batch size") + parser.add_argument( + "--unlike_alpha", type=float, default=1.0, help="unlikelihood loss alpha weight" + ) + parser.add_argument( + "--from_scratch", + action="store_true", + default=False, + help="train on qm9 from scratch", + ) + parser.add_argument( + "--unlikelihood", + action="store_true", + default=False, + help="use unlikelihood loss with gpt pretrain", + ) + parser.add_argument( + "--grad_acc", + type=int, + default=1, + help="number of batches to accumulate gradients", + ) + parser.add_argument( + "--checkpoint_every", + type=int, + default=1000, + help="save checkpoint every x iterations", + ) + parser.add_argument( + "--clip_grad", type=int, default=50, help="Clip gradients to this value" + ) + parser.add_argument( + "--lr_start", type=float, default=3 * 1e-4, help="Initial lr 
value" + ) + parser.add_argument( + "--lr_end", type=float, default=3 * 1e-4, help="Maximum lr weight value" + ) + parser.add_argument( + "--lr_multiplier", type=int, default=1, help="lr weight multiplier" + ) + parser.add_argument( + "--n_last", type=int, default=1000, help="Number of iters to smooth loss calc" + ) + parser.add_argument("--n_jobs", type=int, default=1, help="Number of threads") + parser.add_argument( + "--accelerator", + type=str, + default="ddp", + help="The accelerator backend to use (previously known as distributed_backend)", + ) + parser.add_argument( + "--num_nodes", + type=int, + default=1, + help="number of GPU nodes for distributed training", + ) + parser.add_argument( + "--device", + type=str, + default="cuda", + help='Device to run: "cpu" or "cuda:<device number>"', + ) + parser.add_argument("--seed", type=int, default=12345, help="Seed") + parser.add_argument( + "--init_params_from", + type=str, + default="", + help="Path to a ckpt used to initialize the parameters if no restart_path is provided", + ) + parser.add_argument( + "--train_decoder_every", + type=int, + default=10, + help="Optimize decoder params every n batches", + ) + parser.add_argument( + "--lr_decoder", type=float, default=1e-4, help="Learning rate for decoder part" + ) + parser.add_argument( + "--local_rank", + type=int, + default=-1, + help="local_rank for distributed training on gpus", + ) + parser.add_argument("--gpu", default=None, type=int, help="GPU id to use.") + parser.add_argument( + "--dist-backend", default="nccl", type=str, help="distributed backend" + ) + parser.add_argument( + "--tensorboard_path", default="./runs/deepspeed", help="tensorboard log dir" + ) + + # common_arg = parser.add_argument_group('Common') + parser.add_argument( + "--vocab_load", type=str, required=False, help="Where to load the vocab" + ) + parser.add_argument( + "--n_samples", type=int, required=False, help="Number of samples to sample" + ) + parser.add_argument( + "--gen_save", type=str, required=False, help="Where to save the gen molecules" + ) + parser.add_argument( + "--max_len", type=int, default=100, help="Max of length of SMILES" + ) + parser.add_argument( + "--train_load", type=str, required=False, help="Where to load the model" + ) + parser.add_argument( + "--val_load", type=str, required=False, help="Where to load the model" + ) + parser.add_argument( + "--n_workers", + type=int, + required=False, + default=1, + help="Where to load the model", + ) + # beam search hyper parameters + parser.add_argument( + "--beam_size", type=int, default=0, help="Number of beams to generate" + ) + parser.add_argument( + "--num_seq_returned", + type=int, + default=0, + help="number of beams to be returned (must be <= beam_size", + ) + parser.add_argument( + "--min_len", type=int, default=1, help="minimum length to be generated" + ) + parser.add_argument( + "--nucleus_thresh", type=float, default=0.9, help="nucleus sampling threshold" + ) + parser.add_argument( + "--finetune_path", + type=str, + default="", + help="path to trainer file to continue training", + ) + parser.add_argument( + "--restart_path", + type=str, + default="", + help="path to trainer file to continue training", + ) + parser.add_argument( + "--data_path", type=str, default="", help="path to pubchem file" + ) + parser.add_argument( + "--pretext_size", type=int, default=0, help="number of k-mers to pretext" + ) + parser.add_argument( + "--model_save_dir", + type=str, + required=False, + default="./models_dump/", + help="Where to save the 
models/log/config/vocab", + ) + parser.add_argument( + "--model_save", + type=str, + required=False, + default="model.pt", + help="Where to save the model", + ) + # parser.add_argument('--save_frequency', + # type=int, default=20, + # help='How often to save the model') + parser.add_argument( + "--num_epoch", type=int, default=1, help="number of epochs to train" + ) + # parser.add_argument('--num_iter', + # type=int, default=-1, + # help='how many itersations per epoch (for unlikelihood tuning)') + parser.add_argument( + "--log_file", type=str, required=False, help="Where to save the log" + ) + parser.add_argument( + "--tb_loc", + type=str, + required=False, + help="Where to save the tensorflow location", + ) + parser.add_argument( + "--config_save", type=str, required=False, help="Where to save the config" + ) + parser.add_argument("--vocab_save", type=str, help="Where to save the vocab") + + # resume_arg = parser.add_argument_group('Resume') + parser.add_argument( + "--debug", + default=False, + action="store_true", + help="do not erase cache at end of program", + ) + parser.add_argument( + "--fast_dev_run", + default=False, + help="This flag runs a “unit test” by running n if set to n (int) else 1 if set to True training and validation batch(es).", + ) + parser.add_argument( + "--freeze_model", + default=False, + action="store_true", + help="freeze weights of bert model during fine tuning", + ) + parser.add_argument( + "--resume", default=False, action="store_true", help="Resume from a saved model" + ) + parser.add_argument( + "--rotate", + default=False, + action="store_true", + help="use rotational relative embedding", + ) + parser.add_argument( + "--model_load", type=str, required=False, help="Where to load the model" + ) + parser.add_argument( + "--root_dir", type=str, required=False, default=".", help="location of root dir" + ) + parser.add_argument( + "--config_load", type=str, required=False, help="Where to load the config" + ) + parser.add_argument( + "--gpus", type=int, required=False, default=1, help="number of gpus to use" + ) + # parser.add_argument('--start_epoch', + # type=int, required=False, default=0, + # help='Where to load the config') + + parser.add_argument( + "--model_arch", + type=str, + required=False, + help="used to teack model arch in params", + ) + parser.add_argument( + "--eval_every", + type=int, + default=50000, + help="run evaluation every x iterations", + ) + parser.add_argument( + "--num_feats", + type=int, + required=False, + default=32, + help="number of random reatures for FAVOR+", + ) + parser.add_argument( + "--max_epochs", type=int, required=False, default=1, help="max number of epochs" + ) + + # debug() FINE TUNEING + # parser.add_argument('--save_dir', type=str, required=True) + parser.add_argument( + "--mode", type=str, default="cls", help="type of pooling to use" + ) + parser.add_argument("--dataset_length", type=int, default=None, required=False) + parser.add_argument("--num_workers", type=int, default=0, required=False) + parser.add_argument("--dropout", type=float, default=0.1, required=False) + # parser.add_argument("--dims", type=int, nargs="*", default="", required=False) + parser.add_argument( + "--smiles_embedding", + type=str, + default="/dccstor/medscan7/smallmolecule/runs/ba-predictor/small-data/embeddings/protein/ba_embeddings_tanh_512_2986138_2.pt", + ) + # parser.add_argument("--train_pct", type=str, required=False, default="95") + # parser.add_argument("--aug", type=int, required=True) + parser.add_argument("--dataset_name", 
type=str, required=False, default="sol") + parser.add_argument("--measure_name", type=str, required=False, default="measure") + # parser.add_argument("--emb_type", type=str, required=True) + parser.add_argument("--checkpoints_folder", type=str, required=True) + # parser.add_argument("--results_dir", type=str, required=True) + # parser.add_argument("--patience_epochs", type=int, required=True) + parser.add_argument("--model_path", type=str, default="./smi_ted/") + parser.add_argument("--ckpt_filename", type=str, default="smi_ted_Light_40.pt") + parser.add_argument("--restart_filename", type=str, default="") + # parser.add_argument('--n_output', type=int, default=1) + parser.add_argument("--save_every_epoch", type=int, default=0) + parser.add_argument("--save_ckpt", type=int, default=1) + parser.add_argument("--start_seed", type=int, default=0) + parser.add_argument("--smi_ted_version", type=str, default="v1") + parser.add_argument("--train_decoder", type=int, default=1) + parser.add_argument("--target_metric", type=str, default="rmse") + parser.add_argument("--loss_fn", type=str, default="mae") + + parser.add_argument( + "--data_root", + type=str, + required=False, + default="/dccstor/medscan7/smallmolecule/runs/ba-predictor/small-data/affinity", + ) + # parser.add_argument("--use_bn", type=int, default=0) + parser.add_argument("--use_linear", type=int, default=0) + + parser.add_argument("--lr", type=float, default=0.001) + # parser.add_argument("--weight_decay", type=float, default=5e-4) + # parser.add_argument("--val_check_interval", type=float, default=1.0) + parser.add_argument("--batch_size", type=int, default=64) + + return parser + + +def parse_args(): + parser = get_parser() + args = parser.parse_args() + return args diff --git a/models/smi_ted/finetune/finetune_classification.py b/models/smi_ted/finetune/finetune_classification.py new file mode 100644 index 0000000000000000000000000000000000000000..80e4636f4b45093947838450c3e1a599fdc3d273 --- /dev/null +++ b/models/smi_ted/finetune/finetune_classification.py @@ -0,0 +1,68 @@ +# Deep learning +import torch +import torch.nn as nn +from torch import optim +from trainers import TrainerClassifier +from utils import get_optim_groups + +# Data +import pandas as pd +import numpy as np + +# Standard library +import args +import os + + +def main(config): + device = 'cuda' if torch.cuda.is_available() else 'cpu' + + # load dataset + df_train = pd.read_csv(f"{config.data_root}/train.csv") + df_valid = pd.read_csv(f"{config.data_root}/valid.csv") + df_test = pd.read_csv(f"{config.data_root}/test.csv") + + # load model + if config.smi_ted_version == 'v1': + from smi_ted_light.load import load_smi_ted + elif config.smi_ted_version == 'v2': + from smi_ted_large.load import load_smi_ted + + model = load_smi_ted(folder=config.model_path, ckpt_filename=config.ckpt_filename, n_output=config.n_output, eval=False) + model.net.apply(model._init_weights) + print(model.net) + + lr = config.lr_start*config.lr_multiplier + optim_groups = get_optim_groups(model, keep_decoder=bool(config.train_decoder)) + if config.loss_fn == 'crossentropy': + loss_function = nn.CrossEntropyLoss() + + # init trainer + trainer = TrainerClassifier( + raw_data=(df_train, df_valid, df_test), + dataset_name=config.dataset_name, + target=config.measure_name, + batch_size=config.n_batch, + hparams=config, + target_metric=config.target_metric, + seed=config.start_seed, + smi_ted_version=config.smi_ted_version, + checkpoints_folder=config.checkpoints_folder, + 
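        # checkpoint/restart options below mirror the CLI flags defined in args.py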
restart_filename=config.restart_filename, + device=device, + save_every_epoch=bool(config.save_every_epoch), + save_ckpt=bool(config.save_ckpt) + ) + trainer.compile( + model=model, + optimizer=optim.AdamW(optim_groups, lr=lr, betas=(0.9, 0.99)), + loss_fn=loss_function + ) + trainer.fit(max_epochs=config.max_epochs) + trainer.evaluate() + + +if __name__ == '__main__': + parser = args.get_parser() + config = parser.parse_args() + main(config) \ No newline at end of file diff --git a/models/smi_ted/finetune/finetune_classification_multitask.py b/models/smi_ted/finetune/finetune_classification_multitask.py new file mode 100644 index 0000000000000000000000000000000000000000..d244f3650db1e76a51e8050be0abdb4a92b168cb --- /dev/null +++ b/models/smi_ted/finetune/finetune_classification_multitask.py @@ -0,0 +1,101 @@ +# Deep learning +import torch +import torch.nn as nn +from torch import optim +from trainers import TrainerClassifierMultitask +from utils import get_optim_groups + +# Data +import pandas as pd +import numpy as np + +# Standard library +import args +import os + + +def main(config): + device = 'cuda' if torch.cuda.is_available() else 'cpu' + + # Define Target and Causal Features + if config.dataset_name == 'tox21': + targets = ['NR-AR', 'NR-AR-LBD', 'NR-AhR', 'NR-Aromatase', 'NR-ER', 'NR-ER-LBD', + 'NR-PPAR-gamma', 'SR-ARE', 'SR-ATAD5', 'SR-HSE', 'SR-MMP', 'SR-p53'] + elif config.dataset_name == 'clintox': + targets = ['FDA_APPROVED', 'CT_TOX'] + elif config.dataset_name == 'sider': + targets = [ + 'Hepatobiliary disorders', 'Metabolism and nutrition disorders', + 'Product issues', 'Eye disorders', 'Investigations', + 'Musculoskeletal and connective tissue disorders', + 'Gastrointestinal disorders', 'Social circumstances', + 'Immune system disorders', 'Reproductive system and breast disorders', + 'Neoplasms benign, malignant and unspecified (incl cysts and polyps)', + 'General disorders and administration site conditions', + 'Endocrine disorders', 'Surgical and medical procedures', + 'Vascular disorders', 'Blood and lymphatic system disorders', + 'Skin and subcutaneous tissue disorders', + 'Congenital, familial and genetic disorders', 'Infections and infestations', + 'Respiratory, thoracic and mediastinal disorders', 'Psychiatric disorders', + 'Renal and urinary disorders', + 'Pregnancy, puerperium and perinatal conditions', + 'Ear and labyrinth disorders', 'Cardiac disorders', + 'Nervous system disorders', 'Injury, poisoning and procedural complications' + ] + elif config.dataset_name == 'muv': + targets = [ + 'MUV-466', 'MUV-548', 'MUV-600', 'MUV-644', 'MUV-652', 'MUV-689', + 'MUV-692', 'MUV-712', 'MUV-713', 'MUV-733', 'MUV-737', 'MUV-810', + 'MUV-832', 'MUV-846', 'MUV-852', 'MUV-858', 'MUV-859' + ] + config.n_output = len(targets) + + # load dataset + df_train = pd.read_csv(f"{config.data_root}/train.csv") + df_valid = pd.read_csv(f"{config.data_root}/valid.csv") + df_test = pd.read_csv(f"{config.data_root}/test.csv") + + # load model + if config.smi_ted_version == 'v1': + from smi_ted_light.load import load_smi_ted + elif config.smi_ted_version == 'v2': + from smi_ted_large.load import load_smi_ted + + model = load_smi_ted(folder=config.model_path, ckpt_filename=config.ckpt_filename, n_output=len(targets), eval=False) + model.net.apply(model._init_weights) + print(model.net) + + lr = config.lr_start*config.lr_multiplier + optim_groups = get_optim_groups(model, keep_decoder=bool(config.train_decoder)) + if config.loss_fn == 'bceloss': + loss_function = nn.BCELoss() + + # init 
trainer + trainer = TrainerClassifierMultitask( + raw_data=(df_train, df_valid, df_test), + dataset_name=config.dataset_name, + target=targets, + batch_size=config.n_batch, + hparams=config, + target_metric=config.target_metric, + seed=config.start_seed, + smi_ted_version=config.smi_ted_version, + checkpoints_folder=config.checkpoints_folder, + restart_filename=config.restart_filename, + device=device, + save_every_epoch=bool(config.save_every_epoch), + save_ckpt=bool(config.save_ckpt) + ) + trainer.compile( + model=model, + optimizer=optim.AdamW(optim_groups, lr=lr, betas=(0.9, 0.99)), + loss_fn=loss_function + ) + trainer.fit(max_epochs=config.max_epochs) + trainer.evaluate() + + +if __name__ == '__main__': + parser = args.get_parser() + config = parser.parse_args() + main(config) \ No newline at end of file diff --git a/models/smi_ted/finetune/finetune_regression.py b/models/smi_ted/finetune/finetune_regression.py new file mode 100644 index 0000000000000000000000000000000000000000..2a05d32baa43fd9eb127733f1da00b146d7a5172 --- /dev/null +++ b/models/smi_ted/finetune/finetune_regression.py @@ -0,0 +1,70 @@ +# Deep learning +import torch +import torch.nn as nn +from torch import optim +from trainers import TrainerRegressor +from utils import RMSELoss, get_optim_groups + +# Data +import pandas as pd +import numpy as np + +# Standard library +import args +import os + + +def main(config): + device = 'cuda' if torch.cuda.is_available() else 'cpu' + + # load dataset + df_train = pd.read_csv(f"{config.data_root}/train.csv") + df_valid = pd.read_csv(f"{config.data_root}/valid.csv") + df_test = pd.read_csv(f"{config.data_root}/test.csv") + + # load model + if config.smi_ted_version == 'v1': + from smi_ted_light.load import load_smi_ted + elif config.smi_ted_version == 'v2': + from smi_ted_large.load import load_smi_ted + + model = load_smi_ted(folder=config.model_path, ckpt_filename=config.ckpt_filename, n_output=config.n_output, eval=False) + model.net.apply(model._init_weights) + print(model.net) + + lr = config.lr_start*config.lr_multiplier + optim_groups = get_optim_groups(model, keep_decoder=bool(config.train_decoder)) + if config.loss_fn == 'rmse': + loss_function = RMSELoss() + elif config.loss_fn == 'mae': + loss_function = nn.L1Loss() + + # init trainer + trainer = TrainerRegressor( + raw_data=(df_train, df_valid, df_test), + dataset_name=config.dataset_name, + target=config.measure_name, + batch_size=config.n_batch, + hparams=config, + target_metric=config.target_metric, + seed=config.start_seed, + smi_ted_version=config.smi_ted_version, + checkpoints_folder=config.checkpoints_folder, + restart_filename=config.restart_filename, + device=device, + save_every_epoch=bool(config.save_every_epoch), + save_ckpt=bool(config.save_ckpt) + ) + trainer.compile( + model=model, + optimizer=optim.AdamW(optim_groups, lr=lr, betas=(0.9, 0.99)), + loss_fn=loss_function + ) + trainer.fit(max_epochs=config.max_epochs) + trainer.evaluate() + + +if __name__ == '__main__': + parser = args.get_parser() + config = parser.parse_args() + main(config) \ No newline at end of file diff --git a/models/smi_ted/finetune/moleculenet/bace/test.csv b/models/smi_ted/finetune/moleculenet/bace/test.csv new file mode 100644 index 0000000000000000000000000000000000000000..adc1ccdfbfadfc6b9d6c35079ac028d1a748499d --- /dev/null +++ b/models/smi_ted/finetune/moleculenet/bace/test.csv @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e3af97c680375dd09349c63b4779b35166212302e79e4fc7a1752ef5d71cf35b +size 
400436 diff --git a/models/smi_ted/finetune/moleculenet/bace/train.csv b/models/smi_ted/finetune/moleculenet/bace/train.csv new file mode 100644 index 0000000000000000000000000000000000000000..017b9a6079d76ea84dd61b119ffbc374d765cc09 --- /dev/null +++ b/models/smi_ted/finetune/moleculenet/bace/train.csv @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b5b3426e84dc7e2f40f2cf9d15d4d38328126c07f49c215cfb4fb657f69200de +size 3109699 diff --git a/models/smi_ted/finetune/moleculenet/bace/valid.csv b/models/smi_ted/finetune/moleculenet/bace/valid.csv new file mode 100644 index 0000000000000000000000000000000000000000..11c8f7fbcbe27d30244a8f8d31dd84f35a270e88 --- /dev/null +++ b/models/smi_ted/finetune/moleculenet/bace/valid.csv @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:813c8f2af5a1058568cf60b7021b8b2cd818a17944afd0b09f9d838e36ee985d +size 397085 diff --git a/models/smi_ted/finetune/moleculenet/bbbp/test.csv b/models/smi_ted/finetune/moleculenet/bbbp/test.csv new file mode 100644 index 0000000000000000000000000000000000000000..21089037ffa94aca5db3083f16b887b79bd74212 --- /dev/null +++ b/models/smi_ted/finetune/moleculenet/bbbp/test.csv @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cca4161c44535fd0f8ff917cc68d26703da7fbce19ddecb7dc5f7ae4b4d241a6 +size 14874 diff --git a/models/smi_ted/finetune/moleculenet/bbbp/train.csv b/models/smi_ted/finetune/moleculenet/bbbp/train.csv new file mode 100644 index 0000000000000000000000000000000000000000..314cc5ea086cecbd3d7c0ab9fb96371619aca018 --- /dev/null +++ b/models/smi_ted/finetune/moleculenet/bbbp/train.csv @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7300807bf21ea1177efd81c218e43275ed00b6c3006b5dae7625f774edb6b1a6 +size 115549 diff --git a/models/smi_ted/finetune/moleculenet/bbbp/valid.csv b/models/smi_ted/finetune/moleculenet/bbbp/valid.csv new file mode 100644 index 0000000000000000000000000000000000000000..0255eb26d4d514fd9446bf356938161e3e5d7378 --- /dev/null +++ b/models/smi_ted/finetune/moleculenet/bbbp/valid.csv @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:af39cc3735a356010a072e1e196a64eca6e0d88f0b2a023d4dc1adba7030ce40 +size 15655 diff --git a/models/smi_ted/finetune/moleculenet/biodegradability/biodeg_example.csv b/models/smi_ted/finetune/moleculenet/biodegradability/biodeg_example.csv new file mode 100644 index 0000000000000000000000000000000000000000..af1df8f88f1194796428d43b11b8c8442feeac15 --- /dev/null +++ b/models/smi_ted/finetune/moleculenet/biodegradability/biodeg_example.csv @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c98992c1c22ae7468a41fb7bc86c775ccc30fa29e50053bb148ffc2f2d95551e +size 6352 diff --git a/models/smi_ted/finetune/moleculenet/biodegradability/biodegradability.csv b/models/smi_ted/finetune/moleculenet/biodegradability/biodegradability.csv new file mode 100644 index 0000000000000000000000000000000000000000..667fbf1e87eb9753bffe53851749bf0c0accf8e6 --- /dev/null +++ b/models/smi_ted/finetune/moleculenet/biodegradability/biodegradability.csv @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8ec61887444a0e8925b16cca48433c3b3bff1ac5cf08f448d6b64bbdbc14a318 +size 416181 diff --git a/models/smi_ted/finetune/moleculenet/biodegradability/test.csv b/models/smi_ted/finetune/moleculenet/biodegradability/test.csv new file mode 100644 index 0000000000000000000000000000000000000000..3f89d1cf7d041cc4048d95328f4135f03a98d4e1 --- /dev/null +++ 
b/models/smi_ted/finetune/moleculenet/biodegradability/test.csv @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:86c2f7f39add0fff77358454c0f1b289a233e4a78d50b7f005ec2dc1c632d473 +size 84488 diff --git a/models/smi_ted/finetune/moleculenet/biodegradability/train.csv b/models/smi_ted/finetune/moleculenet/biodegradability/train.csv new file mode 100644 index 0000000000000000000000000000000000000000..92b5d108ee4ef6abe93e0deec05f9b6bac50bbbd --- /dev/null +++ b/models/smi_ted/finetune/moleculenet/biodegradability/train.csv @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1a4a94ae0f8c134ce10f2d853eced84d031a4e7b394662344a9141e7567b3eb2 +size 252230 diff --git a/models/smi_ted/finetune/moleculenet/biodegradability/valid.csv b/models/smi_ted/finetune/moleculenet/biodegradability/valid.csv new file mode 100644 index 0000000000000000000000000000000000000000..301cf2278dee811b8afdaf79f771c650af2b4dba --- /dev/null +++ b/models/smi_ted/finetune/moleculenet/biodegradability/valid.csv @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:09e827ee7e55544f5b327d5e2ef2d9fe09e3f62024e1316b6e71d1fc9be275a1 +size 85290 diff --git a/models/smi_ted/finetune/moleculenet/clintox/test.csv b/models/smi_ted/finetune/moleculenet/clintox/test.csv new file mode 100644 index 0000000000000000000000000000000000000000..58483d5b0478e7cbabb603c70177ab8d1ac0157a --- /dev/null +++ b/models/smi_ted/finetune/moleculenet/clintox/test.csv @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:963a05e8eeaaa38fd3688f448dfc28cd0917ea280b1b9cb5b4297244f7f68fe2 +size 10219 diff --git a/models/smi_ted/finetune/moleculenet/clintox/train.csv b/models/smi_ted/finetune/moleculenet/clintox/train.csv new file mode 100644 index 0000000000000000000000000000000000000000..ff58b106a4761c714aae3c31c11ea210e1534d5b --- /dev/null +++ b/models/smi_ted/finetune/moleculenet/clintox/train.csv @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:04bbee4a0d7fb4942292c9581f318909d06508d529a4a3a76590e6749417c1a7 +size 74357 diff --git a/models/smi_ted/finetune/moleculenet/clintox/valid.csv b/models/smi_ted/finetune/moleculenet/clintox/valid.csv new file mode 100644 index 0000000000000000000000000000000000000000..efface840e8b99c44772e26ca67fac655d0e5a8d --- /dev/null +++ b/models/smi_ted/finetune/moleculenet/clintox/valid.csv @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f3e2b9ab566ffc184c0590002bfbd6a42e6522209e6d6271968262844dde2905 +size 10255 diff --git a/models/smi_ted/finetune/moleculenet/esol/test.csv b/models/smi_ted/finetune/moleculenet/esol/test.csv new file mode 100644 index 0000000000000000000000000000000000000000..835d35a8d39a5c355db39ae890b81598e5b3bc7b --- /dev/null +++ b/models/smi_ted/finetune/moleculenet/esol/test.csv @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7da41a7eab447fdfd163292b4a5eb8ef09a747fc82b0f1cc5c468e46b1b2ef5a +size 9999 diff --git a/models/smi_ted/finetune/moleculenet/esol/train.csv b/models/smi_ted/finetune/moleculenet/esol/train.csv new file mode 100644 index 0000000000000000000000000000000000000000..4c3e49f99bd860f679a8d5006776f44051c2528d --- /dev/null +++ b/models/smi_ted/finetune/moleculenet/esol/train.csv @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:784ba31de05a43ecab98260c94a47e2c807f4d65c0f93d9a88fbd962515976c5 +size 77154 diff --git a/models/smi_ted/finetune/moleculenet/esol/valid.csv b/models/smi_ted/finetune/moleculenet/esol/valid.csv 
new file mode 100644 index 0000000000000000000000000000000000000000..6aa8495439476ef54e8bb536e7943c697b08f907 --- /dev/null +++ b/models/smi_ted/finetune/moleculenet/esol/valid.csv @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bc30e7fa1f774e27ed56de7cfd77e21f07a5a2c38fcc6d928c0084a9a99181e5 +size 9892 diff --git a/models/smi_ted/finetune/moleculenet/freesolv/test.csv b/models/smi_ted/finetune/moleculenet/freesolv/test.csv new file mode 100644 index 0000000000000000000000000000000000000000..08c43b5d60425f5e337889df1a07a197052301b4 --- /dev/null +++ b/models/smi_ted/finetune/moleculenet/freesolv/test.csv @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c8212c391ccbff3722a11d1bd3752b3a9dd187f2a7b33f8b9d2d594950b188d7 +size 3223 diff --git a/models/smi_ted/finetune/moleculenet/freesolv/train.csv b/models/smi_ted/finetune/moleculenet/freesolv/train.csv new file mode 100644 index 0000000000000000000000000000000000000000..0baf09f23bc16c90fac16b6e45714122c2af568f --- /dev/null +++ b/models/smi_ted/finetune/moleculenet/freesolv/train.csv @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f3b781e5d03dbd7d272347288161f92e8e66c628da50e3e2bc06de12225de22d +size 25053 diff --git a/models/smi_ted/finetune/moleculenet/freesolv/valid.csv b/models/smi_ted/finetune/moleculenet/freesolv/valid.csv new file mode 100644 index 0000000000000000000000000000000000000000..6b384a7841a93e57505088bf2dd643aaba76b091 --- /dev/null +++ b/models/smi_ted/finetune/moleculenet/freesolv/valid.csv @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3b35d9c13a02291eefe85bd4b048ccc28f5326a3b018beb937aba12067b072d2 +size 3151 diff --git a/models/smi_ted/finetune/moleculenet/hiv/test.csv b/models/smi_ted/finetune/moleculenet/hiv/test.csv new file mode 100644 index 0000000000000000000000000000000000000000..26c91f8f64c0df5fd65a6dc8a3e19990cf0feae6 --- /dev/null +++ b/models/smi_ted/finetune/moleculenet/hiv/test.csv @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6e86ca708a331966f6e7b06621a2e221a9f6ce45f0141e6cbe919fd64ec50fc7 +size 213176 diff --git a/models/smi_ted/finetune/moleculenet/hiv/train.csv b/models/smi_ted/finetune/moleculenet/hiv/train.csv new file mode 100644 index 0000000000000000000000000000000000000000..8a61257627b0d1088bb89c2d8c7c75d5c7cd27da --- /dev/null +++ b/models/smi_ted/finetune/moleculenet/hiv/train.csv @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c289700d093d7ccbe55a583ad5cb3a670df931a19283ea66880413ed398358ff +size 1685863 diff --git a/models/smi_ted/finetune/moleculenet/hiv/valid.csv b/models/smi_ted/finetune/moleculenet/hiv/valid.csv new file mode 100644 index 0000000000000000000000000000000000000000..ff00c5124f70c15ba58ef905521c02e3bfbc8295 --- /dev/null +++ b/models/smi_ted/finetune/moleculenet/hiv/valid.csv @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:33dd9321f709fb4fbc4545b1bfdc641eaebc410f6f698b9ed331678c5b3c3514 +size 212529 diff --git a/models/smi_ted/finetune/moleculenet/ldtoxdb/test.csv b/models/smi_ted/finetune/moleculenet/ldtoxdb/test.csv new file mode 100644 index 0000000000000000000000000000000000000000..222b1a5641c3d943aeefa1978086979ce88b5e25 --- /dev/null +++ b/models/smi_ted/finetune/moleculenet/ldtoxdb/test.csv @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0947b182a1ba6b783fdca9fd01146cbe1e7bdf28d535e75765fda11a6b9a7458 +size 1541270 diff --git 
a/models/smi_ted/finetune/moleculenet/ldtoxdb/toxicity-prediction.csv b/models/smi_ted/finetune/moleculenet/ldtoxdb/toxicity-prediction.csv new file mode 100644 index 0000000000000000000000000000000000000000..db36447e717b0c1889742e89a7459f149532ea1c --- /dev/null +++ b/models/smi_ted/finetune/moleculenet/ldtoxdb/toxicity-prediction.csv @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:afeaf75aebdb67f18aeab58646ef0a31ae3b2c73f3d621afe3b648ba85990210 +size 7843582 diff --git a/models/smi_ted/finetune/moleculenet/ldtoxdb/train.csv b/models/smi_ted/finetune/moleculenet/ldtoxdb/train.csv new file mode 100644 index 0000000000000000000000000000000000000000..862b852596832a991b593753ab21a19684df13ca --- /dev/null +++ b/models/smi_ted/finetune/moleculenet/ldtoxdb/train.csv @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8ee6370ec81777620a59316995e15f259c93bb52511d43756db1bb744d453485 +size 4587490 diff --git a/models/smi_ted/finetune/moleculenet/ldtoxdb/valid.csv b/models/smi_ted/finetune/moleculenet/ldtoxdb/valid.csv new file mode 100644 index 0000000000000000000000000000000000000000..10e7d231ddb74b3c42ef3ef604282eb644bc33b7 --- /dev/null +++ b/models/smi_ted/finetune/moleculenet/ldtoxdb/valid.csv @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bde2141626dcc6e4d6e5cf78242eca4c1335724b15a297c86ce2ad36fbaf4c4c +size 1525896 diff --git a/models/smi_ted/finetune/moleculenet/lipophilicity/test.csv b/models/smi_ted/finetune/moleculenet/lipophilicity/test.csv new file mode 100644 index 0000000000000000000000000000000000000000..7313b2bd6d643e0e71c3eae315198e8613f26a01 --- /dev/null +++ b/models/smi_ted/finetune/moleculenet/lipophilicity/test.csv @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:82a4f29bc409667a655ea3a7cddcdf74d8066150b15ae5074319ad6747bccfff +size 28696 diff --git a/models/smi_ted/finetune/moleculenet/lipophilicity/train.csv b/models/smi_ted/finetune/moleculenet/lipophilicity/train.csv new file mode 100644 index 0000000000000000000000000000000000000000..b749481478b42a22eae4fcc2cdc0e09bfa4ddbc5 --- /dev/null +++ b/models/smi_ted/finetune/moleculenet/lipophilicity/train.csv @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ebd15220de24d82242b6a0b4ddbd985c9728b8e4797dcf20500655cb17338f36 +size 228704 diff --git a/models/smi_ted/finetune/moleculenet/lipophilicity/valid.csv b/models/smi_ted/finetune/moleculenet/lipophilicity/valid.csv new file mode 100644 index 0000000000000000000000000000000000000000..281520d34cf818d3507078f9004e866fe8f7cbf5 --- /dev/null +++ b/models/smi_ted/finetune/moleculenet/lipophilicity/valid.csv @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:79a7682a54e0c37072dc43a5787cd77a40047011e04be204f1b961501be41613 +size 28318 diff --git a/models/smi_ted/finetune/moleculenet/muv/test.csv b/models/smi_ted/finetune/moleculenet/muv/test.csv new file mode 100644 index 0000000000000000000000000000000000000000..c09b321e7bb7809aea99c1db1a1611dc7487ce55 --- /dev/null +++ b/models/smi_ted/finetune/moleculenet/muv/test.csv @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5167910e2919d94164b3250266f846ef468aea7be1dea43698d08fa91da4933a +size 721037 diff --git a/models/smi_ted/finetune/moleculenet/muv/train.csv b/models/smi_ted/finetune/moleculenet/muv/train.csv new file mode 100644 index 0000000000000000000000000000000000000000..55db4ad896a618d67ba1044d3d97a42a17570496 --- /dev/null +++ 
b/models/smi_ted/finetune/moleculenet/muv/train.csv @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fdc1c542f08aef281fb6c4b8727a92d1f8bfe94e4a370b9240dde03cc866cead +size 5781901 diff --git a/models/smi_ted/finetune/moleculenet/muv/valid.csv b/models/smi_ted/finetune/moleculenet/muv/valid.csv new file mode 100644 index 0000000000000000000000000000000000000000..a605a98f25564f1d7dbf26936b21d57892e68113 --- /dev/null +++ b/models/smi_ted/finetune/moleculenet/muv/valid.csv @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0743918c3610c4bceada0f529f7ffac7a196c50b355d92da3e06bbb9dac56ffe +size 723580 diff --git a/models/smi_ted/finetune/moleculenet/qm8/qm8.csv b/models/smi_ted/finetune/moleculenet/qm8/qm8.csv new file mode 100644 index 0000000000000000000000000000000000000000..1b81b402728c3e075c1ba6bdee6734df4085a7ae --- /dev/null +++ b/models/smi_ted/finetune/moleculenet/qm8/qm8.csv @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3e70e56805edb2f0796b64f96af9b53dd8cca408d775612d47123f7d2da7d61d +size 4719270 diff --git a/models/smi_ted/finetune/moleculenet/qm8/test.csv b/models/smi_ted/finetune/moleculenet/qm8/test.csv new file mode 100644 index 0000000000000000000000000000000000000000..8c6ad38a79f1716a7171c0822d6fca8d5bf6e2c8 --- /dev/null +++ b/models/smi_ted/finetune/moleculenet/qm8/test.csv @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:418fd6afa3db6a219050bd36df002620bce787a2890debab5e63b0829c879914 +size 471657 diff --git a/models/smi_ted/finetune/moleculenet/qm8/train.csv b/models/smi_ted/finetune/moleculenet/qm8/train.csv new file mode 100644 index 0000000000000000000000000000000000000000..5c21bbb766e815373251d74888daaa61652ce414 --- /dev/null +++ b/models/smi_ted/finetune/moleculenet/qm8/train.csv @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:079af359ffdcba408646a556ea862ded8f483381af320e36d4981dcbe28b849b +size 3770636 diff --git a/models/smi_ted/finetune/moleculenet/qm8/valid.csv b/models/smi_ted/finetune/moleculenet/qm8/valid.csv new file mode 100644 index 0000000000000000000000000000000000000000..2c061eeadadccf63bc74614ab5bd72917afecb2b --- /dev/null +++ b/models/smi_ted/finetune/moleculenet/qm8/valid.csv @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4e032fee62d0117743a14095fffe07223b6f8c514a1961298388a6c6bd272fd5 +size 470821 diff --git a/models/smi_ted/finetune/moleculenet/qm9/qm9.csv b/models/smi_ted/finetune/moleculenet/qm9/qm9.csv new file mode 100644 index 0000000000000000000000000000000000000000..0c1cf60535385093f5edd4c17a8605429cce67d6 --- /dev/null +++ b/models/smi_ted/finetune/moleculenet/qm9/qm9.csv @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3e668f8c34e4bc392a90d417a50a5eed3b64b842a817a633024bdc054c68ccb4 +size 29856825 diff --git a/models/smi_ted/finetune/moleculenet/qm9/qm9_small_test.csv b/models/smi_ted/finetune/moleculenet/qm9/qm9_small_test.csv new file mode 100644 index 0000000000000000000000000000000000000000..d835a0e30540d12c3be00088b45df27ac382cf99 --- /dev/null +++ b/models/smi_ted/finetune/moleculenet/qm9/qm9_small_test.csv @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:56872786be70d65eb43e943417a293b5c64efb015da6e1cfa3cdd6bc06d8a057 +size 7255 diff --git a/models/smi_ted/finetune/moleculenet/qm9/qm9_small_train.csv b/models/smi_ted/finetune/moleculenet/qm9/qm9_small_train.csv new file mode 100644 index 
0000000000000000000000000000000000000000..d835a0e30540d12c3be00088b45df27ac382cf99 --- /dev/null +++ b/models/smi_ted/finetune/moleculenet/qm9/qm9_small_train.csv @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:56872786be70d65eb43e943417a293b5c64efb015da6e1cfa3cdd6bc06d8a057 +size 7255 diff --git a/models/smi_ted/finetune/moleculenet/qm9/qm9_small_valid.csv b/models/smi_ted/finetune/moleculenet/qm9/qm9_small_valid.csv new file mode 100644 index 0000000000000000000000000000000000000000..d835a0e30540d12c3be00088b45df27ac382cf99 --- /dev/null +++ b/models/smi_ted/finetune/moleculenet/qm9/qm9_small_valid.csv @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:56872786be70d65eb43e943417a293b5c64efb015da6e1cfa3cdd6bc06d8a057 +size 7255 diff --git a/models/smi_ted/finetune/moleculenet/qm9/test.csv b/models/smi_ted/finetune/moleculenet/qm9/test.csv new file mode 100644 index 0000000000000000000000000000000000000000..21f1f45d52b056bbc1f18dab228f8528b2e324cc --- /dev/null +++ b/models/smi_ted/finetune/moleculenet/qm9/test.csv @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:352e17f5061840e0cdcffdc2e86d5c483ac5aa31a8e8feb1916825247e0ad323 +size 2986085 diff --git a/models/smi_ted/finetune/moleculenet/qm9/train.csv b/models/smi_ted/finetune/moleculenet/qm9/train.csv new file mode 100644 index 0000000000000000000000000000000000000000..a04e7e9390a0040d697ecedc1b6cd54f58b166f1 --- /dev/null +++ b/models/smi_ted/finetune/moleculenet/qm9/train.csv @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8f0d2d1faa91c040ba845dbf8375ab1351d14d292b7840f13675afe50658a2ed +size 24186523 diff --git a/models/smi_ted/finetune/moleculenet/qm9/valid.csv b/models/smi_ted/finetune/moleculenet/qm9/valid.csv new file mode 100644 index 0000000000000000000000000000000000000000..0b68edafa00a14e066d5dee89cee71f82d409a9d --- /dev/null +++ b/models/smi_ted/finetune/moleculenet/qm9/valid.csv @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f701d74d3e19b1bf2daeca0043a8a4403e1ba794831682509d45fa54a54587d1 +size 2687631 diff --git a/models/smi_ted/finetune/moleculenet/sider/test.csv b/models/smi_ted/finetune/moleculenet/sider/test.csv new file mode 100644 index 0000000000000000000000000000000000000000..0e586ffe3cb5e7ecaa62067328195eb33954c1c1 --- /dev/null +++ b/models/smi_ted/finetune/moleculenet/sider/test.csv @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:03f44a09ac140628293a36e4ac6d23a719058b9956cfb07f5db7923e527e187f +size 18568 diff --git a/models/smi_ted/finetune/moleculenet/sider/train.csv b/models/smi_ted/finetune/moleculenet/sider/train.csv new file mode 100644 index 0000000000000000000000000000000000000000..9e89e340142dde4ca670f2e7c760b26081d4a0c9 --- /dev/null +++ b/models/smi_ted/finetune/moleculenet/sider/train.csv @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c71e635cccf17c173fedfca7d91fa35c66c4d95f1c558d92e67c1652b831fb75 +size 147151 diff --git a/models/smi_ted/finetune/moleculenet/sider/valid.csv b/models/smi_ted/finetune/moleculenet/sider/valid.csv new file mode 100644 index 0000000000000000000000000000000000000000..bb6f83f3067d535eea3c6cbe1b5fab4818747e8b --- /dev/null +++ b/models/smi_ted/finetune/moleculenet/sider/valid.csv @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cf790d7e965f90710e45c4f637352d1349c4fa420c9c3cb8e3bab4c86b38755c +size 19691 diff --git a/models/smi_ted/finetune/moleculenet/tox21/test.csv 
b/models/smi_ted/finetune/moleculenet/tox21/test.csv new file mode 100644 index 0000000000000000000000000000000000000000..105b35dc005c562395f8a80803aa759f88b37d70 --- /dev/null +++ b/models/smi_ted/finetune/moleculenet/tox21/test.csv @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:06e9af48940e5eba55ad20229ba3d0f2c1c5110007aa16790cc86df9b0e5de14 +size 53905 diff --git a/models/smi_ted/finetune/moleculenet/tox21/tox21.csv b/models/smi_ted/finetune/moleculenet/tox21/tox21.csv new file mode 100644 index 0000000000000000000000000000000000000000..a31a8e23869b39932128dba54c716abca10b47a9 --- /dev/null +++ b/models/smi_ted/finetune/moleculenet/tox21/tox21.csv @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1689278aa402ef8da840be1126c547253338d390cc8526714910a3b2a39fa1c9 +size 536070 diff --git a/models/smi_ted/finetune/moleculenet/tox21/train.csv b/models/smi_ted/finetune/moleculenet/tox21/train.csv new file mode 100644 index 0000000000000000000000000000000000000000..4f9ae089aafddb9bea9701c03114ca64ca722d10 --- /dev/null +++ b/models/smi_ted/finetune/moleculenet/tox21/train.csv @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9f5c366c2d80a6982fd95bdb37f41631c43a158b7d165e201b74ce8fe68c0a03 +size 416358 diff --git a/models/smi_ted/finetune/moleculenet/tox21/valid.csv b/models/smi_ted/finetune/moleculenet/tox21/valid.csv new file mode 100644 index 0000000000000000000000000000000000000000..922ee60b8888e464173ac9ede8f4fc6a15e16083 --- /dev/null +++ b/models/smi_ted/finetune/moleculenet/tox21/valid.csv @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5d1e23f3b582e66fdc74c7b89757bd832208f7183c3c5fcbabf6d45e321ffed7 +size 55019 diff --git a/models/smi_ted/finetune/smi_ted_large/bace/run_finetune_bace.sh b/models/smi_ted/finetune/smi_ted_large/bace/run_finetune_bace.sh new file mode 100644 index 0000000000000000000000000000000000000000..8d7a5cfe4b6f9af9664e99b48bfa117d347ad385 --- /dev/null +++ b/models/smi_ted/finetune/smi_ted_large/bace/run_finetune_bace.sh @@ -0,0 +1,25 @@ +python ../../finetune_classification.py \ + --n_batch 32 \ + --n_layer 24 \ + --n_head 16 \ + --n_embd 1024 \ + --max_len 202 \ + --d_dropout 0.1 \ + --dropout 0.1 \ + --lr_start 3e-5 \ + --lr_multiplier 1 \ + --max_epochs 500 \ + --num_feats 32 \ + --smi_ted_version 'v2' \ + --model_path '../' \ + --ckpt_filename 'smi-ted-Large_11.pt' \ + --data_root '../../moleculenet/bace' \ + --dataset_name bace \ + --measure_name 'Class' \ + --checkpoints_folder './checkpoints_bace' \ + --loss_fn 'crossentropy' \ + --target_metric 'roc-auc' \ + --n_output 2 \ + --save_ckpt 1 \ + --start_seed 0 \ + --train_decoder 0 \ \ No newline at end of file diff --git a/models/smi_ted/finetune/smi_ted_large/bbbp/run_finetune_bbbp.sh b/models/smi_ted/finetune/smi_ted_large/bbbp/run_finetune_bbbp.sh new file mode 100644 index 0000000000000000000000000000000000000000..bc1d0b7d09eb2ad005700dc2c9d55c49bd714e1a --- /dev/null +++ b/models/smi_ted/finetune/smi_ted_large/bbbp/run_finetune_bbbp.sh @@ -0,0 +1,25 @@ +python ../../finetune_classification.py \ + --n_batch 32 \ + --n_layer 24 \ + --n_head 16 \ + --n_embd 1024 \ + --max_len 202 \ + --d_dropout 0.1 \ + --dropout 0.1 \ + --lr_start 3e-5 \ + --lr_multiplier 1 \ + --max_epochs 500 \ + --num_feats 32 \ + --smi_ted_version 'v2' \ + --model_path '../' \ + --ckpt_filename 'smi-ted-Large_11.pt' \ + --data_root '../../moleculenet/bbbp' \ + --dataset_name bbbp \ + --measure_name 'p_np' \ + --checkpoints_folder 
'./checkpoints_bbbp' \ + --loss_fn 'crossentropy' \ + --target_metric 'roc-auc' \ + --n_output 2 \ + --save_ckpt 1 \ + --start_seed 0 \ + --train_decoder 0 \ \ No newline at end of file diff --git a/models/smi_ted/finetune/smi_ted_large/bert_vocab_curated.txt b/models/smi_ted/finetune/smi_ted_large/bert_vocab_curated.txt new file mode 100644 index 0000000000000000000000000000000000000000..4de43b6c87add8821aa8d2d33d58232929c86fbd --- /dev/null +++ b/models/smi_ted/finetune/smi_ted_large/bert_vocab_curated.txt @@ -0,0 +1,2393 @@ +<bos> +<eos> +<pad> +<mask> +C +c +( +) +1 +O +N +2 += +n +3 +[C@H] +[C@@H] +F +S +4 +Cl +- +o +s +[nH] +# +/ +Br +[C@] +[C@@] +[N+] +[O-] +5 +\ +. +I +6 +[S@] +[S@@] +P +[N-] +[Si] +7 +[n+] +[2H] +8 +[NH+] +B +9 +[C-] +[Na+] +[Cl-] +[c-] +[CH] +%10 +[NH2+] +[P+] +[B] +[I-] +%11 +[CH2-] +[O+] +[NH3+] +[C] +[Br-] +[IH2] +[S-] +[cH-] +%12 +[nH+] +[B-] +[K+] +[Sn] +[Se] +[CH-] +[HH] +[Y] +[n-] +[CH3-] +[SiH] +[S+] +%13 +[SiH2] +[Li+] +[NH-] +%14 +[Na] +[CH2] +[O-2] +[U+2] +[W] +[Al] +[P@] +[Fe+2] +[PH+] +%15 +[Cl+3] +[Zn+2] +[Ir] +[Mg+2] +[Pt+2] +[OH2+] +[As] +[Fe] +[OH+] +[Zr+2] +[3H] +[Ge] +[SiH3] +[OH-] +[NH4+] +[Cu+2] +[P@@] +p +[Pt] +%16 +[Ca+2] +[Zr] +[F-] +[C+] +[Ti] +[P-] +[V] +[se] +[U] +[O] +[Ni+2] +[Zn] +[Co] +[Ni] +[Pd+2] +[Cu] +%17 +[Cu+] +[Te] +[H+] +[CH+] +[Li] +[Pd] +[Mo] +[Ru+2] +[o+] +[Re] +[SH+] +%18 +[Ac] +[Cr] +[NH2-] +[K] +[13CH2] +[c] +[Zr+4] +[Tl] +[13C] +[Mn] +[N@+] +[Hg] +[Rh] +[Ti+4] +[Sb] +[Co+2] +[Ag+] +[Ru] +%19 +[N@@+] +[Ti+2] +[Al+3] +[Pb] +[I+] +[18F] +[s+] +[Rb+] +[Ba+2] +[H-] +[Fe+3] +[Ir+3] +[13cH] +%20 +[AlH2] +[Au+] +[13c] +[SH2+] +[Sn+2] +[Mn+2] +[Si-] +[Ag] +[N] +[Bi] +%21 +[In] +[CH2+] +[Y+3] +[Ga] +%22 +[Co+3] +[Au] +[13CH3] +[Mg] +[Cs+] +[W+2] +[Hf] +[Zn+] +[Se-] +[S-2] +[Ca] +[pH] +[ClH+] +[Ti+3] +%23 +[Ru+] +[SH-] +[13CH] +[IH+] +[Hf+4] +[Rf] +[OH3+] +%24 +[Pt+4] +[Zr+3] +[PH3+] +[Sr+2] +[Cd+2] +[Cd] +%25 +[Os] +[BH-] +[Sn+4] +[Cr+3] +[Ru+3] +[PH2+] +[Rh+2] +[V+2] +%26 +[Gd+3] +[Pb+2] +[PH] +[Hg+] +[Mo+2] +[AlH] +[Sn+] +%27 +[Pd+] +b +[Rh+3] +[Hg+2] +[15NH] +[14C] +%28 +[Mn+3] +[Si+] +[SeH] +[13C@H] +[NH] +[Ga+3] +[SiH-] +[13C@@H] +[Ce] +[Au+3] +[Bi+3] +[15N] +%29 +[BH3-] +[14cH] +[Ti+] +[Gd] +[cH+] +[Cr+2] +[Sb-] +%30 +[Be+2] +[Al+] +[te] +[11CH3] +[Sm] +[Pr] +[La] +%31 +[Al-] +[Ta] +[125I] +[BH2-] +[Nb] +[Si@] +%32 +[14c] +[Sb+3] +[Ba] +%33 +[Os+2] +[Si@@] +[La+3] +[15n] +[15NH2] +[Nd+3] +%34 +[14CH2] +[18O] +[Nd] +[GeH] +[Ni+3] +[Eu] +[Dy+3] +[Sc] +%36 +[Se-2] +[As+] +%35 +[AsH] +[Tb] +[Sb+5] +[Se+] +[Ce+3] +[c+] +[In+3] +[SnH] +[Mo+4] +%37 +[V+4] +[Eu+3] +[Hf+2] +%38 +[Pt+] +[p+] +[123I] +[Tl+] +[Sm+3] +%39 +[Yb+3] +%40 +[Yb] +[Os+] +%41 +[10B] +[Sc+3] +[Al+2] +%42 +[Sr] +[Tb+3] +[Po] +[Tc] +[PH-] +[AlH3] +[Ar] +[U+4] +[SnH2] +[Cl+2] +[si] +[Fe+] +[14CH3] +[U+3] +[Cl+] +%43 +[GeH2] +%44 +[Er+3] +[Mo+3] +[I+2] +[Fe+4] +[99Tc] +%45 +[11C] +%46 +[SnH3] +[S] +[Te+] +[Er] +[Lu+3] +[11B] +%47 +%48 +[P] +[Tm] +[Th] +[Dy] +[Pr+3] +[Ta+5] +[Nb+5] +[Rb] +[GeH3] +[Br+2] +%49 +[131I] +[Fm] +[Cs] +[BH4-] +[Lu] +[15nH] +%50 +[Ru+6] +[b-] +[Ho] +[Th+4] +[Ru+4] +%52 +[14CH] +%51 +[Cr+6] +[18OH] +[Ho+3] +[Ce+4] +[Bi+2] +[Co+] +%53 +[Yb+2] +[Fe+6] +[Be] +%54 +[SH3+] +[Np] +[As-] +%55 +[14C@@H] +[Ir+2] +[GaH3] +[p-] +[GeH4] +[Sn+3] +[Os+4] +%56 +[14C@H] +[sH+] +[19F] +[Eu+2] +[TlH] +%57 +[Cr+4] +%58 +[B@@-] +[SiH+] +[At] +[Am] +[Fe+5] +[AsH2] +[Si+4] +[B@-] +[Pu] +[SbH] +[P-2] +[Tm+3] +* +%59 +[se+] +[IH-] +%60 +[oH+] +[1H] +[15N+] +[124I] +[S@@+] +[P-3] +[H] +[IH2+] +[TeH] +[Xe] +[PH4+] +[Cr+] +[Cm] +[I+3] +%61 +[Nb+2] +[Ru+5] +%62 
+[Ta+2] +[Tc+4] +[CH3+] +[Pm] +[Si@H] +[No] +%63 +[Cr+5] +[Th+2] +[Zn-2] +[13C@] +[Lr] +%64 +[99Tc+3] +%65 +[13C@@] +%66 +[Fe-] +[17O] +[siH] +[Sb+] +[OH] +[IH] +[11CH2] +[Cf] +[SiH2+] +[Gd+2] +[In+] +[Si@@H] +[Mn+] +[99Tc+4] +[Ga-] +%67 +[S@+] +[Ge+4] +[Tl+3] +[16OH] +%68 +[2H-] +[Ra] +[si-] +[NiH2] +[P@@H] +[Rh+] +[12C] +[35S] +[32P] +[SiH2-] +[AlH2+] +[16O] +%69 +[BiH] +[BiH2] +[Zn-] +[BH] +[Tc+3] +[Ir+] +[Ni+] +%70 +[InH2] +[InH] +[Nb+3] +[PbH] +[Bi+] +%71 +[As+3] +%72 +[18O-] +[68Ga+3] +%73 +[Pa] +[76Br] +[Tc+5] +[pH+] +[64Cu+2] +[Ru+8] +%74 +[PH2-] +[Si+2] +[17OH] +[RuH] +[111In+3] +[AlH+] +%75 +%76 +[W+] +[SbH2] +[PoH] +[Ru-] +[XeH] +[Tc+2] +[13C-] +[Br+] +[Pt-2] +[Es] +[Cu-] +[Mg+] +[3HH] +[P@H] +[ClH2+] +%77 +[SH] +[Au-] +[2HH] +%78 +[Sn-] +[11CH] +[PdH2] +0 +[Os+6] +%79 +[Mo+] +%80 +[al] +[PbH2] +[64Cu] +[Cl] +[12CH3] +%81 +[Tc+7] +[11c] +%82 +[Li-] +[99Tc+5] +[He] +[12c] +[Kr] +[RuH+2] +[35Cl] +[Pd-2] +[GaH2] +[4H] +[Sg] +[Cu-2] +[Br+3] +%83 +[37Cl] +[211At] +[IrH+2] +[Mt] +[Ir-2] +[In-] +[12cH] +[12CH2] +[RuH2] +[99Tc+7] +%84 +[15n+] +[ClH2+2] +[16N] +[111In] +[Tc+] +[Ru-2] +[12CH] +[si+] +[Tc+6] +%85 +%86 +[90Y] +[Pd-] +[188Re] +[RuH+] +[NiH] +[SiH3-] +[14n] +[CH3] +[14N] +[10BH2] +%88 +%89 +%90 +[34S] +[77Br] +[GaH] +[Br] +[Ge@] +[B@@H-] +[CuH] +[SiH4] +[3H-] +%87 +%91 +%92 +[67Cu] +[I] +[177Lu] +[ReH] +[67Ga+3] +[Db] +[177Lu+3] +[AlH2-] +[Si+3] +[Ti-2] +[RuH+3] +[al+] +[68Ga] +[2H+] +[B@H-] +[WH2] +[OsH] +[Ir-3] +[AlH-] +[Bk] +[75Se] +[14C@] +[Pt-] +[N@@H+] +[Nb-] +[13NH2] +%93 +[186Re] +[Tb+4] +[PtH] +[IrH2] +[Hg-2] +[AlH3-] +[PdH+] +[Md] +[RhH+2] +[11cH] +[Co-2] +[15N-] +[ZrH2] +%94 +[Hg-] +[127I] +[AsH2+] +[MoH2] +[Te+4] +[14C@@] +[As+5] +[SnH+3] +[Ge@@] +[6Li+] +[WH] +[Ne] +[14NH2] +[14NH] +[12C@@H] +[Os+7] +[RhH] +[Al-3] +[SnH+] +[15NH3+] +[Zr+] +[197Hg+] +%95 +%96 +[90Y+3] +[Os-2] +[98Tc+5] +[15NH3] +[bH-] +[33P] +[Zr-2] +[15O] +[Rh-] +[PbH3] +[PH2] +[Ni-] +[CuH+] +%97 +%98 +%99 +[Os+5] +[PtH+] +[ReH4] +[16NH] +[82Br] +[W-] +[18F-] +[15NH4+] +[Se+4] +[SeH-] +[SH4] +[67Cu+2] +[12C@H] +[AsH3] +[HgH] +[10B-] +[99Tc+6] +[117Sn+4] +[Te@] +[P@+] +[35SH] +[SeH+] +[Ni-2] +[Al-2] +[TeH2] +[Bh] +[99Tc+2] +[Os+8] +[PH-2] +[7Li+] +[14nH] +[AlH+2] +[18FH] +[SnH4] +[18O-2] +[IrH] +[13N] +[Te@@] +[Rh-3] +[15NH+] +[AsH3+] +[SeH2] +[AsH+] +[CoH2] +[16NH2] +[AsH-] +[203Hg+] +[P@@+] +[166Ho+3] +[60Co+3] +[13CH2-] +[SeH2+] +[75Br] +[TlH2] +[80Br] +[siH+] +[Ca+] +[153Sm+3] +[PdH] +[225Ac] +[13CH3-] +[AlH4-] +[FeH] +[13CH-] +[14C-] +[11C-] +[153Sm] +[Re-] +[te+] +[13CH4] +[ClH+2] +[8CH2] +[99Mo] +[ClH3+3] +[SbH3] +[25Mg+2] +[16N+] +[SnH2+] +[PH4] +[11C@H] +[122I] +[Re-2] +[RuH2+2] +[ZrH] +[Bi-] +[Pr+] +[Rn] +[Fr] +[36Cl] +[18o] +[YH] +[79Br] +[121I] +[113In+3] +[InH4-] +[TaH] +[RhH2] +[Ta-] +[67Ga] +[ZnH+] +[SnH2-] +[OsH2] +[16F] +[FeH2] +[14O] +[PbH2+2] +[BH2] +[6H] +[125Te] +[197Hg] +[TaH2] +[TaH3] +[76As] +[Nb-2] +[14N+] +[125I-] +[33S] +[IH2+2] +[NH2] +[PtH2] +[MnH] +[19C] +[17F] +[1H-] +[SnH4+2] +[Mn-2] +[15NH2+] +[TiH2] +[ReH7] +[Cd-2] +[Fe-3] +[SH2] +[17O-] +[siH-] +[CoH+] +[VH] +[10BH] +[Ru-3] +[13O] +[5H] +[CoH] +[PH5] +[15n-] +[153Gd] +[12C@] +[11CH3-] +[IrH3] +[RuH3] +[74Se] +[Se@] +[Hf+] +[77Se] +[166Ho] +[59Fe+2] +[203Hg] +[18OH-] +[8CH] +[12C@@] +[11CH4] +[15C] +[249Cf] +[PbH4] +[64Zn] +[PH3] +[99Tc+] +[14c-] +[149Pm] +[IrH4] +[Se@@] +[13OH] +[14CH3-] +[28Si] +[Rh-2] +[Fe-2] +[131I-] +[51Cr] +[62Cu+2] +[81Br] +[121Sb] +[7Li] +[89Zr+4] +[SbH3+] +[11C@@H] +[98Tc] +[59Fe+3] +[BiH2+] +[SbH+] +[TiH] +[14NH3] +[15OH] +[119Sn] +[201Hg] +[MnH+] +[201Tl] +[51Cr+3] +[123I-] +[MoH] 
+[AlH6-3] +[MnH2] +[WH3] +[213Bi+3] +[SnH2+2] +[123IH] +[13CH+] +[Zr-] +[74As] +[13C+] +[32P+] +[KrH] +[SiH+2] +[ClH3+2] +[13NH] +[9CH2] +[ZrH2+2] +[87Sr+2] +[35s] +[239Pu] +[198Au] +[241Am] +[203Hg+2] +[V+] +[YH2] +[SH5] +[195Pt] +[203Pb] +[RuH4] +[ThH2] +[AuH] +[66Ga+3] +[11B-] +[F] +[24Na+] +[85Sr+2] +[201Tl+] +[14CH4] +[32S] +[TeH2+] +[ClH2+3] +[AgH] +[Ge@H] +[44Ca+2] +[Os-] +[31P] +[15nH+] +[SbH4] +[TiH+] +[Ba+] +[57Co+2] +[Ta+] +[125IH] +[77As] +[129I] +[Fe-4] +[Ta-2] +[19O] +[12O] +[BiH3] +[237Np] +[252Cf] +[86Y] +[Cr-2] +[89Y] +[195Pt+2] +[si+2] +[58Fe+2] +[Hs] +[S@@H] +[OsH6] +[GdH2] +[IH3] +[8CH4] +[164Dy+3] +[47Ca+2] +[57Co] +[NbH2] +[ReH2] +[ZnH2] +[CrH2] +[17NH] +[ZrH3] +[RhH3] +[12C-] +[18O+] +[Bi-2] +[ClH4+3] +[Ni-3] +[Ag-] +[111In-] +[Mo-2] +[55Fe+3] +[204Hg+] +[35Cl-] +[211Pb] +[75Ge] +[8B] +[TeH3] +[SnH3+] +[Zr-3] +[28F] +[249Bk] +[169Yb] +[34SH] +[6Li] +[94Tc] +[197Au] +[195Pt+4] +[169Yb+3] +[32Cl] +[82Se] +[159Gd+3] +[213Bi] +[CoH+2] +[36S] +[35P] +[Ru-4] +[Cr-3] +[60Co] +[1H+] +[18CH2] +[Cd-] +[152Sm+3] +[106Ru] +[238Pu] +[220Rn] +[45Ca+2] +[89Sr+2] +[239Np] +[90Sr+2] +[137Cs+] +[165Dy] +[68GaH3] +[65Zn+2] +[89Zr] +[BiH2+2] +[62Cu] +[165Dy+3] +[238U] +[105Rh+3] +[70Zn] +[12B] +[12OH] +[18CH] +[17CH] +[OsH3] +[SbH-] +[SH6] +[AlH2-2] +[42K] +[76Br-] +[71As] +[NbH3] +[ReH3] +[OsH-] +[WH4] +[MoH3] +[OsH4] +[RuH6] +[PtH3] +[CuH2] +[CoH3] +[TiH4] +[64Zn+2] +[Si-2] +[79BrH] +[14CH2-] +[PtH2+2] +[Os-3] +[29Si] +[Ti-] +[Se+6] +[22Na+] +[42K+] +[131Cs+] +[86Rb+] +[134Cs+] +[209Po] +[208Po] +[81Rb+] +[203Tl+] +[Zr-4] +[148Sm] +[147Sm] +[37Cl-] +[12CH4] +[Ge@@H] +[63Cu] +[13CH2+] +[AsH2-] +[CeH] +[SnH-] +[UH] +[9c] +[21CH3] +[TeH+] +[57Co+3] +[8BH2] +[12BH2] +[19BH2] +[9BH2] +[YbH2] +[CrH+2] +[208Bi] +[152Gd] +[61Cu] +[115In] +[60Co+2] +[13NH2-] +[120I] +[18OH2] +[75SeH] +[SbH2+] +[144Ce] +[16n] +[113In] +[22nH] +[129I-] +[InH3] +[32PH3] +[234U] +[235U] +[59Fe] +[82Rb+] +[65Zn] +[244Cm] +[147Pm] +[91Y] +[237Pu] +[231Pa] +[253Cf] +[127Te] +[187Re] +[236Np] +[235Np] +[72Zn] +[253Es] +[159Dy] +[62Zn] +[101Tc] +[149Tb] +[124I-] +[SeH3+] +[210Pb] +[40K] +[210Po] +[214Pb] +[218Po] +[214Po] +[7Be] +[212Pb] +[205Pb] +[209Pb] +[123Te] +[202Pb] +[72As] +[201Pb] +[70As] +[73Ge] +[200Pb] +[198Pb] +[66Ga] +[73Se] +[195Pb] +[199Pb] +[144Ce+3] +[235U+2] +[90Tc] +[114In+3] +[128I] +[100Tc+] +[82Br-] +[191Pt+2] +[191Pt+4] +[193Pt+4] +[31PH3] +[125I+2] +[131I+2] +[125Te+4] +[82Sr+2] +[149Sm] +[81BrH] +[129Xe] +[193Pt+2] +[123I+2] +[Cr-] +[Co-] +[227Th+4] +[249Cf+3] +[252Cf+3] +[187Os] +[16O-] +[17O+] +[16OH-] +[98Tc+7] +[58Co+2] +[69Ga+3] +[57Fe+2] +[43K+] +[16C] +[52Fe+3] +[SeH5] +[194Pb] +[196Pb] +[197Pb] +[213Pb] +[9B] +[19B] +[11CH-] +[9CH] +[20OH] +[25OH] +[8cH] +[TiH+3] +[SnH6+3] +[N@H+] +[ZnH] +[VH3] +[52Mn+2] +[64Ga] +[13B] +[216Bi] +[117Sn+2] +[232Th] +[SnH+2] +[BiH5] +[77Kr] +[103Cd] +[62Ni] +[LaH3] +[SmH3] +[EuH3] +[MoH5] +[64Ni] +[66Zn] +[68Zn] +[186W] +[FeH4] +[MoH4] +[HgH2] +[15NH2-] +[UH2] +[204Hg] +[GaH4-] +[ThH4] +[WH6] +[PtH4] +[VH2] +[UH3] +[FeH3] +[RuH5] +[BiH4] +[80Br-] +[CeH3] +[37ClH] +[157Gd+3] +[205Tl] +[203Tl] +[62Cu+] +[64Cu+] +[61Cu+] +[37SH2] +[30Si] +[28Al] +[19OH2] +[8He] +[6He] +[153Pm] +[209Bi] +[66Zn+2] +[10CH4] +[191Ir] +[66Cu] +[16O+] +[25O] +[10c] +[Co-3] +[Sn@@] +[17OH-] +[206Po] +[204Po] +[202Po] +[201Po] +[200Po] +[199Po] +[198Po] +[197Po] +[196Po] +[195Po] +[194Po] +[193Po] +[192Po] +[191Po] +[190Po] +[217Po] +[BiH4-] +[TeH4] +[222Ra] +[62Ga] +[39Ar] +[144Sm] +[58Fe] +[153Eu] +[85Rb] +[171Yb] +[172Yb] +[114Cd] +[51Fe] +[142Ce] +[207Tl] +[92Mo] +[115Sn] 
+[140Ce] +[202Hg] +[180W] +[182W] +[183W] +[184W] +[96Mo] +[47Ti] +[111Cd] +[143Nd] +[145Nd] +[126Te] +[128Te] +[130Te] +[185Re] +[97Mo] +[98Mo] +[183Re] +[52V] +[80Se] +[87Kr] +[137Xe] +[196Au] +[146Ce] +[88Kr] +[51Ti] +[138Xe] +[112Cd] +[116Sn] +[120Sn] +[28SiH3] +[35S-] +[15NH-] +[13CH3+] +[34S+] +[34s] +[SiH4-] +[100Tc+5] +[NiH2+2] +[239Th] +[186Lu] +[AuH3] +[I@@-] +[XeH2] +[B+] +[16CH2] +[8C] +[TaH5] +[FeH4-] +[19C@H] +[10NH] +[FeH6-3] +[22CH] +[25N] +[25N+] +[25N-] +[21CH2] +[18cH] +[113I] +[ScH3] +[30PH3] +[43Ca+2] +[41Ca+2] +[106Cd] +[122Sn] +[18CH3] +[58Co+3] +[98Tc+4] +[70Ge] +[76Ge] +[108Cd] +[116Cd] +[130Xe] +[94Mo] +[124Sn] +[186Os] +[188Os] +[190Os] +[192Os] +[106Pd] +[110Pd] +[120Te] +[132Ba] +[134Ba] +[136Ba] +[136Ce] +[138Ce] +[156Dy] +[158Dy] +[160Dy] +[163Dy] +[162Er] +[164Er] +[167Er] +[176Hf] +[26Mg] +[144Nd] +[150Nd] +[41K] +[46Ti] +[48Ti] +[49Ti] +[50Ti] +[170Yb] +[173Yb] +[91Zr] +[92Zr] +[96Zr] +[34S-] +[CuH2-] +[38Cl] +[25Mg] +[51V] +[93Nb] +[95Mo] +[45Sc] +[123Sb] +[139La] +[9Be] +[99Y+3] +[99Y] +[156Ho] +[67Zn] +[144Ce+4] +[210Tl] +[42Ca] +[54Fe] +[193Ir] +[92Nb] +[141Cs] +[52Cr] +[35ClH] +[46Ca] +[139Cs] +[65Cu] +[71Ga] +[60Ni] +[16NH3] +[148Nd] +[72Ge] +[161Dy] +[49Ca] +[43Ca] +[8Be] +[48Ca] +[44Ca] +[120Xe] +[80Rb] +[215At] +[180Re] +[146Sm] +[19Ne] +[74Kr] +[134La] +[76Kr] +[219Fr] +[121Xe] +[220Fr] +[216At] +[223Ac] +[218At] +[37Ar] +[135I] +[110Cd] +[94Tc+7] +[86Y+3] +[135I-] +[15O-2] +[151Eu+3] +[161Tb+3] +[197Hg+2] +[109Cd+2] +[191Os+4] +[170Tm+3] +[205Bi+3] +[233U+4] +[126Sb+3] +[127Sb+3] +[132Cs+] +[136Eu+3] +[136Eu] +[125Sn+4] +[175Yb+3] +[100Mo] +[22Ne] +[13c-] +[13NH4+] +[17C] +[9C] +[31S] +[31SH] +[133I] +[126I] +[36SH] +[30S] +[32SH] +[19CH2] +[19c] +[18c] +[15F] +[10C] +[RuH-] +[62Zn+2] +[32ClH] +[33ClH] +[78BrH] +[12Li+] +[12Li] +[233Ra] +[68Ge+4] +[44Sc+3] +[91Y+3] +[106Ru+3] +[PoH2] +[AtH] +[55Fe] +[233U] +[210PoH2] +[230Th] +[228Th] +[222Rn] +[35SH2] +[227Th] +[192Ir] +[133Xe] +[81Kr] +[95Zr] +[240Pu] +[54Mn] +[103Ru] +[95Nb] +[109Cd] +[141Ce] +[85Kr] +[110Ag] +[58Co] +[241Pu] +[234Th] +[140La] +[63Ni] +[152Eu] +[132IH] +[226Rn] +[154Eu] +[36ClH] +[228Ac] +[155Eu] +[106Rh] +[243Am] +[227Ac] +[243Cm] +[236U] +[144Pr] +[232U] +[32SH2] +[88Y] +[82BrH] +[135IH] +[242Cm] +[115Cd] +[242Pu] +[46Sc] +[56Mn] +[234Pa] +[41Ar] +[147Nd] +[187W] +[151Sm] +[59Ni] +[233Pa] +[52Mn] +[94Nb] +[219Rn] +[236Pu] +[13NH3] +[93Zr] +[51Cr+6] +[TlH3] +[123Xe] +[160Tb] +[170Tm] +[182Ta] +[175Yb] +[93Mo] +[143Ce] +[191Os] +[126IH] +[48V] +[113Cd] +[47Sc] +[181Hf] +[185W] +[143Pr] +[191Pt] +[181W] +[33PH3] +[97Ru] +[97Tc] +[111Ag] +[169Er] +[107Pd] +[103Ru+2] +[34SH2] +[137Ce] +[242Am] +[117SnH2] +[57Ni] +[239U] +[60Cu] +[250Cf] +[193Au] +[69Zn] +[55Co] +[139Ce] +[127Xe] +[159Gd] +[56Co] +[177Hf] +[244Pu] +[38ClH] +[142Pr] +[199Hg] +[179Hf] +[178Hf] +[237U] +[156Eu] +[157Eu] +[105Ru] +[171Tm] +[199Au] +[155Sm] +[80BrH] +[108Ag] +[128IH] +[48Sc] +[45Ti] +[176Lu] +[121SnH2] +[148Pm] +[57Fe] +[10BH3] +[96Tc] +[133IH] +[143Pm] +[105Rh] +[130IH] +[134IH] +[131IH] +[71Zn] +[105Ag] +[97Zr] +[235Pu] +[231Th] +[109Pd] +[93Y] +[190Ir] +[135Xe] +[53Mn] +[134Ce] +[234Np] +[240Am] +[246Cf] +[240Cm] +[241Cm] +[226Th] +[39ClH] +[229Th] +[245Cm] +[240U] +[240Np] +[249Cm] +[243Pu] +[145Pm] +[199Pt] +[246Bk] +[193Pt] +[230U] +[250Cm] +[44Ti] +[175Hf] +[254Fm] +[255Fm] +[257Fm] +[92Y] +[188Ir] +[171Lu] +[257Md] +[247Bk] +[121IH] +[250Bk] +[179Lu] +[224Ac] +[195Hg] +[244Am] +[246Pu] +[194Au] +[252Fm] +[173Hf] +[246Cm] +[135Ce] +[49Cr] +[248Cf] +[247Cm] +[248Cm] +[174Ta] +[176Ta] +[154Tb] 
+[172Ta] +[177Ta] +[175Ta] +[180Ta] +[158Tb] +[115Ag] +[189Os] +[251Cf] +[145Pr] +[147Pr] +[76BrH] +[102Rh] +[238Np] +[185Os] +[246Am] +[233Np] +[166Dy] +[254Es] +[244Cf] +[193Os] +[245Am] +[245Bk] +[239Am] +[238Am] +[97Nb] +[245Pu] +[254Cf] +[188W] +[250Es] +[251Es] +[237Am] +[182Hf] +[258Md] +[232Np] +[238Cm] +[60Fe] +[109Pd+2] +[234Pu] +[141Ce+3] +[136Nd] +[136Pr] +[173Ta] +[110Ru] +[147Tb] +[253Fm] +[139Nd] +[178Re] +[177Re] +[200Au] +[182Re] +[156Tb] +[155Tb] +[157Tb] +[161Tb] +[161Ho] +[167Tm] +[173Lu] +[179Ta] +[171Er] +[44Sc] +[49Sc] +[49V] +[51Mn] +[90Nb] +[88Nb] +[88Zr] +[36SH2] +[174Yb] +[178Lu] +[179W] +[83BrH] +[107Cd] +[75BrH] +[62Co] +[48Cr] +[63Zn] +[102Ag] +[154Sm] +[168Er] +[65Ni] +[137La] +[187Ir] +[144Pm] +[146Pm] +[160Gd] +[166Yb] +[162Dy] +[47V] +[141Nd] +[141Sm] +[166Er] +[150Sm] +[146Eu] +[149Eu] +[174Lu] +[17NH3] +[102Ru] +[170Hf] +[188Pt] +[61Ni] +[56Ni] +[149Gd] +[151Gd] +[141Pm] +[147Gd] +[146Gd] +[161Er] +[103Ag] +[145Eu] +[153Tb] +[155Dy] +[184Re] +[180Os] +[182Os] +[186Pt] +[181Os] +[181Re] +[151Tb] +[178Ta] +[178W] +[189Pt] +[194Hg] +[145Sm] +[150Tb] +[132La] +[158Gd] +[104Ag] +[193Hg] +[94Ru] +[137Pr] +[155Ho] +[117Cd] +[99Ru] +[146Nd] +[218Rn] +[95Y] +[79Kr] +[120IH] +[138Pr] +[100Pd] +[166Tm] +[90Mo] +[151Nd] +[231U] +[138Nd] +[89Nb] +[98Nb] +[162Ho] +[142Sm] +[186Ta] +[104Tc] +[184Ta] +[185Ta] +[170Er] +[107Rh] +[131La] +[169Lu] +[74BrH] +[150Pm] +[172Tm] +[197Pt] +[230Pu] +[170Lu] +[86Zr] +[176W] +[177W] +[101Pd] +[105Pd] +[108Pd] +[149Nd] +[164Ho] +[159Ho] +[167Ho] +[176Yb] +[156Sm] +[77BrH] +[189Re] +[99Rh] +[100Rh] +[151Pm] +[232Pa] +[228Pa] +[230Pa] +[66Ni] +[194Os] +[135La] +[138La] +[141La] +[142La] +[195Ir] +[96Nb] +[157Ho] +[183Hf] +[162Tm] +[172Er] +[148Eu] +[150Eu] +[15CH4] +[89Kr] +[143La] +[58Ni] +[61Co] +[158Eu] +[165Er] +[167Yb] +[173Tm] +[175Tm] +[172Hf] +[172Lu] +[93Tc] +[177Yb] +[124IH] +[194Ir] +[147Eu] +[101Mo] +[180Hf] +[189Ir] +[87Y] +[43Sc] +[195Au] +[112Ag] +[84BrH] +[106Ag] +[109Ag] +[101Rh] +[162Yb] +[228Rn] +[139Pr] +[94Y] +[201Au] +[40PH3] +[110Ag+] +[104Cd] +[133Ba+2] +[226Ac] +[145Gd] +[186Ir] +[184Ir] +[224Rn] +[185Ir] +[182Ir] +[184Hf] +[200Pt] +[227Pa] +[178Yb] +[72Br-] +[72BrH] +[248Am] +[238Th] +[161Gd] +[35S-2] +[107Ag] +[FeH6-4] +[89Sr] +[SnH3-] +[SeH3] +[TeH3+] +[SbH4+] +[AsH4+] +[4He] +[AsH3-] +[1HH] +[3H+] +[82Rb] +[85Sr] +[90Sr] +[137Cs] +[133Ba] +[131Cs] +[SbH5] +[224Ra] +[22Na] +[210Bi] +[214Bi] +[228Ra] +[127Sb] +[136Cs] +[125Sb] +[134Cs] +[140Ba] +[45Ca] +[206Pb] +[207Pb] +[24Na] +[86Rb] +[212Bi] +[208Pb] +[124Sb] +[204Pb] +[44K] +[129Te] +[113Sn] +[204Tl] +[87Sr] +[208Tl] +[87Rb] +[47Ca] +[135Cs] +[216Po] +[137Ba] +[207Bi] +[212Po] +[79Se] +[223Ra] +[86Sr] +[122Sb] +[26Al] +[32Si] +[126Sn] +[225Ra] +[114In] +[72Ga] +[132Te] +[10Be] +[125Sn] +[73As] +[206Bi] +[117Sn] +[40Ca] +[41Ca] +[89Rb] +[116In] +[129Sb] +[91Sr] +[71Ge] +[139Ba] +[69Ga] +[120Sb] +[121Sn] +[123Sn] +[131Te] +[77Ge] +[135Ba] +[82Sr] +[43K] +[131Ba] +[92Sr] +[88Rb] +[129Cs] +[144Cs] +[127Cs] +[200Tl] +[202Tl] +[141Ba] +[117Sb] +[116Sb] +[78As] +[131Sb] +[126Sb] +[128Sb] +[130Sb] +[67Ge] +[68Ge] +[78Ge] +[66Ge] +[223Fr] +[132Cs] +[125Cs] +[138Cs] +[133Te] +[84Rb] +[83Rb] +[81Rb] +[142Ba] +[200Bi] +[115Sb] +[194Tl] +[70Se] +[112In] +[118Sb] +[70Ga] +[27Mg] +[202Bi] +[83Se] +[9Li] +[69As] +[79Rb] +[81Sr] +[83Sr] +[78Se] +[109In] +[29Al] +[118Sn] +[117In] +[119Sb] +[114Sn] +[138Ba] +[69Ge] +[73Ga] +[74Ge] +[206Tl] +[199Tl] +[130Cs] +[28Mg] +[116Te] +[112Sn] +[126Ba] +[211Bi] +[81Se] +[127Sn] +[143Cs] +[134Te] +[80Sr] +[45K] +[215Po] +[207Po] 
+[111Sn] +[211Po] +[128Ba] +[198Tl] +[227Ra] +[213Po] +[220Ra] +[128Sn] +[203Po] +[205Po] +[65Ga] +[197Tl] +[88Sr] +[110In] +[31Si] +[201Bi] +[121Te] +[205Bi] +[203Bi] +[195Tl] +[209Tl] +[110Sn] +[222Fr] +[207At] +[119In] +[As@] +[129IH] +[157Dy] +[111IH] +[230Ra] +[144Pr+3] +[SiH3+] +[3He] +[AsH5] +[72Se] +[95Tc] +[103Pd] +[121Sn+2] +[211Rn] +[38SH2] +[127IH] +[74Br-] +[133I-] +[100Tc+4] +[100Tc] +[36Cl-] +[89Y+3] +[104Rh] +[152Sm] +[226Ra] +[19FH] +[104Pd] +[148Gd] +[157Lu] +[33SH2] +[121I-] +[17FH] +[71Se] +[157Sm] +[148Tb] +[164Dy] +[15OH2] +[15O+] +[39K] +[40Ar] +[50Cr+3] +[50Cr] +[52Ti] +[103Pd+2] +[130Ba] +[142Pm] +[153Gd+3] +[151Eu] +[103Rh] +[124Xe] +[152Tb] +[17OH2] +[20Ne] +[52Fe] +[94Zr+4] +[94Zr] +[149Pr] +[16OH2] +[53Cr+6] +[53Cr] +[81Br-] +[112Pd] +[125Xe] +[155Gd] +[157Gd] +[168Yb] +[184Os] +[166Tb] +[221Fr] +[212Ra] +[75Br-] +[79Br-] +[113Ag] +[23Na] +[34Cl-] +[34ClH] +[38Cl-] +[56Fe] +[68Cu] +[77Br-] +[90Zr+4] +[90Zr] +[102Pd] +[154Eu+3] +[57Mn] +[165Tm] +[152Dy] +[217At] +[77se] +[13cH-] +[122Te] +[156Gd] +[124Te] +[53Ni] +[131Xe] +[174Hf+4] +[174Hf] +[76Se] +[168Tm] +[167Dy] +[154Gd] +[95Ru] +[210At] +[85Br] +[59Co] +[122Xe] +[27Al] +[54Cr] +[198Hg] +[85Rb+] +[214Tl] +[229Rn] +[218Pb] +[218Bi] +[167Tm+3] +[18o+] +[P@@H+] +[P@H+] +[13N+] +[212Pb+2] +[217Bi] +[249Cf+2] +[18OH3+] +[90Sr-] +[Cf+3] +[200Hg] +[86Tc] +[141Pr+3] +[141Pr] +[16nH] +[14NH4+] +[132Xe] +[83Kr] +[70Zn+2] +[137Ba+2] +[36Ar] +[38Ar] +[21Ne] +[126Xe] +[136Xe] +[128Xe] +[134Xe] +[84Kr] +[86Kr] +[78Kr] +[80Kr] +[82Kr] +[67Zn+2] +[65Cu+2] +[110Te] +[58Fe+3] +[142Nd] +[38K] +[198Au+3] +[122IH] +[38PH3] +[130I-] +[40K+] +[38K+] +[28Mg+2] +[208Tl+] +[13OH2] +[198Bi] +[192Bi] +[194Bi] +[196Bi] +[132I-] +[83Sr+2] +[169Er+3] +[122I-] +[120I-] +[92Sr+2] +[126I-] +[24Mg] +[84Sr] +[118Pd+2] +[118Pd] +[AsH4] +[127I-] +[9C-] +[11CH3+] +[17B] +[7B] +[4HH] +[18C-] +[22CH3-] +[22CH4] +[17C-] +[15CH3] +[16CH3] +[11NH3] +[21NH3] +[11N-] +[11NH] +[16CH] +[17CH2] +[99Ru+2] +[181Ta+2] +[181Ta] +[20CH] +[32PH2] +[55Fe+2] +[SH3] +[S@H] +[Mn-] +[IH4] +[ThH] +[GaH-] +[BiH+] +[EuH2] +[FeH4-3] +[FeH6] +[IH5] +[NiH+] +[SrH2] +[VH4] +[YH3] +[seH+] +<unk> diff --git a/models/smi_ted/finetune/smi_ted_large/clintox/run_finetune_clintox.sh b/models/smi_ted/finetune/smi_ted_large/clintox/run_finetune_clintox.sh new file mode 100644 index 0000000000000000000000000000000000000000..48fe562011a3e35d4ba56c471589b20716492691 --- /dev/null +++ b/models/smi_ted/finetune/smi_ted_large/clintox/run_finetune_clintox.sh @@ -0,0 +1,23 @@ +python ../../finetune_classification_multitask.py \ + --n_batch 32 \ + --n_layer 24 \ + --n_head 16 \ + --n_embd 1024 \ + --max_len 202 \ + --d_dropout 0.1 \ + --dropout 0.1 \ + --lr_start 3e-5 \ + --lr_multiplier 1 \ + --max_epochs 100 \ + --num_feats 32 \ + --smi_ted_version 'v2' \ + --model_path '../' \ + --ckpt_filename 'smi-ted-Large_11.pt' \ + --data_root '../../moleculenet/clintox' \ + --dataset_name clintox \ + --checkpoints_folder './checkpoints_clintox' \ + --loss_fn 'bceloss' \ + --target_metric 'roc-auc' \ + --save_ckpt 1 \ + --start_seed 0 \ + --train_decoder 0 \ \ No newline at end of file diff --git a/models/smi_ted/finetune/smi_ted_large/esol/run_finetune_esol.sh b/models/smi_ted/finetune/smi_ted_large/esol/run_finetune_esol.sh new file mode 100644 index 0000000000000000000000000000000000000000..25785cec18f4b11c1e4bc336841509292f5240e1 --- /dev/null +++ b/models/smi_ted/finetune/smi_ted_large/esol/run_finetune_esol.sh @@ -0,0 +1,24 @@ +python ../../finetune_regression.py \ + --n_batch 32 \ + --n_layer 
24 \ + --n_head 16 \ + --n_embd 1024 \ + --max_len 202 \ + --d_dropout 0.1 \ + --dropout 0.1 \ + --lr_start 3e-5 \ + --lr_multiplier 1 \ + --max_epochs 500 \ + --num_feats 32 \ + --smi_ted_version 'v2' \ + --model_path '../' \ + --ckpt_filename 'smi-ted-Large_11.pt' \ + --data_root '../../moleculenet/esol' \ + --dataset_name esol \ + --measure_name 'measured log solubility in mols per litre' \ + --checkpoints_folder './checkpoints_esol' \ + --loss_fn 'rmse' \ + --target_metric 'rmse' \ + --save_ckpt 1 \ + --start_seed 0 \ + --train_decoder 1 \ \ No newline at end of file diff --git a/models/smi_ted/finetune/smi_ted_large/freesolv/run_finetune_freesolv.sh b/models/smi_ted/finetune/smi_ted_large/freesolv/run_finetune_freesolv.sh new file mode 100644 index 0000000000000000000000000000000000000000..fd930a648830373aa38b973d9418a80ddd62f440 --- /dev/null +++ b/models/smi_ted/finetune/smi_ted_large/freesolv/run_finetune_freesolv.sh @@ -0,0 +1,24 @@ +python ../../finetune_regression.py \ + --n_batch 32 \ + --n_layer 24 \ + --n_head 16 \ + --n_embd 1024 \ + --max_len 202 \ + --d_dropout 0.1 \ + --dropout 0.1 \ + --lr_start 3e-6 \ + --lr_multiplier 1 \ + --max_epochs 500 \ + --num_feats 32 \ + --smi_ted_version 'v2' \ + --model_path '../' \ + --ckpt_filename 'smi-ted-Large_11.pt' \ + --data_root '../../moleculenet/freesolv' \ + --dataset_name freesolv \ + --measure_name 'expt' \ + --checkpoints_folder './checkpoints_freesolv' \ + --loss_fn 'rmse' \ + --target_metric 'rmse' \ + --save_ckpt 1 \ + --start_seed 0 \ + --train_decoder 1 \ \ No newline at end of file diff --git a/models/smi_ted/finetune/smi_ted_large/hiv/run_finetune_hiv.sh b/models/smi_ted/finetune/smi_ted_large/hiv/run_finetune_hiv.sh new file mode 100644 index 0000000000000000000000000000000000000000..978f7fc09ff89b64bca71c726229c3b81fce608d --- /dev/null +++ b/models/smi_ted/finetune/smi_ted_large/hiv/run_finetune_hiv.sh @@ -0,0 +1,25 @@ +python ../../finetune_classification.py \ + --n_batch 32 \ + --n_layer 24 \ + --n_head 16 \ + --n_embd 1024 \ + --max_len 202 \ + --d_dropout 0.1 \ + --dropout 0.1 \ + --lr_start 1e-7 \ + --lr_multiplier 1 \ + --max_epochs 500 \ + --num_feats 32 \ + --smi_ted_version 'v2' \ + --model_path '../' \ + --ckpt_filename 'smi-ted-Large_11.pt' \ + --data_root '../../moleculenet/hiv' \ + --dataset_name hiv \ + --measure_name 'HIV_active' \ + --checkpoints_folder './checkpoints_hiv_1e-7' \ + --loss_fn 'crossentropy' \ + --target_metric 'roc-auc' \ + --n_output 2 \ + --save_ckpt 1 \ + --start_seed 0 \ + --train_decoder 0 \ \ No newline at end of file diff --git a/models/smi_ted/finetune/smi_ted_large/lipo/run_finetune_lipo.sh b/models/smi_ted/finetune/smi_ted_large/lipo/run_finetune_lipo.sh new file mode 100644 index 0000000000000000000000000000000000000000..a71a544afaa2accea3fd637454e90a294424700d --- /dev/null +++ b/models/smi_ted/finetune/smi_ted_large/lipo/run_finetune_lipo.sh @@ -0,0 +1,24 @@ +python ../../finetune_regression.py \ + --n_batch 32 \ + --n_layer 24 \ + --n_head 16 \ + --n_embd 1024 \ + --max_len 202 \ + --d_dropout 0.1 \ + --dropout 0.1 \ + --lr_start 1e-6 \ + --lr_multiplier 1 \ + --max_epochs 500 \ + --num_feats 32 \ + --smi_ted_version 'v2' \ + --model_path '../' \ + --ckpt_filename 'smi-ted-Large_11.pt' \ + --data_root '../../moleculenet/lipophilicity' \ + --dataset_name lipophilicity \ + --measure_name 'y' \ + --checkpoints_folder './checkpoints_lipophilicity' \ + --loss_fn 'rmse' \ + --target_metric 'rmse' \ + --save_ckpt 1 \ + --start_seed 0 \ + --train_decoder 1 \ \ No newline at end 
of file diff --git a/models/smi_ted/finetune/smi_ted_large/load.py b/models/smi_ted/finetune/smi_ted_large/load.py new file mode 100644 index 0000000000000000000000000000000000000000..fcc1129cfe73b3c9161413e84623bb0ff7294528 --- /dev/null +++ b/models/smi_ted/finetune/smi_ted_large/load.py @@ -0,0 +1,504 @@ +PATTERN = "(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])" +# Deep learning +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.backends.cudnn as cudnn + +# Transformers +from fast_transformers.attention import AttentionLayer +from fast_transformers.events import QKVEvent +from fast_transformers.transformers import TransformerEncoder, TransformerEncoderLayer +from fast_transformers.builders.transformer_builders import BaseTransformerEncoderBuilder +from fast_transformers.builders.attention_builders import AttentionBuilder +from fast_transformers.feature_maps import GeneralizedRandomFeatures +from fast_transformers.masking import LengthMask +from transformers import BertTokenizer + +# Data +import numpy as np +import pandas as pd + +# Standard library +from functools import partial +import regex as re +import random +import os +import gc +from tqdm import tqdm +tqdm.pandas() + + +class MolTranBertTokenizer(BertTokenizer): + def __init__(self, vocab_file: str = '', + do_lower_case=False, + unk_token='<pad>', + sep_token='<eos>', + pad_token='<pad>', + cls_token='<bos>', + mask_token='<mask>', + **kwargs): + super().__init__(vocab_file, + unk_token=unk_token, + sep_token=sep_token, + pad_token=pad_token, + cls_token=cls_token, + mask_token=mask_token, + **kwargs) + + self.regex_tokenizer = re.compile(PATTERN) + self.wordpiece_tokenizer = None + self.basic_tokenizer = None + with open(vocab_file) as f: + self.padding_idx = f.readlines().index(pad_token+'\n') + + def _tokenize(self, text): + split_tokens = self.regex_tokenizer.findall(text) + return split_tokens + + def convert_idx_to_tokens(self, idx_tensor): + tokens = [self.convert_ids_to_tokens(idx) for idx in idx_tensor.tolist()] + return tokens + + def convert_tokens_to_string(self, tokens): + stopwords = ['<bos>', '<eos>'] + clean_tokens = [word for word in tokens if word not in stopwords] + out_string = ''.join(clean_tokens) + return out_string + + def get_padding_idx(self): + return self.padding_idx + + def idx_to_smiles(self, torch_model, idx): + '''Convert tokens idx back to SMILES text''' + rev_tokens = torch_model.tokenizer.convert_idx_to_tokens(idx) + flat_list_tokens = [item for sublist in rev_tokens for item in sublist] + decoded_smiles = torch_model.tokenizer.convert_tokens_to_string(flat_list_tokens) + return decoded_smiles + + +## Transformer layers +class RotaryEmbedding(torch.nn.Module): + + def __init__(self, dim, base=10000): + super().__init__() + inv_freq = 1. 
/ (base ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer('inv_freq', inv_freq) + self.seq_len_cached = 0 + self.cos_cached = None + self.sin_cached = None + + def forward(self, x, seq_dim=1): + seq_len = x.shape[seq_dim] + if seq_len != self.seq_len_cached: + self.seq_len_cached = seq_len + + t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + freqs = torch.einsum('i,j->ij', t, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1).to(x.device) + + self.cos_cached = emb.cos()[None,:, None, :] + self.sin_cached = emb.sin()[None,:, None, :] + + return self.cos_cached, self.sin_cached + +def rotate_half(x): + x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in earlier torch versions + +@torch.jit.script +def apply_rotary_pos_emb(q, k, cos, sin): + return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) + +class RotateAttentionLayer(AttentionLayer): + """Rotate attention layer inherits from the fast_transformers attention layer. + The only addition is a rotary positional embedding; for more information + on the attention layer see the fast_transformers code + """ + def __init__(self, attention, d_model, n_heads, d_keys=None, + d_values=None, event_dispatcher=""): + super(RotateAttentionLayer, self).__init__(attention,d_model, n_heads, d_keys=d_keys, + d_values=d_values, event_dispatcher=event_dispatcher) + + self.rotaryemb = RotaryEmbedding(d_keys) + print('Using Rotation Embedding') + + def forward(self, queries, keys, values, attn_mask, query_lengths, + key_lengths): + """ + Using the same framework as the fast_transformers attention layer + but injecting rotary information into the queries and the keys + after the keys and queries are projected. + In the argument description we make use of the following sizes + - N: the batch size + - L: The maximum length of the queries + - S: The maximum length of the keys (the actual length per sequence + is given by the length mask) + - D: The input feature dimensionality passed in the constructor as + 'd_model' + Arguments + --------- + queries: (N, L, D) The tensor containing the queries + keys: (N, S, D) The tensor containing the keys + values: (N, S, D) The tensor containing the values + attn_mask: An implementation of BaseMask that encodes where each + query can attend to + query_lengths: An implementation of BaseMask that encodes how + many queries each sequence in the batch consists of + key_lengths: An implementation of BaseMask that encodes how + many keys each sequence in the batch consists of + Returns + ------- + The new value for each query as a tensor of shape (N, L, D).
+ """ + # Extract the dimensions into local variables + N, L, _ = queries.shape + _, S, _ = keys.shape + H = self.n_heads + + # Project the queries/keys/values + queries = self.query_projection(queries).view(N, L, H, -1) + keys = self.key_projection(keys).view(N, S, H, -1) + cos, sin = self.rotaryemb(queries) + queries, keys = apply_rotary_pos_emb(queries, keys, cos, sin) + values = self.value_projection(values).view(N, S, H, -1) + # Let the world know of the qkv + self.event_dispatcher.dispatch(QKVEvent(self, queries, keys, values)) + + + # Compute the attention + new_values = self.inner_attention( + queries, + keys, + values, + attn_mask, + query_lengths, + key_lengths + ).view(N, L, -1) + + # Project the output and return + return self.out_projection(new_values) + +class RotateEncoderBuilder(BaseTransformerEncoderBuilder): + """Build a batch transformer encoder with Relative Rotary embeddings + for training or processing of sequences all elements at a time. + Example usage: + builder = RotateEncoderBuilder() + builder.n_layers = 12 + builder.n_heads = 8 + builder.feed_forward_dimensions = 1024 + builder.query_dimensions = 64 + builder.value_dimensions = 64 + builder.dropout = 0.1 + builder.attention_dropout = 0.1 + builder.attention_type = "linear" + transformer = builder.get() + """ + def _get_attention_builder(self): + """Return an instance of the appropriate attention builder.""" + return AttentionBuilder() + + def _get_attention_layer_class(self): + """Return the class for the layer that projects queries keys and + values.""" + return RotateAttentionLayer + + def _get_encoder_class(self): + """Return the class for the transformer encoder.""" + return TransformerEncoder + + def _get_encoder_layer_class(self): + """Return the class for the transformer encoder layer.""" + return TransformerEncoderLayer + + +class AutoEncoderLayer(nn.Module): + + def __init__(self, feature_size, latent_size): + super().__init__() + self.encoder = self.Encoder(feature_size, latent_size) + self.decoder = self.Decoder(feature_size, latent_size) + + class Encoder(nn.Module): + + def __init__(self, feature_size, latent_size): + super().__init__() + self.is_cuda_available = torch.cuda.is_available() + self.fc1 = nn.Linear(feature_size, latent_size) + self.ln_f = nn.LayerNorm(latent_size) + self.lat = nn.Linear(latent_size, latent_size, bias=False) + + def forward(self, x): + if self.is_cuda_available: + self.fc1.cuda() + self.ln_f.cuda() + self.lat.cuda() + x = x.cuda() + x = F.gelu(self.fc1(x)) + x = self.ln_f(x) + x = self.lat(x) + return x # -> (N, D) + + class Decoder(nn.Module): + + def __init__(self, feature_size, latent_size): + super().__init__() + self.is_cuda_available = torch.cuda.is_available() + self.fc1 = nn.Linear(latent_size, latent_size) + self.ln_f = nn.LayerNorm(latent_size) + self.rec = nn.Linear(latent_size, feature_size, bias=False) + + def forward(self, x): + if self.is_cuda_available: + self.fc1.cuda() + self.ln_f.cuda() + self.rec.cuda() + x = x.cuda() + x = F.gelu(self.fc1(x)) + x = self.ln_f(x) + x = self.rec(x) + return x # -> (N, L*D) + + +class LangLayer(nn.Module): + + def __init__(self, n_embd, n_vocab): + super().__init__() + self.is_cuda_available = torch.cuda.is_available() + self.embed = nn.Linear(n_embd, n_embd) + self.ln_f = nn.LayerNorm(n_embd) + self.head = nn.Linear(n_embd, n_vocab, bias=False) + + def forward(self, tensor): + if self.is_cuda_available: + self.embed.cuda() + self.ln_f.cuda() + self.head.cuda() + tensor = tensor.cuda() + tensor = self.embed(tensor) + 
tensor = F.gelu(tensor) + tensor = self.ln_f(tensor) + tensor = self.head(tensor) + return tensor + + +class Net(nn.Module): + + def __init__(self, smiles_embed_dim, n_output=1, dropout=0.2): + super().__init__() + self.desc_skip_connection = True + self.fc1 = nn.Linear(smiles_embed_dim, smiles_embed_dim) + self.dropout1 = nn.Dropout(dropout) + self.relu1 = nn.GELU() + self.fc2 = nn.Linear(smiles_embed_dim, smiles_embed_dim) + self.dropout2 = nn.Dropout(dropout) + self.relu2 = nn.GELU() + self.final = nn.Linear(smiles_embed_dim, n_output) + + def forward(self, smiles_emb, multitask=False): + x_out = self.fc1(smiles_emb) + x_out = self.dropout1(x_out) + x_out = self.relu1(x_out) + + if self.desc_skip_connection is True: + x_out = x_out + smiles_emb + + z = self.fc2(x_out) + z = self.dropout2(z) + z = self.relu2(z) + if self.desc_skip_connection is True: + z = self.final(z + x_out) + else: + z = self.final(z) + + if multitask: + return F.sigmoid(z) + return z + + +class MoLEncoder(nn.Module): + + def __init__(self, config, n_vocab, eval=False): + super(MoLEncoder, self).__init__() + + # embeddings + self.config = config + self.tok_emb = nn.Embedding(n_vocab, config['n_embd']) + self.drop = nn.Dropout(config['d_dropout']) + + # transformer + builder = RotateEncoderBuilder.from_kwargs( + n_layers=config['n_layer'], + n_heads=config['n_head'], + query_dimensions=config['n_embd']//config['n_head'], + value_dimensions=config['n_embd']//config['n_head'], + feed_forward_dimensions=None, + attention_type='linear', + # unless we do deterministic_eval here, we will have random outputs + feature_map=partial(GeneralizedRandomFeatures, + n_dims=config['num_feats'], + deterministic_eval=eval), + activation='gelu' + ) + self.blocks = builder.get() + + # classification + self.lang_model = LangLayer(config['n_embd'], n_vocab) + + +class MoLDecoder(nn.Module): + + def __init__(self, n_vocab, max_len, n_embd, n_gpu=None): + super(MoLDecoder, self).__init__() + + self.max_len = max_len + self.n_embd = n_embd + self.n_gpu = n_gpu + self.autoencoder = AutoEncoderLayer(n_embd*max_len, n_embd) + self.lang_model = LangLayer(n_embd, n_vocab) + + +class Smi_ted(nn.Module): + """materials.smi-ted-Large 738M Parameters""" + + def __init__(self, tokenizer, config=None, eval=False): + super(Smi_ted, self).__init__() + + # configuration + self.config = config + self.tokenizer = tokenizer + self.padding_idx = tokenizer.get_padding_idx() + self.n_vocab = len(self.tokenizer.vocab) + self.is_cuda_available = torch.cuda.is_available() + + # instantiate modules + if self.config: + self.encoder = MoLEncoder(self.config, self.n_vocab, eval=eval) + self.decoder = MoLDecoder(self.n_vocab, self.config['max_len'], self.config['n_embd']) + self.net = Net(self.config['n_embd'], n_output=self.config['n_output'], dropout=self.config['dropout']) + + def load_checkpoint(self, ckpt_path, n_output, eval=False): + # load checkpoint file + checkpoint = torch.load(ckpt_path, map_location=torch.device('cpu')) + + # load hyparameters + self.config = checkpoint['hparams'] + self.max_len = self.config['max_len'] + self.n_embd = self.config['n_embd'] + self._set_seed(self.config['seed']) + + # instantiate modules + self.encoder = MoLEncoder(self.config, self.n_vocab, eval=eval) + self.decoder = MoLDecoder(self.n_vocab, self.max_len, self.n_embd) + self.net = Net(self.n_embd, n_output=self.config['n_output'] if 'n_output' in self.config else n_output, dropout=self.config['dropout']) + + # load weights + if 'state_dict' in checkpoint: + if 
isinstance(checkpoint['state_dict'], list): + self.encoder.load_state_dict(checkpoint['state_dict'][0], strict=False) + self.decoder.load_state_dict(checkpoint['state_dict'][1], strict=False) + else: + self.load_state_dict(checkpoint['state_dict'], strict=False) + elif 'MODEL_STATE' in checkpoint: + self.load_state_dict(checkpoint['MODEL_STATE'], strict=False) + + # load RNG states each time the model and states are loaded from checkpoint + if 'rng' in self.config: + rng = self.config['rng'] + for key, value in rng.items(): + if key =='torch_state': + torch.set_rng_state(value.cpu()) + elif key =='cuda_state': + torch.cuda.set_rng_state(value.cpu()) + elif key =='numpy_state': + np.random.set_state(value) + elif key =='python_state': + random.setstate(value) + else: + print('unrecognized state') + + def _init_weights(self, module): + if isinstance(module, (nn.Linear, nn.Embedding)): + module.weight.data.normal_(mean=0.0, std=0.02) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + def _set_seed(self, value): + print('Random Seed:', value) + random.seed(value) + torch.manual_seed(value) + torch.cuda.manual_seed(value) + torch.cuda.manual_seed_all(value) + np.random.seed(value) + cudnn.deterministic = True + cudnn.benchmark = False + + def tokenize(self, smiles): + """Tokenize a string into tokens.""" + if isinstance(smiles, str): + batch = [smiles] + else: + batch = smiles + + tokens = self.tokenizer( + batch, + padding=True, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + max_length=self.max_len, + ) + + idx = tokens['input_ids'].clone().detach() + mask = tokens['attention_mask'].clone().detach() + + if self.is_cuda_available: + return idx.cuda(), mask.cuda() + + return idx, mask + + def extract_embeddings(self, smiles): + """Extract token and SMILES embeddings.""" + if self.is_cuda_available: + self.encoder.cuda() + self.decoder.cuda() + + # tokenizer + idx, mask = self.tokenize(smiles) + + # transformer encoder + x = self.encoder.tok_emb(idx) # each index maps to a (learnable) vector + x = self.encoder.drop(x) + x = self.encoder.blocks(x, length_mask=LengthMask(mask.sum(-1), max_len=idx.shape[1])) + + # add padding + token_embeddings = x + input_mask_expanded = mask.unsqueeze(-1).expand(token_embeddings.size()).float() + mask_embeddings = (token_embeddings * input_mask_expanded) + token_embeddings = F.pad(mask_embeddings, pad=(0, 0, 0, self.config['max_len'] - mask_embeddings.shape[1]), value=0) + + # aggregate token embeddings (similar to mean pooling) + # CAUTION: use the embeddings from the autoencoder. 
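+ # Pooling: the padded token matrix (N, max_len, n_embd) is flattened to (N, max_len*n_embd) + # and compressed by the decoder's autoencoder encoder into one n_embd-dimensional vector per SMILES.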
+ smiles_embeddings = self.decoder.autoencoder.encoder(token_embeddings.view(-1, self.max_len*self.n_embd)) + + return smiles_embeddings + + def __str__(self): + return 'smi-ted-Large' + + +def load_smi_ted(folder="./smi_ted_large", + ckpt_filename="smi-ted-Large_30.pt", + vocab_filename="bert_vocab_curated.txt", + n_output=1, + eval=False + ): + tokenizer = MolTranBertTokenizer(os.path.join(folder, vocab_filename)) + model = Smi_ted(tokenizer) + model.load_checkpoint(os.path.join(folder, ckpt_filename), n_output, eval=eval) + print('Vocab size:', len(tokenizer.vocab)) + print(f'[FINETUNE MODE - {str(model)}]') + return model diff --git a/models/smi_ted/finetune/smi_ted_large/qm8/run_finetune_E1-CAM_set.sh b/models/smi_ted/finetune/smi_ted_large/qm8/run_finetune_E1-CAM_set.sh new file mode 100644 index 0000000000000000000000000000000000000000..9debb92f3ad1748661a8aef9b69b78cb045345dc --- /dev/null +++ b/models/smi_ted/finetune/smi_ted_large/qm8/run_finetune_E1-CAM_set.sh @@ -0,0 +1,24 @@ +python ../../finetune_regression.py \ + --n_batch 32 \ + --n_layer 24 \ + --n_head 16 \ + --n_embd 1024 \ + --max_len 202 \ + --d_dropout 0.1 \ + --dropout 0.1 \ + --lr_start 3e-5 \ + --lr_multiplier 1 \ + --max_epochs 500 \ + --num_feats 32 \ + --smi_ted_version 'v2' \ + --model_path '../' \ + --ckpt_filename 'smi-ted-Large_11.pt' \ + --data_root '../../moleculenet/qm8' \ + --dataset_name qm8 \ + --measure_name 'E1-CAM' \ + --checkpoints_folder './checkpoints_QM8-E1-CAM' \ + --loss_fn 'mae' \ + --target_metric 'mae' \ + --save_ckpt 1 \ + --start_seed 0 \ + --train_decoder 1 \ \ No newline at end of file diff --git a/models/smi_ted/finetune/smi_ted_large/qm8/run_finetune_E1-CC2_set.sh b/models/smi_ted/finetune/smi_ted_large/qm8/run_finetune_E1-CC2_set.sh new file mode 100644 index 0000000000000000000000000000000000000000..589e20897b4cd5871c65fea35617a47aa5cbc5df --- /dev/null +++ b/models/smi_ted/finetune/smi_ted_large/qm8/run_finetune_E1-CC2_set.sh @@ -0,0 +1,24 @@ +python ../../finetune_regression.py \ + --n_batch 32 \ + --n_layer 24 \ + --n_head 16 \ + --n_embd 1024 \ + --max_len 202 \ + --d_dropout 0.1 \ + --dropout 0.1 \ + --lr_start 3e-5 \ + --lr_multiplier 1 \ + --max_epochs 500 \ + --num_feats 32 \ + --smi_ted_version 'v2' \ + --model_path '../' \ + --ckpt_filename 'smi-ted-Large_11.pt' \ + --data_root '../../moleculenet/qm8' \ + --dataset_name qm8 \ + --measure_name 'E1-CC2' \ + --checkpoints_folder './checkpoints_QM8-E1-CC2' \ + --loss_fn 'mae' \ + --target_metric 'mae' \ + --save_ckpt 1 \ + --start_seed 0 \ + --train_decoder 1 \ \ No newline at end of file diff --git a/models/smi_ted/finetune/smi_ted_large/qm8/run_finetune_E1-PBE0_set.sh b/models/smi_ted/finetune/smi_ted_large/qm8/run_finetune_E1-PBE0_set.sh new file mode 100644 index 0000000000000000000000000000000000000000..d5c8c90452ccf0b3069887ad0c06a0ea192e95af --- /dev/null +++ b/models/smi_ted/finetune/smi_ted_large/qm8/run_finetune_E1-PBE0_set.sh @@ -0,0 +1,24 @@ +python ../../finetune_regression.py \ + --n_batch 32 \ + --n_layer 24 \ + --n_head 16 \ + --n_embd 1024 \ + --max_len 202 \ + --d_dropout 0.1 \ + --dropout 0.1 \ + --lr_start 3e-5 \ + --lr_multiplier 1 \ + --max_epochs 500 \ + --num_feats 32 \ + --smi_ted_version 'v2' \ + --model_path '../' \ + --ckpt_filename 'smi-ted-Large_11.pt' \ + --data_root '../../moleculenet/qm8' \ + --dataset_name qm8 \ + --measure_name 'E1-PBE0' \ + --checkpoints_folder './checkpoints_QM8-E1-PBE0' \ + --loss_fn 'mae' \ + --target_metric 'mae' \ + --save_ckpt 1 \ + --start_seed 0 \ + --train_decoder 1 
\ \ No newline at end of file diff --git a/models/smi_ted/finetune/smi_ted_large/qm8/run_finetune_E2-CAM_set.sh b/models/smi_ted/finetune/smi_ted_large/qm8/run_finetune_E2-CAM_set.sh new file mode 100644 index 0000000000000000000000000000000000000000..147591f60a1a0d6921c8b9df5e0d895ddd2f7839 --- /dev/null +++ b/models/smi_ted/finetune/smi_ted_large/qm8/run_finetune_E2-CAM_set.sh @@ -0,0 +1,24 @@ +python ../../finetune_regression.py \ + --n_batch 32 \ + --n_layer 24 \ + --n_head 16 \ + --n_embd 1024 \ + --max_len 202 \ + --d_dropout 0.1 \ + --dropout 0.1 \ + --lr_start 3e-5 \ + --lr_multiplier 1 \ + --max_epochs 500 \ + --num_feats 32 \ + --smi_ted_version 'v2' \ + --model_path '../' \ + --ckpt_filename 'smi-ted-Large_11.pt' \ + --data_root '../../moleculenet/qm8' \ + --dataset_name qm8 \ + --measure_name 'E2-CAM' \ + --checkpoints_folder './checkpoints_QM8-E2-CAM' \ + --loss_fn 'mae' \ + --target_metric 'mae' \ + --save_ckpt 1 \ + --start_seed 0 \ + --train_decoder 1 \ \ No newline at end of file diff --git a/models/smi_ted/finetune/smi_ted_large/qm8/run_finetune_E2-CC2_set.sh b/models/smi_ted/finetune/smi_ted_large/qm8/run_finetune_E2-CC2_set.sh new file mode 100644 index 0000000000000000000000000000000000000000..c471c2491b1bd08cefc47a503faf66e1ae12a713 --- /dev/null +++ b/models/smi_ted/finetune/smi_ted_large/qm8/run_finetune_E2-CC2_set.sh @@ -0,0 +1,24 @@ +python ../../finetune_regression.py \ + --n_batch 32 \ + --n_layer 24 \ + --n_head 16 \ + --n_embd 1024 \ + --max_len 202 \ + --d_dropout 0.1 \ + --dropout 0.1 \ + --lr_start 3e-5 \ + --lr_multiplier 1 \ + --max_epochs 500 \ + --num_feats 32 \ + --smi_ted_version 'v2' \ + --model_path '../' \ + --ckpt_filename 'smi-ted-Large_11.pt' \ + --data_root '../../moleculenet/qm8' \ + --dataset_name qm8 \ + --measure_name 'E2-CC2' \ + --checkpoints_folder './checkpoints_QM8-E2-CC2' \ + --loss_fn 'mae' \ + --target_metric 'mae' \ + --save_ckpt 1 \ + --start_seed 0 \ + --train_decoder 1 \ \ No newline at end of file diff --git a/models/smi_ted/finetune/smi_ted_large/qm8/run_finetune_E2-PBE0_set.sh b/models/smi_ted/finetune/smi_ted_large/qm8/run_finetune_E2-PBE0_set.sh new file mode 100644 index 0000000000000000000000000000000000000000..e1527232ef9cfe39cf4576e6def0927c1d4b39fc --- /dev/null +++ b/models/smi_ted/finetune/smi_ted_large/qm8/run_finetune_E2-PBE0_set.sh @@ -0,0 +1,24 @@ +python ../../finetune_regression.py \ + --n_batch 16 \ + --n_layer 24 \ + --n_head 16 \ + --n_embd 1024 \ + --max_len 202 \ + --d_dropout 0.1 \ + --dropout 0.1 \ + --lr_start 3e-6 \ + --lr_multiplier 1 \ + --max_epochs 720 \ + --num_feats 32 \ + --smi_ted_version 'v2' \ + --model_path '../' \ + --ckpt_filename 'smi-ted-Large_11.pt' \ + --data_root '../../moleculenet/qm8' \ + --dataset_name qm8 \ + --measure_name 'E2-PBE0' \ + --checkpoints_folder './checkpoints_QM8-E2-PBE0' \ + --loss_fn 'mae' \ + --target_metric 'mae' \ + --save_ckpt 1 \ + --start_seed 0 \ + --train_decoder 1 \ \ No newline at end of file diff --git a/models/smi_ted/finetune/smi_ted_large/qm8/run_finetune_f1-CAM_set.sh b/models/smi_ted/finetune/smi_ted_large/qm8/run_finetune_f1-CAM_set.sh new file mode 100644 index 0000000000000000000000000000000000000000..64a3297c5f2ff839507614ba94072791ccb60436 --- /dev/null +++ b/models/smi_ted/finetune/smi_ted_large/qm8/run_finetune_f1-CAM_set.sh @@ -0,0 +1,24 @@ +python ../../finetune_regression.py \ + --n_batch 16 \ + --n_layer 24 \ + --n_head 16 \ + --n_embd 1024 \ + --max_len 202 \ + --d_dropout 0.1 \ + --dropout 0.1 \ + --lr_start 3e-6 \ + --lr_multiplier 1 \ 
+ --max_epochs 720 \ + --num_feats 32 \ + --smi_ted_version 'v2' \ + --model_path '../' \ + --ckpt_filename 'smi-ted-Large_11.pt' \ + --data_root '../../moleculenet/qm8' \ + --dataset_name qm8 \ + --measure_name 'f1-CAM' \ + --checkpoints_folder './checkpoints_QM8-f1-CAM' \ + --loss_fn 'mae' \ + --target_metric 'mae' \ + --save_ckpt 1 \ + --start_seed 0 \ + --train_decoder 1 \ \ No newline at end of file diff --git a/models/smi_ted/finetune/smi_ted_large/qm8/run_finetune_f1-CC2_set.sh b/models/smi_ted/finetune/smi_ted_large/qm8/run_finetune_f1-CC2_set.sh new file mode 100644 index 0000000000000000000000000000000000000000..a326d8686aa7d12f2162ab48323131947e7d88f5 --- /dev/null +++ b/models/smi_ted/finetune/smi_ted_large/qm8/run_finetune_f1-CC2_set.sh @@ -0,0 +1,24 @@ +python ../../finetune_regression.py \ + --n_batch 32 \ + --n_layer 24 \ + --n_head 16 \ + --n_embd 1024 \ + --max_len 202 \ + --d_dropout 0.1 \ + --dropout 0.1 \ + --lr_start 3e-5 \ + --lr_multiplier 1 \ + --max_epochs 500 \ + --num_feats 32 \ + --smi_ted_version 'v2' \ + --model_path '../' \ + --ckpt_filename 'smi-ted-Large_11.pt' \ + --data_root '../../moleculenet/qm8' \ + --dataset_name qm8 \ + --measure_name 'f1-CC2' \ + --checkpoints_folder './checkpoints_QM8-f1-CC2' \ + --loss_fn 'mae' \ + --target_metric 'mae' \ + --save_ckpt 1 \ + --start_seed 0 \ + --train_decoder 1 \ \ No newline at end of file diff --git a/models/smi_ted/finetune/smi_ted_large/qm8/run_finetune_f1-PBE0_set.sh b/models/smi_ted/finetune/smi_ted_large/qm8/run_finetune_f1-PBE0_set.sh new file mode 100644 index 0000000000000000000000000000000000000000..210cbf1928ef9b1f356c3a14813c81a0891e349b --- /dev/null +++ b/models/smi_ted/finetune/smi_ted_large/qm8/run_finetune_f1-PBE0_set.sh @@ -0,0 +1,24 @@ +python ../../finetune_regression.py \ + --n_batch 16 \ + --n_layer 24 \ + --n_head 16 \ + --n_embd 1024 \ + --max_len 202 \ + --d_dropout 0.1 \ + --dropout 0.1 \ + --lr_start 3e-6 \ + --lr_multiplier 1 \ + --max_epochs 720 \ + --num_feats 32 \ + --smi_ted_version 'v2' \ + --model_path '../' \ + --ckpt_filename 'smi-ted-Large_11.pt' \ + --data_root '../../moleculenet/qm8' \ + --dataset_name qm8 \ + --measure_name 'f1-PBE0' \ + --checkpoints_folder './checkpoints_QM8-f1-PBE0' \ + --loss_fn 'mae' \ + --target_metric 'mae' \ + --save_ckpt 1 \ + --start_seed 0 \ + --train_decoder 1 \ \ No newline at end of file diff --git a/models/smi_ted/finetune/smi_ted_large/qm8/run_finetune_f2-CAM_set.sh b/models/smi_ted/finetune/smi_ted_large/qm8/run_finetune_f2-CAM_set.sh new file mode 100644 index 0000000000000000000000000000000000000000..c3d38f5854afcd95df057464d278e530a50e90bb --- /dev/null +++ b/models/smi_ted/finetune/smi_ted_large/qm8/run_finetune_f2-CAM_set.sh @@ -0,0 +1,24 @@ +python ../../finetune_regression.py \ + --n_batch 16 \ + --n_layer 24 \ + --n_head 16 \ + --n_embd 1024 \ + --max_len 202 \ + --d_dropout 0.1 \ + --dropout 0.1 \ + --lr_start 3e-6 \ + --lr_multiplier 1 \ + --max_epochs 720 \ + --num_feats 32 \ + --smi_ted_version 'v2' \ + --model_path '../' \ + --ckpt_filename 'smi-ted-Large_11.pt' \ + --data_root '../../moleculenet/qm8' \ + --dataset_name qm8 \ + --measure_name 'f2-CAM' \ + --checkpoints_folder './checkpoints_QM8-f2-CAM' \ + --loss_fn 'mae' \ + --target_metric 'mae' \ + --save_ckpt 1 \ + --start_seed 0 \ + --train_decoder 1 \ \ No newline at end of file diff --git a/models/smi_ted/finetune/smi_ted_large/qm8/run_finetune_f2-CC2_set.sh b/models/smi_ted/finetune/smi_ted_large/qm8/run_finetune_f2-CC2_set.sh new file mode 100644 index 
0000000000000000000000000000000000000000..597f9c37937169379d3b8e28cc221d9637fc566a --- /dev/null +++ b/models/smi_ted/finetune/smi_ted_large/qm8/run_finetune_f2-CC2_set.sh @@ -0,0 +1,24 @@ +python ../../finetune_regression.py \ + --n_batch 16 \ + --n_layer 24 \ + --n_head 16 \ + --n_embd 1024 \ + --max_len 202 \ + --d_dropout 0.1 \ + --dropout 0.1 \ + --lr_start 3e-6 \ + --lr_multiplier 1 \ + --max_epochs 720 \ + --num_feats 32 \ + --smi_ted_version 'v2' \ + --model_path '../' \ + --ckpt_filename 'smi-ted-Large_11.pt' \ + --data_root '../../moleculenet/qm8' \ + --dataset_name qm8 \ + --measure_name 'f2-CC2' \ + --checkpoints_folder './checkpoints_QM8-f2-CC2' \ + --loss_fn 'mae' \ + --target_metric 'mae' \ + --save_ckpt 1 \ + --start_seed 0 \ + --train_decoder 1 \ \ No newline at end of file diff --git a/models/smi_ted/finetune/smi_ted_large/qm8/run_finetune_f2-PBE0_set.sh b/models/smi_ted/finetune/smi_ted_large/qm8/run_finetune_f2-PBE0_set.sh new file mode 100644 index 0000000000000000000000000000000000000000..f4eaaa4cb96bef5b34ee38e508a36a6819016e49 --- /dev/null +++ b/models/smi_ted/finetune/smi_ted_large/qm8/run_finetune_f2-PBE0_set.sh @@ -0,0 +1,24 @@ +python ../../finetune_regression.py \ + --n_batch 16 \ + --n_layer 24 \ + --n_head 16 \ + --n_embd 1024 \ + --max_len 202 \ + --d_dropout 0.1 \ + --dropout 0.1 \ + --lr_start 3e-6 \ + --lr_multiplier 1 \ + --max_epochs 720 \ + --num_feats 32 \ + --smi_ted_version 'v2' \ + --model_path '../' \ + --ckpt_filename 'smi-ted-Large_11.pt' \ + --data_root '../../moleculenet/qm8' \ + --dataset_name qm8 \ + --measure_name 'f2-PBE0' \ + --checkpoints_folder './checkpoints_QM8-f2-PBE0' \ + --loss_fn 'mae' \ + --target_metric 'mae' \ + --save_ckpt 1 \ + --start_seed 0 \ + --train_decoder 1 \ \ No newline at end of file diff --git a/models/smi_ted/finetune/smi_ted_large/qm9/run_finetune_alpha.sh b/models/smi_ted/finetune/smi_ted_large/qm9/run_finetune_alpha.sh new file mode 100644 index 0000000000000000000000000000000000000000..54ddaa254018f5090abed6bf930d74beaccbe29d --- /dev/null +++ b/models/smi_ted/finetune/smi_ted_large/qm9/run_finetune_alpha.sh @@ -0,0 +1,24 @@ +python ../../finetune_regression.py \ + --n_batch 128 \ + --n_layer 24 \ + --n_head 16 \ + --n_embd 1024 \ + --max_len 202 \ + --d_dropout 0.1 \ + --dropout 0.1 \ + --lr_start 3e-5 \ + --lr_multiplier 1 \ + --max_epochs 500 \ + --num_feats 32 \ + --smi_ted_version 'v2' \ + --model_path '../' \ + --ckpt_filename 'smi-ted-Large_11.pt' \ + --data_root '../../moleculenet/qm9' \ + --dataset_name qm9 \ + --measure_name 'alpha' \ + --checkpoints_folder './checkpoints_QM9-alpha' \ + --loss_fn 'mae' \ + --target_metric 'mae' \ + --save_ckpt 1 \ + --start_seed 0 \ + --train_decoder 1 \ \ No newline at end of file diff --git a/models/smi_ted/finetune/smi_ted_large/qm9/run_finetune_cv.sh b/models/smi_ted/finetune/smi_ted_large/qm9/run_finetune_cv.sh new file mode 100644 index 0000000000000000000000000000000000000000..78accb629f42542fde4c26e39da6e1e1453cee7e --- /dev/null +++ b/models/smi_ted/finetune/smi_ted_large/qm9/run_finetune_cv.sh @@ -0,0 +1,24 @@ +python ../../finetune_regression.py \ + --n_batch 128 \ + --n_layer 24 \ + --n_head 16 \ + --n_embd 1024 \ + --max_len 202 \ + --d_dropout 0.1 \ + --dropout 0.1 \ + --lr_start 3e-5 \ + --lr_multiplier 1 \ + --max_epochs 500 \ + --num_feats 32 \ + --smi_ted_version 'v2' \ + --model_path '../' \ + --ckpt_filename 'smi-ted-Large_11.pt' \ + --data_root '../../moleculenet/qm9' \ + --dataset_name qm9 \ + --measure_name 'cv' \ + --checkpoints_folder 
'./checkpoints_QM9-cv' \ + --loss_fn 'mae' \ + --target_metric 'mae' \ + --save_ckpt 1 \ + --start_seed 0 \ + --train_decoder 1 \ \ No newline at end of file diff --git a/models/smi_ted/finetune/smi_ted_large/qm9/run_finetune_g298.sh b/models/smi_ted/finetune/smi_ted_large/qm9/run_finetune_g298.sh new file mode 100644 index 0000000000000000000000000000000000000000..19843c43b83626d61764d152138b3df1d2d62023 --- /dev/null +++ b/models/smi_ted/finetune/smi_ted_large/qm9/run_finetune_g298.sh @@ -0,0 +1,24 @@ +python ../../finetune_regression.py \ + --n_batch 128 \ + --n_layer 24 \ + --n_head 16 \ + --n_embd 1024 \ + --max_len 202 \ + --d_dropout 0.1 \ + --dropout 0.1 \ + --lr_start 3e-5 \ + --lr_multiplier 1 \ + --max_epochs 500 \ + --num_feats 32 \ + --smi_ted_version 'v2' \ + --model_path '../' \ + --ckpt_filename 'smi-ted-Large_11.pt' \ + --data_root '../../moleculenet/qm9' \ + --dataset_name qm9 \ + --measure_name 'g298' \ + --checkpoints_folder './checkpoints_QM9-g298' \ + --loss_fn 'mae' \ + --target_metric 'mae' \ + --save_ckpt 1 \ + --start_seed 0 \ + --train_decoder 1 \ \ No newline at end of file diff --git a/models/smi_ted/finetune/smi_ted_large/qm9/run_finetune_gap.sh b/models/smi_ted/finetune/smi_ted_large/qm9/run_finetune_gap.sh new file mode 100644 index 0000000000000000000000000000000000000000..d2726bb325164a18a578a4103de314deac704973 --- /dev/null +++ b/models/smi_ted/finetune/smi_ted_large/qm9/run_finetune_gap.sh @@ -0,0 +1,24 @@ +python ../../finetune_regression.py \ + --n_batch 128 \ + --n_layer 24 \ + --n_head 16 \ + --n_embd 1024 \ + --max_len 202 \ + --d_dropout 0.1 \ + --dropout 0.1 \ + --lr_start 3e-5 \ + --lr_multiplier 1 \ + --max_epochs 500 \ + --num_feats 32 \ + --smi_ted_version 'v2' \ + --model_path '../' \ + --ckpt_filename 'smi-ted-Large_11.pt' \ + --data_root '../../moleculenet/qm9' \ + --dataset_name qm9 \ + --measure_name 'gap' \ + --checkpoints_folder './checkpoints_QM9-gap' \ + --loss_fn 'mae' \ + --target_metric 'mae' \ + --save_ckpt 1 \ + --start_seed 0 \ + --train_decoder 1 \ \ No newline at end of file diff --git a/models/smi_ted/finetune/smi_ted_large/qm9/run_finetune_h298.sh b/models/smi_ted/finetune/smi_ted_large/qm9/run_finetune_h298.sh new file mode 100644 index 0000000000000000000000000000000000000000..d638fe5a899d4b091714c133b2ff3d1ac7e72991 --- /dev/null +++ b/models/smi_ted/finetune/smi_ted_large/qm9/run_finetune_h298.sh @@ -0,0 +1,24 @@ +python ../../finetune_regression.py \ + --n_batch 128 \ + --n_layer 24 \ + --n_head 16 \ + --n_embd 1024 \ + --max_len 202 \ + --d_dropout 0.1 \ + --dropout 0.1 \ + --lr_start 3e-5 \ + --lr_multiplier 1 \ + --max_epochs 500 \ + --num_feats 32 \ + --smi_ted_version 'v2' \ + --model_path '../' \ + --ckpt_filename 'smi-ted-Large_11.pt' \ + --data_root '../../moleculenet/qm9' \ + --dataset_name qm9 \ + --measure_name 'h298' \ + --checkpoints_folder './checkpoints_QM9-h298' \ + --loss_fn 'mae' \ + --target_metric 'mae' \ + --save_ckpt 1 \ + --start_seed 0 \ + --train_decoder 1 \ \ No newline at end of file diff --git a/models/smi_ted/finetune/smi_ted_large/qm9/run_finetune_homo.sh b/models/smi_ted/finetune/smi_ted_large/qm9/run_finetune_homo.sh new file mode 100644 index 0000000000000000000000000000000000000000..44b39ed0e5fc3130d97246712da0823adcf75b5c --- /dev/null +++ b/models/smi_ted/finetune/smi_ted_large/qm9/run_finetune_homo.sh @@ -0,0 +1,24 @@ +python ../../finetune_regression.py \ + --n_batch 128 \ + --n_layer 24 \ + --n_head 16 \ + --n_embd 1024 \ + --max_len 202 \ + --d_dropout 0.1 \ + --dropout 0.1 \ + 
--lr_start 3e-5 \ + --lr_multiplier 1 \ + --max_epochs 500 \ + --num_feats 32 \ + --smi_ted_version 'v2' \ + --model_path '../' \ + --ckpt_filename 'smi-ted-Large_11.pt' \ + --data_root '../../moleculenet/qm9' \ + --dataset_name qm9 \ + --measure_name 'homo' \ + --checkpoints_folder './checkpoints_QM9-homo' \ + --loss_fn 'mae' \ + --target_metric 'mae' \ + --save_ckpt 1 \ + --start_seed 0 \ + --train_decoder 1 \ \ No newline at end of file diff --git a/models/smi_ted/finetune/smi_ted_large/qm9/run_finetune_lumo.sh b/models/smi_ted/finetune/smi_ted_large/qm9/run_finetune_lumo.sh new file mode 100644 index 0000000000000000000000000000000000000000..4619898a4587d23b8f11a52c1e147f2b491bbc94 --- /dev/null +++ b/models/smi_ted/finetune/smi_ted_large/qm9/run_finetune_lumo.sh @@ -0,0 +1,24 @@ +python ../../finetune_regression.py \ + --n_batch 128 \ + --n_layer 24 \ + --n_head 16 \ + --n_embd 1024 \ + --max_len 202 \ + --d_dropout 0.1 \ + --dropout 0.1 \ + --lr_start 3e-5 \ + --lr_multiplier 1 \ + --max_epochs 500 \ + --num_feats 32 \ + --smi_ted_version 'v2' \ + --model_path '../' \ + --ckpt_filename 'smi-ted-Large_11.pt' \ + --data_root '../../moleculenet/qm9' \ + --dataset_name qm9 \ + --measure_name 'lumo' \ + --checkpoints_folder './checkpoints_QM9-lumo' \ + --loss_fn 'mae' \ + --target_metric 'mae' \ + --save_ckpt 1 \ + --start_seed 0 \ + --train_decoder 1 \ \ No newline at end of file diff --git a/models/smi_ted/finetune/smi_ted_large/qm9/run_finetune_mu.sh b/models/smi_ted/finetune/smi_ted_large/qm9/run_finetune_mu.sh new file mode 100644 index 0000000000000000000000000000000000000000..b8c5b8aa012c6668bfc7d41781fc5ea66acc68ec --- /dev/null +++ b/models/smi_ted/finetune/smi_ted_large/qm9/run_finetune_mu.sh @@ -0,0 +1,24 @@ +python ../../finetune_regression.py \ + --n_batch 128 \ + --n_layer 24 \ + --n_head 16 \ + --n_embd 1024 \ + --max_len 202 \ + --d_dropout 0.1 \ + --dropout 0.1 \ + --lr_start 3e-5 \ + --lr_multiplier 1 \ + --max_epochs 500 \ + --num_feats 32 \ + --smi_ted_version 'v2' \ + --model_path '../' \ + --ckpt_filename 'smi-ted-Large_11.pt' \ + --data_root '../../moleculenet/qm9' \ + --dataset_name qm9 \ + --measure_name 'mu' \ + --checkpoints_folder './checkpoints_QM9-mu' \ + --loss_fn 'mae' \ + --target_metric 'mae' \ + --save_ckpt 1 \ + --start_seed 0 \ + --train_decoder 1 \ \ No newline at end of file diff --git a/models/smi_ted/finetune/smi_ted_large/qm9/run_finetune_r2.sh b/models/smi_ted/finetune/smi_ted_large/qm9/run_finetune_r2.sh new file mode 100644 index 0000000000000000000000000000000000000000..1d7c630b89a226ec42ce68fa53e4f6f916a2eb7d --- /dev/null +++ b/models/smi_ted/finetune/smi_ted_large/qm9/run_finetune_r2.sh @@ -0,0 +1,24 @@ +python ../../finetune_regression.py \ + --n_batch 128 \ + --n_layer 24 \ + --n_head 16 \ + --n_embd 1024 \ + --max_len 202 \ + --d_dropout 0.1 \ + --dropout 0.1 \ + --lr_start 3e-5 \ + --lr_multiplier 1 \ + --max_epochs 500 \ + --num_feats 32 \ + --smi_ted_version 'v2' \ + --model_path '../' \ + --ckpt_filename 'smi-ted-Large_11.pt' \ + --data_root '../../moleculenet/qm9' \ + --dataset_name qm9 \ + --measure_name 'r2' \ + --checkpoints_folder './checkpoints_QM9-r2' \ + --loss_fn 'mae' \ + --target_metric 'mae' \ + --save_ckpt 1 \ + --start_seed 0 \ + --train_decoder 1 \ \ No newline at end of file diff --git a/models/smi_ted/finetune/smi_ted_large/qm9/run_finetune_u0.sh b/models/smi_ted/finetune/smi_ted_large/qm9/run_finetune_u0.sh new file mode 100644 index 0000000000000000000000000000000000000000..e9769ec917660bb88444f0e7de63e36c15576146 
--- /dev/null +++ b/models/smi_ted/finetune/smi_ted_large/qm9/run_finetune_u0.sh @@ -0,0 +1,24 @@ +python ../../finetune_regression.py \ + --n_batch 128 \ + --n_layer 24 \ + --n_head 16 \ + --n_embd 1024 \ + --max_len 202 \ + --d_dropout 0.1 \ + --dropout 0.1 \ + --lr_start 3e-5 \ + --lr_multiplier 1 \ + --max_epochs 500 \ + --num_feats 32 \ + --smi_ted_version 'v2' \ + --model_path '../' \ + --ckpt_filename 'smi-ted-Large_11.pt' \ + --data_root '../../moleculenet/qm9' \ + --dataset_name qm9 \ + --measure_name 'u0' \ + --checkpoints_folder './checkpoints_QM9-u0' \ + --loss_fn 'mae' \ + --target_metric 'mae' \ + --save_ckpt 1 \ + --start_seed 0 \ + --train_decoder 1 \ \ No newline at end of file diff --git a/models/smi_ted/finetune/smi_ted_large/qm9/run_finetune_u298.sh b/models/smi_ted/finetune/smi_ted_large/qm9/run_finetune_u298.sh new file mode 100644 index 0000000000000000000000000000000000000000..c1eac1faae458ab03ce06c4aa729ac96387f2921 --- /dev/null +++ b/models/smi_ted/finetune/smi_ted_large/qm9/run_finetune_u298.sh @@ -0,0 +1,24 @@ +python ../../finetune_regression.py \ + --n_batch 128 \ + --n_layer 24 \ + --n_head 16 \ + --n_embd 1024 \ + --max_len 202 \ + --d_dropout 0.1 \ + --dropout 0.1 \ + --lr_start 3e-5 \ + --lr_multiplier 1 \ + --max_epochs 500 \ + --num_feats 32 \ + --smi_ted_version 'v2' \ + --model_path '../' \ + --ckpt_filename 'smi-ted-Large_11.pt' \ + --data_root '../../moleculenet/qm9' \ + --dataset_name qm9 \ + --measure_name 'u298' \ + --checkpoints_folder './checkpoints_QM9-u298' \ + --loss_fn 'mae' \ + --target_metric 'mae' \ + --save_ckpt 1 \ + --start_seed 0 \ + --train_decoder 1 \ \ No newline at end of file diff --git a/models/smi_ted/finetune/smi_ted_large/qm9/run_finetune_zpve.sh b/models/smi_ted/finetune/smi_ted_large/qm9/run_finetune_zpve.sh new file mode 100644 index 0000000000000000000000000000000000000000..aa31850b5a618fbbef203c0f86f6680faec960e3 --- /dev/null +++ b/models/smi_ted/finetune/smi_ted_large/qm9/run_finetune_zpve.sh @@ -0,0 +1,24 @@ +python ../../finetune_regression.py \ + --n_batch 128 \ + --n_layer 24 \ + --n_head 16 \ + --n_embd 1024 \ + --max_len 202 \ + --d_dropout 0.1 \ + --dropout 0.1 \ + --lr_start 3e-5 \ + --lr_multiplier 1 \ + --max_epochs 500 \ + --num_feats 32 \ + --smi_ted_version 'v2' \ + --model_path '../' \ + --ckpt_filename 'smi-ted-Large_11.pt' \ + --data_root '../../moleculenet/qm9' \ + --dataset_name qm9 \ + --measure_name 'zpve' \ + --checkpoints_folder './checkpoints_QM9-zpve' \ + --loss_fn 'mae' \ + --target_metric 'mae' \ + --save_ckpt 1 \ + --start_seed 0 \ + --train_decoder 1 \ \ No newline at end of file diff --git a/models/smi_ted/finetune/smi_ted_large/sider/run_finetune_sider.sh b/models/smi_ted/finetune/smi_ted_large/sider/run_finetune_sider.sh new file mode 100644 index 0000000000000000000000000000000000000000..94522a262f20bd736fdabfb9d00f2004ae644bd6 --- /dev/null +++ b/models/smi_ted/finetune/smi_ted_large/sider/run_finetune_sider.sh @@ -0,0 +1,23 @@ +python ../../finetune_classification_multitask.py \ + --n_batch 32 \ + --n_layer 24 \ + --n_head 16 \ + --n_embd 1024 \ + --max_len 202 \ + --d_dropout 0.1 \ + --dropout 0.1 \ + --lr_start 3e-5 \ + --lr_multiplier 1 \ + --max_epochs 500 \ + --num_feats 32 \ + --smi_ted_version 'v2' \ + --model_path '../' \ + --ckpt_filename 'smi-ted-Large_11.pt' \ + --data_root '../../moleculenet/sider' \ + --dataset_name sider \ + --checkpoints_folder './checkpoints_sider' \ + --loss_fn 'bceloss' \ + --target_metric 'roc-auc' \ + --save_ckpt 1 \ + --start_seed 0 \ + 
--train_decoder 0 \ \ No newline at end of file diff --git a/models/smi_ted/finetune/smi_ted_large/tox21/run_finetune_tox21.sh b/models/smi_ted/finetune/smi_ted_large/tox21/run_finetune_tox21.sh new file mode 100644 index 0000000000000000000000000000000000000000..84302a36fcc7babfde0233f5f12b4a03cb63c02c --- /dev/null +++ b/models/smi_ted/finetune/smi_ted_large/tox21/run_finetune_tox21.sh @@ -0,0 +1,23 @@ +python ../../finetune_classification_multitask.py \ + --n_batch 32 \ + --n_layer 24 \ + --n_head 16 \ + --n_embd 1024 \ + --max_len 202 \ + --d_dropout 0.1 \ + --dropout 0.1 \ + --lr_start 1e-6 \ + --lr_multiplier 1 \ + --max_epochs 100 \ + --num_feats 32 \ + --smi_ted_version 'v2' \ + --model_path '../' \ + --ckpt_filename 'smi-ted-Large_11.pt' \ + --data_root '../../moleculenet/tox21' \ + --dataset_name tox21 \ + --checkpoints_folder './checkpoints_tox21' \ + --loss_fn 'bceloss' \ + --target_metric 'roc-auc' \ + --save_ckpt 1 \ + --start_seed 0 \ + --train_decoder 0 \ \ No newline at end of file diff --git a/models/smi_ted/finetune/smi_ted_light/bace/run_finetune_bace.sh b/models/smi_ted/finetune/smi_ted_light/bace/run_finetune_bace.sh new file mode 100644 index 0000000000000000000000000000000000000000..da1b97953d4ad65a64c3fd68f495ea8ca91f5bde --- /dev/null +++ b/models/smi_ted/finetune/smi_ted_light/bace/run_finetune_bace.sh @@ -0,0 +1,25 @@ +python ../../finetune_classification.py \ + --n_batch 32 \ + --n_layer 12 \ + --n_head 12 \ + --n_embd 768 \ + --max_len 202 \ + --d_dropout 0.1 \ + --dropout 0.1 \ + --lr_start 3e-5 \ + --lr_multiplier 1 \ + --max_epochs 500 \ + --num_feats 32 \ + --smi_ted_version 'v1' \ + --model_path '../' \ + --ckpt_filename 'smi-ted-Light_40.pt' \ + --data_root '../../moleculenet/bace' \ + --dataset_name bace \ + --measure_name 'Class' \ + --checkpoints_folder './checkpoints_bace' \ + --loss_fn 'crossentropy' \ + --target_metric 'roc-auc' \ + --n_output 2 \ + --save_ckpt 1 \ + --start_seed 0 \ + --train_decoder 0 \ \ No newline at end of file diff --git a/models/smi_ted/finetune/smi_ted_light/bbbp/run_finetune_bbbp.sh b/models/smi_ted/finetune/smi_ted_light/bbbp/run_finetune_bbbp.sh new file mode 100644 index 0000000000000000000000000000000000000000..860d657ca01b36ee78fe031dfb5f44880b8eda3f --- /dev/null +++ b/models/smi_ted/finetune/smi_ted_light/bbbp/run_finetune_bbbp.sh @@ -0,0 +1,25 @@ +python ../../finetune_classification.py \ + --n_batch 32 \ + --n_layer 12 \ + --n_head 12 \ + --n_embd 768 \ + --max_len 202 \ + --d_dropout 0.1 \ + --dropout 0.1 \ + --lr_start 3e-5 \ + --lr_multiplier 1 \ + --max_epochs 500 \ + --num_feats 32 \ + --smi_ted_version 'v1' \ + --model_path '../' \ + --ckpt_filename 'smi-ted-Light_40.pt' \ + --data_root '../../moleculenet/bbbp' \ + --dataset_name bbbp \ + --measure_name 'p_np' \ + --checkpoints_folder './checkpoints_bbbp' \ + --loss_fn 'crossentropy' \ + --target_metric 'roc-auc' \ + --n_output 2 \ + --save_ckpt 1 \ + --start_seed 0 \ + --train_decoder 0 \ \ No newline at end of file diff --git a/models/smi_ted/finetune/smi_ted_light/bert_vocab_curated.txt b/models/smi_ted/finetune/smi_ted_light/bert_vocab_curated.txt new file mode 100644 index 0000000000000000000000000000000000000000..4de43b6c87add8821aa8d2d33d58232929c86fbd --- /dev/null +++ b/models/smi_ted/finetune/smi_ted_light/bert_vocab_curated.txt @@ -0,0 +1,2393 @@ +<bos> +<eos> +<pad> +<mask> +C +c +( +) +1 +O +N +2 += +n +3 +[C@H] +[C@@H] +F +S +4 +Cl +- +o +s +[nH] +# +/ +Br +[C@] +[C@@] +[N+] +[O-] +5 +\ +. 
+I +6 +[S@] +[S@@] +P +[N-] +[Si] +7 +[n+] +[2H] +8 +[NH+] +B +9 +[C-] +[Na+] +[Cl-] +[c-] +[CH] +%10 +[NH2+] +[P+] +[B] +[I-] +%11 +[CH2-] +[O+] +[NH3+] +[C] +[Br-] +[IH2] +[S-] +[cH-] +%12 +[nH+] +[B-] +[K+] +[Sn] +[Se] +[CH-] +[HH] +[Y] +[n-] +[CH3-] +[SiH] +[S+] +%13 +[SiH2] +[Li+] +[NH-] +%14 +[Na] +[CH2] +[O-2] +[U+2] +[W] +[Al] +[P@] +[Fe+2] +[PH+] +%15 +[Cl+3] +[Zn+2] +[Ir] +[Mg+2] +[Pt+2] +[OH2+] +[As] +[Fe] +[OH+] +[Zr+2] +[3H] +[Ge] +[SiH3] +[OH-] +[NH4+] +[Cu+2] +[P@@] +p +[Pt] +%16 +[Ca+2] +[Zr] +[F-] +[C+] +[Ti] +[P-] +[V] +[se] +[U] +[O] +[Ni+2] +[Zn] +[Co] +[Ni] +[Pd+2] +[Cu] +%17 +[Cu+] +[Te] +[H+] +[CH+] +[Li] +[Pd] +[Mo] +[Ru+2] +[o+] +[Re] +[SH+] +%18 +[Ac] +[Cr] +[NH2-] +[K] +[13CH2] +[c] +[Zr+4] +[Tl] +[13C] +[Mn] +[N@+] +[Hg] +[Rh] +[Ti+4] +[Sb] +[Co+2] +[Ag+] +[Ru] +%19 +[N@@+] +[Ti+2] +[Al+3] +[Pb] +[I+] +[18F] +[s+] +[Rb+] +[Ba+2] +[H-] +[Fe+3] +[Ir+3] +[13cH] +%20 +[AlH2] +[Au+] +[13c] +[SH2+] +[Sn+2] +[Mn+2] +[Si-] +[Ag] +[N] +[Bi] +%21 +[In] +[CH2+] +[Y+3] +[Ga] +%22 +[Co+3] +[Au] +[13CH3] +[Mg] +[Cs+] +[W+2] +[Hf] +[Zn+] +[Se-] +[S-2] +[Ca] +[pH] +[ClH+] +[Ti+3] +%23 +[Ru+] +[SH-] +[13CH] +[IH+] +[Hf+4] +[Rf] +[OH3+] +%24 +[Pt+4] +[Zr+3] +[PH3+] +[Sr+2] +[Cd+2] +[Cd] +%25 +[Os] +[BH-] +[Sn+4] +[Cr+3] +[Ru+3] +[PH2+] +[Rh+2] +[V+2] +%26 +[Gd+3] +[Pb+2] +[PH] +[Hg+] +[Mo+2] +[AlH] +[Sn+] +%27 +[Pd+] +b +[Rh+3] +[Hg+2] +[15NH] +[14C] +%28 +[Mn+3] +[Si+] +[SeH] +[13C@H] +[NH] +[Ga+3] +[SiH-] +[13C@@H] +[Ce] +[Au+3] +[Bi+3] +[15N] +%29 +[BH3-] +[14cH] +[Ti+] +[Gd] +[cH+] +[Cr+2] +[Sb-] +%30 +[Be+2] +[Al+] +[te] +[11CH3] +[Sm] +[Pr] +[La] +%31 +[Al-] +[Ta] +[125I] +[BH2-] +[Nb] +[Si@] +%32 +[14c] +[Sb+3] +[Ba] +%33 +[Os+2] +[Si@@] +[La+3] +[15n] +[15NH2] +[Nd+3] +%34 +[14CH2] +[18O] +[Nd] +[GeH] +[Ni+3] +[Eu] +[Dy+3] +[Sc] +%36 +[Se-2] +[As+] +%35 +[AsH] +[Tb] +[Sb+5] +[Se+] +[Ce+3] +[c+] +[In+3] +[SnH] +[Mo+4] +%37 +[V+4] +[Eu+3] +[Hf+2] +%38 +[Pt+] +[p+] +[123I] +[Tl+] +[Sm+3] +%39 +[Yb+3] +%40 +[Yb] +[Os+] +%41 +[10B] +[Sc+3] +[Al+2] +%42 +[Sr] +[Tb+3] +[Po] +[Tc] +[PH-] +[AlH3] +[Ar] +[U+4] +[SnH2] +[Cl+2] +[si] +[Fe+] +[14CH3] +[U+3] +[Cl+] +%43 +[GeH2] +%44 +[Er+3] +[Mo+3] +[I+2] +[Fe+4] +[99Tc] +%45 +[11C] +%46 +[SnH3] +[S] +[Te+] +[Er] +[Lu+3] +[11B] +%47 +%48 +[P] +[Tm] +[Th] +[Dy] +[Pr+3] +[Ta+5] +[Nb+5] +[Rb] +[GeH3] +[Br+2] +%49 +[131I] +[Fm] +[Cs] +[BH4-] +[Lu] +[15nH] +%50 +[Ru+6] +[b-] +[Ho] +[Th+4] +[Ru+4] +%52 +[14CH] +%51 +[Cr+6] +[18OH] +[Ho+3] +[Ce+4] +[Bi+2] +[Co+] +%53 +[Yb+2] +[Fe+6] +[Be] +%54 +[SH3+] +[Np] +[As-] +%55 +[14C@@H] +[Ir+2] +[GaH3] +[p-] +[GeH4] +[Sn+3] +[Os+4] +%56 +[14C@H] +[sH+] +[19F] +[Eu+2] +[TlH] +%57 +[Cr+4] +%58 +[B@@-] +[SiH+] +[At] +[Am] +[Fe+5] +[AsH2] +[Si+4] +[B@-] +[Pu] +[SbH] +[P-2] +[Tm+3] +* +%59 +[se+] +[IH-] +%60 +[oH+] +[1H] +[15N+] +[124I] +[S@@+] +[P-3] +[H] +[IH2+] +[TeH] +[Xe] +[PH4+] +[Cr+] +[Cm] +[I+3] +%61 +[Nb+2] +[Ru+5] +%62 +[Ta+2] +[Tc+4] +[CH3+] +[Pm] +[Si@H] +[No] +%63 +[Cr+5] +[Th+2] +[Zn-2] +[13C@] +[Lr] +%64 +[99Tc+3] +%65 +[13C@@] +%66 +[Fe-] +[17O] +[siH] +[Sb+] +[OH] +[IH] +[11CH2] +[Cf] +[SiH2+] +[Gd+2] +[In+] +[Si@@H] +[Mn+] +[99Tc+4] +[Ga-] +%67 +[S@+] +[Ge+4] +[Tl+3] +[16OH] +%68 +[2H-] +[Ra] +[si-] +[NiH2] +[P@@H] +[Rh+] +[12C] +[35S] +[32P] +[SiH2-] +[AlH2+] +[16O] +%69 +[BiH] +[BiH2] +[Zn-] +[BH] +[Tc+3] +[Ir+] +[Ni+] +%70 +[InH2] +[InH] +[Nb+3] +[PbH] +[Bi+] +%71 +[As+3] +%72 +[18O-] +[68Ga+3] +%73 +[Pa] +[76Br] +[Tc+5] +[pH+] +[64Cu+2] +[Ru+8] +%74 +[PH2-] +[Si+2] +[17OH] +[RuH] +[111In+3] +[AlH+] +%75 +%76 +[W+] +[SbH2] +[PoH] +[Ru-] +[XeH] +[Tc+2] +[13C-] +[Br+] +[Pt-2] +[Es] 
+[Cu-] +[Mg+] +[3HH] +[P@H] +[ClH2+] +%77 +[SH] +[Au-] +[2HH] +%78 +[Sn-] +[11CH] +[PdH2] +0 +[Os+6] +%79 +[Mo+] +%80 +[al] +[PbH2] +[64Cu] +[Cl] +[12CH3] +%81 +[Tc+7] +[11c] +%82 +[Li-] +[99Tc+5] +[He] +[12c] +[Kr] +[RuH+2] +[35Cl] +[Pd-2] +[GaH2] +[4H] +[Sg] +[Cu-2] +[Br+3] +%83 +[37Cl] +[211At] +[IrH+2] +[Mt] +[Ir-2] +[In-] +[12cH] +[12CH2] +[RuH2] +[99Tc+7] +%84 +[15n+] +[ClH2+2] +[16N] +[111In] +[Tc+] +[Ru-2] +[12CH] +[si+] +[Tc+6] +%85 +%86 +[90Y] +[Pd-] +[188Re] +[RuH+] +[NiH] +[SiH3-] +[14n] +[CH3] +[14N] +[10BH2] +%88 +%89 +%90 +[34S] +[77Br] +[GaH] +[Br] +[Ge@] +[B@@H-] +[CuH] +[SiH4] +[3H-] +%87 +%91 +%92 +[67Cu] +[I] +[177Lu] +[ReH] +[67Ga+3] +[Db] +[177Lu+3] +[AlH2-] +[Si+3] +[Ti-2] +[RuH+3] +[al+] +[68Ga] +[2H+] +[B@H-] +[WH2] +[OsH] +[Ir-3] +[AlH-] +[Bk] +[75Se] +[14C@] +[Pt-] +[N@@H+] +[Nb-] +[13NH2] +%93 +[186Re] +[Tb+4] +[PtH] +[IrH2] +[Hg-2] +[AlH3-] +[PdH+] +[Md] +[RhH+2] +[11cH] +[Co-2] +[15N-] +[ZrH2] +%94 +[Hg-] +[127I] +[AsH2+] +[MoH2] +[Te+4] +[14C@@] +[As+5] +[SnH+3] +[Ge@@] +[6Li+] +[WH] +[Ne] +[14NH2] +[14NH] +[12C@@H] +[Os+7] +[RhH] +[Al-3] +[SnH+] +[15NH3+] +[Zr+] +[197Hg+] +%95 +%96 +[90Y+3] +[Os-2] +[98Tc+5] +[15NH3] +[bH-] +[33P] +[Zr-2] +[15O] +[Rh-] +[PbH3] +[PH2] +[Ni-] +[CuH+] +%97 +%98 +%99 +[Os+5] +[PtH+] +[ReH4] +[16NH] +[82Br] +[W-] +[18F-] +[15NH4+] +[Se+4] +[SeH-] +[SH4] +[67Cu+2] +[12C@H] +[AsH3] +[HgH] +[10B-] +[99Tc+6] +[117Sn+4] +[Te@] +[P@+] +[35SH] +[SeH+] +[Ni-2] +[Al-2] +[TeH2] +[Bh] +[99Tc+2] +[Os+8] +[PH-2] +[7Li+] +[14nH] +[AlH+2] +[18FH] +[SnH4] +[18O-2] +[IrH] +[13N] +[Te@@] +[Rh-3] +[15NH+] +[AsH3+] +[SeH2] +[AsH+] +[CoH2] +[16NH2] +[AsH-] +[203Hg+] +[P@@+] +[166Ho+3] +[60Co+3] +[13CH2-] +[SeH2+] +[75Br] +[TlH2] +[80Br] +[siH+] +[Ca+] +[153Sm+3] +[PdH] +[225Ac] +[13CH3-] +[AlH4-] +[FeH] +[13CH-] +[14C-] +[11C-] +[153Sm] +[Re-] +[te+] +[13CH4] +[ClH+2] +[8CH2] +[99Mo] +[ClH3+3] +[SbH3] +[25Mg+2] +[16N+] +[SnH2+] +[PH4] +[11C@H] +[122I] +[Re-2] +[RuH2+2] +[ZrH] +[Bi-] +[Pr+] +[Rn] +[Fr] +[36Cl] +[18o] +[YH] +[79Br] +[121I] +[113In+3] +[InH4-] +[TaH] +[RhH2] +[Ta-] +[67Ga] +[ZnH+] +[SnH2-] +[OsH2] +[16F] +[FeH2] +[14O] +[PbH2+2] +[BH2] +[6H] +[125Te] +[197Hg] +[TaH2] +[TaH3] +[76As] +[Nb-2] +[14N+] +[125I-] +[33S] +[IH2+2] +[NH2] +[PtH2] +[MnH] +[19C] +[17F] +[1H-] +[SnH4+2] +[Mn-2] +[15NH2+] +[TiH2] +[ReH7] +[Cd-2] +[Fe-3] +[SH2] +[17O-] +[siH-] +[CoH+] +[VH] +[10BH] +[Ru-3] +[13O] +[5H] +[CoH] +[PH5] +[15n-] +[153Gd] +[12C@] +[11CH3-] +[IrH3] +[RuH3] +[74Se] +[Se@] +[Hf+] +[77Se] +[166Ho] +[59Fe+2] +[203Hg] +[18OH-] +[8CH] +[12C@@] +[11CH4] +[15C] +[249Cf] +[PbH4] +[64Zn] +[PH3] +[99Tc+] +[14c-] +[149Pm] +[IrH4] +[Se@@] +[13OH] +[14CH3-] +[28Si] +[Rh-2] +[Fe-2] +[131I-] +[51Cr] +[62Cu+2] +[81Br] +[121Sb] +[7Li] +[89Zr+4] +[SbH3+] +[11C@@H] +[98Tc] +[59Fe+3] +[BiH2+] +[SbH+] +[TiH] +[14NH3] +[15OH] +[119Sn] +[201Hg] +[MnH+] +[201Tl] +[51Cr+3] +[123I-] +[MoH] +[AlH6-3] +[MnH2] +[WH3] +[213Bi+3] +[SnH2+2] +[123IH] +[13CH+] +[Zr-] +[74As] +[13C+] +[32P+] +[KrH] +[SiH+2] +[ClH3+2] +[13NH] +[9CH2] +[ZrH2+2] +[87Sr+2] +[35s] +[239Pu] +[198Au] +[241Am] +[203Hg+2] +[V+] +[YH2] +[SH5] +[195Pt] +[203Pb] +[RuH4] +[ThH2] +[AuH] +[66Ga+3] +[11B-] +[F] +[24Na+] +[85Sr+2] +[201Tl+] +[14CH4] +[32S] +[TeH2+] +[ClH2+3] +[AgH] +[Ge@H] +[44Ca+2] +[Os-] +[31P] +[15nH+] +[SbH4] +[TiH+] +[Ba+] +[57Co+2] +[Ta+] +[125IH] +[77As] +[129I] +[Fe-4] +[Ta-2] +[19O] +[12O] +[BiH3] +[237Np] +[252Cf] +[86Y] +[Cr-2] +[89Y] +[195Pt+2] +[si+2] +[58Fe+2] +[Hs] +[S@@H] +[OsH6] +[GdH2] +[IH3] +[8CH4] +[164Dy+3] +[47Ca+2] +[57Co] +[NbH2] +[ReH2] +[ZnH2] +[CrH2] +[17NH] 
+[ZrH3] +[RhH3] +[12C-] +[18O+] +[Bi-2] +[ClH4+3] +[Ni-3] +[Ag-] +[111In-] +[Mo-2] +[55Fe+3] +[204Hg+] +[35Cl-] +[211Pb] +[75Ge] +[8B] +[TeH3] +[SnH3+] +[Zr-3] +[28F] +[249Bk] +[169Yb] +[34SH] +[6Li] +[94Tc] +[197Au] +[195Pt+4] +[169Yb+3] +[32Cl] +[82Se] +[159Gd+3] +[213Bi] +[CoH+2] +[36S] +[35P] +[Ru-4] +[Cr-3] +[60Co] +[1H+] +[18CH2] +[Cd-] +[152Sm+3] +[106Ru] +[238Pu] +[220Rn] +[45Ca+2] +[89Sr+2] +[239Np] +[90Sr+2] +[137Cs+] +[165Dy] +[68GaH3] +[65Zn+2] +[89Zr] +[BiH2+2] +[62Cu] +[165Dy+3] +[238U] +[105Rh+3] +[70Zn] +[12B] +[12OH] +[18CH] +[17CH] +[OsH3] +[SbH-] +[SH6] +[AlH2-2] +[42K] +[76Br-] +[71As] +[NbH3] +[ReH3] +[OsH-] +[WH4] +[MoH3] +[OsH4] +[RuH6] +[PtH3] +[CuH2] +[CoH3] +[TiH4] +[64Zn+2] +[Si-2] +[79BrH] +[14CH2-] +[PtH2+2] +[Os-3] +[29Si] +[Ti-] +[Se+6] +[22Na+] +[42K+] +[131Cs+] +[86Rb+] +[134Cs+] +[209Po] +[208Po] +[81Rb+] +[203Tl+] +[Zr-4] +[148Sm] +[147Sm] +[37Cl-] +[12CH4] +[Ge@@H] +[63Cu] +[13CH2+] +[AsH2-] +[CeH] +[SnH-] +[UH] +[9c] +[21CH3] +[TeH+] +[57Co+3] +[8BH2] +[12BH2] +[19BH2] +[9BH2] +[YbH2] +[CrH+2] +[208Bi] +[152Gd] +[61Cu] +[115In] +[60Co+2] +[13NH2-] +[120I] +[18OH2] +[75SeH] +[SbH2+] +[144Ce] +[16n] +[113In] +[22nH] +[129I-] +[InH3] +[32PH3] +[234U] +[235U] +[59Fe] +[82Rb+] +[65Zn] +[244Cm] +[147Pm] +[91Y] +[237Pu] +[231Pa] +[253Cf] +[127Te] +[187Re] +[236Np] +[235Np] +[72Zn] +[253Es] +[159Dy] +[62Zn] +[101Tc] +[149Tb] +[124I-] +[SeH3+] +[210Pb] +[40K] +[210Po] +[214Pb] +[218Po] +[214Po] +[7Be] +[212Pb] +[205Pb] +[209Pb] +[123Te] +[202Pb] +[72As] +[201Pb] +[70As] +[73Ge] +[200Pb] +[198Pb] +[66Ga] +[73Se] +[195Pb] +[199Pb] +[144Ce+3] +[235U+2] +[90Tc] +[114In+3] +[128I] +[100Tc+] +[82Br-] +[191Pt+2] +[191Pt+4] +[193Pt+4] +[31PH3] +[125I+2] +[131I+2] +[125Te+4] +[82Sr+2] +[149Sm] +[81BrH] +[129Xe] +[193Pt+2] +[123I+2] +[Cr-] +[Co-] +[227Th+4] +[249Cf+3] +[252Cf+3] +[187Os] +[16O-] +[17O+] +[16OH-] +[98Tc+7] +[58Co+2] +[69Ga+3] +[57Fe+2] +[43K+] +[16C] +[52Fe+3] +[SeH5] +[194Pb] +[196Pb] +[197Pb] +[213Pb] +[9B] +[19B] +[11CH-] +[9CH] +[20OH] +[25OH] +[8cH] +[TiH+3] +[SnH6+3] +[N@H+] +[ZnH] +[VH3] +[52Mn+2] +[64Ga] +[13B] +[216Bi] +[117Sn+2] +[232Th] +[SnH+2] +[BiH5] +[77Kr] +[103Cd] +[62Ni] +[LaH3] +[SmH3] +[EuH3] +[MoH5] +[64Ni] +[66Zn] +[68Zn] +[186W] +[FeH4] +[MoH4] +[HgH2] +[15NH2-] +[UH2] +[204Hg] +[GaH4-] +[ThH4] +[WH6] +[PtH4] +[VH2] +[UH3] +[FeH3] +[RuH5] +[BiH4] +[80Br-] +[CeH3] +[37ClH] +[157Gd+3] +[205Tl] +[203Tl] +[62Cu+] +[64Cu+] +[61Cu+] +[37SH2] +[30Si] +[28Al] +[19OH2] +[8He] +[6He] +[153Pm] +[209Bi] +[66Zn+2] +[10CH4] +[191Ir] +[66Cu] +[16O+] +[25O] +[10c] +[Co-3] +[Sn@@] +[17OH-] +[206Po] +[204Po] +[202Po] +[201Po] +[200Po] +[199Po] +[198Po] +[197Po] +[196Po] +[195Po] +[194Po] +[193Po] +[192Po] +[191Po] +[190Po] +[217Po] +[BiH4-] +[TeH4] +[222Ra] +[62Ga] +[39Ar] +[144Sm] +[58Fe] +[153Eu] +[85Rb] +[171Yb] +[172Yb] +[114Cd] +[51Fe] +[142Ce] +[207Tl] +[92Mo] +[115Sn] +[140Ce] +[202Hg] +[180W] +[182W] +[183W] +[184W] +[96Mo] +[47Ti] +[111Cd] +[143Nd] +[145Nd] +[126Te] +[128Te] +[130Te] +[185Re] +[97Mo] +[98Mo] +[183Re] +[52V] +[80Se] +[87Kr] +[137Xe] +[196Au] +[146Ce] +[88Kr] +[51Ti] +[138Xe] +[112Cd] +[116Sn] +[120Sn] +[28SiH3] +[35S-] +[15NH-] +[13CH3+] +[34S+] +[34s] +[SiH4-] +[100Tc+5] +[NiH2+2] +[239Th] +[186Lu] +[AuH3] +[I@@-] +[XeH2] +[B+] +[16CH2] +[8C] +[TaH5] +[FeH4-] +[19C@H] +[10NH] +[FeH6-3] +[22CH] +[25N] +[25N+] +[25N-] +[21CH2] +[18cH] +[113I] +[ScH3] +[30PH3] +[43Ca+2] +[41Ca+2] +[106Cd] +[122Sn] +[18CH3] +[58Co+3] +[98Tc+4] +[70Ge] +[76Ge] +[108Cd] +[116Cd] +[130Xe] +[94Mo] +[124Sn] +[186Os] +[188Os] +[190Os] +[192Os] +[106Pd] 
+[110Pd] +[120Te] +[132Ba] +[134Ba] +[136Ba] +[136Ce] +[138Ce] +[156Dy] +[158Dy] +[160Dy] +[163Dy] +[162Er] +[164Er] +[167Er] +[176Hf] +[26Mg] +[144Nd] +[150Nd] +[41K] +[46Ti] +[48Ti] +[49Ti] +[50Ti] +[170Yb] +[173Yb] +[91Zr] +[92Zr] +[96Zr] +[34S-] +[CuH2-] +[38Cl] +[25Mg] +[51V] +[93Nb] +[95Mo] +[45Sc] +[123Sb] +[139La] +[9Be] +[99Y+3] +[99Y] +[156Ho] +[67Zn] +[144Ce+4] +[210Tl] +[42Ca] +[54Fe] +[193Ir] +[92Nb] +[141Cs] +[52Cr] +[35ClH] +[46Ca] +[139Cs] +[65Cu] +[71Ga] +[60Ni] +[16NH3] +[148Nd] +[72Ge] +[161Dy] +[49Ca] +[43Ca] +[8Be] +[48Ca] +[44Ca] +[120Xe] +[80Rb] +[215At] +[180Re] +[146Sm] +[19Ne] +[74Kr] +[134La] +[76Kr] +[219Fr] +[121Xe] +[220Fr] +[216At] +[223Ac] +[218At] +[37Ar] +[135I] +[110Cd] +[94Tc+7] +[86Y+3] +[135I-] +[15O-2] +[151Eu+3] +[161Tb+3] +[197Hg+2] +[109Cd+2] +[191Os+4] +[170Tm+3] +[205Bi+3] +[233U+4] +[126Sb+3] +[127Sb+3] +[132Cs+] +[136Eu+3] +[136Eu] +[125Sn+4] +[175Yb+3] +[100Mo] +[22Ne] +[13c-] +[13NH4+] +[17C] +[9C] +[31S] +[31SH] +[133I] +[126I] +[36SH] +[30S] +[32SH] +[19CH2] +[19c] +[18c] +[15F] +[10C] +[RuH-] +[62Zn+2] +[32ClH] +[33ClH] +[78BrH] +[12Li+] +[12Li] +[233Ra] +[68Ge+4] +[44Sc+3] +[91Y+3] +[106Ru+3] +[PoH2] +[AtH] +[55Fe] +[233U] +[210PoH2] +[230Th] +[228Th] +[222Rn] +[35SH2] +[227Th] +[192Ir] +[133Xe] +[81Kr] +[95Zr] +[240Pu] +[54Mn] +[103Ru] +[95Nb] +[109Cd] +[141Ce] +[85Kr] +[110Ag] +[58Co] +[241Pu] +[234Th] +[140La] +[63Ni] +[152Eu] +[132IH] +[226Rn] +[154Eu] +[36ClH] +[228Ac] +[155Eu] +[106Rh] +[243Am] +[227Ac] +[243Cm] +[236U] +[144Pr] +[232U] +[32SH2] +[88Y] +[82BrH] +[135IH] +[242Cm] +[115Cd] +[242Pu] +[46Sc] +[56Mn] +[234Pa] +[41Ar] +[147Nd] +[187W] +[151Sm] +[59Ni] +[233Pa] +[52Mn] +[94Nb] +[219Rn] +[236Pu] +[13NH3] +[93Zr] +[51Cr+6] +[TlH3] +[123Xe] +[160Tb] +[170Tm] +[182Ta] +[175Yb] +[93Mo] +[143Ce] +[191Os] +[126IH] +[48V] +[113Cd] +[47Sc] +[181Hf] +[185W] +[143Pr] +[191Pt] +[181W] +[33PH3] +[97Ru] +[97Tc] +[111Ag] +[169Er] +[107Pd] +[103Ru+2] +[34SH2] +[137Ce] +[242Am] +[117SnH2] +[57Ni] +[239U] +[60Cu] +[250Cf] +[193Au] +[69Zn] +[55Co] +[139Ce] +[127Xe] +[159Gd] +[56Co] +[177Hf] +[244Pu] +[38ClH] +[142Pr] +[199Hg] +[179Hf] +[178Hf] +[237U] +[156Eu] +[157Eu] +[105Ru] +[171Tm] +[199Au] +[155Sm] +[80BrH] +[108Ag] +[128IH] +[48Sc] +[45Ti] +[176Lu] +[121SnH2] +[148Pm] +[57Fe] +[10BH3] +[96Tc] +[133IH] +[143Pm] +[105Rh] +[130IH] +[134IH] +[131IH] +[71Zn] +[105Ag] +[97Zr] +[235Pu] +[231Th] +[109Pd] +[93Y] +[190Ir] +[135Xe] +[53Mn] +[134Ce] +[234Np] +[240Am] +[246Cf] +[240Cm] +[241Cm] +[226Th] +[39ClH] +[229Th] +[245Cm] +[240U] +[240Np] +[249Cm] +[243Pu] +[145Pm] +[199Pt] +[246Bk] +[193Pt] +[230U] +[250Cm] +[44Ti] +[175Hf] +[254Fm] +[255Fm] +[257Fm] +[92Y] +[188Ir] +[171Lu] +[257Md] +[247Bk] +[121IH] +[250Bk] +[179Lu] +[224Ac] +[195Hg] +[244Am] +[246Pu] +[194Au] +[252Fm] +[173Hf] +[246Cm] +[135Ce] +[49Cr] +[248Cf] +[247Cm] +[248Cm] +[174Ta] +[176Ta] +[154Tb] +[172Ta] +[177Ta] +[175Ta] +[180Ta] +[158Tb] +[115Ag] +[189Os] +[251Cf] +[145Pr] +[147Pr] +[76BrH] +[102Rh] +[238Np] +[185Os] +[246Am] +[233Np] +[166Dy] +[254Es] +[244Cf] +[193Os] +[245Am] +[245Bk] +[239Am] +[238Am] +[97Nb] +[245Pu] +[254Cf] +[188W] +[250Es] +[251Es] +[237Am] +[182Hf] +[258Md] +[232Np] +[238Cm] +[60Fe] +[109Pd+2] +[234Pu] +[141Ce+3] +[136Nd] +[136Pr] +[173Ta] +[110Ru] +[147Tb] +[253Fm] +[139Nd] +[178Re] +[177Re] +[200Au] +[182Re] +[156Tb] +[155Tb] +[157Tb] +[161Tb] +[161Ho] +[167Tm] +[173Lu] +[179Ta] +[171Er] +[44Sc] +[49Sc] +[49V] +[51Mn] +[90Nb] +[88Nb] +[88Zr] +[36SH2] +[174Yb] +[178Lu] +[179W] +[83BrH] +[107Cd] +[75BrH] +[62Co] +[48Cr] +[63Zn] +[102Ag] +[154Sm] 
+[168Er] +[65Ni] +[137La] +[187Ir] +[144Pm] +[146Pm] +[160Gd] +[166Yb] +[162Dy] +[47V] +[141Nd] +[141Sm] +[166Er] +[150Sm] +[146Eu] +[149Eu] +[174Lu] +[17NH3] +[102Ru] +[170Hf] +[188Pt] +[61Ni] +[56Ni] +[149Gd] +[151Gd] +[141Pm] +[147Gd] +[146Gd] +[161Er] +[103Ag] +[145Eu] +[153Tb] +[155Dy] +[184Re] +[180Os] +[182Os] +[186Pt] +[181Os] +[181Re] +[151Tb] +[178Ta] +[178W] +[189Pt] +[194Hg] +[145Sm] +[150Tb] +[132La] +[158Gd] +[104Ag] +[193Hg] +[94Ru] +[137Pr] +[155Ho] +[117Cd] +[99Ru] +[146Nd] +[218Rn] +[95Y] +[79Kr] +[120IH] +[138Pr] +[100Pd] +[166Tm] +[90Mo] +[151Nd] +[231U] +[138Nd] +[89Nb] +[98Nb] +[162Ho] +[142Sm] +[186Ta] +[104Tc] +[184Ta] +[185Ta] +[170Er] +[107Rh] +[131La] +[169Lu] +[74BrH] +[150Pm] +[172Tm] +[197Pt] +[230Pu] +[170Lu] +[86Zr] +[176W] +[177W] +[101Pd] +[105Pd] +[108Pd] +[149Nd] +[164Ho] +[159Ho] +[167Ho] +[176Yb] +[156Sm] +[77BrH] +[189Re] +[99Rh] +[100Rh] +[151Pm] +[232Pa] +[228Pa] +[230Pa] +[66Ni] +[194Os] +[135La] +[138La] +[141La] +[142La] +[195Ir] +[96Nb] +[157Ho] +[183Hf] +[162Tm] +[172Er] +[148Eu] +[150Eu] +[15CH4] +[89Kr] +[143La] +[58Ni] +[61Co] +[158Eu] +[165Er] +[167Yb] +[173Tm] +[175Tm] +[172Hf] +[172Lu] +[93Tc] +[177Yb] +[124IH] +[194Ir] +[147Eu] +[101Mo] +[180Hf] +[189Ir] +[87Y] +[43Sc] +[195Au] +[112Ag] +[84BrH] +[106Ag] +[109Ag] +[101Rh] +[162Yb] +[228Rn] +[139Pr] +[94Y] +[201Au] +[40PH3] +[110Ag+] +[104Cd] +[133Ba+2] +[226Ac] +[145Gd] +[186Ir] +[184Ir] +[224Rn] +[185Ir] +[182Ir] +[184Hf] +[200Pt] +[227Pa] +[178Yb] +[72Br-] +[72BrH] +[248Am] +[238Th] +[161Gd] +[35S-2] +[107Ag] +[FeH6-4] +[89Sr] +[SnH3-] +[SeH3] +[TeH3+] +[SbH4+] +[AsH4+] +[4He] +[AsH3-] +[1HH] +[3H+] +[82Rb] +[85Sr] +[90Sr] +[137Cs] +[133Ba] +[131Cs] +[SbH5] +[224Ra] +[22Na] +[210Bi] +[214Bi] +[228Ra] +[127Sb] +[136Cs] +[125Sb] +[134Cs] +[140Ba] +[45Ca] +[206Pb] +[207Pb] +[24Na] +[86Rb] +[212Bi] +[208Pb] +[124Sb] +[204Pb] +[44K] +[129Te] +[113Sn] +[204Tl] +[87Sr] +[208Tl] +[87Rb] +[47Ca] +[135Cs] +[216Po] +[137Ba] +[207Bi] +[212Po] +[79Se] +[223Ra] +[86Sr] +[122Sb] +[26Al] +[32Si] +[126Sn] +[225Ra] +[114In] +[72Ga] +[132Te] +[10Be] +[125Sn] +[73As] +[206Bi] +[117Sn] +[40Ca] +[41Ca] +[89Rb] +[116In] +[129Sb] +[91Sr] +[71Ge] +[139Ba] +[69Ga] +[120Sb] +[121Sn] +[123Sn] +[131Te] +[77Ge] +[135Ba] +[82Sr] +[43K] +[131Ba] +[92Sr] +[88Rb] +[129Cs] +[144Cs] +[127Cs] +[200Tl] +[202Tl] +[141Ba] +[117Sb] +[116Sb] +[78As] +[131Sb] +[126Sb] +[128Sb] +[130Sb] +[67Ge] +[68Ge] +[78Ge] +[66Ge] +[223Fr] +[132Cs] +[125Cs] +[138Cs] +[133Te] +[84Rb] +[83Rb] +[81Rb] +[142Ba] +[200Bi] +[115Sb] +[194Tl] +[70Se] +[112In] +[118Sb] +[70Ga] +[27Mg] +[202Bi] +[83Se] +[9Li] +[69As] +[79Rb] +[81Sr] +[83Sr] +[78Se] +[109In] +[29Al] +[118Sn] +[117In] +[119Sb] +[114Sn] +[138Ba] +[69Ge] +[73Ga] +[74Ge] +[206Tl] +[199Tl] +[130Cs] +[28Mg] +[116Te] +[112Sn] +[126Ba] +[211Bi] +[81Se] +[127Sn] +[143Cs] +[134Te] +[80Sr] +[45K] +[215Po] +[207Po] +[111Sn] +[211Po] +[128Ba] +[198Tl] +[227Ra] +[213Po] +[220Ra] +[128Sn] +[203Po] +[205Po] +[65Ga] +[197Tl] +[88Sr] +[110In] +[31Si] +[201Bi] +[121Te] +[205Bi] +[203Bi] +[195Tl] +[209Tl] +[110Sn] +[222Fr] +[207At] +[119In] +[As@] +[129IH] +[157Dy] +[111IH] +[230Ra] +[144Pr+3] +[SiH3+] +[3He] +[AsH5] +[72Se] +[95Tc] +[103Pd] +[121Sn+2] +[211Rn] +[38SH2] +[127IH] +[74Br-] +[133I-] +[100Tc+4] +[100Tc] +[36Cl-] +[89Y+3] +[104Rh] +[152Sm] +[226Ra] +[19FH] +[104Pd] +[148Gd] +[157Lu] +[33SH2] +[121I-] +[17FH] +[71Se] +[157Sm] +[148Tb] +[164Dy] +[15OH2] +[15O+] +[39K] +[40Ar] +[50Cr+3] +[50Cr] +[52Ti] +[103Pd+2] +[130Ba] +[142Pm] +[153Gd+3] +[151Eu] +[103Rh] +[124Xe] +[152Tb] +[17OH2] +[20Ne] 
+[52Fe] +[94Zr+4] +[94Zr] +[149Pr] +[16OH2] +[53Cr+6] +[53Cr] +[81Br-] +[112Pd] +[125Xe] +[155Gd] +[157Gd] +[168Yb] +[184Os] +[166Tb] +[221Fr] +[212Ra] +[75Br-] +[79Br-] +[113Ag] +[23Na] +[34Cl-] +[34ClH] +[38Cl-] +[56Fe] +[68Cu] +[77Br-] +[90Zr+4] +[90Zr] +[102Pd] +[154Eu+3] +[57Mn] +[165Tm] +[152Dy] +[217At] +[77se] +[13cH-] +[122Te] +[156Gd] +[124Te] +[53Ni] +[131Xe] +[174Hf+4] +[174Hf] +[76Se] +[168Tm] +[167Dy] +[154Gd] +[95Ru] +[210At] +[85Br] +[59Co] +[122Xe] +[27Al] +[54Cr] +[198Hg] +[85Rb+] +[214Tl] +[229Rn] +[218Pb] +[218Bi] +[167Tm+3] +[18o+] +[P@@H+] +[P@H+] +[13N+] +[212Pb+2] +[217Bi] +[249Cf+2] +[18OH3+] +[90Sr-] +[Cf+3] +[200Hg] +[86Tc] +[141Pr+3] +[141Pr] +[16nH] +[14NH4+] +[132Xe] +[83Kr] +[70Zn+2] +[137Ba+2] +[36Ar] +[38Ar] +[21Ne] +[126Xe] +[136Xe] +[128Xe] +[134Xe] +[84Kr] +[86Kr] +[78Kr] +[80Kr] +[82Kr] +[67Zn+2] +[65Cu+2] +[110Te] +[58Fe+3] +[142Nd] +[38K] +[198Au+3] +[122IH] +[38PH3] +[130I-] +[40K+] +[38K+] +[28Mg+2] +[208Tl+] +[13OH2] +[198Bi] +[192Bi] +[194Bi] +[196Bi] +[132I-] +[83Sr+2] +[169Er+3] +[122I-] +[120I-] +[92Sr+2] +[126I-] +[24Mg] +[84Sr] +[118Pd+2] +[118Pd] +[AsH4] +[127I-] +[9C-] +[11CH3+] +[17B] +[7B] +[4HH] +[18C-] +[22CH3-] +[22CH4] +[17C-] +[15CH3] +[16CH3] +[11NH3] +[21NH3] +[11N-] +[11NH] +[16CH] +[17CH2] +[99Ru+2] +[181Ta+2] +[181Ta] +[20CH] +[32PH2] +[55Fe+2] +[SH3] +[S@H] +[Mn-] +[IH4] +[ThH] +[GaH-] +[BiH+] +[EuH2] +[FeH4-3] +[FeH6] +[IH5] +[NiH+] +[SrH2] +[VH4] +[YH3] +[seH+] +<unk> diff --git a/models/smi_ted/finetune/smi_ted_light/clintox/run_finetune_clintox.sh b/models/smi_ted/finetune/smi_ted_light/clintox/run_finetune_clintox.sh new file mode 100644 index 0000000000000000000000000000000000000000..c6ffd279c957bef21f52a009c2e363c26c630ea6 --- /dev/null +++ b/models/smi_ted/finetune/smi_ted_light/clintox/run_finetune_clintox.sh @@ -0,0 +1,23 @@ +python ../../finetune_classification_multitask.py \ + --n_batch 32 \ + --n_layer 12 \ + --n_head 12 \ + --n_embd 768 \ + --max_len 202 \ + --d_dropout 0.1 \ + --dropout 0.1 \ + --lr_start 3e-5 \ + --lr_multiplier 1 \ + --max_epochs 100 \ + --num_feats 32 \ + --smi_ted_version 'v1' \ + --model_path '../' \ + --ckpt_filename 'smi-ted-Light_40.pt' \ + --data_root '../../moleculenet/clintox' \ + --dataset_name clintox \ + --checkpoints_folder './checkpoints_clintox' \ + --loss_fn 'bceloss' \ + --target_metric 'roc-auc' \ + --save_ckpt 1 \ + --start_seed 0 \ + --train_decoder 0 \ \ No newline at end of file diff --git a/models/smi_ted/finetune/smi_ted_light/esol/run_finetune_esol.sh b/models/smi_ted/finetune/smi_ted_light/esol/run_finetune_esol.sh new file mode 100644 index 0000000000000000000000000000000000000000..dd573cf4bac8c917b001eb157ce160ef4fca9d72 --- /dev/null +++ b/models/smi_ted/finetune/smi_ted_light/esol/run_finetune_esol.sh @@ -0,0 +1,24 @@ +python ../../finetune_regression.py \ + --n_batch 32 \ + --n_layer 12 \ + --n_head 12 \ + --n_embd 768 \ + --max_len 202 \ + --d_dropout 0.1 \ + --dropout 0.1 \ + --lr_start 3e-5 \ + --lr_multiplier 1 \ + --max_epochs 500 \ + --num_feats 32 \ + --smi_ted_version 'v1' \ + --model_path '../' \ + --ckpt_filename 'smi-ted-Light_40.pt' \ + --data_root '../../moleculenet/esol' \ + --dataset_name esol \ + --measure_name 'measured log solubility in mols per litre' \ + --checkpoints_folder './checkpoints_esol' \ + --loss_fn 'rmse' \ + --target_metric 'rmse' \ + --save_ckpt 1 \ + --start_seed 0 \ + --train_decoder 1 \ \ No newline at end of file diff --git a/models/smi_ted/finetune/smi_ted_light/freesolv/run_finetune_freesolv.sh 
b/models/smi_ted/finetune/smi_ted_light/freesolv/run_finetune_freesolv.sh new file mode 100644 index 0000000000000000000000000000000000000000..43c1f321e357f873f51460e428e605ca30f2c3e9 --- /dev/null +++ b/models/smi_ted/finetune/smi_ted_light/freesolv/run_finetune_freesolv.sh @@ -0,0 +1,24 @@ +python ../../finetune_regression.py \ + --n_batch 32 \ + --n_layer 12 \ + --n_head 12 \ + --n_embd 768 \ + --max_len 202 \ + --d_dropout 0.1 \ + --dropout 0.1 \ + --lr_start 3e-5 \ + --lr_multiplier 1 \ + --max_epochs 500 \ + --num_feats 32 \ + --smi_ted_version 'v1' \ + --model_path '../' \ + --ckpt_filename 'smi-ted-Light_40.pt' \ + --data_root '../../moleculenet/freesolv' \ + --dataset_name freesolv \ + --measure_name 'expt' \ + --checkpoints_folder './checkpoints_freesolv' \ + --loss_fn 'rmse' \ + --target_metric 'rmse' \ + --save_ckpt 1 \ + --start_seed 0 \ + --train_decoder 1 \ \ No newline at end of file diff --git a/models/smi_ted/finetune/smi_ted_light/hiv/run_finetune_hiv.sh b/models/smi_ted/finetune/smi_ted_light/hiv/run_finetune_hiv.sh new file mode 100644 index 0000000000000000000000000000000000000000..8191fb6960d793e184c5237bb87db0223ee6c888 --- /dev/null +++ b/models/smi_ted/finetune/smi_ted_light/hiv/run_finetune_hiv.sh @@ -0,0 +1,25 @@ +python ../../finetune_classification.py \ + --n_batch 32 \ + --n_layer 12 \ + --n_head 12 \ + --n_embd 768 \ + --max_len 202 \ + --d_dropout 0.1 \ + --dropout 0.1 \ + --lr_start 1e-7 \ + --lr_multiplier 1 \ + --max_epochs 500 \ + --num_feats 32 \ + --smi_ted_version 'v1' \ + --model_path '../' \ + --ckpt_filename 'smi-ted-Light_40.pt' \ + --data_root '../../moleculenet/hiv' \ + --dataset_name hiv \ + --measure_name 'HIV_active' \ + --checkpoints_folder './checkpoints_hiv_1e-7' \ + --loss_fn 'crossentropy' \ + --target_metric 'roc-auc' \ + --n_output 2 \ + --save_ckpt 1 \ + --start_seed 0 \ + --train_decoder 0 \ \ No newline at end of file diff --git a/models/smi_ted/finetune/smi_ted_light/lipo/run_finetune_lipo.sh b/models/smi_ted/finetune/smi_ted_light/lipo/run_finetune_lipo.sh new file mode 100644 index 0000000000000000000000000000000000000000..fe4932283f31494e03808982800a8d205a4485ec --- /dev/null +++ b/models/smi_ted/finetune/smi_ted_light/lipo/run_finetune_lipo.sh @@ -0,0 +1,24 @@ +python ../../finetune_regression.py \ + --n_batch 32 \ + --n_layer 12 \ + --n_head 12 \ + --n_embd 768 \ + --max_len 202 \ + --d_dropout 0.1 \ + --dropout 0.1 \ + --lr_start 3e-5 \ + --lr_multiplier 1 \ + --max_epochs 500 \ + --num_feats 32 \ + --smi_ted_version 'v1' \ + --model_path '../' \ + --ckpt_filename 'smi-ted-Light_40.pt' \ + --data_root '../../moleculenet/lipophilicity' \ + --dataset_name lipophilicity \ + --measure_name 'y' \ + --checkpoints_folder './checkpoints_lipophilicity' \ + --loss_fn 'rmse' \ + --target_metric 'rmse' \ + --save_ckpt 1 \ + --start_seed 0 \ + --train_decoder 1 \ \ No newline at end of file diff --git a/models/smi_ted/finetune/smi_ted_light/load.py b/models/smi_ted/finetune/smi_ted_light/load.py new file mode 100644 index 0000000000000000000000000000000000000000..2f3aeea2cfe802ae32706cc283f980f2e74ec6a0 --- /dev/null +++ b/models/smi_ted/finetune/smi_ted_light/load.py @@ -0,0 +1,504 @@ +PATTERN = "(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])" +# Deep learning +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.backends.cudnn as cudnn + +# Transformers +from fast_transformers.attention import AttentionLayer +from fast_transformers.events import 
QKVEvent +from fast_transformers.transformers import TransformerEncoder, TransformerEncoderLayer +from fast_transformers.builders.transformer_builders import BaseTransformerEncoderBuilder +from fast_transformers.builders.attention_builders import AttentionBuilder +from fast_transformers.feature_maps import GeneralizedRandomFeatures +from fast_transformers.masking import LengthMask +from transformers import BertTokenizer + +# Data +import numpy as np +import pandas as pd + +# Standard library +from functools import partial +import regex as re +import random +import os +import gc +from tqdm import tqdm +tqdm.pandas() + + +class MolTranBertTokenizer(BertTokenizer): + def __init__(self, vocab_file: str = '', + do_lower_case=False, + unk_token='<pad>', + sep_token='<eos>', + pad_token='<pad>', + cls_token='<bos>', + mask_token='<mask>', + **kwargs): + super().__init__(vocab_file, + unk_token=unk_token, + sep_token=sep_token, + pad_token=pad_token, + cls_token=cls_token, + mask_token=mask_token, + **kwargs) + + self.regex_tokenizer = re.compile(PATTERN) + self.wordpiece_tokenizer = None + self.basic_tokenizer = None + with open(vocab_file) as f: + self.padding_idx = f.readlines().index(pad_token+'\n') + + def _tokenize(self, text): + split_tokens = self.regex_tokenizer.findall(text) + return split_tokens + + def convert_idx_to_tokens(self, idx_tensor): + tokens = [self.convert_ids_to_tokens(idx) for idx in idx_tensor.tolist()] + return tokens + + def convert_tokens_to_string(self, tokens): + stopwords = ['<bos>', '<eos>'] + clean_tokens = [word for word in tokens if word not in stopwords] + out_string = ''.join(clean_tokens) + return out_string + + def get_padding_idx(self): + return self.padding_idx + + def idx_to_smiles(self, torch_model, idx): + '''Convert tokens idx back to SMILES text''' + rev_tokens = torch_model.tokenizer.convert_idx_to_tokens(idx) + flat_list_tokens = [item for sublist in rev_tokens for item in sublist] + decoded_smiles = torch_model.tokenizer.convert_tokens_to_string(flat_list_tokens) + return decoded_smiles + + +## Transformer layers +class RotaryEmbedding(torch.nn.Module): + + def __init__(self, dim, base=10000): + super().__init__() + inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer('inv_freq', inv_freq) + self.seq_len_cached = 0 + self.cos_cached = None + self.sin_cached = None + + def forward(self, x, seq_dim=1): + seq_len = x.shape[seq_dim] + if seq_len != self.seq_len_cached: + self.seq_len_cached = seq_len + + t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + freqs = torch.einsum('i,j->ij', t, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1).to(x.device) + + self.cos_cached = emb.cos()[None,:, None, :] + self.sin_cached = emb.sin()[None,:, None, :] + + return self.cos_cached, self.sin_cached + +def rotate_half(x): + x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in earlier torch versions + +@torch.jit.script +def apply_rotary_pos_emb(q, k, cos, sin): + return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) + +class RotateAttentionLayer(AttentionLayer): + """Rotate attention layer inherits from fast_transformer attention layer. 
+ The only thing added is an Embedding encoding, for more information + on the attention layer see the fast_transformers code + """ + def __init__(self, attention, d_model, n_heads, d_keys=None, + d_values=None, event_dispatcher=""): + super(RotateAttentionLayer, self).__init__(attention,d_model, n_heads, d_keys=d_keys, + d_values=d_values, event_dispatcher=event_dispatcher) + + self.rotaryemb = RotaryEmbedding(d_keys) + print('Using Rotation Embedding') + + def forward(self, queries, keys, values, attn_mask, query_lengths, + key_lengths): + """ + Using the same frame work as the fast_Transformers attention layer + but injecting rotary information to the queries and the keys + after the keys and queries are projected. + In the argument description we make use of the following sizes + - N: the batch size + - L: The maximum length of the queries + - S: The maximum length of the keys (the actual length per sequence + is given by the length mask) + - D: The input feature dimensionality passed in the constructor as + 'd_model' + Arguments + --------- + queries: (N, L, D) The tensor containing the queries + keys: (N, S, D) The tensor containing the keys + values: (N, S, D) The tensor containing the values + attn_mask: An implementation of BaseMask that encodes where each + query can attend to + query_lengths: An implementation of BaseMask that encodes how + many queries each sequence in the batch consists of + key_lengths: An implementation of BaseMask that encodes how + many queries each sequence in the batch consists of + Returns + ------- + The new value for each query as a tensor of shape (N, L, D). + """ + # Extract the dimensions into local variables + N, L, _ = queries.shape + _, S, _ = keys.shape + H = self.n_heads + + # Project the queries/keys/values + queries = self.query_projection(queries).view(N, L, H, -1) + keys = self.key_projection(keys).view(N, S, H, -1) + cos, sin = self.rotaryemb(queries) + queries, keys = apply_rotary_pos_emb(queries, keys, cos, sin) + values = self.value_projection(values).view(N, S, H, -1) + # Let the world know of the qkv + self.event_dispatcher.dispatch(QKVEvent(self, queries, keys, values)) + + + # Compute the attention + new_values = self.inner_attention( + queries, + keys, + values, + attn_mask, + query_lengths, + key_lengths + ).view(N, L, -1) + + # Project the output and return + return self.out_projection(new_values) + +class RotateEncoderBuilder(BaseTransformerEncoderBuilder): + """Build a batch transformer encoder with Relative Rotary embeddings + for training or processing of sequences all elements at a time. 
+ Example usage: + builder = RotateEncoderBuilder() + builder.n_layers = 12 + builder.n_heads = 8 + builder.feed_forward_dimensions = 1024 + builder.query_dimensions = 64 + builder.value_dimensions = 64 + builder.dropout = 0.1 + builder.attention_dropout = 0.1 + builder.attention_type = "linear" + transformer = builder.get() + """ + def _get_attention_builder(self): + """Return an instance of the appropriate attention builder.""" + return AttentionBuilder() + + def _get_attention_layer_class(self): + """Return the class for the layer that projects queries keys and + values.""" + return RotateAttentionLayer + + def _get_encoder_class(self): + """Return the class for the transformer encoder.""" + return TransformerEncoder + + def _get_encoder_layer_class(self): + """Return the class for the transformer encoder layer.""" + return TransformerEncoderLayer + + +class AutoEncoderLayer(nn.Module): + + def __init__(self, feature_size, latent_size): + super().__init__() + self.encoder = self.Encoder(feature_size, latent_size) + self.decoder = self.Decoder(feature_size, latent_size) + + class Encoder(nn.Module): + + def __init__(self, feature_size, latent_size): + super().__init__() + self.is_cuda_available = torch.cuda.is_available() + self.fc1 = nn.Linear(feature_size, latent_size) + self.ln_f = nn.LayerNorm(latent_size) + self.lat = nn.Linear(latent_size, latent_size, bias=False) + + def forward(self, x): + if self.is_cuda_available: + self.fc1.cuda() + self.ln_f.cuda() + self.lat.cuda() + x = x.cuda() + x = F.gelu(self.fc1(x)) + x = self.ln_f(x) + x = self.lat(x) + return x # -> (N, D) + + class Decoder(nn.Module): + + def __init__(self, feature_size, latent_size): + super().__init__() + self.is_cuda_available = torch.cuda.is_available() + self.fc1 = nn.Linear(latent_size, latent_size) + self.ln_f = nn.LayerNorm(latent_size) + self.rec = nn.Linear(latent_size, feature_size, bias=False) + + def forward(self, x): + if self.is_cuda_available: + self.fc1.cuda() + self.ln_f.cuda() + self.rec.cuda() + x = x.cuda() + x = F.gelu(self.fc1(x)) + x = self.ln_f(x) + x = self.rec(x) + return x # -> (N, L*D) + + +class LangLayer(nn.Module): + + def __init__(self, n_embd, n_vocab): + super().__init__() + self.is_cuda_available = torch.cuda.is_available() + self.embed = nn.Linear(n_embd, n_embd) + self.ln_f = nn.LayerNorm(n_embd) + self.head = nn.Linear(n_embd, n_vocab, bias=False) + + def forward(self, tensor): + if self.is_cuda_available: + self.embed.cuda() + self.ln_f.cuda() + self.head.cuda() + tensor = tensor.cuda() + tensor = self.embed(tensor) + tensor = F.gelu(tensor) + tensor = self.ln_f(tensor) + tensor = self.head(tensor) + return tensor + + +class Net(nn.Module): + + def __init__(self, smiles_embed_dim, n_output=1, dropout=0.2): + super().__init__() + self.desc_skip_connection = True + self.fc1 = nn.Linear(smiles_embed_dim, smiles_embed_dim) + self.dropout1 = nn.Dropout(dropout) + self.relu1 = nn.GELU() + self.fc2 = nn.Linear(smiles_embed_dim, smiles_embed_dim) + self.dropout2 = nn.Dropout(dropout) + self.relu2 = nn.GELU() + self.final = nn.Linear(smiles_embed_dim, n_output) + + def forward(self, smiles_emb, multitask=False): + x_out = self.fc1(smiles_emb) + x_out = self.dropout1(x_out) + x_out = self.relu1(x_out) + + if self.desc_skip_connection is True: + x_out = x_out + smiles_emb + + z = self.fc2(x_out) + z = self.dropout2(z) + z = self.relu2(z) + if self.desc_skip_connection is True: + z = self.final(z + x_out) + else: + z = self.final(z) + + if multitask: + return F.sigmoid(z) + return z + + 
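For orientation, a brief illustration of how the Net prediction head defined above is meant to be used: it maps a pooled SMILES embedding through two GELU blocks with skip connections to either raw regression outputs or sigmoid-activated multitask scores. This is a sketch only; it assumes the finetune working directory where this file is importable as smi_ted_light.load (the layout trainers.py relies on) and an embedding width of 768, matching the --n_embd value in the run_finetune_*.sh scripts.

# Illustrative usage of the Net head; names and the 768 width are assumptions
# taken from the surrounding diff, not a prescribed API.
import torch
from smi_ted_light.load import Net

head = Net(smiles_embed_dim=768, n_output=1, dropout=0.2)
pooled = torch.randn(4, 768)          # batch of 4 pooled SMILES embeddings
preds = head(pooled)                  # shape (4, 1): regression outputs
probs = head(pooled, multitask=True)  # sigmoid outputs for multitask classification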
+class MoLEncoder(nn.Module): + + def __init__(self, config, n_vocab, eval=False): + super(MoLEncoder, self).__init__() + + # embeddings + self.config = config + self.tok_emb = nn.Embedding(n_vocab, config['n_embd']) + self.drop = nn.Dropout(config['d_dropout']) + + # transformer + builder = RotateEncoderBuilder.from_kwargs( + n_layers=config['n_layer'], + n_heads=config['n_head'], + query_dimensions=config['n_embd']//config['n_head'], + value_dimensions=config['n_embd']//config['n_head'], + feed_forward_dimensions=config['n_embd'], + attention_type='linear', + # unless we do deterministic_eval here, we will have random outputs + feature_map=partial(GeneralizedRandomFeatures, + n_dims=config['num_feats'], + deterministic_eval=eval), + activation='gelu' + ) + self.blocks = builder.get() + + # classification + self.lang_model = LangLayer(config['n_embd'], n_vocab) + + +class MoLDecoder(nn.Module): + + def __init__(self, n_vocab, max_len, n_embd, n_gpu=None): + super(MoLDecoder, self).__init__() + + self.max_len = max_len + self.n_embd = n_embd + self.n_gpu = n_gpu + self.autoencoder = AutoEncoderLayer(n_embd*max_len, n_embd) + self.lang_model = LangLayer(n_embd, n_vocab) + + +class Smi_ted(nn.Module): + """materials.smi-ted-Light 289M Parameters""" + + def __init__(self, tokenizer, config=None, eval=False): + super(Smi_ted, self).__init__() + + # configuration + self.config = config + self.tokenizer = tokenizer + self.padding_idx = tokenizer.get_padding_idx() + self.n_vocab = len(self.tokenizer.vocab) + self.is_cuda_available = torch.cuda.is_available() + + # instantiate modules + if self.config: + self.encoder = MoLEncoder(self.config, self.n_vocab, eval=eval) + self.decoder = MoLDecoder(self.n_vocab, self.config['max_len'], self.config['n_embd']) + self.net = Net(self.config['n_embd'], n_output=self.config['n_output'], dropout=self.config['dropout']) + + def load_checkpoint(self, ckpt_path, n_output, eval=False): + # load checkpoint file + checkpoint = torch.load(ckpt_path, map_location=torch.device('cpu')) + + # load hyparameters + self.config = checkpoint['hparams'] + self.max_len = self.config['max_len'] + self.n_embd = self.config['n_embd'] + self._set_seed(self.config['seed']) + + # instantiate modules + self.encoder = MoLEncoder(self.config, self.n_vocab, eval=eval) + self.decoder = MoLDecoder(self.n_vocab, self.max_len, self.n_embd) + self.net = Net(self.n_embd, n_output=self.config['n_output'] if 'n_output' in self.config else n_output, dropout=self.config['dropout']) + + # load weights + if 'state_dict' in checkpoint: + if isinstance(checkpoint['state_dict'], list): + self.encoder.load_state_dict(checkpoint['state_dict'][0], strict=False) + self.decoder.load_state_dict(checkpoint['state_dict'][1], strict=False) + else: + self.load_state_dict(checkpoint['state_dict'], strict=False) + elif 'MODEL_STATE' in checkpoint: + self.load_state_dict(checkpoint['MODEL_STATE'], strict=False) + + # load RNG states each time the model and states are loaded from checkpoint + if 'rng' in self.config: + rng = self.config['rng'] + for key, value in rng.items(): + if key =='torch_state': + torch.set_rng_state(value.cpu()) + elif key =='cuda_state': + torch.cuda.set_rng_state(value.cpu()) + elif key =='numpy_state': + np.random.set_state(value) + elif key =='python_state': + random.setstate(value) + else: + print('unrecognized state') + + def _init_weights(self, module): + if isinstance(module, (nn.Linear, nn.Embedding)): + module.weight.data.normal_(mean=0.0, std=0.02) + if isinstance(module, 
nn.Linear) and module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + def _set_seed(self, value): + print('Random Seed:', value) + random.seed(value) + torch.manual_seed(value) + torch.cuda.manual_seed(value) + torch.cuda.manual_seed_all(value) + np.random.seed(value) + cudnn.deterministic = True + cudnn.benchmark = False + + def tokenize(self, smiles): + """Tokenize a string into tokens.""" + if isinstance(smiles, str): + batch = [smiles] + else: + batch = smiles + + tokens = self.tokenizer( + batch, + padding=True, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + max_length=self.max_len, + ) + + idx = tokens['input_ids'].clone().detach() + mask = tokens['attention_mask'].clone().detach() + + if self.is_cuda_available: + return idx.cuda(), mask.cuda() + + return idx, mask + + def extract_embeddings(self, smiles): + """Extract token and SMILES embeddings.""" + if self.is_cuda_available: + self.encoder.cuda() + self.decoder.cuda() + + # tokenizer + idx, mask = self.tokenize(smiles) + + # transformer encoder + x = self.encoder.tok_emb(idx) # each index maps to a (learnable) vector + x = self.encoder.drop(x) + x = self.encoder.blocks(x, length_mask=LengthMask(mask.sum(-1), max_len=idx.shape[1])) + + # add padding + token_embeddings = x + input_mask_expanded = mask.unsqueeze(-1).expand(token_embeddings.size()).float() + mask_embeddings = (token_embeddings * input_mask_expanded) + token_embeddings = F.pad(mask_embeddings, pad=(0, 0, 0, self.config['max_len'] - mask_embeddings.shape[1]), value=0) + + # aggregate token embeddings (similar to mean pooling) + # CAUTION: use the embeddings from the autoencoder. + smiles_embeddings = self.decoder.autoencoder.encoder(token_embeddings.view(-1, self.max_len*self.n_embd)) + + return smiles_embeddings + + def __str__(self): + return 'smi-ted-Light' + + +def load_smi_ted(folder="./smi_ted_light", + ckpt_filename="smi-ted-Light_40.pt", + vocab_filename="bert_vocab_curated.txt", + n_output=1, + eval=False + ): + tokenizer = MolTranBertTokenizer(os.path.join(folder, vocab_filename)) + model = Smi_ted(tokenizer) + model.load_checkpoint(os.path.join(folder, ckpt_filename), n_output, eval=eval) + print('Vocab size:', len(tokenizer.vocab)) + print(f'[FINETUNE MODE - {str(model)}]') + return model \ No newline at end of file diff --git a/models/smi_ted/finetune/smi_ted_light/qm8/run_finetune_E1-CAM_set.sh b/models/smi_ted/finetune/smi_ted_light/qm8/run_finetune_E1-CAM_set.sh new file mode 100644 index 0000000000000000000000000000000000000000..21d42f953b398fe05a6edb592d4d1da9275ec844 --- /dev/null +++ b/models/smi_ted/finetune/smi_ted_light/qm8/run_finetune_E1-CAM_set.sh @@ -0,0 +1,24 @@ +python ../../finetune_regression.py \ + --n_batch 32 \ + --n_layer 12 \ + --n_head 12 \ + --n_embd 768 \ + --max_len 202 \ + --d_dropout 0.1 \ + --dropout 0.1 \ + --lr_start 3e-5 \ + --lr_multiplier 1 \ + --max_epochs 500 \ + --num_feats 32 \ + --smi_ted_version 'v1' \ + --model_path '../' \ + --ckpt_filename 'smi-ted-Light_40.pt' \ + --data_root '../../moleculenet/qm8' \ + --dataset_name qm8 \ + --measure_name 'E1-CAM' \ + --checkpoints_folder './checkpoints_QM8-E1-CAM' \ + --loss_fn 'mae' \ + --target_metric 'mae' \ + --save_ckpt 1 \ + --start_seed 0 \ + --train_decoder 1 \ \ No newline at end of file diff --git a/models/smi_ted/finetune/smi_ted_light/qm8/run_finetune_E1-CC2_set.sh b/models/smi_ted/finetune/smi_ted_light/qm8/run_finetune_E1-CC2_set.sh new 
file mode 100644 index 0000000000000000000000000000000000000000..79f7fbade5c91e7d61bd94121423a76b3db600c8 --- /dev/null +++ b/models/smi_ted/finetune/smi_ted_light/qm8/run_finetune_E1-CC2_set.sh @@ -0,0 +1,24 @@ +python ../../finetune_regression.py \ + --n_batch 32 \ + --n_layer 12 \ + --n_head 12 \ + --n_embd 768 \ + --max_len 202 \ + --d_dropout 0.1 \ + --dropout 0.1 \ + --lr_start 3e-5 \ + --lr_multiplier 1 \ + --max_epochs 500 \ + --num_feats 32 \ + --smi_ted_version 'v1' \ + --model_path '../' \ + --ckpt_filename 'smi-ted-Light_40.pt' \ + --data_root '../../moleculenet/qm8' \ + --dataset_name qm8 \ + --measure_name 'E1-CC2' \ + --checkpoints_folder './checkpoints_QM8-E1-CC2' \ + --loss_fn 'mae' \ + --target_metric 'mae' \ + --save_ckpt 1 \ + --start_seed 0 \ + --train_decoder 1 \ \ No newline at end of file diff --git a/models/smi_ted/finetune/smi_ted_light/qm8/run_finetune_E1-PBE0_set.sh b/models/smi_ted/finetune/smi_ted_light/qm8/run_finetune_E1-PBE0_set.sh new file mode 100644 index 0000000000000000000000000000000000000000..0128628560ad2b08b0959deaeff7ddb1d4d01239 --- /dev/null +++ b/models/smi_ted/finetune/smi_ted_light/qm8/run_finetune_E1-PBE0_set.sh @@ -0,0 +1,24 @@ +python ../../finetune_regression.py \ + --n_batch 32 \ + --n_layer 12 \ + --n_head 12 \ + --n_embd 768 \ + --max_len 202 \ + --d_dropout 0.1 \ + --dropout 0.1 \ + --lr_start 3e-5 \ + --lr_multiplier 1 \ + --max_epochs 500 \ + --num_feats 32 \ + --smi_ted_version 'v1' \ + --model_path '../' \ + --ckpt_filename 'smi-ted-Light_40.pt' \ + --data_root '../../moleculenet/qm8' \ + --dataset_name qm8 \ + --measure_name 'E1-PBE0' \ + --checkpoints_folder './checkpoints_QM8-E1-PBE0' \ + --loss_fn 'mae' \ + --target_metric 'mae' \ + --save_ckpt 1 \ + --start_seed 0 \ + --train_decoder 1 \ \ No newline at end of file diff --git a/models/smi_ted/finetune/smi_ted_light/qm8/run_finetune_E2-CAM_set.sh b/models/smi_ted/finetune/smi_ted_light/qm8/run_finetune_E2-CAM_set.sh new file mode 100644 index 0000000000000000000000000000000000000000..db7eec5338116d1e213a9f81a4704df525105701 --- /dev/null +++ b/models/smi_ted/finetune/smi_ted_light/qm8/run_finetune_E2-CAM_set.sh @@ -0,0 +1,24 @@ +python ../../finetune_regression.py \ + --n_batch 32 \ + --n_layer 12 \ + --n_head 12 \ + --n_embd 768 \ + --max_len 202 \ + --d_dropout 0.1 \ + --dropout 0.1 \ + --lr_start 3e-5 \ + --lr_multiplier 1 \ + --max_epochs 500 \ + --num_feats 32 \ + --smi_ted_version 'v1' \ + --model_path '../' \ + --ckpt_filename 'smi-ted-Light_40.pt' \ + --data_root '../../moleculenet/qm8' \ + --dataset_name qm8 \ + --measure_name 'E2-CAM' \ + --checkpoints_folder './checkpoints_QM8-E2-CAM' \ + --loss_fn 'mae' \ + --target_metric 'mae' \ + --save_ckpt 1 \ + --start_seed 0 \ + --train_decoder 1 \ \ No newline at end of file diff --git a/models/smi_ted/finetune/smi_ted_light/qm8/run_finetune_E2-CC2_set.sh b/models/smi_ted/finetune/smi_ted_light/qm8/run_finetune_E2-CC2_set.sh new file mode 100644 index 0000000000000000000000000000000000000000..f769c4b78585d0f6a3b64e56e4b634ee1a40c3db --- /dev/null +++ b/models/smi_ted/finetune/smi_ted_light/qm8/run_finetune_E2-CC2_set.sh @@ -0,0 +1,24 @@ +python ../../finetune_regression.py \ + --n_batch 32 \ + --n_layer 12 \ + --n_head 12 \ + --n_embd 768 \ + --max_len 202 \ + --d_dropout 0.1 \ + --dropout 0.1 \ + --lr_start 3e-5 \ + --lr_multiplier 1 \ + --max_epochs 500 \ + --num_feats 32 \ + --smi_ted_version 'v1' \ + --model_path '../' \ + --ckpt_filename 'smi-ted-Light_40.pt' \ + --data_root '../../moleculenet/qm8' \ + --dataset_name 
qm8 \ + --measure_name 'E2-CC2' \ + --checkpoints_folder './checkpoints_QM8-E2-CC2' \ + --loss_fn 'mae' \ + --target_metric 'mae' \ + --save_ckpt 1 \ + --start_seed 0 \ + --train_decoder 1 \ \ No newline at end of file diff --git a/models/smi_ted/finetune/smi_ted_light/qm8/run_finetune_E2-PBE0_set.sh b/models/smi_ted/finetune/smi_ted_light/qm8/run_finetune_E2-PBE0_set.sh new file mode 100644 index 0000000000000000000000000000000000000000..39abacf5d6b2e917b82b45c357363f045dea54c1 --- /dev/null +++ b/models/smi_ted/finetune/smi_ted_light/qm8/run_finetune_E2-PBE0_set.sh @@ -0,0 +1,24 @@ +python ../../finetune_regression.py \ + --n_batch 16 \ + --n_layer 12 \ + --n_head 12 \ + --n_embd 768 \ + --max_len 202 \ + --d_dropout 0.1 \ + --dropout 0.1 \ + --lr_start 3e-6 \ + --lr_multiplier 1 \ + --max_epochs 720 \ + --num_feats 32 \ + --smi_ted_version 'v1' \ + --model_path '../' \ + --ckpt_filename 'smi-ted-Light_40.pt' \ + --data_root '../../moleculenet/qm8' \ + --dataset_name qm8 \ + --measure_name 'E2-PBE0' \ + --checkpoints_folder './checkpoints_QM8-E2-PBE0' \ + --loss_fn 'mae' \ + --target_metric 'mae' \ + --save_ckpt 1 \ + --start_seed 0 \ + --train_decoder 1 \ \ No newline at end of file diff --git a/models/smi_ted/finetune/smi_ted_light/qm8/run_finetune_f1-CAM_set.sh b/models/smi_ted/finetune/smi_ted_light/qm8/run_finetune_f1-CAM_set.sh new file mode 100644 index 0000000000000000000000000000000000000000..91ee475d98f5b0dc6719000e1977635f2636e22e --- /dev/null +++ b/models/smi_ted/finetune/smi_ted_light/qm8/run_finetune_f1-CAM_set.sh @@ -0,0 +1,24 @@ +python ../../finetune_regression.py \ + --n_batch 16 \ + --n_layer 12 \ + --n_head 12 \ + --n_embd 768 \ + --max_len 202 \ + --d_dropout 0.1 \ + --dropout 0.1 \ + --lr_start 3e-6 \ + --lr_multiplier 1 \ + --max_epochs 720 \ + --num_feats 32 \ + --smi_ted_version 'v1' \ + --model_path '../' \ + --ckpt_filename 'smi-ted-Light_40.pt' \ + --data_root '../../moleculenet/qm8' \ + --dataset_name qm8 \ + --measure_name 'f1-CAM' \ + --checkpoints_folder './checkpoints_QM8-f1-CAM' \ + --loss_fn 'mae' \ + --target_metric 'mae' \ + --save_ckpt 1 \ + --start_seed 0 \ + --train_decoder 1 \ \ No newline at end of file diff --git a/models/smi_ted/finetune/smi_ted_light/qm8/run_finetune_f1-CC2_set.sh b/models/smi_ted/finetune/smi_ted_light/qm8/run_finetune_f1-CC2_set.sh new file mode 100644 index 0000000000000000000000000000000000000000..9603905036f6614a5de8a8a71be259d099c50a6e --- /dev/null +++ b/models/smi_ted/finetune/smi_ted_light/qm8/run_finetune_f1-CC2_set.sh @@ -0,0 +1,24 @@ +python ../../finetune_regression.py \ + --n_batch 32 \ + --n_layer 12 \ + --n_head 12 \ + --n_embd 768 \ + --max_len 202 \ + --d_dropout 0.1 \ + --dropout 0.1 \ + --lr_start 3e-5 \ + --lr_multiplier 1 \ + --max_epochs 500 \ + --num_feats 32 \ + --smi_ted_version 'v1' \ + --model_path '../' \ + --ckpt_filename 'smi-ted-Light_40.pt' \ + --data_root '../../moleculenet/qm8' \ + --dataset_name qm8 \ + --measure_name 'f1-CC2' \ + --checkpoints_folder './checkpoints_QM8-f1-CC2' \ + --loss_fn 'mae' \ + --target_metric 'mae' \ + --save_ckpt 1 \ + --start_seed 0 \ + --train_decoder 1 \ \ No newline at end of file diff --git a/models/smi_ted/finetune/smi_ted_light/qm8/run_finetune_f1-PBE0_set.sh b/models/smi_ted/finetune/smi_ted_light/qm8/run_finetune_f1-PBE0_set.sh new file mode 100644 index 0000000000000000000000000000000000000000..874b9d0a357d058a94796c29dc10adf02befc568 --- /dev/null +++ b/models/smi_ted/finetune/smi_ted_light/qm8/run_finetune_f1-PBE0_set.sh @@ -0,0 +1,24 @@ +python 
../../finetune_regression.py \ + --n_batch 16 \ + --n_layer 12 \ + --n_head 12 \ + --n_embd 768 \ + --max_len 202 \ + --d_dropout 0.1 \ + --dropout 0.1 \ + --lr_start 3e-6 \ + --lr_multiplier 1 \ + --max_epochs 720 \ + --num_feats 32 \ + --smi_ted_version 'v1' \ + --model_path '../' \ + --ckpt_filename 'smi-ted-Light_40.pt' \ + --data_root '../../moleculenet/qm8' \ + --dataset_name qm8 \ + --measure_name 'f1-PBE0' \ + --checkpoints_folder './checkpoints_QM8-f1-PBE0' \ + --loss_fn 'mae' \ + --target_metric 'mae' \ + --save_ckpt 1 \ + --start_seed 0 \ + --train_decoder 1 \ \ No newline at end of file diff --git a/models/smi_ted/finetune/smi_ted_light/qm8/run_finetune_f2-CAM_set.sh b/models/smi_ted/finetune/smi_ted_light/qm8/run_finetune_f2-CAM_set.sh new file mode 100644 index 0000000000000000000000000000000000000000..40832c450e3e3fc9b98c9715de8a5a15e5509ded --- /dev/null +++ b/models/smi_ted/finetune/smi_ted_light/qm8/run_finetune_f2-CAM_set.sh @@ -0,0 +1,24 @@ +python ../../finetune_regression.py \ + --n_batch 16 \ + --n_layer 12 \ + --n_head 12 \ + --n_embd 768 \ + --max_len 202 \ + --d_dropout 0.1 \ + --dropout 0.1 \ + --lr_start 3e-6 \ + --lr_multiplier 1 \ + --max_epochs 720 \ + --num_feats 32 \ + --smi_ted_version 'v1' \ + --model_path '../' \ + --ckpt_filename 'smi-ted-Light_40.pt' \ + --data_root '../../moleculenet/qm8' \ + --dataset_name qm8 \ + --measure_name 'f2-CAM' \ + --checkpoints_folder './checkpoints_QM8-f2-CAM' \ + --loss_fn 'mae' \ + --target_metric 'mae' \ + --save_ckpt 1 \ + --start_seed 0 \ + --train_decoder 1 \ \ No newline at end of file diff --git a/models/smi_ted/finetune/smi_ted_light/qm8/run_finetune_f2-CC2_set.sh b/models/smi_ted/finetune/smi_ted_light/qm8/run_finetune_f2-CC2_set.sh new file mode 100644 index 0000000000000000000000000000000000000000..e7356e62ec55f32d4902bb8a4a3896ac7849e748 --- /dev/null +++ b/models/smi_ted/finetune/smi_ted_light/qm8/run_finetune_f2-CC2_set.sh @@ -0,0 +1,24 @@ +python ../../finetune_regression.py \ + --n_batch 16 \ + --n_layer 12 \ + --n_head 12 \ + --n_embd 768 \ + --max_len 202 \ + --d_dropout 0.1 \ + --dropout 0.1 \ + --lr_start 3e-6 \ + --lr_multiplier 1 \ + --max_epochs 720 \ + --num_feats 32 \ + --smi_ted_version 'v1' \ + --model_path '../' \ + --ckpt_filename 'smi-ted-Light_40.pt' \ + --data_root '../../moleculenet/qm8' \ + --dataset_name qm8 \ + --measure_name 'f2-CC2' \ + --checkpoints_folder './checkpoints_QM8-f2-CC2' \ + --loss_fn 'mae' \ + --target_metric 'mae' \ + --save_ckpt 1 \ + --start_seed 0 \ + --train_decoder 1 \ \ No newline at end of file diff --git a/models/smi_ted/finetune/smi_ted_light/qm8/run_finetune_f2-PBE0_set.sh b/models/smi_ted/finetune/smi_ted_light/qm8/run_finetune_f2-PBE0_set.sh new file mode 100644 index 0000000000000000000000000000000000000000..855c5223471ddde0940f05ac6471e9481c727a8d --- /dev/null +++ b/models/smi_ted/finetune/smi_ted_light/qm8/run_finetune_f2-PBE0_set.sh @@ -0,0 +1,24 @@ +python ../../finetune_regression.py \ + --n_batch 16 \ + --n_layer 12 \ + --n_head 12 \ + --n_embd 768 \ + --max_len 202 \ + --d_dropout 0.1 \ + --dropout 0.1 \ + --lr_start 3e-6 \ + --lr_multiplier 1 \ + --max_epochs 720 \ + --num_feats 32 \ + --smi_ted_version 'v1' \ + --model_path '../' \ + --ckpt_filename 'smi-ted-Light_40.pt' \ + --data_root '../../moleculenet/qm8' \ + --dataset_name qm8 \ + --measure_name 'f2-PBE0' \ + --checkpoints_folder './checkpoints_QM8-f2-PBE0' \ + --loss_fn 'mae' \ + --target_metric 'mae' \ + --save_ckpt 1 \ + --start_seed 0 \ + --train_decoder 1 \ \ No newline at end of file 
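The QM8 scripts above (and the QM9 scripts that follow) all drive finetune_regression.py against the pretrained smi-ted-Light checkpoint. For inspecting embeddings outside of finetuning, the load.py module introduced earlier in this diff can also be used directly. The sketch below is an assumption-laden example, not part of the patch: it presumes smi-ted-Light_40.pt and bert_vocab_curated.txt are available under ./smi_ted_light, as the scripts expect.

# Sketch: extract molecule embeddings with the loader defined in load.py.
# eval=True makes the linear-attention feature maps deterministic; the
# expected output width (768) comes from the checkpoint hparams and is an
# assumption here.
import torch
from smi_ted_light.load import load_smi_ted

model = load_smi_ted(
    folder="./smi_ted_light",
    ckpt_filename="smi-ted-Light_40.pt",
    eval=True,
)
with torch.no_grad():
    emb = model.extract_embeddings(["CCO", "c1ccccc1"])
print(emb.shape)  # expected (2, 768) for the Light checkpoint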
diff --git a/models/smi_ted/finetune/smi_ted_light/qm9/run_finetune_alpha.sh b/models/smi_ted/finetune/smi_ted_light/qm9/run_finetune_alpha.sh new file mode 100644 index 0000000000000000000000000000000000000000..62111cac55825827603340cc5d5bc45218339ad5 --- /dev/null +++ b/models/smi_ted/finetune/smi_ted_light/qm9/run_finetune_alpha.sh @@ -0,0 +1,24 @@ +python ../../finetune_regression.py \ + --n_batch 128 \ + --n_layer 12 \ + --n_head 12 \ + --n_embd 768 \ + --max_len 202 \ + --d_dropout 0.1 \ + --dropout 0.1 \ + --lr_start 3e-5 \ + --lr_multiplier 1 \ + --max_epochs 500 \ + --num_feats 32 \ + --smi_ted_version 'v1' \ + --model_path '../' \ + --ckpt_filename 'smi-ted-Light_40.pt' \ + --data_root '../../moleculenet/qm9' \ + --dataset_name qm9 \ + --measure_name 'alpha' \ + --checkpoints_folder './checkpoints_QM9-alpha' \ + --loss_fn 'mae' \ + --target_metric 'mae' \ + --save_ckpt 1 \ + --start_seed 0 \ + --train_decoder 1 \ \ No newline at end of file diff --git a/models/smi_ted/finetune/smi_ted_light/qm9/run_finetune_cv.sh b/models/smi_ted/finetune/smi_ted_light/qm9/run_finetune_cv.sh new file mode 100644 index 0000000000000000000000000000000000000000..f499840a8fea8821d7703efe79bb07661b5767df --- /dev/null +++ b/models/smi_ted/finetune/smi_ted_light/qm9/run_finetune_cv.sh @@ -0,0 +1,24 @@ +python ../../finetune_regression.py \ + --n_batch 128 \ + --n_layer 12 \ + --n_head 12 \ + --n_embd 768 \ + --max_len 202 \ + --d_dropout 0.1 \ + --dropout 0.1 \ + --lr_start 3e-5 \ + --lr_multiplier 1 \ + --max_epochs 500 \ + --num_feats 32 \ + --smi_ted_version 'v1' \ + --model_path '../' \ + --ckpt_filename 'smi-ted-Light_40.pt' \ + --data_root '../../moleculenet/qm9' \ + --dataset_name qm9 \ + --measure_name 'cv' \ + --checkpoints_folder './checkpoints_QM9-cv' \ + --loss_fn 'mae' \ + --target_metric 'mae' \ + --save_ckpt 1 \ + --start_seed 0 \ + --train_decoder 1 \ \ No newline at end of file diff --git a/models/smi_ted/finetune/smi_ted_light/qm9/run_finetune_g298.sh b/models/smi_ted/finetune/smi_ted_light/qm9/run_finetune_g298.sh new file mode 100644 index 0000000000000000000000000000000000000000..fd001dda63f72fc301b11072a06d8dc62e54b5e2 --- /dev/null +++ b/models/smi_ted/finetune/smi_ted_light/qm9/run_finetune_g298.sh @@ -0,0 +1,24 @@ +python ../../finetune_regression.py \ + --n_batch 128 \ + --n_layer 12 \ + --n_head 12 \ + --n_embd 768 \ + --max_len 202 \ + --d_dropout 0.1 \ + --dropout 0.1 \ + --lr_start 3e-5 \ + --lr_multiplier 1 \ + --max_epochs 500 \ + --num_feats 32 \ + --smi_ted_version 'v1' \ + --model_path '../' \ + --ckpt_filename 'smi-ted-Light_40.pt' \ + --data_root '../../moleculenet/qm9' \ + --dataset_name qm9 \ + --measure_name 'g298' \ + --checkpoints_folder './checkpoints_QM9-g298' \ + --loss_fn 'mae' \ + --target_metric 'mae' \ + --save_ckpt 1 \ + --start_seed 0 \ + --train_decoder 1 \ \ No newline at end of file diff --git a/models/smi_ted/finetune/smi_ted_light/qm9/run_finetune_gap.sh b/models/smi_ted/finetune/smi_ted_light/qm9/run_finetune_gap.sh new file mode 100644 index 0000000000000000000000000000000000000000..8031170234e71cc55af842b6231fd448bdf34b99 --- /dev/null +++ b/models/smi_ted/finetune/smi_ted_light/qm9/run_finetune_gap.sh @@ -0,0 +1,24 @@ +python ../../finetune_regression.py \ + --n_batch 128 \ + --n_layer 12 \ + --n_head 12 \ + --n_embd 768 \ + --max_len 202 \ + --d_dropout 0.1 \ + --dropout 0.1 \ + --lr_start 3e-5 \ + --lr_multiplier 1 \ + --max_epochs 500 \ + --num_feats 32 \ + --smi_ted_version 'v1' \ + --model_path '../' \ + --ckpt_filename 
'smi-ted-Light_40.pt' \ + --data_root '../../moleculenet/qm9' \ + --dataset_name qm9 \ + --measure_name 'gap' \ + --checkpoints_folder './checkpoints_QM9-gap' \ + --loss_fn 'mae' \ + --target_metric 'mae' \ + --save_ckpt 1 \ + --start_seed 0 \ + --train_decoder 1 \ \ No newline at end of file diff --git a/models/smi_ted/finetune/smi_ted_light/qm9/run_finetune_h298.sh b/models/smi_ted/finetune/smi_ted_light/qm9/run_finetune_h298.sh new file mode 100644 index 0000000000000000000000000000000000000000..c1945ad80e05d6916ec35c264361b938cd1333f0 --- /dev/null +++ b/models/smi_ted/finetune/smi_ted_light/qm9/run_finetune_h298.sh @@ -0,0 +1,24 @@ +python ../../finetune_regression.py \ + --n_batch 128 \ + --n_layer 12 \ + --n_head 12 \ + --n_embd 768 \ + --max_len 202 \ + --d_dropout 0.1 \ + --dropout 0.1 \ + --lr_start 3e-5 \ + --lr_multiplier 1 \ + --max_epochs 500 \ + --num_feats 32 \ + --smi_ted_version 'v1' \ + --model_path '../' \ + --ckpt_filename 'smi-ted-Light_40.pt' \ + --data_root '../../moleculenet/qm9' \ + --dataset_name qm9 \ + --measure_name 'h298' \ + --checkpoints_folder './checkpoints_QM9-h298' \ + --loss_fn 'mae' \ + --target_metric 'mae' \ + --save_ckpt 1 \ + --start_seed 0 \ + --train_decoder 1 \ \ No newline at end of file diff --git a/models/smi_ted/finetune/smi_ted_light/qm9/run_finetune_homo.sh b/models/smi_ted/finetune/smi_ted_light/qm9/run_finetune_homo.sh new file mode 100644 index 0000000000000000000000000000000000000000..cb9f4a6fab256465e60ee21530957198e8160f58 --- /dev/null +++ b/models/smi_ted/finetune/smi_ted_light/qm9/run_finetune_homo.sh @@ -0,0 +1,24 @@ +python ../../finetune_regression.py \ + --n_batch 128 \ + --n_layer 12 \ + --n_head 12 \ + --n_embd 768 \ + --max_len 202 \ + --d_dropout 0.1 \ + --dropout 0.1 \ + --lr_start 3e-5 \ + --lr_multiplier 1 \ + --max_epochs 500 \ + --num_feats 32 \ + --smi_ted_version 'v1' \ + --model_path '../' \ + --ckpt_filename 'smi-ted-Light_40.pt' \ + --data_root '../../moleculenet/qm9' \ + --dataset_name qm9 \ + --measure_name 'homo' \ + --checkpoints_folder './checkpoints_QM9-homo' \ + --loss_fn 'mae' \ + --target_metric 'mae' \ + --save_ckpt 1 \ + --start_seed 0 \ + --train_decoder 1 \ \ No newline at end of file diff --git a/models/smi_ted/finetune/smi_ted_light/qm9/run_finetune_lumo.sh b/models/smi_ted/finetune/smi_ted_light/qm9/run_finetune_lumo.sh new file mode 100644 index 0000000000000000000000000000000000000000..d012bd7167b4997290a0ee0659f988748e9f83e7 --- /dev/null +++ b/models/smi_ted/finetune/smi_ted_light/qm9/run_finetune_lumo.sh @@ -0,0 +1,24 @@ +python ../../finetune_regression.py \ + --n_batch 128 \ + --n_layer 12 \ + --n_head 12 \ + --n_embd 768 \ + --max_len 202 \ + --d_dropout 0.1 \ + --dropout 0.1 \ + --lr_start 3e-5 \ + --lr_multiplier 1 \ + --max_epochs 500 \ + --num_feats 32 \ + --smi_ted_version 'v1' \ + --model_path '../' \ + --ckpt_filename 'smi-ted-Light_40.pt' \ + --data_root '../../moleculenet/qm9' \ + --dataset_name qm9 \ + --measure_name 'lumo' \ + --checkpoints_folder './checkpoints_QM9-lumo' \ + --loss_fn 'mae' \ + --target_metric 'mae' \ + --save_ckpt 1 \ + --start_seed 0 \ + --train_decoder 1 \ \ No newline at end of file diff --git a/models/smi_ted/finetune/smi_ted_light/qm9/run_finetune_mu.sh b/models/smi_ted/finetune/smi_ted_light/qm9/run_finetune_mu.sh new file mode 100644 index 0000000000000000000000000000000000000000..ac604c0c050401deace608447fb1f7089a4af4b6 --- /dev/null +++ b/models/smi_ted/finetune/smi_ted_light/qm9/run_finetune_mu.sh @@ -0,0 +1,24 @@ +python 
../../finetune_regression.py \ + --n_batch 128 \ + --n_layer 12 \ + --n_head 12 \ + --n_embd 768 \ + --max_len 202 \ + --d_dropout 0.1 \ + --dropout 0.1 \ + --lr_start 3e-5 \ + --lr_multiplier 1 \ + --max_epochs 500 \ + --num_feats 32 \ + --smi_ted_version 'v1' \ + --model_path '../' \ + --ckpt_filename 'smi-ted-Light_40.pt' \ + --data_root '../../moleculenet/qm9' \ + --dataset_name qm9 \ + --measure_name 'mu' \ + --checkpoints_folder './checkpoints_QM9-mu' \ + --loss_fn 'mae' \ + --target_metric 'mae' \ + --save_ckpt 1 \ + --start_seed 0 \ + --train_decoder 1 \ \ No newline at end of file diff --git a/models/smi_ted/finetune/smi_ted_light/qm9/run_finetune_r2.sh b/models/smi_ted/finetune/smi_ted_light/qm9/run_finetune_r2.sh new file mode 100644 index 0000000000000000000000000000000000000000..d688eb79d433dcda120bda929fbfd410d6c193bb --- /dev/null +++ b/models/smi_ted/finetune/smi_ted_light/qm9/run_finetune_r2.sh @@ -0,0 +1,24 @@ +python ../../finetune_regression.py \ + --n_batch 128 \ + --n_layer 12 \ + --n_head 12 \ + --n_embd 768 \ + --max_len 202 \ + --d_dropout 0.1 \ + --dropout 0.1 \ + --lr_start 3e-5 \ + --lr_multiplier 1 \ + --max_epochs 500 \ + --num_feats 32 \ + --smi_ted_version 'v1' \ + --model_path '../' \ + --ckpt_filename 'smi-ted-Light_40.pt' \ + --data_root '../../moleculenet/qm9' \ + --dataset_name qm9 \ + --measure_name 'r2' \ + --checkpoints_folder './checkpoints_QM9-r2' \ + --loss_fn 'mae' \ + --target_metric 'mae' \ + --save_ckpt 1 \ + --start_seed 0 \ + --train_decoder 1 \ \ No newline at end of file diff --git a/models/smi_ted/finetune/smi_ted_light/qm9/run_finetune_u0.sh b/models/smi_ted/finetune/smi_ted_light/qm9/run_finetune_u0.sh new file mode 100644 index 0000000000000000000000000000000000000000..1ff6506190997a57fc0ce9655e9e358984204ff5 --- /dev/null +++ b/models/smi_ted/finetune/smi_ted_light/qm9/run_finetune_u0.sh @@ -0,0 +1,24 @@ +python ../../finetune_regression.py \ + --n_batch 128 \ + --n_layer 12 \ + --n_head 12 \ + --n_embd 768 \ + --max_len 202 \ + --d_dropout 0.1 \ + --dropout 0.1 \ + --lr_start 3e-5 \ + --lr_multiplier 1 \ + --max_epochs 500 \ + --num_feats 32 \ + --smi_ted_version 'v1' \ + --model_path '../' \ + --ckpt_filename 'smi-ted-Light_40.pt' \ + --data_root '../../moleculenet/qm9' \ + --dataset_name qm9 \ + --measure_name 'u0' \ + --checkpoints_folder './checkpoints_QM9-u0' \ + --loss_fn 'mae' \ + --target_metric 'mae' \ + --save_ckpt 1 \ + --start_seed 0 \ + --train_decoder 1 \ \ No newline at end of file diff --git a/models/smi_ted/finetune/smi_ted_light/qm9/run_finetune_u298.sh b/models/smi_ted/finetune/smi_ted_light/qm9/run_finetune_u298.sh new file mode 100644 index 0000000000000000000000000000000000000000..880c6a8f6c0a351218efe17bbdcd2581bd0dd6f8 --- /dev/null +++ b/models/smi_ted/finetune/smi_ted_light/qm9/run_finetune_u298.sh @@ -0,0 +1,24 @@ +python ../../finetune_regression.py \ + --n_batch 128 \ + --n_layer 12 \ + --n_head 12 \ + --n_embd 768 \ + --max_len 202 \ + --d_dropout 0.1 \ + --dropout 0.1 \ + --lr_start 3e-5 \ + --lr_multiplier 1 \ + --max_epochs 500 \ + --num_feats 32 \ + --smi_ted_version 'v1' \ + --model_path '../' \ + --ckpt_filename 'smi-ted-Light_40.pt' \ + --data_root '../../moleculenet/qm9' \ + --dataset_name qm9 \ + --measure_name 'u298' \ + --checkpoints_folder './checkpoints_QM9-u298' \ + --loss_fn 'mae' \ + --target_metric 'mae' \ + --save_ckpt 1 \ + --start_seed 0 \ + --train_decoder 1 \ \ No newline at end of file diff --git a/models/smi_ted/finetune/smi_ted_light/qm9/run_finetune_zpve.sh 
b/models/smi_ted/finetune/smi_ted_light/qm9/run_finetune_zpve.sh new file mode 100644 index 0000000000000000000000000000000000000000..45adaf0769fddb8e8761d4ffe24c26a47b3940a0 --- /dev/null +++ b/models/smi_ted/finetune/smi_ted_light/qm9/run_finetune_zpve.sh @@ -0,0 +1,24 @@ +python ../../finetune_regression.py \ + --n_batch 128 \ + --n_layer 12 \ + --n_head 12 \ + --n_embd 768 \ + --max_len 202 \ + --d_dropout 0.1 \ + --dropout 0.1 \ + --lr_start 3e-5 \ + --lr_multiplier 1 \ + --max_epochs 500 \ + --num_feats 32 \ + --smi_ted_version 'v1' \ + --model_path '../' \ + --ckpt_filename 'smi-ted-Light_40.pt' \ + --data_root '../../moleculenet/qm9' \ + --dataset_name qm9 \ + --measure_name 'zpve' \ + --checkpoints_folder './checkpoints_QM9-zpve' \ + --loss_fn 'mae' \ + --target_metric 'mae' \ + --save_ckpt 1 \ + --start_seed 0 \ + --train_decoder 1 \ \ No newline at end of file diff --git a/models/smi_ted/finetune/smi_ted_light/sider/run_finetune_sider.sh b/models/smi_ted/finetune/smi_ted_light/sider/run_finetune_sider.sh new file mode 100644 index 0000000000000000000000000000000000000000..bb9d03d7d920bd3c54f97343f868b1859cea1e7c --- /dev/null +++ b/models/smi_ted/finetune/smi_ted_light/sider/run_finetune_sider.sh @@ -0,0 +1,23 @@ +python ../../finetune_classification_multitask.py \ + --n_batch 32 \ + --n_layer 12 \ + --n_head 12 \ + --n_embd 768 \ + --max_len 202 \ + --d_dropout 0.1 \ + --dropout 0.1 \ + --lr_start 3e-5 \ + --lr_multiplier 1 \ + --max_epochs 500 \ + --num_feats 32 \ + --smi_ted_version 'v1' \ + --model_path '../' \ + --ckpt_filename 'smi-ted-Light_40.pt' \ + --data_root '../../moleculenet/sider' \ + --dataset_name sider \ + --checkpoints_folder './checkpoints_sider' \ + --loss_fn 'bceloss' \ + --target_metric 'roc-auc' \ + --save_ckpt 1 \ + --start_seed 0 \ + --train_decoder 0 \ \ No newline at end of file diff --git a/models/smi_ted/finetune/smi_ted_light/tox21/run_finetune_tox21.sh b/models/smi_ted/finetune/smi_ted_light/tox21/run_finetune_tox21.sh new file mode 100644 index 0000000000000000000000000000000000000000..46a37d65e3d05ae0c95faed776c5cb90829726a5 --- /dev/null +++ b/models/smi_ted/finetune/smi_ted_light/tox21/run_finetune_tox21.sh @@ -0,0 +1,23 @@ +python ../../finetune_classification_multitask.py \ + --n_batch 32 \ + --n_layer 12 \ + --n_head 12 \ + --n_embd 768 \ + --max_len 202 \ + --d_dropout 0.1 \ + --dropout 0.1 \ + --lr_start 3e-5 \ + --lr_multiplier 1 \ + --max_epochs 100 \ + --num_feats 32 \ + --smi_ted_version 'v1' \ + --model_path '../' \ + --ckpt_filename 'smi-ted-Light_40.pt' \ + --data_root '../../moleculenet/tox21' \ + --dataset_name tox21 \ + --checkpoints_folder './checkpoints_tox21' \ + --loss_fn 'bceloss' \ + --target_metric 'roc-auc' \ + --save_ckpt 1 \ + --start_seed 0 \ + --train_decoder 0 \ \ No newline at end of file diff --git a/models/smi_ted/finetune/trainers.py b/models/smi_ted/finetune/trainers.py new file mode 100644 index 0000000000000000000000000000000000000000..4db4917c03d77b5af472b3727d209bf701336ca5 --- /dev/null +++ b/models/smi_ted/finetune/trainers.py @@ -0,0 +1,591 @@ +# Deep learning +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.backends.cudnn as cudnn +from torch.utils.data import DataLoader +from utils import CustomDataset, CustomDatasetMultitask, RMSELoss, normalize_smiles + +# Data +import pandas as pd +import numpy as np + +# Standard library +import random +import args +import os +import shutil +from tqdm import tqdm + +# Machine Learning +from sklearn.metrics import 
mean_absolute_error, r2_score, accuracy_score, roc_auc_score, roc_curve, auc, precision_recall_curve +from scipy import stats +from utils import RMSE, sensitivity, specificity + + +class Trainer: + + def __init__(self, raw_data, dataset_name, target, batch_size, hparams, + target_metric='rmse', seed=0, smi_ted_version=None, checkpoints_folder='./checkpoints', restart_filename=None, save_every_epoch=False, save_ckpt=True, device='cpu'): + # data + self.df_train = raw_data[0] + self.df_valid = raw_data[1] + self.df_test = raw_data[2] + self.dataset_name = dataset_name + self.target = target + self.batch_size = batch_size + self.hparams = hparams + self._prepare_data() + + # config + self.target_metric = target_metric + self.seed = seed + self.smi_ted_version = smi_ted_version + self.checkpoints_folder = checkpoints_folder + self.restart_filename = restart_filename + self.start_epoch = 1 + self.save_every_epoch = save_every_epoch + self.save_ckpt = save_ckpt + self.device = device + self.best_vloss = float('inf') + self.last_filename = None + self._set_seed(seed) + + def _prepare_data(self): + # normalize dataset + self.df_train['canon_smiles'] = self.df_train['smiles'].apply(normalize_smiles) + self.df_valid['canon_smiles'] = self.df_valid['smiles'].apply(normalize_smiles) + self.df_test['canon_smiles'] = self.df_test['smiles'].apply(normalize_smiles) + + self.df_train = self.df_train.dropna(subset=['canon_smiles']) + self.df_valid = self.df_valid.dropna(subset=['canon_smiles']) + self.df_test = self.df_test.dropna(subset=['canon_smiles']) + + # create dataloader + self.train_loader = DataLoader( + CustomDataset(self.df_train, self.target), + batch_size=self.batch_size, + shuffle=True, + pin_memory=True + ) + self.valid_loader = DataLoader( + CustomDataset(self.df_valid, self.target), + batch_size=self.batch_size, + shuffle=False, + pin_memory=True + ) + self.test_loader = DataLoader( + CustomDataset(self.df_test, self.target), + batch_size=self.batch_size, + shuffle=False, + pin_memory=True + ) + + def compile(self, model, optimizer, loss_fn): + self.model = model + self.optimizer = optimizer + self.loss_fn = loss_fn + self._print_configuration() + if self.restart_filename: + self._load_checkpoint(self.restart_filename) + print('Checkpoint restored!') + + def fit(self, max_epochs=500): + for epoch in range(self.start_epoch, max_epochs+1): + print(f'\n=====Epoch [{epoch}/{max_epochs}]=====') + + # training + self.model.to(self.device) + self.model.train() + train_loss = self._train_one_epoch() + + # validation + self.model.eval() + val_preds, val_loss, val_metrics = self._validate_one_epoch(self.valid_loader) + for m in val_metrics.keys(): + print(f"[VALID] Evaluation {m.upper()}: {round(val_metrics[m], 4)}") + + ############################### Save Finetune checkpoint ####################################### + if ((val_loss < self.best_vloss) or self.save_every_epoch) and self.save_ckpt: + # remove old checkpoint + if (self.last_filename != None) and (not self.save_every_epoch): + os.remove(os.path.join(self.checkpoints_folder, self.last_filename)) + + # filename + model_name = f'{str(self.model)}-Finetune' + self.last_filename = f"{model_name}_seed{self.seed}_{self.dataset_name}_epoch={epoch}_valloss={round(val_loss, 4)}.pt" + + # update best loss + self.best_vloss = val_loss + + # save checkpoint + print('Saving checkpoint...') + self._save_checkpoint(epoch, self.last_filename) + + def evaluate(self, verbose=True): + if verbose: + print("\n=====Test Evaluation=====") + + if 
self.smi_ted_version == 'v1': + import smi_ted_light.load as load + elif self.smi_ted_version == 'v2': + import smi_ted_large.load as load + else: + raise Exception('Please, specify the SMI-TED version: `v1` or `v2`.') + + # copy vocabulary to checkpoint folder + if not os.path.exists(os.path.join(self.checkpoints_folder, 'bert_vocab_curated.txt')): + smi_ted_path = os.path.dirname(load.__file__) + shutil.copy(os.path.join(smi_ted_path, 'bert_vocab_curated.txt'), self.checkpoints_folder) + + # load model for inference + model_inf = load.load_smi_ted( + folder=self.checkpoints_folder, + ckpt_filename=self.last_filename, + eval=True, + ).to(self.device) + + # set model evaluation mode + model_inf.eval() + + # evaluate on test set + tst_preds, tst_loss, tst_metrics = self._validate_one_epoch(self.test_loader, model_inf) + + if verbose: + # show metrics + for m in tst_metrics.keys(): + print(f"[TEST] Evaluation {m.upper()}: {round(tst_metrics[m], 4)}") + + # save predictions + pd.DataFrame(tst_preds).to_csv( + os.path.join( + self.checkpoints_folder, + f'{self.dataset_name}_{self.target if isinstance(self.target, str) else self.target[0]}_predict_test_seed{self.seed}.csv'), + index=False + ) + + def _train_one_epoch(self): + raise NotImplementedError + + def _validate_one_epoch(self, data_loader, model=None): + raise NotImplementedError + + def _print_configuration(self): + print('----Finetune information----') + print('Dataset:\t', self.dataset_name) + print('Target:\t\t', self.target) + print('Batch size:\t', self.batch_size) + print('LR:\t\t', self._get_lr()) + print('Device:\t\t', self.device) + print('Optimizer:\t', self.optimizer.__class__.__name__) + print('Loss function:\t', self.loss_fn.__class__.__name__) + print('Seed:\t\t', self.seed) + print('Train size:\t', self.df_train.shape[0]) + print('Valid size:\t', self.df_valid.shape[0]) + print('Test size:\t', self.df_test.shape[0]) + + def _load_checkpoint(self, filename): + ckpt_path = os.path.join(self.checkpoints_folder, filename) + ckpt_dict = torch.load(ckpt_path, map_location='cpu') + self.model.load_state_dict(ckpt_dict['MODEL_STATE']) + self.start_epoch = ckpt_dict['EPOCHS_RUN'] + 1 + self.best_vloss = ckpt_dict['finetune_info']['best_vloss'] + + def _save_checkpoint(self, current_epoch, filename): + if not os.path.exists(self.checkpoints_folder): + os.makedirs(self.checkpoints_folder) + + ckpt_dict = { + 'MODEL_STATE': self.model.state_dict(), + 'EPOCHS_RUN': current_epoch, + 'hparams': vars(self.hparams), + 'finetune_info': { + 'dataset': self.dataset_name, + 'target`': self.target, + 'batch_size': self.batch_size, + 'lr': self._get_lr(), + 'device': self.device, + 'optim': self.optimizer.__class__.__name__, + 'loss_fn': self.loss_fn.__class__.__name__, + 'train_size': self.df_train.shape[0], + 'valid_size': self.df_valid.shape[0], + 'test_size': self.df_test.shape[0], + 'best_vloss': self.best_vloss, + }, + 'seed': self.seed, + } + + assert list(ckpt_dict.keys()) == ['MODEL_STATE', 'EPOCHS_RUN', 'hparams', 'finetune_info', 'seed'] + + torch.save(ckpt_dict, os.path.join(self.checkpoints_folder, filename)) + + def _set_seed(self, value): + random.seed(value) + torch.manual_seed(value) + np.random.seed(value) + if torch.cuda.is_available(): + torch.cuda.manual_seed(value) + torch.cuda.manual_seed_all(value) + cudnn.deterministic = True + cudnn.benchmark = False + + def _get_lr(self): + for param_group in self.optimizer.param_groups: + return param_group['lr'] + + +class TrainerRegressor(Trainer): + + def __init__(self, raw_data, 
dataset_name, target, batch_size, hparams, + target_metric='rmse', seed=0, smi_ted_version=None, checkpoints_folder='./checkpoints', restart_filename=None, save_every_epoch=False, save_ckpt=True, device='cpu'): + super().__init__(raw_data, dataset_name, target, batch_size, hparams, + target_metric, seed, smi_ted_version, checkpoints_folder, restart_filename, save_every_epoch, save_ckpt, device) + + def _train_one_epoch(self): + running_loss = 0.0 + + for idx, data in enumerate(pbar := tqdm(self.train_loader)): + # Every data instance is an input + label pair + smiles, targets = data + targets = targets.clone().detach().to(self.device) + + # zero the parameter gradients (otherwise they are accumulated) + self.optimizer.zero_grad() + + # Make predictions for this batch + embeddings = self.model.extract_embeddings(smiles).to(self.device) + outputs = self.model.net(embeddings).squeeze() + + # Compute the loss and its gradients + loss = self.loss_fn(outputs, targets) + loss.backward() + + # Adjust learning weights + self.optimizer.step() + + # print statistics + running_loss += loss.item() + + # progress bar + pbar.set_description('[TRAINING]') + pbar.set_postfix(loss=running_loss/(idx+1)) + pbar.refresh() + + return running_loss / len(self.train_loader) + + def _validate_one_epoch(self, data_loader, model=None): + data_targets = [] + data_preds = [] + running_loss = 0.0 + + model = self.model if model is None else model + + with torch.no_grad(): + for idx, data in enumerate(pbar := tqdm(data_loader)): + # Every data instance is an input + label pair + smiles, targets = data + targets = targets.clone().detach().to(self.device) + + # Make predictions for this batch + embeddings = model.extract_embeddings(smiles).to(self.device) + predictions = model.net(embeddings).squeeze() + + # Compute the loss + loss = self.loss_fn(predictions, targets) + + data_targets.append(targets.view(-1)) + data_preds.append(predictions.view(-1)) + + # print statistics + running_loss += loss.item() + + # progress bar + pbar.set_description('[EVALUATION]') + pbar.set_postfix(loss=running_loss/(idx+1)) + pbar.refresh() + + # Put together predictions and labels from batches + preds = torch.cat(data_preds, dim=0).cpu().numpy() + tgts = torch.cat(data_targets, dim=0).cpu().numpy() + + # Compute metrics + mae = mean_absolute_error(tgts, preds) + r2 = r2_score(tgts, preds) + rmse = RMSE(preds, tgts) + spearman = stats.spearmanr(tgts, preds).statistic # scipy 1.12.0 + + # Rearange metrics + metrics = { + 'mae': mae, + 'r2': r2, + 'rmse': rmse, + 'spearman': spearman, + } + + return preds, running_loss / len(data_loader), metrics + + +class TrainerClassifier(Trainer): + + def __init__(self, raw_data, dataset_name, target, batch_size, hparams, + target_metric='roc-auc', seed=0, smi_ted_version=None, checkpoints_folder='./checkpoints', restart_filename=None, save_every_epoch=False, save_ckpt=True, device='cpu'): + super().__init__(raw_data, dataset_name, target, batch_size, hparams, + target_metric, seed, smi_ted_version, checkpoints_folder, restart_filename, save_every_epoch, save_ckpt, device) + + def _train_one_epoch(self): + running_loss = 0.0 + + for idx, data in enumerate(pbar := tqdm(self.train_loader)): + # Every data instance is an input + label pair + smiles, targets = data + targets = targets.clone().detach().to(self.device) + + # zero the parameter gradients (otherwise they are accumulated) + self.optimizer.zero_grad() + + # Make predictions for this batch + embeddings = 
self.model.extract_embeddings(smiles).to(self.device) + outputs = self.model.net(embeddings).squeeze() + + # Compute the loss and its gradients + loss = self.loss_fn(outputs, targets.long()) + loss.backward() + + # Adjust learning weights + self.optimizer.step() + + # print statistics + running_loss += loss.item() + + # progress bar + pbar.set_description('[TRAINING]') + pbar.set_postfix(loss=running_loss/(idx+1)) + pbar.refresh() + + return running_loss / len(self.train_loader) + + def _validate_one_epoch(self, data_loader, model=None): + data_targets = [] + data_preds = [] + running_loss = 0.0 + + model = self.model if model is None else model + + with torch.no_grad(): + for idx, data in enumerate(pbar := tqdm(data_loader)): + # Every data instance is an input + label pair + smiles, targets = data + targets = targets.clone().detach().to(self.device) + + # Make predictions for this batch + embeddings = model.extract_embeddings(smiles).to(self.device) + predictions = model.net(embeddings).squeeze() + + # Compute the loss + loss = self.loss_fn(predictions, targets.long()) + + data_targets.append(targets.view(-1)) + data_preds.append(predictions) + + # print statistics + running_loss += loss.item() + + # progress bar + pbar.set_description('[EVALUATION]') + pbar.set_postfix(loss=running_loss/(idx+1)) + pbar.refresh() + + # Put together predictions and labels from batches + preds = torch.cat(data_preds, dim=0).cpu().numpy() + tgts = torch.cat(data_targets, dim=0).cpu().numpy() + + # Compute metrics + preds_cpu = F.softmax(torch.tensor(preds), dim=1).cpu().numpy()[:, 1] + + # accuracy + y_pred = np.where(preds_cpu >= 0.5, 1, 0) + accuracy = accuracy_score(tgts, y_pred) + + # sensitivity + sn = sensitivity(tgts, y_pred) + + # specificity + sp = specificity(tgts, y_pred) + + # roc-auc + fpr, tpr, _ = roc_curve(tgts, preds_cpu) + roc_auc = auc(fpr, tpr) + + # prc-auc + precision, recall, _ = precision_recall_curve(tgts, preds_cpu) + prc_auc = auc(recall, precision) + + # Rearange metrics + metrics = { + 'acc': accuracy, + 'roc-auc': roc_auc, + 'prc-auc': prc_auc, + 'sensitivity': sn, + 'specificity': sp, + } + + return preds, running_loss / len(data_loader), metrics + + +class TrainerClassifierMultitask(Trainer): + + def __init__(self, raw_data, dataset_name, target, batch_size, hparams, + target_metric='roc-auc', seed=0, smi_ted_version=None, checkpoints_folder='./checkpoints', restart_filename=None, save_every_epoch=False, save_ckpt=True, device='cpu'): + super().__init__(raw_data, dataset_name, target, batch_size, hparams, + target_metric, seed, smi_ted_version, checkpoints_folder, restart_filename, save_every_epoch, save_ckpt, device) + + def _prepare_data(self): + # normalize dataset + self.df_train['canon_smiles'] = self.df_train['smiles'].apply(normalize_smiles) + self.df_valid['canon_smiles'] = self.df_valid['smiles'].apply(normalize_smiles) + self.df_test['canon_smiles'] = self.df_test['smiles'].apply(normalize_smiles) + + self.df_train = self.df_train.dropna(subset=['canon_smiles']) + self.df_valid = self.df_valid.dropna(subset=['canon_smiles']) + self.df_test = self.df_test.dropna(subset=['canon_smiles']) + + # create dataloader + self.train_loader = DataLoader( + CustomDatasetMultitask(self.df_train, self.target), + batch_size=self.batch_size, + shuffle=True, + pin_memory=True + ) + self.valid_loader = DataLoader( + CustomDatasetMultitask(self.df_valid, self.target), + batch_size=self.batch_size, + shuffle=False, + pin_memory=True + ) + self.test_loader = DataLoader( + 
CustomDatasetMultitask(self.df_test, self.target), + batch_size=self.batch_size, + shuffle=False, + pin_memory=True + ) + + def _train_one_epoch(self): + running_loss = 0.0 + + for idx, data in enumerate(pbar := tqdm(self.train_loader)): + # Every data instance is an input + label pair + mask + smiles, targets, target_masks = data + targets = targets.clone().detach().to(self.device) + + # zero the parameter gradients (otherwise they are accumulated) + self.optimizer.zero_grad() + + # Make predictions for this batch + embeddings = self.model.extract_embeddings(smiles).to(self.device) + outputs = self.model.net(embeddings, multitask=True).squeeze() + outputs = outputs * target_masks.to(self.device) + + # Compute the loss and its gradients + loss = self.loss_fn(outputs, targets) + loss.backward() + + # Adjust learning weights + self.optimizer.step() + + # print statistics + running_loss += loss.item() + + # progress bar + pbar.set_description('[TRAINING]') + pbar.set_postfix(loss=running_loss/(idx+1)) + pbar.refresh() + + return running_loss / len(self.train_loader) + + def _validate_one_epoch(self, data_loader, model=None): + data_targets = [] + data_preds = [] + data_masks = [] + running_loss = 0.0 + + model = self.model if model is None else model + + with torch.no_grad(): + for idx, data in enumerate(pbar := tqdm(data_loader)): + # Every data instance is an input + label pair + mask + smiles, targets, target_masks = data + targets = targets.clone().detach().to(self.device) + + # Make predictions for this batch + embeddings = model.extract_embeddings(smiles).to(self.device) + predictions = model.net(embeddings, multitask=True).squeeze() + predictions = predictions * target_masks.to(self.device) + + # Compute the loss + loss = self.loss_fn(predictions, targets) + + data_targets.append(targets) + data_preds.append(predictions) + data_masks.append(target_masks) + + # print statistics + running_loss += loss.item() + + # progress bar + pbar.set_description('[EVALUATION]') + pbar.set_postfix(loss=running_loss/(idx+1)) + pbar.refresh() + + # Put together predictions and labels from batches + preds = torch.cat(data_preds, dim=0) + tgts = torch.cat(data_targets, dim=0) + mask = torch.cat(data_masks, dim=0) + mask = mask > 0 + + # Compute metrics + roc_aucs = [] + prc_aucs = [] + sns = [] + sps = [] + num_tasks = len(self.target) + for idx in range(num_tasks): + actuals_task = torch.masked_select(tgts[:, idx], mask[:, idx].to(self.device)) + preds_task = torch.masked_select(preds[:, idx], mask[:, idx].to(self.device)) + + # accuracy + y_pred = np.where(preds_task.cpu().detach() >= 0.5, 1, 0) + accuracy = accuracy_score(actuals_task.cpu().numpy(), y_pred) + + # sensitivity + sn = sensitivity(actuals_task.cpu().numpy(), y_pred) + + # specificity + sp = specificity(actuals_task.cpu().numpy(), y_pred) + + # roc-auc + roc_auc = roc_auc_score(actuals_task.cpu().numpy(), preds_task.cpu().numpy()) + + # prc-auc + precision, recall, thresholds = precision_recall_curve(actuals_task.cpu().numpy(), preds_task.cpu().numpy()) + prc_auc = auc(recall, precision) + + # append + sns.append(sn) + sps.append(sp) + roc_aucs.append(roc_auc) + prc_aucs.append(prc_auc) + average_sn = torch.mean(torch.tensor(sns)) + average_sp = torch.mean(torch.tensor(sps)) + average_roc_auc = torch.mean(torch.tensor(roc_aucs)) + average_prc_auc = torch.mean(torch.tensor(prc_aucs)) + + # Rearange metrics + metrics = { + 'acc': accuracy, + 'roc-auc': average_roc_auc.item(), + 'prc-auc': average_prc_auc.item(), + 'sensitivity': 
average_sn.item(), + 'specificity': average_sp.item(), + } + + return preds.cpu().numpy(), running_loss / len(data_loader), metrics \ No newline at end of file diff --git a/models/smi_ted/finetune/utils.py b/models/smi_ted/finetune/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2b0fd79adf4edd67417138cb258c91ab0201b5de --- /dev/null +++ b/models/smi_ted/finetune/utils.py @@ -0,0 +1,115 @@ +# Deep learning +import torch +from torch.utils.data import Dataset +from sklearn.metrics import confusion_matrix + +# Data +import pandas as pd +import numpy as np + +# Standard library +import os + +# Chemistry +from rdkit import Chem +from rdkit.Chem import PandasTools +from rdkit.Chem import Descriptors +PandasTools.RenderImagesInAllDataFrames(True) + + +def normalize_smiles(smi, canonical=True, isomeric=False): + try: + normalized = Chem.MolToSmiles( + Chem.MolFromSmiles(smi), canonical=canonical, isomericSmiles=isomeric + ) + except: + normalized = None + return normalized + + +class RMSELoss: + def __init__(self): + pass + + def __call__(self, yhat, y): + return torch.sqrt(torch.mean((yhat-y)**2)) + + +def RMSE(predictions, targets): + return np.sqrt(((predictions - targets) ** 2).mean()) + + +def sensitivity(y_true, y_pred): + tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel() + return (tp/(tp+fn)) + + +def specificity(y_true, y_pred): + tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel() + return (tn/(tn+fp)) + + +def get_optim_groups(module, keep_decoder=False): + # setup optimizer + # separate out all parameters to those that will and won't experience regularizing weight decay + decay = set() + no_decay = set() + whitelist_weight_modules = (torch.nn.Linear,) + blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding) + for mn, m in module.named_modules(): + for pn, p in m.named_parameters(): + fpn = '%s.%s' % (mn, pn) if mn else pn # full param name + + if not keep_decoder and 'decoder' in fpn: # exclude decoder components + continue + + if pn.endswith('bias'): + # all biases will not be decayed + no_decay.add(fpn) + elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules): + # weights of whitelist modules will be weight decayed + decay.add(fpn) + elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules): + # weights of blacklist modules will NOT be weight decayed + no_decay.add(fpn) + + # validate that we considered every parameter + param_dict = {pn: p for pn, p in module.named_parameters()} + + # create the pytorch optimizer object + optim_groups = [ + {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": 0.0}, + {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0}, + ] + + return optim_groups + + +class CustomDataset(Dataset): + def __init__(self, dataset, target): + self.dataset = dataset + self.target = target + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, idx): + smiles = self.dataset['canon_smiles'].iloc[idx] + labels = self.dataset[self.target].iloc[idx] + return smiles, labels + + +class CustomDatasetMultitask(Dataset): + def __init__(self, dataset, targets): + self.dataset = dataset + self.targets = targets + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, idx): + smiles = self.dataset['canon_smiles'].iloc[idx] + labels = self.dataset[self.targets].iloc[idx].to_numpy() + mask = [0.0 if np.isnan(x) else 1.0 for x in labels] + labels = [0.0 if np.isnan(x) else x for x in labels] + return smiles, 
torch.tensor(labels, dtype=torch.float32), torch.tensor(mask) \ No newline at end of file diff --git a/models/smi_ted/images/smi-ted.png b/models/smi_ted/images/smi-ted.png new file mode 100644 index 0000000000000000000000000000000000000000..4f1688456dc60e5adef22533de0fc0b1f0e3b561 --- /dev/null +++ b/models/smi_ted/images/smi-ted.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fa41339c9e8c14412f05dcaa5f42d4d185e101dff4d97b82749cedf671678a71 +size 1891667 diff --git a/models/smi_ted/inference/smi_ted_large/bert_vocab_curated.txt b/models/smi_ted/inference/smi_ted_large/bert_vocab_curated.txt new file mode 100644 index 0000000000000000000000000000000000000000..4de43b6c87add8821aa8d2d33d58232929c86fbd --- /dev/null +++ b/models/smi_ted/inference/smi_ted_large/bert_vocab_curated.txt @@ -0,0 +1,2393 @@ +<bos> +<eos> +<pad> +<mask> +C +c +( +) +1 +O +N +2 += +n +3 +[C@H] +[C@@H] +F +S +4 +Cl +- +o +s +[nH] +# +/ +Br +[C@] +[C@@] +[N+] +[O-] +5 +\ +. +I +6 +[S@] +[S@@] +P +[N-] +[Si] +7 +[n+] +[2H] +8 +[NH+] +B +9 +[C-] +[Na+] +[Cl-] +[c-] +[CH] +%10 +[NH2+] +[P+] +[B] +[I-] +%11 +[CH2-] +[O+] +[NH3+] +[C] +[Br-] +[IH2] +[S-] +[cH-] +%12 +[nH+] +[B-] +[K+] +[Sn] +[Se] +[CH-] +[HH] +[Y] +[n-] +[CH3-] +[SiH] +[S+] +%13 +[SiH2] +[Li+] +[NH-] +%14 +[Na] +[CH2] +[O-2] +[U+2] +[W] +[Al] +[P@] +[Fe+2] +[PH+] +%15 +[Cl+3] +[Zn+2] +[Ir] +[Mg+2] +[Pt+2] +[OH2+] +[As] +[Fe] +[OH+] +[Zr+2] +[3H] +[Ge] +[SiH3] +[OH-] +[NH4+] +[Cu+2] +[P@@] +p +[Pt] +%16 +[Ca+2] +[Zr] +[F-] +[C+] +[Ti] +[P-] +[V] +[se] +[U] +[O] +[Ni+2] +[Zn] +[Co] +[Ni] +[Pd+2] +[Cu] +%17 +[Cu+] +[Te] +[H+] +[CH+] +[Li] +[Pd] +[Mo] +[Ru+2] +[o+] +[Re] +[SH+] +%18 +[Ac] +[Cr] +[NH2-] +[K] +[13CH2] +[c] +[Zr+4] +[Tl] +[13C] +[Mn] +[N@+] +[Hg] +[Rh] +[Ti+4] +[Sb] +[Co+2] +[Ag+] +[Ru] +%19 +[N@@+] +[Ti+2] +[Al+3] +[Pb] +[I+] +[18F] +[s+] +[Rb+] +[Ba+2] +[H-] +[Fe+3] +[Ir+3] +[13cH] +%20 +[AlH2] +[Au+] +[13c] +[SH2+] +[Sn+2] +[Mn+2] +[Si-] +[Ag] +[N] +[Bi] +%21 +[In] +[CH2+] +[Y+3] +[Ga] +%22 +[Co+3] +[Au] +[13CH3] +[Mg] +[Cs+] +[W+2] +[Hf] +[Zn+] +[Se-] +[S-2] +[Ca] +[pH] +[ClH+] +[Ti+3] +%23 +[Ru+] +[SH-] +[13CH] +[IH+] +[Hf+4] +[Rf] +[OH3+] +%24 +[Pt+4] +[Zr+3] +[PH3+] +[Sr+2] +[Cd+2] +[Cd] +%25 +[Os] +[BH-] +[Sn+4] +[Cr+3] +[Ru+3] +[PH2+] +[Rh+2] +[V+2] +%26 +[Gd+3] +[Pb+2] +[PH] +[Hg+] +[Mo+2] +[AlH] +[Sn+] +%27 +[Pd+] +b +[Rh+3] +[Hg+2] +[15NH] +[14C] +%28 +[Mn+3] +[Si+] +[SeH] +[13C@H] +[NH] +[Ga+3] +[SiH-] +[13C@@H] +[Ce] +[Au+3] +[Bi+3] +[15N] +%29 +[BH3-] +[14cH] +[Ti+] +[Gd] +[cH+] +[Cr+2] +[Sb-] +%30 +[Be+2] +[Al+] +[te] +[11CH3] +[Sm] +[Pr] +[La] +%31 +[Al-] +[Ta] +[125I] +[BH2-] +[Nb] +[Si@] +%32 +[14c] +[Sb+3] +[Ba] +%33 +[Os+2] +[Si@@] +[La+3] +[15n] +[15NH2] +[Nd+3] +%34 +[14CH2] +[18O] +[Nd] +[GeH] +[Ni+3] +[Eu] +[Dy+3] +[Sc] +%36 +[Se-2] +[As+] +%35 +[AsH] +[Tb] +[Sb+5] +[Se+] +[Ce+3] +[c+] +[In+3] +[SnH] +[Mo+4] +%37 +[V+4] +[Eu+3] +[Hf+2] +%38 +[Pt+] +[p+] +[123I] +[Tl+] +[Sm+3] +%39 +[Yb+3] +%40 +[Yb] +[Os+] +%41 +[10B] +[Sc+3] +[Al+2] +%42 +[Sr] +[Tb+3] +[Po] +[Tc] +[PH-] +[AlH3] +[Ar] +[U+4] +[SnH2] +[Cl+2] +[si] +[Fe+] +[14CH3] +[U+3] +[Cl+] +%43 +[GeH2] +%44 +[Er+3] +[Mo+3] +[I+2] +[Fe+4] +[99Tc] +%45 +[11C] +%46 +[SnH3] +[S] +[Te+] +[Er] +[Lu+3] +[11B] +%47 +%48 +[P] +[Tm] +[Th] +[Dy] +[Pr+3] +[Ta+5] +[Nb+5] +[Rb] +[GeH3] +[Br+2] +%49 +[131I] +[Fm] +[Cs] +[BH4-] +[Lu] +[15nH] +%50 +[Ru+6] +[b-] +[Ho] +[Th+4] +[Ru+4] +%52 +[14CH] +%51 +[Cr+6] +[18OH] +[Ho+3] +[Ce+4] +[Bi+2] +[Co+] +%53 +[Yb+2] +[Fe+6] +[Be] +%54 +[SH3+] +[Np] +[As-] +%55 +[14C@@H] +[Ir+2] +[GaH3] +[p-] +[GeH4] +[Sn+3] +[Os+4] 
+%56 +[14C@H] +[sH+] +[19F] +[Eu+2] +[TlH] +%57 +[Cr+4] +%58 +[B@@-] +[SiH+] +[At] +[Am] +[Fe+5] +[AsH2] +[Si+4] +[B@-] +[Pu] +[SbH] +[P-2] +[Tm+3] +* +%59 +[se+] +[IH-] +%60 +[oH+] +[1H] +[15N+] +[124I] +[S@@+] +[P-3] +[H] +[IH2+] +[TeH] +[Xe] +[PH4+] +[Cr+] +[Cm] +[I+3] +%61 +[Nb+2] +[Ru+5] +%62 +[Ta+2] +[Tc+4] +[CH3+] +[Pm] +[Si@H] +[No] +%63 +[Cr+5] +[Th+2] +[Zn-2] +[13C@] +[Lr] +%64 +[99Tc+3] +%65 +[13C@@] +%66 +[Fe-] +[17O] +[siH] +[Sb+] +[OH] +[IH] +[11CH2] +[Cf] +[SiH2+] +[Gd+2] +[In+] +[Si@@H] +[Mn+] +[99Tc+4] +[Ga-] +%67 +[S@+] +[Ge+4] +[Tl+3] +[16OH] +%68 +[2H-] +[Ra] +[si-] +[NiH2] +[P@@H] +[Rh+] +[12C] +[35S] +[32P] +[SiH2-] +[AlH2+] +[16O] +%69 +[BiH] +[BiH2] +[Zn-] +[BH] +[Tc+3] +[Ir+] +[Ni+] +%70 +[InH2] +[InH] +[Nb+3] +[PbH] +[Bi+] +%71 +[As+3] +%72 +[18O-] +[68Ga+3] +%73 +[Pa] +[76Br] +[Tc+5] +[pH+] +[64Cu+2] +[Ru+8] +%74 +[PH2-] +[Si+2] +[17OH] +[RuH] +[111In+3] +[AlH+] +%75 +%76 +[W+] +[SbH2] +[PoH] +[Ru-] +[XeH] +[Tc+2] +[13C-] +[Br+] +[Pt-2] +[Es] +[Cu-] +[Mg+] +[3HH] +[P@H] +[ClH2+] +%77 +[SH] +[Au-] +[2HH] +%78 +[Sn-] +[11CH] +[PdH2] +0 +[Os+6] +%79 +[Mo+] +%80 +[al] +[PbH2] +[64Cu] +[Cl] +[12CH3] +%81 +[Tc+7] +[11c] +%82 +[Li-] +[99Tc+5] +[He] +[12c] +[Kr] +[RuH+2] +[35Cl] +[Pd-2] +[GaH2] +[4H] +[Sg] +[Cu-2] +[Br+3] +%83 +[37Cl] +[211At] +[IrH+2] +[Mt] +[Ir-2] +[In-] +[12cH] +[12CH2] +[RuH2] +[99Tc+7] +%84 +[15n+] +[ClH2+2] +[16N] +[111In] +[Tc+] +[Ru-2] +[12CH] +[si+] +[Tc+6] +%85 +%86 +[90Y] +[Pd-] +[188Re] +[RuH+] +[NiH] +[SiH3-] +[14n] +[CH3] +[14N] +[10BH2] +%88 +%89 +%90 +[34S] +[77Br] +[GaH] +[Br] +[Ge@] +[B@@H-] +[CuH] +[SiH4] +[3H-] +%87 +%91 +%92 +[67Cu] +[I] +[177Lu] +[ReH] +[67Ga+3] +[Db] +[177Lu+3] +[AlH2-] +[Si+3] +[Ti-2] +[RuH+3] +[al+] +[68Ga] +[2H+] +[B@H-] +[WH2] +[OsH] +[Ir-3] +[AlH-] +[Bk] +[75Se] +[14C@] +[Pt-] +[N@@H+] +[Nb-] +[13NH2] +%93 +[186Re] +[Tb+4] +[PtH] +[IrH2] +[Hg-2] +[AlH3-] +[PdH+] +[Md] +[RhH+2] +[11cH] +[Co-2] +[15N-] +[ZrH2] +%94 +[Hg-] +[127I] +[AsH2+] +[MoH2] +[Te+4] +[14C@@] +[As+5] +[SnH+3] +[Ge@@] +[6Li+] +[WH] +[Ne] +[14NH2] +[14NH] +[12C@@H] +[Os+7] +[RhH] +[Al-3] +[SnH+] +[15NH3+] +[Zr+] +[197Hg+] +%95 +%96 +[90Y+3] +[Os-2] +[98Tc+5] +[15NH3] +[bH-] +[33P] +[Zr-2] +[15O] +[Rh-] +[PbH3] +[PH2] +[Ni-] +[CuH+] +%97 +%98 +%99 +[Os+5] +[PtH+] +[ReH4] +[16NH] +[82Br] +[W-] +[18F-] +[15NH4+] +[Se+4] +[SeH-] +[SH4] +[67Cu+2] +[12C@H] +[AsH3] +[HgH] +[10B-] +[99Tc+6] +[117Sn+4] +[Te@] +[P@+] +[35SH] +[SeH+] +[Ni-2] +[Al-2] +[TeH2] +[Bh] +[99Tc+2] +[Os+8] +[PH-2] +[7Li+] +[14nH] +[AlH+2] +[18FH] +[SnH4] +[18O-2] +[IrH] +[13N] +[Te@@] +[Rh-3] +[15NH+] +[AsH3+] +[SeH2] +[AsH+] +[CoH2] +[16NH2] +[AsH-] +[203Hg+] +[P@@+] +[166Ho+3] +[60Co+3] +[13CH2-] +[SeH2+] +[75Br] +[TlH2] +[80Br] +[siH+] +[Ca+] +[153Sm+3] +[PdH] +[225Ac] +[13CH3-] +[AlH4-] +[FeH] +[13CH-] +[14C-] +[11C-] +[153Sm] +[Re-] +[te+] +[13CH4] +[ClH+2] +[8CH2] +[99Mo] +[ClH3+3] +[SbH3] +[25Mg+2] +[16N+] +[SnH2+] +[PH4] +[11C@H] +[122I] +[Re-2] +[RuH2+2] +[ZrH] +[Bi-] +[Pr+] +[Rn] +[Fr] +[36Cl] +[18o] +[YH] +[79Br] +[121I] +[113In+3] +[InH4-] +[TaH] +[RhH2] +[Ta-] +[67Ga] +[ZnH+] +[SnH2-] +[OsH2] +[16F] +[FeH2] +[14O] +[PbH2+2] +[BH2] +[6H] +[125Te] +[197Hg] +[TaH2] +[TaH3] +[76As] +[Nb-2] +[14N+] +[125I-] +[33S] +[IH2+2] +[NH2] +[PtH2] +[MnH] +[19C] +[17F] +[1H-] +[SnH4+2] +[Mn-2] +[15NH2+] +[TiH2] +[ReH7] +[Cd-2] +[Fe-3] +[SH2] +[17O-] +[siH-] +[CoH+] +[VH] +[10BH] +[Ru-3] +[13O] +[5H] +[CoH] +[PH5] +[15n-] +[153Gd] +[12C@] +[11CH3-] +[IrH3] +[RuH3] +[74Se] +[Se@] +[Hf+] +[77Se] +[166Ho] +[59Fe+2] +[203Hg] +[18OH-] +[8CH] +[12C@@] +[11CH4] +[15C] +[249Cf] +[PbH4] 
+[64Zn] +[PH3] +[99Tc+] +[14c-] +[149Pm] +[IrH4] +[Se@@] +[13OH] +[14CH3-] +[28Si] +[Rh-2] +[Fe-2] +[131I-] +[51Cr] +[62Cu+2] +[81Br] +[121Sb] +[7Li] +[89Zr+4] +[SbH3+] +[11C@@H] +[98Tc] +[59Fe+3] +[BiH2+] +[SbH+] +[TiH] +[14NH3] +[15OH] +[119Sn] +[201Hg] +[MnH+] +[201Tl] +[51Cr+3] +[123I-] +[MoH] +[AlH6-3] +[MnH2] +[WH3] +[213Bi+3] +[SnH2+2] +[123IH] +[13CH+] +[Zr-] +[74As] +[13C+] +[32P+] +[KrH] +[SiH+2] +[ClH3+2] +[13NH] +[9CH2] +[ZrH2+2] +[87Sr+2] +[35s] +[239Pu] +[198Au] +[241Am] +[203Hg+2] +[V+] +[YH2] +[SH5] +[195Pt] +[203Pb] +[RuH4] +[ThH2] +[AuH] +[66Ga+3] +[11B-] +[F] +[24Na+] +[85Sr+2] +[201Tl+] +[14CH4] +[32S] +[TeH2+] +[ClH2+3] +[AgH] +[Ge@H] +[44Ca+2] +[Os-] +[31P] +[15nH+] +[SbH4] +[TiH+] +[Ba+] +[57Co+2] +[Ta+] +[125IH] +[77As] +[129I] +[Fe-4] +[Ta-2] +[19O] +[12O] +[BiH3] +[237Np] +[252Cf] +[86Y] +[Cr-2] +[89Y] +[195Pt+2] +[si+2] +[58Fe+2] +[Hs] +[S@@H] +[OsH6] +[GdH2] +[IH3] +[8CH4] +[164Dy+3] +[47Ca+2] +[57Co] +[NbH2] +[ReH2] +[ZnH2] +[CrH2] +[17NH] +[ZrH3] +[RhH3] +[12C-] +[18O+] +[Bi-2] +[ClH4+3] +[Ni-3] +[Ag-] +[111In-] +[Mo-2] +[55Fe+3] +[204Hg+] +[35Cl-] +[211Pb] +[75Ge] +[8B] +[TeH3] +[SnH3+] +[Zr-3] +[28F] +[249Bk] +[169Yb] +[34SH] +[6Li] +[94Tc] +[197Au] +[195Pt+4] +[169Yb+3] +[32Cl] +[82Se] +[159Gd+3] +[213Bi] +[CoH+2] +[36S] +[35P] +[Ru-4] +[Cr-3] +[60Co] +[1H+] +[18CH2] +[Cd-] +[152Sm+3] +[106Ru] +[238Pu] +[220Rn] +[45Ca+2] +[89Sr+2] +[239Np] +[90Sr+2] +[137Cs+] +[165Dy] +[68GaH3] +[65Zn+2] +[89Zr] +[BiH2+2] +[62Cu] +[165Dy+3] +[238U] +[105Rh+3] +[70Zn] +[12B] +[12OH] +[18CH] +[17CH] +[OsH3] +[SbH-] +[SH6] +[AlH2-2] +[42K] +[76Br-] +[71As] +[NbH3] +[ReH3] +[OsH-] +[WH4] +[MoH3] +[OsH4] +[RuH6] +[PtH3] +[CuH2] +[CoH3] +[TiH4] +[64Zn+2] +[Si-2] +[79BrH] +[14CH2-] +[PtH2+2] +[Os-3] +[29Si] +[Ti-] +[Se+6] +[22Na+] +[42K+] +[131Cs+] +[86Rb+] +[134Cs+] +[209Po] +[208Po] +[81Rb+] +[203Tl+] +[Zr-4] +[148Sm] +[147Sm] +[37Cl-] +[12CH4] +[Ge@@H] +[63Cu] +[13CH2+] +[AsH2-] +[CeH] +[SnH-] +[UH] +[9c] +[21CH3] +[TeH+] +[57Co+3] +[8BH2] +[12BH2] +[19BH2] +[9BH2] +[YbH2] +[CrH+2] +[208Bi] +[152Gd] +[61Cu] +[115In] +[60Co+2] +[13NH2-] +[120I] +[18OH2] +[75SeH] +[SbH2+] +[144Ce] +[16n] +[113In] +[22nH] +[129I-] +[InH3] +[32PH3] +[234U] +[235U] +[59Fe] +[82Rb+] +[65Zn] +[244Cm] +[147Pm] +[91Y] +[237Pu] +[231Pa] +[253Cf] +[127Te] +[187Re] +[236Np] +[235Np] +[72Zn] +[253Es] +[159Dy] +[62Zn] +[101Tc] +[149Tb] +[124I-] +[SeH3+] +[210Pb] +[40K] +[210Po] +[214Pb] +[218Po] +[214Po] +[7Be] +[212Pb] +[205Pb] +[209Pb] +[123Te] +[202Pb] +[72As] +[201Pb] +[70As] +[73Ge] +[200Pb] +[198Pb] +[66Ga] +[73Se] +[195Pb] +[199Pb] +[144Ce+3] +[235U+2] +[90Tc] +[114In+3] +[128I] +[100Tc+] +[82Br-] +[191Pt+2] +[191Pt+4] +[193Pt+4] +[31PH3] +[125I+2] +[131I+2] +[125Te+4] +[82Sr+2] +[149Sm] +[81BrH] +[129Xe] +[193Pt+2] +[123I+2] +[Cr-] +[Co-] +[227Th+4] +[249Cf+3] +[252Cf+3] +[187Os] +[16O-] +[17O+] +[16OH-] +[98Tc+7] +[58Co+2] +[69Ga+3] +[57Fe+2] +[43K+] +[16C] +[52Fe+3] +[SeH5] +[194Pb] +[196Pb] +[197Pb] +[213Pb] +[9B] +[19B] +[11CH-] +[9CH] +[20OH] +[25OH] +[8cH] +[TiH+3] +[SnH6+3] +[N@H+] +[ZnH] +[VH3] +[52Mn+2] +[64Ga] +[13B] +[216Bi] +[117Sn+2] +[232Th] +[SnH+2] +[BiH5] +[77Kr] +[103Cd] +[62Ni] +[LaH3] +[SmH3] +[EuH3] +[MoH5] +[64Ni] +[66Zn] +[68Zn] +[186W] +[FeH4] +[MoH4] +[HgH2] +[15NH2-] +[UH2] +[204Hg] +[GaH4-] +[ThH4] +[WH6] +[PtH4] +[VH2] +[UH3] +[FeH3] +[RuH5] +[BiH4] +[80Br-] +[CeH3] +[37ClH] +[157Gd+3] +[205Tl] +[203Tl] +[62Cu+] +[64Cu+] +[61Cu+] +[37SH2] +[30Si] +[28Al] +[19OH2] +[8He] +[6He] +[153Pm] +[209Bi] +[66Zn+2] +[10CH4] +[191Ir] +[66Cu] +[16O+] +[25O] +[10c] +[Co-3] +[Sn@@] 
+[17OH-] +[206Po] +[204Po] +[202Po] +[201Po] +[200Po] +[199Po] +[198Po] +[197Po] +[196Po] +[195Po] +[194Po] +[193Po] +[192Po] +[191Po] +[190Po] +[217Po] +[BiH4-] +[TeH4] +[222Ra] +[62Ga] +[39Ar] +[144Sm] +[58Fe] +[153Eu] +[85Rb] +[171Yb] +[172Yb] +[114Cd] +[51Fe] +[142Ce] +[207Tl] +[92Mo] +[115Sn] +[140Ce] +[202Hg] +[180W] +[182W] +[183W] +[184W] +[96Mo] +[47Ti] +[111Cd] +[143Nd] +[145Nd] +[126Te] +[128Te] +[130Te] +[185Re] +[97Mo] +[98Mo] +[183Re] +[52V] +[80Se] +[87Kr] +[137Xe] +[196Au] +[146Ce] +[88Kr] +[51Ti] +[138Xe] +[112Cd] +[116Sn] +[120Sn] +[28SiH3] +[35S-] +[15NH-] +[13CH3+] +[34S+] +[34s] +[SiH4-] +[100Tc+5] +[NiH2+2] +[239Th] +[186Lu] +[AuH3] +[I@@-] +[XeH2] +[B+] +[16CH2] +[8C] +[TaH5] +[FeH4-] +[19C@H] +[10NH] +[FeH6-3] +[22CH] +[25N] +[25N+] +[25N-] +[21CH2] +[18cH] +[113I] +[ScH3] +[30PH3] +[43Ca+2] +[41Ca+2] +[106Cd] +[122Sn] +[18CH3] +[58Co+3] +[98Tc+4] +[70Ge] +[76Ge] +[108Cd] +[116Cd] +[130Xe] +[94Mo] +[124Sn] +[186Os] +[188Os] +[190Os] +[192Os] +[106Pd] +[110Pd] +[120Te] +[132Ba] +[134Ba] +[136Ba] +[136Ce] +[138Ce] +[156Dy] +[158Dy] +[160Dy] +[163Dy] +[162Er] +[164Er] +[167Er] +[176Hf] +[26Mg] +[144Nd] +[150Nd] +[41K] +[46Ti] +[48Ti] +[49Ti] +[50Ti] +[170Yb] +[173Yb] +[91Zr] +[92Zr] +[96Zr] +[34S-] +[CuH2-] +[38Cl] +[25Mg] +[51V] +[93Nb] +[95Mo] +[45Sc] +[123Sb] +[139La] +[9Be] +[99Y+3] +[99Y] +[156Ho] +[67Zn] +[144Ce+4] +[210Tl] +[42Ca] +[54Fe] +[193Ir] +[92Nb] +[141Cs] +[52Cr] +[35ClH] +[46Ca] +[139Cs] +[65Cu] +[71Ga] +[60Ni] +[16NH3] +[148Nd] +[72Ge] +[161Dy] +[49Ca] +[43Ca] +[8Be] +[48Ca] +[44Ca] +[120Xe] +[80Rb] +[215At] +[180Re] +[146Sm] +[19Ne] +[74Kr] +[134La] +[76Kr] +[219Fr] +[121Xe] +[220Fr] +[216At] +[223Ac] +[218At] +[37Ar] +[135I] +[110Cd] +[94Tc+7] +[86Y+3] +[135I-] +[15O-2] +[151Eu+3] +[161Tb+3] +[197Hg+2] +[109Cd+2] +[191Os+4] +[170Tm+3] +[205Bi+3] +[233U+4] +[126Sb+3] +[127Sb+3] +[132Cs+] +[136Eu+3] +[136Eu] +[125Sn+4] +[175Yb+3] +[100Mo] +[22Ne] +[13c-] +[13NH4+] +[17C] +[9C] +[31S] +[31SH] +[133I] +[126I] +[36SH] +[30S] +[32SH] +[19CH2] +[19c] +[18c] +[15F] +[10C] +[RuH-] +[62Zn+2] +[32ClH] +[33ClH] +[78BrH] +[12Li+] +[12Li] +[233Ra] +[68Ge+4] +[44Sc+3] +[91Y+3] +[106Ru+3] +[PoH2] +[AtH] +[55Fe] +[233U] +[210PoH2] +[230Th] +[228Th] +[222Rn] +[35SH2] +[227Th] +[192Ir] +[133Xe] +[81Kr] +[95Zr] +[240Pu] +[54Mn] +[103Ru] +[95Nb] +[109Cd] +[141Ce] +[85Kr] +[110Ag] +[58Co] +[241Pu] +[234Th] +[140La] +[63Ni] +[152Eu] +[132IH] +[226Rn] +[154Eu] +[36ClH] +[228Ac] +[155Eu] +[106Rh] +[243Am] +[227Ac] +[243Cm] +[236U] +[144Pr] +[232U] +[32SH2] +[88Y] +[82BrH] +[135IH] +[242Cm] +[115Cd] +[242Pu] +[46Sc] +[56Mn] +[234Pa] +[41Ar] +[147Nd] +[187W] +[151Sm] +[59Ni] +[233Pa] +[52Mn] +[94Nb] +[219Rn] +[236Pu] +[13NH3] +[93Zr] +[51Cr+6] +[TlH3] +[123Xe] +[160Tb] +[170Tm] +[182Ta] +[175Yb] +[93Mo] +[143Ce] +[191Os] +[126IH] +[48V] +[113Cd] +[47Sc] +[181Hf] +[185W] +[143Pr] +[191Pt] +[181W] +[33PH3] +[97Ru] +[97Tc] +[111Ag] +[169Er] +[107Pd] +[103Ru+2] +[34SH2] +[137Ce] +[242Am] +[117SnH2] +[57Ni] +[239U] +[60Cu] +[250Cf] +[193Au] +[69Zn] +[55Co] +[139Ce] +[127Xe] +[159Gd] +[56Co] +[177Hf] +[244Pu] +[38ClH] +[142Pr] +[199Hg] +[179Hf] +[178Hf] +[237U] +[156Eu] +[157Eu] +[105Ru] +[171Tm] +[199Au] +[155Sm] +[80BrH] +[108Ag] +[128IH] +[48Sc] +[45Ti] +[176Lu] +[121SnH2] +[148Pm] +[57Fe] +[10BH3] +[96Tc] +[133IH] +[143Pm] +[105Rh] +[130IH] +[134IH] +[131IH] +[71Zn] +[105Ag] +[97Zr] +[235Pu] +[231Th] +[109Pd] +[93Y] +[190Ir] +[135Xe] +[53Mn] +[134Ce] +[234Np] +[240Am] +[246Cf] +[240Cm] +[241Cm] +[226Th] +[39ClH] +[229Th] +[245Cm] +[240U] +[240Np] +[249Cm] +[243Pu] +[145Pm] 
+[199Pt] +[246Bk] +[193Pt] +[230U] +[250Cm] +[44Ti] +[175Hf] +[254Fm] +[255Fm] +[257Fm] +[92Y] +[188Ir] +[171Lu] +[257Md] +[247Bk] +[121IH] +[250Bk] +[179Lu] +[224Ac] +[195Hg] +[244Am] +[246Pu] +[194Au] +[252Fm] +[173Hf] +[246Cm] +[135Ce] +[49Cr] +[248Cf] +[247Cm] +[248Cm] +[174Ta] +[176Ta] +[154Tb] +[172Ta] +[177Ta] +[175Ta] +[180Ta] +[158Tb] +[115Ag] +[189Os] +[251Cf] +[145Pr] +[147Pr] +[76BrH] +[102Rh] +[238Np] +[185Os] +[246Am] +[233Np] +[166Dy] +[254Es] +[244Cf] +[193Os] +[245Am] +[245Bk] +[239Am] +[238Am] +[97Nb] +[245Pu] +[254Cf] +[188W] +[250Es] +[251Es] +[237Am] +[182Hf] +[258Md] +[232Np] +[238Cm] +[60Fe] +[109Pd+2] +[234Pu] +[141Ce+3] +[136Nd] +[136Pr] +[173Ta] +[110Ru] +[147Tb] +[253Fm] +[139Nd] +[178Re] +[177Re] +[200Au] +[182Re] +[156Tb] +[155Tb] +[157Tb] +[161Tb] +[161Ho] +[167Tm] +[173Lu] +[179Ta] +[171Er] +[44Sc] +[49Sc] +[49V] +[51Mn] +[90Nb] +[88Nb] +[88Zr] +[36SH2] +[174Yb] +[178Lu] +[179W] +[83BrH] +[107Cd] +[75BrH] +[62Co] +[48Cr] +[63Zn] +[102Ag] +[154Sm] +[168Er] +[65Ni] +[137La] +[187Ir] +[144Pm] +[146Pm] +[160Gd] +[166Yb] +[162Dy] +[47V] +[141Nd] +[141Sm] +[166Er] +[150Sm] +[146Eu] +[149Eu] +[174Lu] +[17NH3] +[102Ru] +[170Hf] +[188Pt] +[61Ni] +[56Ni] +[149Gd] +[151Gd] +[141Pm] +[147Gd] +[146Gd] +[161Er] +[103Ag] +[145Eu] +[153Tb] +[155Dy] +[184Re] +[180Os] +[182Os] +[186Pt] +[181Os] +[181Re] +[151Tb] +[178Ta] +[178W] +[189Pt] +[194Hg] +[145Sm] +[150Tb] +[132La] +[158Gd] +[104Ag] +[193Hg] +[94Ru] +[137Pr] +[155Ho] +[117Cd] +[99Ru] +[146Nd] +[218Rn] +[95Y] +[79Kr] +[120IH] +[138Pr] +[100Pd] +[166Tm] +[90Mo] +[151Nd] +[231U] +[138Nd] +[89Nb] +[98Nb] +[162Ho] +[142Sm] +[186Ta] +[104Tc] +[184Ta] +[185Ta] +[170Er] +[107Rh] +[131La] +[169Lu] +[74BrH] +[150Pm] +[172Tm] +[197Pt] +[230Pu] +[170Lu] +[86Zr] +[176W] +[177W] +[101Pd] +[105Pd] +[108Pd] +[149Nd] +[164Ho] +[159Ho] +[167Ho] +[176Yb] +[156Sm] +[77BrH] +[189Re] +[99Rh] +[100Rh] +[151Pm] +[232Pa] +[228Pa] +[230Pa] +[66Ni] +[194Os] +[135La] +[138La] +[141La] +[142La] +[195Ir] +[96Nb] +[157Ho] +[183Hf] +[162Tm] +[172Er] +[148Eu] +[150Eu] +[15CH4] +[89Kr] +[143La] +[58Ni] +[61Co] +[158Eu] +[165Er] +[167Yb] +[173Tm] +[175Tm] +[172Hf] +[172Lu] +[93Tc] +[177Yb] +[124IH] +[194Ir] +[147Eu] +[101Mo] +[180Hf] +[189Ir] +[87Y] +[43Sc] +[195Au] +[112Ag] +[84BrH] +[106Ag] +[109Ag] +[101Rh] +[162Yb] +[228Rn] +[139Pr] +[94Y] +[201Au] +[40PH3] +[110Ag+] +[104Cd] +[133Ba+2] +[226Ac] +[145Gd] +[186Ir] +[184Ir] +[224Rn] +[185Ir] +[182Ir] +[184Hf] +[200Pt] +[227Pa] +[178Yb] +[72Br-] +[72BrH] +[248Am] +[238Th] +[161Gd] +[35S-2] +[107Ag] +[FeH6-4] +[89Sr] +[SnH3-] +[SeH3] +[TeH3+] +[SbH4+] +[AsH4+] +[4He] +[AsH3-] +[1HH] +[3H+] +[82Rb] +[85Sr] +[90Sr] +[137Cs] +[133Ba] +[131Cs] +[SbH5] +[224Ra] +[22Na] +[210Bi] +[214Bi] +[228Ra] +[127Sb] +[136Cs] +[125Sb] +[134Cs] +[140Ba] +[45Ca] +[206Pb] +[207Pb] +[24Na] +[86Rb] +[212Bi] +[208Pb] +[124Sb] +[204Pb] +[44K] +[129Te] +[113Sn] +[204Tl] +[87Sr] +[208Tl] +[87Rb] +[47Ca] +[135Cs] +[216Po] +[137Ba] +[207Bi] +[212Po] +[79Se] +[223Ra] +[86Sr] +[122Sb] +[26Al] +[32Si] +[126Sn] +[225Ra] +[114In] +[72Ga] +[132Te] +[10Be] +[125Sn] +[73As] +[206Bi] +[117Sn] +[40Ca] +[41Ca] +[89Rb] +[116In] +[129Sb] +[91Sr] +[71Ge] +[139Ba] +[69Ga] +[120Sb] +[121Sn] +[123Sn] +[131Te] +[77Ge] +[135Ba] +[82Sr] +[43K] +[131Ba] +[92Sr] +[88Rb] +[129Cs] +[144Cs] +[127Cs] +[200Tl] +[202Tl] +[141Ba] +[117Sb] +[116Sb] +[78As] +[131Sb] +[126Sb] +[128Sb] +[130Sb] +[67Ge] +[68Ge] +[78Ge] +[66Ge] +[223Fr] +[132Cs] +[125Cs] +[138Cs] +[133Te] +[84Rb] +[83Rb] +[81Rb] +[142Ba] +[200Bi] +[115Sb] +[194Tl] +[70Se] +[112In] +[118Sb] +[70Ga] 
+[27Mg] +[202Bi] +[83Se] +[9Li] +[69As] +[79Rb] +[81Sr] +[83Sr] +[78Se] +[109In] +[29Al] +[118Sn] +[117In] +[119Sb] +[114Sn] +[138Ba] +[69Ge] +[73Ga] +[74Ge] +[206Tl] +[199Tl] +[130Cs] +[28Mg] +[116Te] +[112Sn] +[126Ba] +[211Bi] +[81Se] +[127Sn] +[143Cs] +[134Te] +[80Sr] +[45K] +[215Po] +[207Po] +[111Sn] +[211Po] +[128Ba] +[198Tl] +[227Ra] +[213Po] +[220Ra] +[128Sn] +[203Po] +[205Po] +[65Ga] +[197Tl] +[88Sr] +[110In] +[31Si] +[201Bi] +[121Te] +[205Bi] +[203Bi] +[195Tl] +[209Tl] +[110Sn] +[222Fr] +[207At] +[119In] +[As@] +[129IH] +[157Dy] +[111IH] +[230Ra] +[144Pr+3] +[SiH3+] +[3He] +[AsH5] +[72Se] +[95Tc] +[103Pd] +[121Sn+2] +[211Rn] +[38SH2] +[127IH] +[74Br-] +[133I-] +[100Tc+4] +[100Tc] +[36Cl-] +[89Y+3] +[104Rh] +[152Sm] +[226Ra] +[19FH] +[104Pd] +[148Gd] +[157Lu] +[33SH2] +[121I-] +[17FH] +[71Se] +[157Sm] +[148Tb] +[164Dy] +[15OH2] +[15O+] +[39K] +[40Ar] +[50Cr+3] +[50Cr] +[52Ti] +[103Pd+2] +[130Ba] +[142Pm] +[153Gd+3] +[151Eu] +[103Rh] +[124Xe] +[152Tb] +[17OH2] +[20Ne] +[52Fe] +[94Zr+4] +[94Zr] +[149Pr] +[16OH2] +[53Cr+6] +[53Cr] +[81Br-] +[112Pd] +[125Xe] +[155Gd] +[157Gd] +[168Yb] +[184Os] +[166Tb] +[221Fr] +[212Ra] +[75Br-] +[79Br-] +[113Ag] +[23Na] +[34Cl-] +[34ClH] +[38Cl-] +[56Fe] +[68Cu] +[77Br-] +[90Zr+4] +[90Zr] +[102Pd] +[154Eu+3] +[57Mn] +[165Tm] +[152Dy] +[217At] +[77se] +[13cH-] +[122Te] +[156Gd] +[124Te] +[53Ni] +[131Xe] +[174Hf+4] +[174Hf] +[76Se] +[168Tm] +[167Dy] +[154Gd] +[95Ru] +[210At] +[85Br] +[59Co] +[122Xe] +[27Al] +[54Cr] +[198Hg] +[85Rb+] +[214Tl] +[229Rn] +[218Pb] +[218Bi] +[167Tm+3] +[18o+] +[P@@H+] +[P@H+] +[13N+] +[212Pb+2] +[217Bi] +[249Cf+2] +[18OH3+] +[90Sr-] +[Cf+3] +[200Hg] +[86Tc] +[141Pr+3] +[141Pr] +[16nH] +[14NH4+] +[132Xe] +[83Kr] +[70Zn+2] +[137Ba+2] +[36Ar] +[38Ar] +[21Ne] +[126Xe] +[136Xe] +[128Xe] +[134Xe] +[84Kr] +[86Kr] +[78Kr] +[80Kr] +[82Kr] +[67Zn+2] +[65Cu+2] +[110Te] +[58Fe+3] +[142Nd] +[38K] +[198Au+3] +[122IH] +[38PH3] +[130I-] +[40K+] +[38K+] +[28Mg+2] +[208Tl+] +[13OH2] +[198Bi] +[192Bi] +[194Bi] +[196Bi] +[132I-] +[83Sr+2] +[169Er+3] +[122I-] +[120I-] +[92Sr+2] +[126I-] +[24Mg] +[84Sr] +[118Pd+2] +[118Pd] +[AsH4] +[127I-] +[9C-] +[11CH3+] +[17B] +[7B] +[4HH] +[18C-] +[22CH3-] +[22CH4] +[17C-] +[15CH3] +[16CH3] +[11NH3] +[21NH3] +[11N-] +[11NH] +[16CH] +[17CH2] +[99Ru+2] +[181Ta+2] +[181Ta] +[20CH] +[32PH2] +[55Fe+2] +[SH3] +[S@H] +[Mn-] +[IH4] +[ThH] +[GaH-] +[BiH+] +[EuH2] +[FeH4-3] +[FeH6] +[IH5] +[NiH+] +[SrH2] +[VH4] +[YH3] +[seH+] +<unk> diff --git a/models/smi_ted/inference/smi_ted_large/load.py b/models/smi_ted/inference/smi_ted_large/load.py new file mode 100644 index 0000000000000000000000000000000000000000..5ebe464b24cb7ac04223ee61528ddfea4a216f54 --- /dev/null +++ b/models/smi_ted/inference/smi_ted_large/load.py @@ -0,0 +1,672 @@ +PATTERN = "(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])" +# Deep learning +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.backends.cudnn as cudnn + +# Transformers +from fast_transformers.attention import AttentionLayer +from fast_transformers.events import QKVEvent +from fast_transformers.transformers import TransformerEncoder, TransformerEncoderLayer +from fast_transformers.builders.transformer_builders import BaseTransformerEncoderBuilder +from fast_transformers.builders.attention_builders import AttentionBuilder +from fast_transformers.feature_maps import GeneralizedRandomFeatures +from fast_transformers.masking import LengthMask +from transformers import BertTokenizer + +# Data +import numpy as 
np +import pandas as pd + +# Chemistry +from rdkit import Chem +from rdkit.Chem import PandasTools +from rdkit.Chem import Descriptors +PandasTools.RenderImagesInAllDataFrames(True) + +# Standard library +from functools import partial +import regex as re +import random +import os +import gc +from tqdm import tqdm +tqdm.pandas() + + +# function to canonicalize SMILES +def normalize_smiles(smi, canonical=True, isomeric=False): + try: + normalized = Chem.MolToSmiles( + Chem.MolFromSmiles(smi), canonical=canonical, isomericSmiles=isomeric + ) + except: + normalized = None + return normalized + + +class MolTranBertTokenizer(BertTokenizer): + def __init__(self, vocab_file: str = '', + do_lower_case=False, + unk_token='<pad>', + sep_token='<eos>', + pad_token='<pad>', + cls_token='<bos>', + mask_token='<mask>', + **kwargs): + super().__init__(vocab_file, + unk_token=unk_token, + sep_token=sep_token, + pad_token=pad_token, + cls_token=cls_token, + mask_token=mask_token, + **kwargs) + + self.regex_tokenizer = re.compile(PATTERN) + self.wordpiece_tokenizer = None + self.basic_tokenizer = None + with open(vocab_file) as f: + self.padding_idx = f.readlines().index(pad_token+'\n') + + def _tokenize(self, text): + split_tokens = self.regex_tokenizer.findall(text) + return split_tokens + + def convert_idx_to_tokens(self, idx_tensor): + tokens = [self.convert_ids_to_tokens(idx) for idx in idx_tensor.tolist()] + return tokens + + def convert_tokens_to_string(self, tokens): + stopwords = ['<bos>', '<eos>'] + clean_tokens = [word for word in tokens if word not in stopwords] + out_string = ''.join(clean_tokens) + return out_string + + def get_padding_idx(self): + return self.padding_idx + + def idx_to_smiles(self, torch_model, idx): + '''Convert tokens idx back to SMILES text''' + rev_tokens = torch_model.tokenizer.convert_idx_to_tokens(idx) + flat_list_tokens = [item for sublist in rev_tokens for item in sublist] + decoded_smiles = torch_model.tokenizer.convert_tokens_to_string(flat_list_tokens) + return decoded_smiles + + +## Transformer layers +class RotaryEmbedding(torch.nn.Module): + + def __init__(self, dim, base=10000): + super().__init__() + inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer('inv_freq', inv_freq) + self.seq_len_cached = 0 + self.cos_cached = None + self.sin_cached = None + + def forward(self, x, seq_dim=1): + seq_len = x.shape[seq_dim] + if seq_len != self.seq_len_cached: + self.seq_len_cached = seq_len + + t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + freqs = torch.einsum('i,j->ij', t, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1).to(x.device) + + self.cos_cached = emb.cos()[None,:, None, :] + self.sin_cached = emb.sin()[None,:, None, :] + + return self.cos_cached, self.sin_cached + +def rotate_half(x): + x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in earlier torch versions + +@torch.jit.script +def apply_rotary_pos_emb(q, k, cos, sin): + return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) + +class RotateAttentionLayer(AttentionLayer): + """Rotate attention layer inherits from fast_transformer attention layer. 
+ The only thing added is an Embedding encoding, for more information + on the attention layer see the fast_transformers code + """ + def __init__(self, attention, d_model, n_heads, d_keys=None, + d_values=None, event_dispatcher=""): + super(RotateAttentionLayer, self).__init__(attention,d_model, n_heads, d_keys=d_keys, + d_values=d_values, event_dispatcher=event_dispatcher) + + self.rotaryemb = RotaryEmbedding(d_keys) + print('Using Rotation Embedding') + + def forward(self, queries, keys, values, attn_mask, query_lengths, + key_lengths): + """ + Using the same frame work as the fast_Transformers attention layer + but injecting rotary information to the queries and the keys + after the keys and queries are projected. + In the argument description we make use of the following sizes + - N: the batch size + - L: The maximum length of the queries + - S: The maximum length of the keys (the actual length per sequence + is given by the length mask) + - D: The input feature dimensionality passed in the constructor as + 'd_model' + Arguments + --------- + queries: (N, L, D) The tensor containing the queries + keys: (N, S, D) The tensor containing the keys + values: (N, S, D) The tensor containing the values + attn_mask: An implementation of BaseMask that encodes where each + query can attend to + query_lengths: An implementation of BaseMask that encodes how + many queries each sequence in the batch consists of + key_lengths: An implementation of BaseMask that encodes how + many queries each sequence in the batch consists of + Returns + ------- + The new value for each query as a tensor of shape (N, L, D). + """ + # Extract the dimensions into local variables + N, L, _ = queries.shape + _, S, _ = keys.shape + H = self.n_heads + + # Project the queries/keys/values + queries = self.query_projection(queries).view(N, L, H, -1) + keys = self.key_projection(keys).view(N, S, H, -1) + cos, sin = self.rotaryemb(queries) + queries, keys = apply_rotary_pos_emb(queries, keys, cos, sin) + values = self.value_projection(values).view(N, S, H, -1) + # Let the world know of the qkv + self.event_dispatcher.dispatch(QKVEvent(self, queries, keys, values)) + + + # Compute the attention + new_values = self.inner_attention( + queries, + keys, + values, + attn_mask, + query_lengths, + key_lengths + ).view(N, L, -1) + + # Project the output and return + return self.out_projection(new_values) + +class RotateEncoderBuilder(BaseTransformerEncoderBuilder): + """Build a batch transformer encoder with Relative Rotary embeddings + for training or processing of sequences all elements at a time. 
+ Example usage: + builder = RotateEncoderBuilder() + builder.n_layers = 12 + builder.n_heads = 8 + builder.feed_forward_dimensions = 1024 + builder.query_dimensions = 64 + builder.value_dimensions = 64 + builder.dropout = 0.1 + builder.attention_dropout = 0.1 + builder.attention_type = "linear" + transformer = builder.get() + """ + def _get_attention_builder(self): + """Return an instance of the appropriate attention builder.""" + return AttentionBuilder() + + def _get_attention_layer_class(self): + """Return the class for the layer that projects queries keys and + values.""" + return RotateAttentionLayer + + def _get_encoder_class(self): + """Return the class for the transformer encoder.""" + return TransformerEncoder + + def _get_encoder_layer_class(self): + """Return the class for the transformer encoder layer.""" + return TransformerEncoderLayer + + +class AutoEncoderLayer(nn.Module): + + def __init__(self, feature_size, latent_size): + super().__init__() + self.encoder = self.Encoder(feature_size, latent_size) + self.decoder = self.Decoder(feature_size, latent_size) + + class Encoder(nn.Module): + + def __init__(self, feature_size, latent_size): + super().__init__() + self.is_cuda_available = torch.cuda.is_available() + self.fc1 = nn.Linear(feature_size, latent_size) + self.ln_f = nn.LayerNorm(latent_size) + self.lat = nn.Linear(latent_size, latent_size, bias=False) + + def forward(self, x): + if self.is_cuda_available: + self.fc1.cuda() + self.ln_f.cuda() + self.lat.cuda() + x = x.cuda() + x = F.gelu(self.fc1(x)) + x = self.ln_f(x) + x = self.lat(x) + return x # -> (N, D) + + class Decoder(nn.Module): + + def __init__(self, feature_size, latent_size): + super().__init__() + self.is_cuda_available = torch.cuda.is_available() + self.fc1 = nn.Linear(latent_size, latent_size) + self.ln_f = nn.LayerNorm(latent_size) + self.rec = nn.Linear(latent_size, feature_size, bias=False) + + def forward(self, x): + if self.is_cuda_available: + self.fc1.cuda() + self.ln_f.cuda() + self.rec.cuda() + x = x.cuda() + x = F.gelu(self.fc1(x)) + x = self.ln_f(x) + x = self.rec(x) + return x # -> (N, L*D) + + +class LangLayer(nn.Module): + + def __init__(self, n_embd, n_vocab): + super().__init__() + self.is_cuda_available = torch.cuda.is_available() + self.embed = nn.Linear(n_embd, n_embd) + self.ln_f = nn.LayerNorm(n_embd) + self.head = nn.Linear(n_embd, n_vocab, bias=False) + + def forward(self, tensor): + if self.is_cuda_available: + self.embed.cuda() + self.ln_f.cuda() + self.head.cuda() + tensor = tensor.cuda() + tensor = self.embed(tensor) + tensor = F.gelu(tensor) + tensor = self.ln_f(tensor) + tensor = self.head(tensor) + return tensor + + +class Net(nn.Module): + + def __init__(self, smiles_embed_dim, n_output=1, dropout=0.2): + super().__init__() + self.desc_skip_connection = True + self.fc1 = nn.Linear(smiles_embed_dim, smiles_embed_dim) + self.dropout1 = nn.Dropout(dropout) + self.relu1 = nn.GELU() + self.fc2 = nn.Linear(smiles_embed_dim, smiles_embed_dim) + self.dropout2 = nn.Dropout(dropout) + self.relu2 = nn.GELU() + self.final = nn.Linear(smiles_embed_dim, n_output) + + def forward(self, smiles_emb, multitask=False): + x_out = self.fc1(smiles_emb) + x_out = self.dropout1(x_out) + x_out = self.relu1(x_out) + + if self.desc_skip_connection is True: + x_out = x_out + smiles_emb + + z = self.fc2(x_out) + z = self.dropout2(z) + z = self.relu2(z) + if self.desc_skip_connection is True: + z = self.final(z + x_out) + else: + z = self.final(z) + + if multitask: + return F.sigmoid(z) + return z + + 
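
For orientation, the Net module defined just above is the small prediction head that the finetune trainers drive with pooled SMILES embeddings (produced further below by MoLDecoder.autoencoder.encoder via Smi_ted.extract_embeddings). The following is a minimal illustrative sketch, not part of the diff: it assumes the Net class above is importable, and uses a random stand-in batch of 768-dimensional embeddings (mirroring the --n_embd 768 used in the finetune run scripts).

import torch

# illustrative stand-in for pooled SMILES embeddings of shape (N, n_embd);
# in the real pipeline these come from MoLDecoder.autoencoder.encoder
# through Smi_ted.extract_embeddings (defined later in this file)
emb = torch.randn(4, 768)

# single-target regression head, as driven by TrainerRegressor
head = Net(smiles_embed_dim=768, n_output=1, dropout=0.2)
y_hat = head(emb)                     # shape (4, 1); the trainers squeeze this before the loss

# multi-label head, as driven by TrainerClassifierMultitask (e.g. the 12 Tox21 tasks)
mt_head = Net(smiles_embed_dim=768, n_output=12, dropout=0.2)
probs = mt_head(emb, multitask=True)  # shape (4, 12); per-task sigmoid outputs

The head stays shallow on purpose: two GELU blocks with the residual connection (desc_skip_connection) reuse the embedding dimensionality, and task-specific behaviour is confined to the final linear layer plus the optional sigmoid for multitask classification.
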
+class MoLEncoder(nn.Module): + + def __init__(self, config, n_vocab): + super(MoLEncoder, self).__init__() + + # embeddings + self.config = config + self.tok_emb = nn.Embedding(n_vocab, config['n_embd']) + self.drop = nn.Dropout(config['d_dropout']) + + # transformer + builder = RotateEncoderBuilder.from_kwargs( + n_layers=config['n_layer'], + n_heads=config['n_head'], + query_dimensions=config['n_embd']//config['n_head'], + value_dimensions=config['n_embd']//config['n_head'], + feed_forward_dimensions=None, + attention_type='linear', + # unless we do deterministic_eval here, we will have random outputs + feature_map=partial(GeneralizedRandomFeatures, + n_dims=config['num_feats'], + deterministic_eval=True), + activation='gelu' + ) + self.blocks = builder.get() + + # classification + self.lang_model = LangLayer(config['n_embd'], n_vocab) + + def forward(self, idx, mask): + # transformer encoder + x = self.tok_emb(idx) # each index maps to a (learnable) vector + x = self.drop(x) + x = self.blocks(x, length_mask=LengthMask(mask.sum(-1), max_len=idx.shape[1])) + + # add padding + token_embeddings = x + input_mask_expanded = mask.unsqueeze(-1).expand(token_embeddings.size()).float() + mask_embeddings = (token_embeddings * input_mask_expanded) + token_embeddings = F.pad(mask_embeddings, pad=(0, 0, 0, self.config['max_len'] - mask_embeddings.shape[1]), value=0) + + return token_embeddings + + +class MoLDecoder(nn.Module): + + def __init__(self, n_vocab, max_len, n_embd, n_gpu=None): + super(MoLDecoder, self).__init__() + + self.max_len = max_len + self.n_embd = n_embd + self.n_gpu = n_gpu + self.autoencoder = AutoEncoderLayer(n_embd*max_len, n_embd) + self.lang_model = LangLayer(n_embd, n_vocab) + + +class Smi_ted(nn.Module): + """materials.smi-ted-Large 738M Parameters""" + + def __init__(self, tokenizer, config=None): + super(Smi_ted, self).__init__() + + # configuration + self.config = config + self.tokenizer = tokenizer + self.padding_idx = tokenizer.get_padding_idx() + self.n_vocab = len(self.tokenizer.vocab) + self.is_cuda_available = torch.cuda.is_available() + + # instantiate modules + if self.config: + self.encoder = MoLEncoder(self.config, self.n_vocab) + self.decoder = MoLDecoder(self.n_vocab, self.config['max_len'], self.config['n_embd']) + self.net = Net(self.config['n_embd'], n_output=self.config['n_output'], dropout=self.config['d_dropout']) + + def load_checkpoint(self, ckpt_path): + # load checkpoint file + checkpoint = torch.load(ckpt_path, map_location=torch.device('cpu')) + + # load hyparameters + self.config = checkpoint['hparams'] + self.max_len = self.config['max_len'] + self.n_embd = self.config['n_embd'] + self._set_seed(self.config['seed']) + + # instantiate modules + self.encoder = MoLEncoder(self.config, self.n_vocab) + self.decoder = MoLDecoder(self.n_vocab, self.max_len, self.n_embd) + self.net = Net(self.n_embd, n_output=self.config['n_output'] if 'n_output' in self.config else 1, dropout=self.config['d_dropout']) + + # load weights + if 'state_dict' in checkpoint: + if isinstance(checkpoint['state_dict'], list): + self.encoder.load_state_dict(checkpoint['state_dict'][0], strict=False) + self.decoder.load_state_dict(checkpoint['state_dict'][1], strict=False) + else: + self.load_state_dict(checkpoint['state_dict'], strict=False) + elif 'MODEL_STATE' in checkpoint: + self.load_state_dict(checkpoint['MODEL_STATE'], strict=False) + + # load RNG states each time the model and states are loaded from checkpoint + if 'rng' in self.config: + rng = self.config['rng'] + for 
key, value in rng.items(): + if key =='torch_state': + torch.set_rng_state(value.cpu()) + elif key =='cuda_state': + torch.cuda.set_rng_state(value.cpu()) + elif key =='numpy_state': + np.random.set_state(value) + elif key =='python_state': + random.setstate(value) + else: + print('unrecognized state') + + def _set_seed(self, value): + print('Random Seed:', value) + random.seed(value) + torch.manual_seed(value) + torch.cuda.manual_seed(value) + torch.cuda.manual_seed_all(value) + np.random.seed(value) + cudnn.deterministic = True + cudnn.benchmark = False + + def forward(self, smiles, batch_size=100): + return self.decode(self.encode(smiles, batch_size=batch_size, return_torch=True)) + + def tokenize(self, smiles): + """Tokenize a string into tokens.""" + if isinstance(smiles, str): + batch = [smiles] + else: + batch = smiles + + tokens = self.tokenizer( + batch, + padding=True, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + max_length=self.max_len, + ) + + idx = tokens['input_ids'].clone().detach() + mask = tokens['attention_mask'].clone().detach() + + if self.is_cuda_available: + return idx.cuda(), mask.cuda() + + return idx, mask + + def extract_all(self, smiles): + """Extract all elements from each part of smi-ted. Be careful.""" + # evaluation mode + self.encoder.eval() + self.decoder.eval() + if self.is_cuda_available: + self.encoder.cuda() + self.decoder.cuda() + + # handle single str or a list of str + smiles = pd.Series(smiles) if isinstance(smiles, str) else pd.Series(list(smiles)) + + # SMILES normalization + smiles = smiles.apply(normalize_smiles) + null_idx = smiles[smiles.isnull()].index.to_list() # keep track of SMILES that cannot normalize + smiles = smiles.dropna() + + # tokenizer + idx, mask = self.tokenize(smiles.to_list()) + + ########### + # Encoder # + ########### + # encoder forward + x = self.encoder.tok_emb(idx) # each index maps to a (learnable) vector + x = self.encoder.drop(x) + x = self.encoder.blocks(x, length_mask=LengthMask(mask.sum(-1))) + + # mean pooling + token_embeddings = x + input_mask_expanded = mask.unsqueeze(-1).expand(token_embeddings.size()).float() + sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1) + sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9) + true_set = sum_embeddings / sum_mask # DO NOT USE THIS FOR DOWNSTREAM TASKS, USE `pred_set` INSTEAD + + # add padding + mask_embeddings = (token_embeddings * input_mask_expanded) + token_embeddings = F.pad(mask_embeddings, pad=(0, 0, 0, self.max_len - mask_embeddings.shape[1]), value=0) + idx = F.pad(idx, pad=(0, self.max_len - idx.shape[1], 0, 0), value=2) + + true_ids = idx + true_cte = token_embeddings + true_cte = true_cte.view(-1, self.max_len*self.n_embd) + + ########### + # Decoder # + ########### + # CTE autoencoder + pred_set = self.decoder.autoencoder.encoder(true_cte) + pred_cte = self.decoder.autoencoder.decoder(pred_set) + + # reconstruct tokens + pred_ids = self.decoder.lang_model(pred_cte.view(-1, self.max_len, self.n_embd)) + pred_ids = torch.argmax(pred_ids, axis=-1) + + # replacing null SMILES with NaN values + for idx in null_idx: + true_ids = true_ids.tolist() + pred_ids = pred_ids.tolist() + true_cte = true_cte.tolist() + pred_cte = pred_cte.tolist() + true_set = true_set.tolist() + pred_set = pred_set.tolist() + + true_ids.insert(idx, np.array([np.nan]*self.config['max_len'])) + pred_ids.insert(idx, np.array([np.nan]*self.config['max_len'])) + true_cte.insert(idx, np.array([np.nan] * 
(self.config['max_len']*self.config['n_embd']))) + pred_cte.insert(idx, np.array([np.nan] * (self.config['max_len']*self.config['n_embd']))) + true_set.insert(idx, np.array([np.nan]*self.config['n_embd'])) + pred_set.insert(idx, np.array([np.nan]*self.config['n_embd'])) + + if len(null_idx) > 0: + true_ids = torch.tensor(true_ids) + pred_ids = torch.tensor(pred_ids) + true_cte = torch.tensor(true_cte) + pred_cte = torch.tensor(pred_cte) + true_set = torch.tensor(true_set) + pred_set = torch.tensor(pred_set) + + return ((true_ids, pred_ids), # tokens + (true_cte, pred_cte), # token embeddings + (true_set, pred_set)) # smiles embeddings + + def extract_embeddings(self, smiles): + """Extract token and SMILES embeddings.""" + # evaluation mode + self.encoder.eval() + if self.is_cuda_available: + self.encoder.cuda() + + # tokenizer + idx, mask = self.tokenize(smiles) + + # encoder forward + token_embeddings = self.encoder(idx, mask) + + # aggregate token embeddings (similar to mean pooling) + # CAUTION: use the embeddings from the autoencoder. + smiles_embeddings = self.decoder.autoencoder.encoder(token_embeddings.view(-1, self.max_len*self.n_embd)) + + # add padding + idx = F.pad(idx, pad=(0, self.max_len - idx.shape[1], 0, 0), value=self.padding_idx) + + return idx, token_embeddings, smiles_embeddings + + def encode(self, smiles, useCuda=False, batch_size=100, return_torch=False): + """Extract efficiently SMILES embeddings per batches.""" + # TODO: remove useCuda argument + + # handle single str or a list of str + smiles = pd.Series(smiles) if isinstance(smiles, str) else pd.Series(list(smiles)) + + # SMILES normalization + smiles = smiles.apply(normalize_smiles) + null_idx = smiles[smiles.isnull()].index.to_list() # keep track of SMILES that cannot normalize + smiles = smiles.dropna() + + # process in batches + n_split = smiles.shape[0] // batch_size if smiles.shape[0] >= batch_size else smiles.shape[0] + embeddings = [ + self.extract_embeddings(list(batch))[2].cpu().detach().numpy() + for batch in tqdm(np.array_split(smiles, n_split)) + ] + flat_list = [item for sublist in embeddings for item in sublist] + + # clear GPU memory + if self.is_cuda_available: + torch.cuda.empty_cache() + gc.collect() + + # replacing null SMILES with NaN values + for idx in null_idx: + flat_list.insert(idx, np.array([np.nan]*self.config['n_embd'])) + flat_list = np.asarray(flat_list) + + if return_torch: + return torch.tensor(flat_list) + return pd.DataFrame(flat_list) + + def decode(self, smiles_embeddings): + """Decode SMILES embeddings back to SMILES.""" + # evaluation mode + self.decoder.eval() + if self.is_cuda_available: + self.decoder.cuda() + + # reconstruct token embeddings + pred_token_embds = self.decoder.autoencoder.decoder(smiles_embeddings) + + # reconstruct tokens + pred_idx = self.decoder.lang_model(pred_token_embds.view(-1, self.max_len, self.n_embd)) + pred_idx = torch.argmax(pred_idx, axis=-1).cpu().detach().numpy() + + # convert idx to tokens + pred_smiles = [] + for i in range(pred_idx.shape[0]): + idx = pred_idx[i] + smiles = self.tokenizer.idx_to_smiles(self, idx) + smiles = smiles.replace('<bos>', '') # begin token + smiles = smiles.replace('<eos>', '') # end token + smiles = smiles.replace('<pad>', '') # pad token + pred_smiles.append(smiles) + + # clear GPU memory + if self.is_cuda_available: + torch.cuda.empty_cache() + gc.collect() + + return pred_smiles + + def __str__(self): + return 'smi-ted-Large' + + +def load_smi_ted(folder="./smi_ted_large", + 
ckpt_filename="smi-ted-Large_30.pt", + vocab_filename="bert_vocab_curated.txt" + ): + tokenizer = MolTranBertTokenizer(os.path.join(folder, vocab_filename)) + model = Smi_ted(tokenizer) + model.load_checkpoint(os.path.join(folder, ckpt_filename)) + model.eval() + print('Vocab size:', len(tokenizer.vocab)) + print(f'[INFERENCE MODE - {str(model)}]') + return model \ No newline at end of file diff --git a/models/smi_ted/inference/smi_ted_light/bert_vocab_curated.txt b/models/smi_ted/inference/smi_ted_light/bert_vocab_curated.txt new file mode 100644 index 0000000000000000000000000000000000000000..4de43b6c87add8821aa8d2d33d58232929c86fbd --- /dev/null +++ b/models/smi_ted/inference/smi_ted_light/bert_vocab_curated.txt @@ -0,0 +1,2393 @@ +<bos> +<eos> +<pad> +<mask> +C +c +( +) +1 +O +N +2 += +n +3 +[C@H] +[C@@H] +F +S +4 +Cl +- +o +s +[nH] +# +/ +Br +[C@] +[C@@] +[N+] +[O-] +5 +\ +. +I +6 +[S@] +[S@@] +P +[N-] +[Si] +7 +[n+] +[2H] +8 +[NH+] +B +9 +[C-] +[Na+] +[Cl-] +[c-] +[CH] +%10 +[NH2+] +[P+] +[B] +[I-] +%11 +[CH2-] +[O+] +[NH3+] +[C] +[Br-] +[IH2] +[S-] +[cH-] +%12 +[nH+] +[B-] +[K+] +[Sn] +[Se] +[CH-] +[HH] +[Y] +[n-] +[CH3-] +[SiH] +[S+] +%13 +[SiH2] +[Li+] +[NH-] +%14 +[Na] +[CH2] +[O-2] +[U+2] +[W] +[Al] +[P@] +[Fe+2] +[PH+] +%15 +[Cl+3] +[Zn+2] +[Ir] +[Mg+2] +[Pt+2] +[OH2+] +[As] +[Fe] +[OH+] +[Zr+2] +[3H] +[Ge] +[SiH3] +[OH-] +[NH4+] +[Cu+2] +[P@@] +p +[Pt] +%16 +[Ca+2] +[Zr] +[F-] +[C+] +[Ti] +[P-] +[V] +[se] +[U] +[O] +[Ni+2] +[Zn] +[Co] +[Ni] +[Pd+2] +[Cu] +%17 +[Cu+] +[Te] +[H+] +[CH+] +[Li] +[Pd] +[Mo] +[Ru+2] +[o+] +[Re] +[SH+] +%18 +[Ac] +[Cr] +[NH2-] +[K] +[13CH2] +[c] +[Zr+4] +[Tl] +[13C] +[Mn] +[N@+] +[Hg] +[Rh] +[Ti+4] +[Sb] +[Co+2] +[Ag+] +[Ru] +%19 +[N@@+] +[Ti+2] +[Al+3] +[Pb] +[I+] +[18F] +[s+] +[Rb+] +[Ba+2] +[H-] +[Fe+3] +[Ir+3] +[13cH] +%20 +[AlH2] +[Au+] +[13c] +[SH2+] +[Sn+2] +[Mn+2] +[Si-] +[Ag] +[N] +[Bi] +%21 +[In] +[CH2+] +[Y+3] +[Ga] +%22 +[Co+3] +[Au] +[13CH3] +[Mg] +[Cs+] +[W+2] +[Hf] +[Zn+] +[Se-] +[S-2] +[Ca] +[pH] +[ClH+] +[Ti+3] +%23 +[Ru+] +[SH-] +[13CH] +[IH+] +[Hf+4] +[Rf] +[OH3+] +%24 +[Pt+4] +[Zr+3] +[PH3+] +[Sr+2] +[Cd+2] +[Cd] +%25 +[Os] +[BH-] +[Sn+4] +[Cr+3] +[Ru+3] +[PH2+] +[Rh+2] +[V+2] +%26 +[Gd+3] +[Pb+2] +[PH] +[Hg+] +[Mo+2] +[AlH] +[Sn+] +%27 +[Pd+] +b +[Rh+3] +[Hg+2] +[15NH] +[14C] +%28 +[Mn+3] +[Si+] +[SeH] +[13C@H] +[NH] +[Ga+3] +[SiH-] +[13C@@H] +[Ce] +[Au+3] +[Bi+3] +[15N] +%29 +[BH3-] +[14cH] +[Ti+] +[Gd] +[cH+] +[Cr+2] +[Sb-] +%30 +[Be+2] +[Al+] +[te] +[11CH3] +[Sm] +[Pr] +[La] +%31 +[Al-] +[Ta] +[125I] +[BH2-] +[Nb] +[Si@] +%32 +[14c] +[Sb+3] +[Ba] +%33 +[Os+2] +[Si@@] +[La+3] +[15n] +[15NH2] +[Nd+3] +%34 +[14CH2] +[18O] +[Nd] +[GeH] +[Ni+3] +[Eu] +[Dy+3] +[Sc] +%36 +[Se-2] +[As+] +%35 +[AsH] +[Tb] +[Sb+5] +[Se+] +[Ce+3] +[c+] +[In+3] +[SnH] +[Mo+4] +%37 +[V+4] +[Eu+3] +[Hf+2] +%38 +[Pt+] +[p+] +[123I] +[Tl+] +[Sm+3] +%39 +[Yb+3] +%40 +[Yb] +[Os+] +%41 +[10B] +[Sc+3] +[Al+2] +%42 +[Sr] +[Tb+3] +[Po] +[Tc] +[PH-] +[AlH3] +[Ar] +[U+4] +[SnH2] +[Cl+2] +[si] +[Fe+] +[14CH3] +[U+3] +[Cl+] +%43 +[GeH2] +%44 +[Er+3] +[Mo+3] +[I+2] +[Fe+4] +[99Tc] +%45 +[11C] +%46 +[SnH3] +[S] +[Te+] +[Er] +[Lu+3] +[11B] +%47 +%48 +[P] +[Tm] +[Th] +[Dy] +[Pr+3] +[Ta+5] +[Nb+5] +[Rb] +[GeH3] +[Br+2] +%49 +[131I] +[Fm] +[Cs] +[BH4-] +[Lu] +[15nH] +%50 +[Ru+6] +[b-] +[Ho] +[Th+4] +[Ru+4] +%52 +[14CH] +%51 +[Cr+6] +[18OH] +[Ho+3] +[Ce+4] +[Bi+2] +[Co+] +%53 +[Yb+2] +[Fe+6] +[Be] +%54 +[SH3+] +[Np] +[As-] +%55 +[14C@@H] +[Ir+2] +[GaH3] +[p-] +[GeH4] +[Sn+3] +[Os+4] +%56 +[14C@H] +[sH+] +[19F] +[Eu+2] +[TlH] +%57 +[Cr+4] +%58 +[B@@-] +[SiH+] +[At] +[Am] 
+[Fe+5] +[AsH2] +[Si+4] +[B@-] +[Pu] +[SbH] +[P-2] +[Tm+3] +* +%59 +[se+] +[IH-] +%60 +[oH+] +[1H] +[15N+] +[124I] +[S@@+] +[P-3] +[H] +[IH2+] +[TeH] +[Xe] +[PH4+] +[Cr+] +[Cm] +[I+3] +%61 +[Nb+2] +[Ru+5] +%62 +[Ta+2] +[Tc+4] +[CH3+] +[Pm] +[Si@H] +[No] +%63 +[Cr+5] +[Th+2] +[Zn-2] +[13C@] +[Lr] +%64 +[99Tc+3] +%65 +[13C@@] +%66 +[Fe-] +[17O] +[siH] +[Sb+] +[OH] +[IH] +[11CH2] +[Cf] +[SiH2+] +[Gd+2] +[In+] +[Si@@H] +[Mn+] +[99Tc+4] +[Ga-] +%67 +[S@+] +[Ge+4] +[Tl+3] +[16OH] +%68 +[2H-] +[Ra] +[si-] +[NiH2] +[P@@H] +[Rh+] +[12C] +[35S] +[32P] +[SiH2-] +[AlH2+] +[16O] +%69 +[BiH] +[BiH2] +[Zn-] +[BH] +[Tc+3] +[Ir+] +[Ni+] +%70 +[InH2] +[InH] +[Nb+3] +[PbH] +[Bi+] +%71 +[As+3] +%72 +[18O-] +[68Ga+3] +%73 +[Pa] +[76Br] +[Tc+5] +[pH+] +[64Cu+2] +[Ru+8] +%74 +[PH2-] +[Si+2] +[17OH] +[RuH] +[111In+3] +[AlH+] +%75 +%76 +[W+] +[SbH2] +[PoH] +[Ru-] +[XeH] +[Tc+2] +[13C-] +[Br+] +[Pt-2] +[Es] +[Cu-] +[Mg+] +[3HH] +[P@H] +[ClH2+] +%77 +[SH] +[Au-] +[2HH] +%78 +[Sn-] +[11CH] +[PdH2] +0 +[Os+6] +%79 +[Mo+] +%80 +[al] +[PbH2] +[64Cu] +[Cl] +[12CH3] +%81 +[Tc+7] +[11c] +%82 +[Li-] +[99Tc+5] +[He] +[12c] +[Kr] +[RuH+2] +[35Cl] +[Pd-2] +[GaH2] +[4H] +[Sg] +[Cu-2] +[Br+3] +%83 +[37Cl] +[211At] +[IrH+2] +[Mt] +[Ir-2] +[In-] +[12cH] +[12CH2] +[RuH2] +[99Tc+7] +%84 +[15n+] +[ClH2+2] +[16N] +[111In] +[Tc+] +[Ru-2] +[12CH] +[si+] +[Tc+6] +%85 +%86 +[90Y] +[Pd-] +[188Re] +[RuH+] +[NiH] +[SiH3-] +[14n] +[CH3] +[14N] +[10BH2] +%88 +%89 +%90 +[34S] +[77Br] +[GaH] +[Br] +[Ge@] +[B@@H-] +[CuH] +[SiH4] +[3H-] +%87 +%91 +%92 +[67Cu] +[I] +[177Lu] +[ReH] +[67Ga+3] +[Db] +[177Lu+3] +[AlH2-] +[Si+3] +[Ti-2] +[RuH+3] +[al+] +[68Ga] +[2H+] +[B@H-] +[WH2] +[OsH] +[Ir-3] +[AlH-] +[Bk] +[75Se] +[14C@] +[Pt-] +[N@@H+] +[Nb-] +[13NH2] +%93 +[186Re] +[Tb+4] +[PtH] +[IrH2] +[Hg-2] +[AlH3-] +[PdH+] +[Md] +[RhH+2] +[11cH] +[Co-2] +[15N-] +[ZrH2] +%94 +[Hg-] +[127I] +[AsH2+] +[MoH2] +[Te+4] +[14C@@] +[As+5] +[SnH+3] +[Ge@@] +[6Li+] +[WH] +[Ne] +[14NH2] +[14NH] +[12C@@H] +[Os+7] +[RhH] +[Al-3] +[SnH+] +[15NH3+] +[Zr+] +[197Hg+] +%95 +%96 +[90Y+3] +[Os-2] +[98Tc+5] +[15NH3] +[bH-] +[33P] +[Zr-2] +[15O] +[Rh-] +[PbH3] +[PH2] +[Ni-] +[CuH+] +%97 +%98 +%99 +[Os+5] +[PtH+] +[ReH4] +[16NH] +[82Br] +[W-] +[18F-] +[15NH4+] +[Se+4] +[SeH-] +[SH4] +[67Cu+2] +[12C@H] +[AsH3] +[HgH] +[10B-] +[99Tc+6] +[117Sn+4] +[Te@] +[P@+] +[35SH] +[SeH+] +[Ni-2] +[Al-2] +[TeH2] +[Bh] +[99Tc+2] +[Os+8] +[PH-2] +[7Li+] +[14nH] +[AlH+2] +[18FH] +[SnH4] +[18O-2] +[IrH] +[13N] +[Te@@] +[Rh-3] +[15NH+] +[AsH3+] +[SeH2] +[AsH+] +[CoH2] +[16NH2] +[AsH-] +[203Hg+] +[P@@+] +[166Ho+3] +[60Co+3] +[13CH2-] +[SeH2+] +[75Br] +[TlH2] +[80Br] +[siH+] +[Ca+] +[153Sm+3] +[PdH] +[225Ac] +[13CH3-] +[AlH4-] +[FeH] +[13CH-] +[14C-] +[11C-] +[153Sm] +[Re-] +[te+] +[13CH4] +[ClH+2] +[8CH2] +[99Mo] +[ClH3+3] +[SbH3] +[25Mg+2] +[16N+] +[SnH2+] +[PH4] +[11C@H] +[122I] +[Re-2] +[RuH2+2] +[ZrH] +[Bi-] +[Pr+] +[Rn] +[Fr] +[36Cl] +[18o] +[YH] +[79Br] +[121I] +[113In+3] +[InH4-] +[TaH] +[RhH2] +[Ta-] +[67Ga] +[ZnH+] +[SnH2-] +[OsH2] +[16F] +[FeH2] +[14O] +[PbH2+2] +[BH2] +[6H] +[125Te] +[197Hg] +[TaH2] +[TaH3] +[76As] +[Nb-2] +[14N+] +[125I-] +[33S] +[IH2+2] +[NH2] +[PtH2] +[MnH] +[19C] +[17F] +[1H-] +[SnH4+2] +[Mn-2] +[15NH2+] +[TiH2] +[ReH7] +[Cd-2] +[Fe-3] +[SH2] +[17O-] +[siH-] +[CoH+] +[VH] +[10BH] +[Ru-3] +[13O] +[5H] +[CoH] +[PH5] +[15n-] +[153Gd] +[12C@] +[11CH3-] +[IrH3] +[RuH3] +[74Se] +[Se@] +[Hf+] +[77Se] +[166Ho] +[59Fe+2] +[203Hg] +[18OH-] +[8CH] +[12C@@] +[11CH4] +[15C] +[249Cf] +[PbH4] +[64Zn] +[PH3] +[99Tc+] +[14c-] +[149Pm] +[IrH4] +[Se@@] +[13OH] +[14CH3-] +[28Si] +[Rh-2] 
+[Fe-2] +[131I-] +[51Cr] +[62Cu+2] +[81Br] +[121Sb] +[7Li] +[89Zr+4] +[SbH3+] +[11C@@H] +[98Tc] +[59Fe+3] +[BiH2+] +[SbH+] +[TiH] +[14NH3] +[15OH] +[119Sn] +[201Hg] +[MnH+] +[201Tl] +[51Cr+3] +[123I-] +[MoH] +[AlH6-3] +[MnH2] +[WH3] +[213Bi+3] +[SnH2+2] +[123IH] +[13CH+] +[Zr-] +[74As] +[13C+] +[32P+] +[KrH] +[SiH+2] +[ClH3+2] +[13NH] +[9CH2] +[ZrH2+2] +[87Sr+2] +[35s] +[239Pu] +[198Au] +[241Am] +[203Hg+2] +[V+] +[YH2] +[SH5] +[195Pt] +[203Pb] +[RuH4] +[ThH2] +[AuH] +[66Ga+3] +[11B-] +[F] +[24Na+] +[85Sr+2] +[201Tl+] +[14CH4] +[32S] +[TeH2+] +[ClH2+3] +[AgH] +[Ge@H] +[44Ca+2] +[Os-] +[31P] +[15nH+] +[SbH4] +[TiH+] +[Ba+] +[57Co+2] +[Ta+] +[125IH] +[77As] +[129I] +[Fe-4] +[Ta-2] +[19O] +[12O] +[BiH3] +[237Np] +[252Cf] +[86Y] +[Cr-2] +[89Y] +[195Pt+2] +[si+2] +[58Fe+2] +[Hs] +[S@@H] +[OsH6] +[GdH2] +[IH3] +[8CH4] +[164Dy+3] +[47Ca+2] +[57Co] +[NbH2] +[ReH2] +[ZnH2] +[CrH2] +[17NH] +[ZrH3] +[RhH3] +[12C-] +[18O+] +[Bi-2] +[ClH4+3] +[Ni-3] +[Ag-] +[111In-] +[Mo-2] +[55Fe+3] +[204Hg+] +[35Cl-] +[211Pb] +[75Ge] +[8B] +[TeH3] +[SnH3+] +[Zr-3] +[28F] +[249Bk] +[169Yb] +[34SH] +[6Li] +[94Tc] +[197Au] +[195Pt+4] +[169Yb+3] +[32Cl] +[82Se] +[159Gd+3] +[213Bi] +[CoH+2] +[36S] +[35P] +[Ru-4] +[Cr-3] +[60Co] +[1H+] +[18CH2] +[Cd-] +[152Sm+3] +[106Ru] +[238Pu] +[220Rn] +[45Ca+2] +[89Sr+2] +[239Np] +[90Sr+2] +[137Cs+] +[165Dy] +[68GaH3] +[65Zn+2] +[89Zr] +[BiH2+2] +[62Cu] +[165Dy+3] +[238U] +[105Rh+3] +[70Zn] +[12B] +[12OH] +[18CH] +[17CH] +[OsH3] +[SbH-] +[SH6] +[AlH2-2] +[42K] +[76Br-] +[71As] +[NbH3] +[ReH3] +[OsH-] +[WH4] +[MoH3] +[OsH4] +[RuH6] +[PtH3] +[CuH2] +[CoH3] +[TiH4] +[64Zn+2] +[Si-2] +[79BrH] +[14CH2-] +[PtH2+2] +[Os-3] +[29Si] +[Ti-] +[Se+6] +[22Na+] +[42K+] +[131Cs+] +[86Rb+] +[134Cs+] +[209Po] +[208Po] +[81Rb+] +[203Tl+] +[Zr-4] +[148Sm] +[147Sm] +[37Cl-] +[12CH4] +[Ge@@H] +[63Cu] +[13CH2+] +[AsH2-] +[CeH] +[SnH-] +[UH] +[9c] +[21CH3] +[TeH+] +[57Co+3] +[8BH2] +[12BH2] +[19BH2] +[9BH2] +[YbH2] +[CrH+2] +[208Bi] +[152Gd] +[61Cu] +[115In] +[60Co+2] +[13NH2-] +[120I] +[18OH2] +[75SeH] +[SbH2+] +[144Ce] +[16n] +[113In] +[22nH] +[129I-] +[InH3] +[32PH3] +[234U] +[235U] +[59Fe] +[82Rb+] +[65Zn] +[244Cm] +[147Pm] +[91Y] +[237Pu] +[231Pa] +[253Cf] +[127Te] +[187Re] +[236Np] +[235Np] +[72Zn] +[253Es] +[159Dy] +[62Zn] +[101Tc] +[149Tb] +[124I-] +[SeH3+] +[210Pb] +[40K] +[210Po] +[214Pb] +[218Po] +[214Po] +[7Be] +[212Pb] +[205Pb] +[209Pb] +[123Te] +[202Pb] +[72As] +[201Pb] +[70As] +[73Ge] +[200Pb] +[198Pb] +[66Ga] +[73Se] +[195Pb] +[199Pb] +[144Ce+3] +[235U+2] +[90Tc] +[114In+3] +[128I] +[100Tc+] +[82Br-] +[191Pt+2] +[191Pt+4] +[193Pt+4] +[31PH3] +[125I+2] +[131I+2] +[125Te+4] +[82Sr+2] +[149Sm] +[81BrH] +[129Xe] +[193Pt+2] +[123I+2] +[Cr-] +[Co-] +[227Th+4] +[249Cf+3] +[252Cf+3] +[187Os] +[16O-] +[17O+] +[16OH-] +[98Tc+7] +[58Co+2] +[69Ga+3] +[57Fe+2] +[43K+] +[16C] +[52Fe+3] +[SeH5] +[194Pb] +[196Pb] +[197Pb] +[213Pb] +[9B] +[19B] +[11CH-] +[9CH] +[20OH] +[25OH] +[8cH] +[TiH+3] +[SnH6+3] +[N@H+] +[ZnH] +[VH3] +[52Mn+2] +[64Ga] +[13B] +[216Bi] +[117Sn+2] +[232Th] +[SnH+2] +[BiH5] +[77Kr] +[103Cd] +[62Ni] +[LaH3] +[SmH3] +[EuH3] +[MoH5] +[64Ni] +[66Zn] +[68Zn] +[186W] +[FeH4] +[MoH4] +[HgH2] +[15NH2-] +[UH2] +[204Hg] +[GaH4-] +[ThH4] +[WH6] +[PtH4] +[VH2] +[UH3] +[FeH3] +[RuH5] +[BiH4] +[80Br-] +[CeH3] +[37ClH] +[157Gd+3] +[205Tl] +[203Tl] +[62Cu+] +[64Cu+] +[61Cu+] +[37SH2] +[30Si] +[28Al] +[19OH2] +[8He] +[6He] +[153Pm] +[209Bi] +[66Zn+2] +[10CH4] +[191Ir] +[66Cu] +[16O+] +[25O] +[10c] +[Co-3] +[Sn@@] +[17OH-] +[206Po] +[204Po] +[202Po] +[201Po] +[200Po] +[199Po] +[198Po] +[197Po] +[196Po] 
+[195Po] +[194Po] +[193Po] +[192Po] +[191Po] +[190Po] +[217Po] +[BiH4-] +[TeH4] +[222Ra] +[62Ga] +[39Ar] +[144Sm] +[58Fe] +[153Eu] +[85Rb] +[171Yb] +[172Yb] +[114Cd] +[51Fe] +[142Ce] +[207Tl] +[92Mo] +[115Sn] +[140Ce] +[202Hg] +[180W] +[182W] +[183W] +[184W] +[96Mo] +[47Ti] +[111Cd] +[143Nd] +[145Nd] +[126Te] +[128Te] +[130Te] +[185Re] +[97Mo] +[98Mo] +[183Re] +[52V] +[80Se] +[87Kr] +[137Xe] +[196Au] +[146Ce] +[88Kr] +[51Ti] +[138Xe] +[112Cd] +[116Sn] +[120Sn] +[28SiH3] +[35S-] +[15NH-] +[13CH3+] +[34S+] +[34s] +[SiH4-] +[100Tc+5] +[NiH2+2] +[239Th] +[186Lu] +[AuH3] +[I@@-] +[XeH2] +[B+] +[16CH2] +[8C] +[TaH5] +[FeH4-] +[19C@H] +[10NH] +[FeH6-3] +[22CH] +[25N] +[25N+] +[25N-] +[21CH2] +[18cH] +[113I] +[ScH3] +[30PH3] +[43Ca+2] +[41Ca+2] +[106Cd] +[122Sn] +[18CH3] +[58Co+3] +[98Tc+4] +[70Ge] +[76Ge] +[108Cd] +[116Cd] +[130Xe] +[94Mo] +[124Sn] +[186Os] +[188Os] +[190Os] +[192Os] +[106Pd] +[110Pd] +[120Te] +[132Ba] +[134Ba] +[136Ba] +[136Ce] +[138Ce] +[156Dy] +[158Dy] +[160Dy] +[163Dy] +[162Er] +[164Er] +[167Er] +[176Hf] +[26Mg] +[144Nd] +[150Nd] +[41K] +[46Ti] +[48Ti] +[49Ti] +[50Ti] +[170Yb] +[173Yb] +[91Zr] +[92Zr] +[96Zr] +[34S-] +[CuH2-] +[38Cl] +[25Mg] +[51V] +[93Nb] +[95Mo] +[45Sc] +[123Sb] +[139La] +[9Be] +[99Y+3] +[99Y] +[156Ho] +[67Zn] +[144Ce+4] +[210Tl] +[42Ca] +[54Fe] +[193Ir] +[92Nb] +[141Cs] +[52Cr] +[35ClH] +[46Ca] +[139Cs] +[65Cu] +[71Ga] +[60Ni] +[16NH3] +[148Nd] +[72Ge] +[161Dy] +[49Ca] +[43Ca] +[8Be] +[48Ca] +[44Ca] +[120Xe] +[80Rb] +[215At] +[180Re] +[146Sm] +[19Ne] +[74Kr] +[134La] +[76Kr] +[219Fr] +[121Xe] +[220Fr] +[216At] +[223Ac] +[218At] +[37Ar] +[135I] +[110Cd] +[94Tc+7] +[86Y+3] +[135I-] +[15O-2] +[151Eu+3] +[161Tb+3] +[197Hg+2] +[109Cd+2] +[191Os+4] +[170Tm+3] +[205Bi+3] +[233U+4] +[126Sb+3] +[127Sb+3] +[132Cs+] +[136Eu+3] +[136Eu] +[125Sn+4] +[175Yb+3] +[100Mo] +[22Ne] +[13c-] +[13NH4+] +[17C] +[9C] +[31S] +[31SH] +[133I] +[126I] +[36SH] +[30S] +[32SH] +[19CH2] +[19c] +[18c] +[15F] +[10C] +[RuH-] +[62Zn+2] +[32ClH] +[33ClH] +[78BrH] +[12Li+] +[12Li] +[233Ra] +[68Ge+4] +[44Sc+3] +[91Y+3] +[106Ru+3] +[PoH2] +[AtH] +[55Fe] +[233U] +[210PoH2] +[230Th] +[228Th] +[222Rn] +[35SH2] +[227Th] +[192Ir] +[133Xe] +[81Kr] +[95Zr] +[240Pu] +[54Mn] +[103Ru] +[95Nb] +[109Cd] +[141Ce] +[85Kr] +[110Ag] +[58Co] +[241Pu] +[234Th] +[140La] +[63Ni] +[152Eu] +[132IH] +[226Rn] +[154Eu] +[36ClH] +[228Ac] +[155Eu] +[106Rh] +[243Am] +[227Ac] +[243Cm] +[236U] +[144Pr] +[232U] +[32SH2] +[88Y] +[82BrH] +[135IH] +[242Cm] +[115Cd] +[242Pu] +[46Sc] +[56Mn] +[234Pa] +[41Ar] +[147Nd] +[187W] +[151Sm] +[59Ni] +[233Pa] +[52Mn] +[94Nb] +[219Rn] +[236Pu] +[13NH3] +[93Zr] +[51Cr+6] +[TlH3] +[123Xe] +[160Tb] +[170Tm] +[182Ta] +[175Yb] +[93Mo] +[143Ce] +[191Os] +[126IH] +[48V] +[113Cd] +[47Sc] +[181Hf] +[185W] +[143Pr] +[191Pt] +[181W] +[33PH3] +[97Ru] +[97Tc] +[111Ag] +[169Er] +[107Pd] +[103Ru+2] +[34SH2] +[137Ce] +[242Am] +[117SnH2] +[57Ni] +[239U] +[60Cu] +[250Cf] +[193Au] +[69Zn] +[55Co] +[139Ce] +[127Xe] +[159Gd] +[56Co] +[177Hf] +[244Pu] +[38ClH] +[142Pr] +[199Hg] +[179Hf] +[178Hf] +[237U] +[156Eu] +[157Eu] +[105Ru] +[171Tm] +[199Au] +[155Sm] +[80BrH] +[108Ag] +[128IH] +[48Sc] +[45Ti] +[176Lu] +[121SnH2] +[148Pm] +[57Fe] +[10BH3] +[96Tc] +[133IH] +[143Pm] +[105Rh] +[130IH] +[134IH] +[131IH] +[71Zn] +[105Ag] +[97Zr] +[235Pu] +[231Th] +[109Pd] +[93Y] +[190Ir] +[135Xe] +[53Mn] +[134Ce] +[234Np] +[240Am] +[246Cf] +[240Cm] +[241Cm] +[226Th] +[39ClH] +[229Th] +[245Cm] +[240U] +[240Np] +[249Cm] +[243Pu] +[145Pm] +[199Pt] +[246Bk] +[193Pt] +[230U] +[250Cm] +[44Ti] +[175Hf] +[254Fm] +[255Fm] +[257Fm] +[92Y] 
+[188Ir] +[171Lu] +[257Md] +[247Bk] +[121IH] +[250Bk] +[179Lu] +[224Ac] +[195Hg] +[244Am] +[246Pu] +[194Au] +[252Fm] +[173Hf] +[246Cm] +[135Ce] +[49Cr] +[248Cf] +[247Cm] +[248Cm] +[174Ta] +[176Ta] +[154Tb] +[172Ta] +[177Ta] +[175Ta] +[180Ta] +[158Tb] +[115Ag] +[189Os] +[251Cf] +[145Pr] +[147Pr] +[76BrH] +[102Rh] +[238Np] +[185Os] +[246Am] +[233Np] +[166Dy] +[254Es] +[244Cf] +[193Os] +[245Am] +[245Bk] +[239Am] +[238Am] +[97Nb] +[245Pu] +[254Cf] +[188W] +[250Es] +[251Es] +[237Am] +[182Hf] +[258Md] +[232Np] +[238Cm] +[60Fe] +[109Pd+2] +[234Pu] +[141Ce+3] +[136Nd] +[136Pr] +[173Ta] +[110Ru] +[147Tb] +[253Fm] +[139Nd] +[178Re] +[177Re] +[200Au] +[182Re] +[156Tb] +[155Tb] +[157Tb] +[161Tb] +[161Ho] +[167Tm] +[173Lu] +[179Ta] +[171Er] +[44Sc] +[49Sc] +[49V] +[51Mn] +[90Nb] +[88Nb] +[88Zr] +[36SH2] +[174Yb] +[178Lu] +[179W] +[83BrH] +[107Cd] +[75BrH] +[62Co] +[48Cr] +[63Zn] +[102Ag] +[154Sm] +[168Er] +[65Ni] +[137La] +[187Ir] +[144Pm] +[146Pm] +[160Gd] +[166Yb] +[162Dy] +[47V] +[141Nd] +[141Sm] +[166Er] +[150Sm] +[146Eu] +[149Eu] +[174Lu] +[17NH3] +[102Ru] +[170Hf] +[188Pt] +[61Ni] +[56Ni] +[149Gd] +[151Gd] +[141Pm] +[147Gd] +[146Gd] +[161Er] +[103Ag] +[145Eu] +[153Tb] +[155Dy] +[184Re] +[180Os] +[182Os] +[186Pt] +[181Os] +[181Re] +[151Tb] +[178Ta] +[178W] +[189Pt] +[194Hg] +[145Sm] +[150Tb] +[132La] +[158Gd] +[104Ag] +[193Hg] +[94Ru] +[137Pr] +[155Ho] +[117Cd] +[99Ru] +[146Nd] +[218Rn] +[95Y] +[79Kr] +[120IH] +[138Pr] +[100Pd] +[166Tm] +[90Mo] +[151Nd] +[231U] +[138Nd] +[89Nb] +[98Nb] +[162Ho] +[142Sm] +[186Ta] +[104Tc] +[184Ta] +[185Ta] +[170Er] +[107Rh] +[131La] +[169Lu] +[74BrH] +[150Pm] +[172Tm] +[197Pt] +[230Pu] +[170Lu] +[86Zr] +[176W] +[177W] +[101Pd] +[105Pd] +[108Pd] +[149Nd] +[164Ho] +[159Ho] +[167Ho] +[176Yb] +[156Sm] +[77BrH] +[189Re] +[99Rh] +[100Rh] +[151Pm] +[232Pa] +[228Pa] +[230Pa] +[66Ni] +[194Os] +[135La] +[138La] +[141La] +[142La] +[195Ir] +[96Nb] +[157Ho] +[183Hf] +[162Tm] +[172Er] +[148Eu] +[150Eu] +[15CH4] +[89Kr] +[143La] +[58Ni] +[61Co] +[158Eu] +[165Er] +[167Yb] +[173Tm] +[175Tm] +[172Hf] +[172Lu] +[93Tc] +[177Yb] +[124IH] +[194Ir] +[147Eu] +[101Mo] +[180Hf] +[189Ir] +[87Y] +[43Sc] +[195Au] +[112Ag] +[84BrH] +[106Ag] +[109Ag] +[101Rh] +[162Yb] +[228Rn] +[139Pr] +[94Y] +[201Au] +[40PH3] +[110Ag+] +[104Cd] +[133Ba+2] +[226Ac] +[145Gd] +[186Ir] +[184Ir] +[224Rn] +[185Ir] +[182Ir] +[184Hf] +[200Pt] +[227Pa] +[178Yb] +[72Br-] +[72BrH] +[248Am] +[238Th] +[161Gd] +[35S-2] +[107Ag] +[FeH6-4] +[89Sr] +[SnH3-] +[SeH3] +[TeH3+] +[SbH4+] +[AsH4+] +[4He] +[AsH3-] +[1HH] +[3H+] +[82Rb] +[85Sr] +[90Sr] +[137Cs] +[133Ba] +[131Cs] +[SbH5] +[224Ra] +[22Na] +[210Bi] +[214Bi] +[228Ra] +[127Sb] +[136Cs] +[125Sb] +[134Cs] +[140Ba] +[45Ca] +[206Pb] +[207Pb] +[24Na] +[86Rb] +[212Bi] +[208Pb] +[124Sb] +[204Pb] +[44K] +[129Te] +[113Sn] +[204Tl] +[87Sr] +[208Tl] +[87Rb] +[47Ca] +[135Cs] +[216Po] +[137Ba] +[207Bi] +[212Po] +[79Se] +[223Ra] +[86Sr] +[122Sb] +[26Al] +[32Si] +[126Sn] +[225Ra] +[114In] +[72Ga] +[132Te] +[10Be] +[125Sn] +[73As] +[206Bi] +[117Sn] +[40Ca] +[41Ca] +[89Rb] +[116In] +[129Sb] +[91Sr] +[71Ge] +[139Ba] +[69Ga] +[120Sb] +[121Sn] +[123Sn] +[131Te] +[77Ge] +[135Ba] +[82Sr] +[43K] +[131Ba] +[92Sr] +[88Rb] +[129Cs] +[144Cs] +[127Cs] +[200Tl] +[202Tl] +[141Ba] +[117Sb] +[116Sb] +[78As] +[131Sb] +[126Sb] +[128Sb] +[130Sb] +[67Ge] +[68Ge] +[78Ge] +[66Ge] +[223Fr] +[132Cs] +[125Cs] +[138Cs] +[133Te] +[84Rb] +[83Rb] +[81Rb] +[142Ba] +[200Bi] +[115Sb] +[194Tl] +[70Se] +[112In] +[118Sb] +[70Ga] +[27Mg] +[202Bi] +[83Se] +[9Li] +[69As] +[79Rb] +[81Sr] +[83Sr] +[78Se] +[109In] +[29Al] 
+[118Sn] +[117In] +[119Sb] +[114Sn] +[138Ba] +[69Ge] +[73Ga] +[74Ge] +[206Tl] +[199Tl] +[130Cs] +[28Mg] +[116Te] +[112Sn] +[126Ba] +[211Bi] +[81Se] +[127Sn] +[143Cs] +[134Te] +[80Sr] +[45K] +[215Po] +[207Po] +[111Sn] +[211Po] +[128Ba] +[198Tl] +[227Ra] +[213Po] +[220Ra] +[128Sn] +[203Po] +[205Po] +[65Ga] +[197Tl] +[88Sr] +[110In] +[31Si] +[201Bi] +[121Te] +[205Bi] +[203Bi] +[195Tl] +[209Tl] +[110Sn] +[222Fr] +[207At] +[119In] +[As@] +[129IH] +[157Dy] +[111IH] +[230Ra] +[144Pr+3] +[SiH3+] +[3He] +[AsH5] +[72Se] +[95Tc] +[103Pd] +[121Sn+2] +[211Rn] +[38SH2] +[127IH] +[74Br-] +[133I-] +[100Tc+4] +[100Tc] +[36Cl-] +[89Y+3] +[104Rh] +[152Sm] +[226Ra] +[19FH] +[104Pd] +[148Gd] +[157Lu] +[33SH2] +[121I-] +[17FH] +[71Se] +[157Sm] +[148Tb] +[164Dy] +[15OH2] +[15O+] +[39K] +[40Ar] +[50Cr+3] +[50Cr] +[52Ti] +[103Pd+2] +[130Ba] +[142Pm] +[153Gd+3] +[151Eu] +[103Rh] +[124Xe] +[152Tb] +[17OH2] +[20Ne] +[52Fe] +[94Zr+4] +[94Zr] +[149Pr] +[16OH2] +[53Cr+6] +[53Cr] +[81Br-] +[112Pd] +[125Xe] +[155Gd] +[157Gd] +[168Yb] +[184Os] +[166Tb] +[221Fr] +[212Ra] +[75Br-] +[79Br-] +[113Ag] +[23Na] +[34Cl-] +[34ClH] +[38Cl-] +[56Fe] +[68Cu] +[77Br-] +[90Zr+4] +[90Zr] +[102Pd] +[154Eu+3] +[57Mn] +[165Tm] +[152Dy] +[217At] +[77se] +[13cH-] +[122Te] +[156Gd] +[124Te] +[53Ni] +[131Xe] +[174Hf+4] +[174Hf] +[76Se] +[168Tm] +[167Dy] +[154Gd] +[95Ru] +[210At] +[85Br] +[59Co] +[122Xe] +[27Al] +[54Cr] +[198Hg] +[85Rb+] +[214Tl] +[229Rn] +[218Pb] +[218Bi] +[167Tm+3] +[18o+] +[P@@H+] +[P@H+] +[13N+] +[212Pb+2] +[217Bi] +[249Cf+2] +[18OH3+] +[90Sr-] +[Cf+3] +[200Hg] +[86Tc] +[141Pr+3] +[141Pr] +[16nH] +[14NH4+] +[132Xe] +[83Kr] +[70Zn+2] +[137Ba+2] +[36Ar] +[38Ar] +[21Ne] +[126Xe] +[136Xe] +[128Xe] +[134Xe] +[84Kr] +[86Kr] +[78Kr] +[80Kr] +[82Kr] +[67Zn+2] +[65Cu+2] +[110Te] +[58Fe+3] +[142Nd] +[38K] +[198Au+3] +[122IH] +[38PH3] +[130I-] +[40K+] +[38K+] +[28Mg+2] +[208Tl+] +[13OH2] +[198Bi] +[192Bi] +[194Bi] +[196Bi] +[132I-] +[83Sr+2] +[169Er+3] +[122I-] +[120I-] +[92Sr+2] +[126I-] +[24Mg] +[84Sr] +[118Pd+2] +[118Pd] +[AsH4] +[127I-] +[9C-] +[11CH3+] +[17B] +[7B] +[4HH] +[18C-] +[22CH3-] +[22CH4] +[17C-] +[15CH3] +[16CH3] +[11NH3] +[21NH3] +[11N-] +[11NH] +[16CH] +[17CH2] +[99Ru+2] +[181Ta+2] +[181Ta] +[20CH] +[32PH2] +[55Fe+2] +[SH3] +[S@H] +[Mn-] +[IH4] +[ThH] +[GaH-] +[BiH+] +[EuH2] +[FeH4-3] +[FeH6] +[IH5] +[NiH+] +[SrH2] +[VH4] +[YH3] +[seH+] +<unk> diff --git a/models/smi_ted/inference/smi_ted_light/load.py b/models/smi_ted/inference/smi_ted_light/load.py new file mode 100644 index 0000000000000000000000000000000000000000..54d8422b6e54ca164f42c8ea6bc9e84b8c3e3103 --- /dev/null +++ b/models/smi_ted/inference/smi_ted_light/load.py @@ -0,0 +1,672 @@ +PATTERN = "(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])" +# Deep learning +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.backends.cudnn as cudnn + +# Transformers +from fast_transformers.attention import AttentionLayer +from fast_transformers.events import QKVEvent +from fast_transformers.transformers import TransformerEncoder, TransformerEncoderLayer +from fast_transformers.builders.transformer_builders import BaseTransformerEncoderBuilder +from fast_transformers.builders.attention_builders import AttentionBuilder +from fast_transformers.feature_maps import GeneralizedRandomFeatures +from fast_transformers.masking import LengthMask +from transformers import BertTokenizer + +# Data +import numpy as np +import pandas as pd + +# Chemistry +from rdkit import Chem +from rdkit.Chem import 
PandasTools +from rdkit.Chem import Descriptors +PandasTools.RenderImagesInAllDataFrames(True) + +# Standard library +from functools import partial +import regex as re +import random +import os +import gc +from tqdm import tqdm +tqdm.pandas() + + +# function to canonicalize SMILES +def normalize_smiles(smi, canonical=True, isomeric=False): + try: + normalized = Chem.MolToSmiles( + Chem.MolFromSmiles(smi), canonical=canonical, isomericSmiles=isomeric + ) + except: + normalized = None + return normalized + + +class MolTranBertTokenizer(BertTokenizer): + def __init__(self, vocab_file: str = '', + do_lower_case=False, + unk_token='<pad>', + sep_token='<eos>', + pad_token='<pad>', + cls_token='<bos>', + mask_token='<mask>', + **kwargs): + super().__init__(vocab_file, + unk_token=unk_token, + sep_token=sep_token, + pad_token=pad_token, + cls_token=cls_token, + mask_token=mask_token, + **kwargs) + + self.regex_tokenizer = re.compile(PATTERN) + self.wordpiece_tokenizer = None + self.basic_tokenizer = None + with open(vocab_file) as f: + self.padding_idx = f.readlines().index(pad_token+'\n') + + def _tokenize(self, text): + split_tokens = self.regex_tokenizer.findall(text) + return split_tokens + + def convert_idx_to_tokens(self, idx_tensor): + tokens = [self.convert_ids_to_tokens(idx) for idx in idx_tensor.tolist()] + return tokens + + def convert_tokens_to_string(self, tokens): + stopwords = ['<bos>', '<eos>'] + clean_tokens = [word for word in tokens if word not in stopwords] + out_string = ''.join(clean_tokens) + return out_string + + def get_padding_idx(self): + return self.padding_idx + + def idx_to_smiles(self, torch_model, idx): + '''Convert tokens idx back to SMILES text''' + rev_tokens = torch_model.tokenizer.convert_idx_to_tokens(idx) + flat_list_tokens = [item for sublist in rev_tokens for item in sublist] + decoded_smiles = torch_model.tokenizer.convert_tokens_to_string(flat_list_tokens) + return decoded_smiles + + +## Transformer layers +class RotaryEmbedding(torch.nn.Module): + + def __init__(self, dim, base=10000): + super().__init__() + inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer('inv_freq', inv_freq) + self.seq_len_cached = 0 + self.cos_cached = None + self.sin_cached = None + + def forward(self, x, seq_dim=1): + seq_len = x.shape[seq_dim] + if seq_len != self.seq_len_cached: + self.seq_len_cached = seq_len + + t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + freqs = torch.einsum('i,j->ij', t, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1).to(x.device) + + self.cos_cached = emb.cos()[None,:, None, :] + self.sin_cached = emb.sin()[None,:, None, :] + + return self.cos_cached, self.sin_cached + +def rotate_half(x): + x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in earlier torch versions + +@torch.jit.script +def apply_rotary_pos_emb(q, k, cos, sin): + return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) + +class RotateAttentionLayer(AttentionLayer): + """Rotate attention layer inherits from fast_transformer attention layer. 
+ The only thing added is an Embedding encoding, for more information + on the attention layer see the fast_transformers code + """ + def __init__(self, attention, d_model, n_heads, d_keys=None, + d_values=None, event_dispatcher=""): + super(RotateAttentionLayer, self).__init__(attention,d_model, n_heads, d_keys=d_keys, + d_values=d_values, event_dispatcher=event_dispatcher) + + self.rotaryemb = RotaryEmbedding(d_keys) + print('Using Rotation Embedding') + + def forward(self, queries, keys, values, attn_mask, query_lengths, + key_lengths): + """ + Using the same frame work as the fast_Transformers attention layer + but injecting rotary information to the queries and the keys + after the keys and queries are projected. + In the argument description we make use of the following sizes + - N: the batch size + - L: The maximum length of the queries + - S: The maximum length of the keys (the actual length per sequence + is given by the length mask) + - D: The input feature dimensionality passed in the constructor as + 'd_model' + Arguments + --------- + queries: (N, L, D) The tensor containing the queries + keys: (N, S, D) The tensor containing the keys + values: (N, S, D) The tensor containing the values + attn_mask: An implementation of BaseMask that encodes where each + query can attend to + query_lengths: An implementation of BaseMask that encodes how + many queries each sequence in the batch consists of + key_lengths: An implementation of BaseMask that encodes how + many queries each sequence in the batch consists of + Returns + ------- + The new value for each query as a tensor of shape (N, L, D). + """ + # Extract the dimensions into local variables + N, L, _ = queries.shape + _, S, _ = keys.shape + H = self.n_heads + + # Project the queries/keys/values + queries = self.query_projection(queries).view(N, L, H, -1) + keys = self.key_projection(keys).view(N, S, H, -1) + cos, sin = self.rotaryemb(queries) + queries, keys = apply_rotary_pos_emb(queries, keys, cos, sin) + values = self.value_projection(values).view(N, S, H, -1) + # Let the world know of the qkv + self.event_dispatcher.dispatch(QKVEvent(self, queries, keys, values)) + + + # Compute the attention + new_values = self.inner_attention( + queries, + keys, + values, + attn_mask, + query_lengths, + key_lengths + ).view(N, L, -1) + + # Project the output and return + return self.out_projection(new_values) + +class RotateEncoderBuilder(BaseTransformerEncoderBuilder): + """Build a batch transformer encoder with Relative Rotary embeddings + for training or processing of sequences all elements at a time. 
+ Example usage: + builder = RotateEncoderBuilder() + builder.n_layers = 12 + builder.n_heads = 8 + builder.feed_forward_dimensions = 1024 + builder.query_dimensions = 64 + builder.value_dimensions = 64 + builder.dropout = 0.1 + builder.attention_dropout = 0.1 + builder.attention_type = "linear" + transformer = builder.get() + """ + def _get_attention_builder(self): + """Return an instance of the appropriate attention builder.""" + return AttentionBuilder() + + def _get_attention_layer_class(self): + """Return the class for the layer that projects queries keys and + values.""" + return RotateAttentionLayer + + def _get_encoder_class(self): + """Return the class for the transformer encoder.""" + return TransformerEncoder + + def _get_encoder_layer_class(self): + """Return the class for the transformer encoder layer.""" + return TransformerEncoderLayer + + +class AutoEncoderLayer(nn.Module): + + def __init__(self, feature_size, latent_size): + super().__init__() + self.encoder = self.Encoder(feature_size, latent_size) + self.decoder = self.Decoder(feature_size, latent_size) + + class Encoder(nn.Module): + + def __init__(self, feature_size, latent_size): + super().__init__() + self.is_cuda_available = torch.cuda.is_available() + self.fc1 = nn.Linear(feature_size, latent_size) + self.ln_f = nn.LayerNorm(latent_size) + self.lat = nn.Linear(latent_size, latent_size, bias=False) + + def forward(self, x): + if self.is_cuda_available: + self.fc1.cuda() + self.ln_f.cuda() + self.lat.cuda() + x = x.cuda() + x = F.gelu(self.fc1(x)) + x = self.ln_f(x) + x = self.lat(x) + return x # -> (N, D) + + class Decoder(nn.Module): + + def __init__(self, feature_size, latent_size): + super().__init__() + self.is_cuda_available = torch.cuda.is_available() + self.fc1 = nn.Linear(latent_size, latent_size) + self.ln_f = nn.LayerNorm(latent_size) + self.rec = nn.Linear(latent_size, feature_size, bias=False) + + def forward(self, x): + if self.is_cuda_available: + self.fc1.cuda() + self.ln_f.cuda() + self.rec.cuda() + x = x.cuda() + x = F.gelu(self.fc1(x)) + x = self.ln_f(x) + x = self.rec(x) + return x # -> (N, L*D) + + +class LangLayer(nn.Module): + + def __init__(self, n_embd, n_vocab): + super().__init__() + self.is_cuda_available = torch.cuda.is_available() + self.embed = nn.Linear(n_embd, n_embd) + self.ln_f = nn.LayerNorm(n_embd) + self.head = nn.Linear(n_embd, n_vocab, bias=False) + + def forward(self, tensor): + if self.is_cuda_available: + self.embed.cuda() + self.ln_f.cuda() + self.head.cuda() + tensor = tensor.cuda() + tensor = self.embed(tensor) + tensor = F.gelu(tensor) + tensor = self.ln_f(tensor) + tensor = self.head(tensor) + return tensor + + +class Net(nn.Module): + + def __init__(self, smiles_embed_dim, n_output=1, dropout=0.2): + super().__init__() + self.desc_skip_connection = True + self.fc1 = nn.Linear(smiles_embed_dim, smiles_embed_dim) + self.dropout1 = nn.Dropout(dropout) + self.relu1 = nn.GELU() + self.fc2 = nn.Linear(smiles_embed_dim, smiles_embed_dim) + self.dropout2 = nn.Dropout(dropout) + self.relu2 = nn.GELU() + self.final = nn.Linear(smiles_embed_dim, n_output) + + def forward(self, smiles_emb, multitask=False): + x_out = self.fc1(smiles_emb) + x_out = self.dropout1(x_out) + x_out = self.relu1(x_out) + + if self.desc_skip_connection is True: + x_out = x_out + smiles_emb + + z = self.fc2(x_out) + z = self.dropout2(z) + z = self.relu2(z) + if self.desc_skip_connection is True: + z = self.final(z + x_out) + else: + z = self.final(z) + + if multitask: + return F.sigmoid(z) + return z + + 
+class MoLEncoder(nn.Module): + + def __init__(self, config, n_vocab): + super(MoLEncoder, self).__init__() + + # embeddings + self.config = config + self.tok_emb = nn.Embedding(n_vocab, config['n_embd']) + self.drop = nn.Dropout(config['d_dropout']) + + # transformer + builder = RotateEncoderBuilder.from_kwargs( + n_layers=config['n_layer'], + n_heads=config['n_head'], + query_dimensions=config['n_embd']//config['n_head'], + value_dimensions=config['n_embd']//config['n_head'], + feed_forward_dimensions=config['n_embd'], + attention_type='linear', + # unless we do deterministic_eval here, we will have random outputs + feature_map=partial(GeneralizedRandomFeatures, + n_dims=config['num_feats'], + deterministic_eval=True), + activation='gelu' + ) + self.blocks = builder.get() + + # classification + self.lang_model = LangLayer(config['n_embd'], n_vocab) + + def forward(self, idx, mask): + # transformer encoder + x = self.tok_emb(idx) # each index maps to a (learnable) vector + x = self.drop(x) + x = self.blocks(x, length_mask=LengthMask(mask.sum(-1), max_len=idx.shape[1])) + + # add padding + token_embeddings = x + input_mask_expanded = mask.unsqueeze(-1).expand(token_embeddings.size()).float() + mask_embeddings = (token_embeddings * input_mask_expanded) + token_embeddings = F.pad(mask_embeddings, pad=(0, 0, 0, self.config['max_len'] - mask_embeddings.shape[1]), value=0) + + return token_embeddings + + +class MoLDecoder(nn.Module): + + def __init__(self, n_vocab, max_len, n_embd, n_gpu=None): + super(MoLDecoder, self).__init__() + + self.max_len = max_len + self.n_embd = n_embd + self.n_gpu = n_gpu + self.autoencoder = AutoEncoderLayer(n_embd*max_len, n_embd) + self.lang_model = LangLayer(n_embd, n_vocab) + + +class Smi_ted(nn.Module): + """materials.smi-ted-Light 289M Parameters""" + + def __init__(self, tokenizer, config=None): + super(Smi_ted, self).__init__() + + # configuration + self.config = config + self.tokenizer = tokenizer + self.padding_idx = tokenizer.get_padding_idx() + self.n_vocab = len(self.tokenizer.vocab) + self.is_cuda_available = torch.cuda.is_available() + + # instantiate modules + if self.config: + self.encoder = MoLEncoder(self.config, self.n_vocab) + self.decoder = MoLDecoder(self.n_vocab, self.config['max_len'], self.config['n_embd']) + self.net = Net(self.config['n_embd'], n_output=self.config['n_output'], dropout=self.config['d_dropout']) + + def load_checkpoint(self, ckpt_path): + # load checkpoint file + checkpoint = torch.load(ckpt_path, map_location=torch.device('cpu')) + + # load hyparameters + self.config = checkpoint['hparams'] + self.max_len = self.config['max_len'] + self.n_embd = self.config['n_embd'] + self._set_seed(self.config['seed']) + + # instantiate modules + self.encoder = MoLEncoder(self.config, self.n_vocab) + self.decoder = MoLDecoder(self.n_vocab, self.max_len, self.n_embd) + self.net = Net(self.n_embd, n_output=self.config['n_output'] if 'n_output' in self.config else 1, dropout=self.config['d_dropout']) + + # load weights + if 'state_dict' in checkpoint: + if isinstance(checkpoint['state_dict'], list): + self.encoder.load_state_dict(checkpoint['state_dict'][0], strict=False) + self.decoder.load_state_dict(checkpoint['state_dict'][1], strict=False) + else: + self.load_state_dict(checkpoint['state_dict'], strict=False) + elif 'MODEL_STATE' in checkpoint: + self.load_state_dict(checkpoint['MODEL_STATE'], strict=False) + + # load RNG states each time the model and states are loaded from checkpoint + if 'rng' in self.config: + rng = 
self.config['rng'] + for key, value in rng.items(): + if key =='torch_state': + torch.set_rng_state(value.cpu()) + elif key =='cuda_state': + torch.cuda.set_rng_state(value.cpu()) + elif key =='numpy_state': + np.random.set_state(value) + elif key =='python_state': + random.setstate(value) + else: + print('unrecognized state') + + def _set_seed(self, value): + print('Random Seed:', value) + random.seed(value) + torch.manual_seed(value) + torch.cuda.manual_seed(value) + torch.cuda.manual_seed_all(value) + np.random.seed(value) + cudnn.deterministic = True + cudnn.benchmark = False + + def forward(self, smiles, batch_size=100): + return self.decode(self.encode(smiles, batch_size=batch_size, return_torch=True)) + + def tokenize(self, smiles): + """Tokenize a string into tokens.""" + if isinstance(smiles, str): + batch = [smiles] + else: + batch = smiles + + tokens = self.tokenizer( + batch, + padding=True, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + max_length=self.max_len, + ) + + idx = tokens['input_ids'].clone().detach() + mask = tokens['attention_mask'].clone().detach() + + if self.is_cuda_available: + return idx.cuda(), mask.cuda() + + return idx, mask + + def extract_all(self, smiles): + """Extract all elements from each part of smi-ted. Be careful.""" + # evaluation mode + self.encoder.eval() + self.decoder.eval() + if self.is_cuda_available: + self.encoder.cuda() + self.decoder.cuda() + + # handle single str or a list of str + smiles = pd.Series(smiles) if isinstance(smiles, str) else pd.Series(list(smiles)) + + # SMILES normalization + smiles = smiles.apply(normalize_smiles) + null_idx = smiles[smiles.isnull()].index.to_list() # keep track of SMILES that cannot normalize + smiles = smiles.dropna() + + # tokenizer + idx, mask = self.tokenize(smiles.to_list()) + + ########### + # Encoder # + ########### + # encoder forward + x = self.encoder.tok_emb(idx) # each index maps to a (learnable) vector + x = self.encoder.drop(x) + x = self.encoder.blocks(x, length_mask=LengthMask(mask.sum(-1))) + + # mean pooling + token_embeddings = x + input_mask_expanded = mask.unsqueeze(-1).expand(token_embeddings.size()).float() + sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1) + sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9) + true_set = sum_embeddings / sum_mask # DO NOT USE THIS FOR DOWNSTREAM TASKS, USE `pred_set` INSTEAD + + # add padding + mask_embeddings = (token_embeddings * input_mask_expanded) + token_embeddings = F.pad(mask_embeddings, pad=(0, 0, 0, self.max_len - mask_embeddings.shape[1]), value=0) + idx = F.pad(idx, pad=(0, self.max_len - idx.shape[1], 0, 0), value=2) + + true_ids = idx + true_cte = token_embeddings + true_cte = true_cte.view(-1, self.max_len*self.n_embd) + + ########### + # Decoder # + ########### + # CTE autoencoder + pred_set = self.decoder.autoencoder.encoder(true_cte) + pred_cte = self.decoder.autoencoder.decoder(pred_set) + + # reconstruct tokens + pred_ids = self.decoder.lang_model(pred_cte.view(-1, self.max_len, self.n_embd)) + pred_ids = torch.argmax(pred_ids, axis=-1) + + # replacing null SMILES with NaN values + for idx in null_idx: + true_ids = true_ids.tolist() + pred_ids = pred_ids.tolist() + true_cte = true_cte.tolist() + pred_cte = pred_cte.tolist() + true_set = true_set.tolist() + pred_set = pred_set.tolist() + + true_ids.insert(idx, np.array([np.nan]*self.config['max_len'])) + pred_ids.insert(idx, np.array([np.nan]*self.config['max_len'])) + true_cte.insert(idx, np.array([np.nan] * 
(self.config['max_len']*self.config['n_embd']))) + pred_cte.insert(idx, np.array([np.nan] * (self.config['max_len']*self.config['n_embd']))) + true_set.insert(idx, np.array([np.nan]*self.config['n_embd'])) + pred_set.insert(idx, np.array([np.nan]*self.config['n_embd'])) + + if len(null_idx) > 0: + true_ids = torch.tensor(true_ids) + pred_ids = torch.tensor(pred_ids) + true_cte = torch.tensor(true_cte) + pred_cte = torch.tensor(pred_cte) + true_set = torch.tensor(true_set) + pred_set = torch.tensor(pred_set) + + return ((true_ids, pred_ids), # tokens + (true_cte, pred_cte), # token embeddings + (true_set, pred_set)) # smiles embeddings + + def extract_embeddings(self, smiles): + """Extract token and SMILES embeddings.""" + # evaluation mode + self.encoder.eval() + if self.is_cuda_available: + self.encoder.cuda() + + # tokenizer + idx, mask = self.tokenize(smiles) + + # encoder forward + token_embeddings = self.encoder(idx, mask) + + # aggregate token embeddings (similar to mean pooling) + # CAUTION: use the embeddings from the autoencoder. + smiles_embeddings = self.decoder.autoencoder.encoder(token_embeddings.view(-1, self.max_len*self.n_embd)) + + # add padding + idx = F.pad(idx, pad=(0, self.max_len - idx.shape[1], 0, 0), value=self.padding_idx) + + return idx, token_embeddings, smiles_embeddings + + def encode(self, smiles, useCuda=False, batch_size=100, return_torch=False): + """Extract efficiently SMILES embeddings per batches.""" + # TODO: remove useCuda argument + + # handle single str or a list of str + smiles = pd.Series(smiles) if isinstance(smiles, str) else pd.Series(list(smiles)) + + # SMILES normalization + smiles = smiles.apply(normalize_smiles) + null_idx = smiles[smiles.isnull()].index.to_list() # keep track of SMILES that cannot normalize + smiles = smiles.dropna() + + # process in batches + n_split = smiles.shape[0] // batch_size if smiles.shape[0] >= batch_size else smiles.shape[0] + embeddings = [ + self.extract_embeddings(list(batch))[2].cpu().detach().numpy() + for batch in tqdm(np.array_split(smiles, n_split)) + ] + flat_list = [item for sublist in embeddings for item in sublist] + + # clear GPU memory + if self.is_cuda_available: + torch.cuda.empty_cache() + gc.collect() + + # replacing null SMILES with NaN values + for idx in null_idx: + flat_list.insert(idx, np.array([np.nan]*self.config['n_embd'])) + flat_list = np.asarray(flat_list) + + if return_torch: + return torch.tensor(flat_list) + return pd.DataFrame(flat_list) + + def decode(self, smiles_embeddings): + """Decode SMILES embeddings back to SMILES.""" + # evaluation mode + self.decoder.eval() + if self.is_cuda_available: + self.decoder.cuda() + + # reconstruct token embeddings + pred_token_embds = self.decoder.autoencoder.decoder(smiles_embeddings) + + # reconstruct tokens + pred_idx = self.decoder.lang_model(pred_token_embds.view(-1, self.max_len, self.n_embd)) + pred_idx = torch.argmax(pred_idx, axis=-1).cpu().detach().numpy() + + # convert idx to tokens + pred_smiles = [] + for i in range(pred_idx.shape[0]): + idx = pred_idx[i] + smiles = self.tokenizer.idx_to_smiles(self, idx) + smiles = smiles.replace('<bos>', '') # begin token + smiles = smiles.replace('<eos>', '') # end token + smiles = smiles.replace('<pad>', '') # pad token + pred_smiles.append(smiles) + + # clear GPU memory + if self.is_cuda_available: + torch.cuda.empty_cache() + gc.collect() + + return pred_smiles + + def __str__(self): + return 'smi-ted-Light' + + +def load_smi_ted(folder="./smi_ted_light", + 
ckpt_filename="smi-ted-Light_40.pt", + vocab_filename="bert_vocab_curated.txt" + ): + tokenizer = MolTranBertTokenizer(os.path.join(folder, vocab_filename)) + model = Smi_ted(tokenizer) + model.load_checkpoint(os.path.join(folder, ckpt_filename)) + model.eval() + print('Vocab size:', len(tokenizer.vocab)) + print(f'[INFERENCE MODE - {str(model)}]') + return model \ No newline at end of file diff --git a/models/smi_ted/notebooks/data/moses_test.csv b/models/smi_ted/notebooks/data/moses_test.csv new file mode 100644 index 0000000000000000000000000000000000000000..8160fe3ed057ad7524ce1191e389e08558a5e864 --- /dev/null +++ b/models/smi_ted/notebooks/data/moses_test.csv @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0248e682a9c29ca7649184dd88a4edc83f48f0e84af6b923990069c3da4501b6 +size 6490097 diff --git a/models/smi_ted/notebooks/smi_ted_encoder_decoder_example.ipynb b/models/smi_ted/notebooks/smi_ted_encoder_decoder_example.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..6313b3abfbc282d941f07fd01d6e65582d2585e5 --- /dev/null +++ b/models/smi_ted/notebooks/smi_ted_encoder_decoder_example.ipynb @@ -0,0 +1,334 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# granite.materials.smi-TED - Encoder & Decoder" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "sys.path.append('../inference')" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# materials.smi-ted (smi-ted)\n", + "from smi_ted_light.load import load_smi_ted\n", + "\n", + "# Data\n", + "import pandas as pd\n", + "import numpy as np\n", + "import torch\n", + "\n", + "# Chemistry\n", + "from rdkit import Chem\n", + "from rdkit.Chem import PandasTools\n", + "from rdkit.Chem import Descriptors\n", + "from rdkit.Chem import AllChem\n", + "from rdkit.DataStructs import FingerprintSimilarity\n", + "from rdkit.DataStructs import TanimotoSimilarity" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "# function to canonicalize SMILES\n", + "def normalize_smiles(smi, canonical=True, isomeric=False):\n", + " try:\n", + " normalized = Chem.MolToSmiles(\n", + " Chem.MolFromSmiles(smi), canonical=canonical, isomericSmiles=isomeric\n", + " )\n", + " except:\n", + " normalized = None\n", + " return normalized\n", + "\n", + "# function to calculate pairwise Tanimoto similarity\n", + "def calculate_tanimoto_similarities(fps1, fps2):\n", + " similarities = []\n", + " for i in range(len(fps1)):\n", + " sim = TanimotoSimilarity(fps1[i], fps2[i])\n", + " similarities.append(sim)\n", + " return similarities" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Load smi-ted" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Random Seed: 12345\n", + "Using Rotation Embedding\n", + "Using Rotation Embedding\n", + "Using Rotation Embedding\n", + "Using Rotation Embedding\n", + "Using Rotation Embedding\n", + "Using Rotation Embedding\n", + "Using Rotation Embedding\n", + "Using Rotation Embedding\n", + "Using Rotation Embedding\n", + "Using Rotation Embedding\n", + "Using Rotation Embedding\n", + "Using Rotation Embedding\n", + "Vocab size: 2393\n", + "[INFERENCE MODE - smi-ted-Light]\n" + ] + } + ], + "source": [ + "model_smi_ted = 
load_smi_ted(\n", + " folder='../inference/smi_ted_light',\n", + " ckpt_filename='smi-ted-Light_40.pt'\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load Dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "df_moses = pd.read_csv(\"./data/moses_test.csv\", nrows=1000)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(1000, 1)\n" + ] + }, + { + "data": { + "text/html": [ + "<div>\n", + "<style scoped>\n", + " .dataframe tbody tr th:only-of-type {\n", + " vertical-align: middle;\n", + " }\n", + "\n", + " .dataframe tbody tr th {\n", + " vertical-align: top;\n", + " }\n", + "\n", + " .dataframe thead th {\n", + " text-align: right;\n", + " }\n", + "</style>\n", + "<table border=\"1\" class=\"dataframe\">\n", + " <thead>\n", + " <tr style=\"text-align: right;\">\n", + " <th></th>\n", + " <th>SMILES</th>\n", + " </tr>\n", + " </thead>\n", + " <tbody>\n", + " <tr>\n", + " <th>0</th>\n", + " <td>CC1C2CCC(C2)C1CN(CCO)C(=O)c1ccc(Cl)cc1</td>\n", + " </tr>\n", + " <tr>\n", + " <th>1</th>\n", + " <td>COc1ccc(-c2cc(=O)c3c(O)c(OC)c(OC)cc3o2)cc1O</td>\n", + " </tr>\n", + " <tr>\n", + " <th>2</th>\n", + " <td>CCOC(=O)c1ncn2c1CN(C)C(=O)c1cc(F)ccc1-2</td>\n", + " </tr>\n", + " <tr>\n", + " <th>3</th>\n", + " <td>Clc1ccccc1-c1nc(-c2ccncc2)no1</td>\n", + " </tr>\n", + " <tr>\n", + " <th>4</th>\n", + " <td>CC(C)(Oc1ccc(Cl)cc1)C(=O)OCc1cccc(CO)n1</td>\n", + " </tr>\n", + " </tbody>\n", + "</table>\n", + "</div>" + ], + "text/plain": [ + " SMILES\n", + "0 CC1C2CCC(C2)C1CN(CCO)C(=O)c1ccc(Cl)cc1\n", + "1 COc1ccc(-c2cc(=O)c3c(O)c(OC)c(OC)cc3o2)cc1O\n", + "2 CCOC(=O)c1ncn2c1CN(C)C(=O)c1cc(F)ccc1-2\n", + "3 Clc1ccccc1-c1nc(-c2ccncc2)no1\n", + "4 CC(C)(Oc1ccc(Cl)cc1)C(=O)OCc1cccc(CO)n1" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_moses['SMILES'] = df_moses['SMILES'].apply(normalize_smiles)\n", + "df_test_normalized = df_moses.dropna()\n", + "print(df_test_normalized.shape)\n", + "df_test_normalized.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Encode SMILES - smi-ted" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 10/10 [00:06<00:00, 1.52it/s]\n" + ] + } + ], + "source": [ + "with torch.no_grad():\n", + " encode_embeddings = model_smi_ted.encode(df_moses['SMILES'], return_torch=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Decode smi-ted embeddings into SMILES" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "with torch.no_grad():\n", + " decoded_smiles = model_smi_ted.decode(encode_embeddings)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['CC1C2CCC(C2)C1CN(CCO)C(=O)c1ccc(Cl)cc1',\n", + " 'COc1ccc(-c2cc(=O)c3c(O)c(OC)c(OC)cc3o2)cc1O',\n", + " 'CCOC(=O)c1ncn2c1CN(C)C(=O)c1cc(F)ccc1-2',\n", + " 'Clc1ccccc1-c1nc(-c2ccncc2)no1',\n", + " 'CC(C)(Oc1ccc(Cl)cc1)C(=O)OCc1cccc(CO)n1']" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "decoded_smiles[0:5]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + 
"### Compare similarities" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Mean Tanimoto Similarity: 1.00\n" + ] + } + ], + "source": [ + "# Convert SMILES to RDKit molecule objects\n", + "mols1 = [Chem.MolFromSmiles(smiles) for smiles in df_moses['SMILES'].to_list()]\n", + "mols2 = [Chem.MolFromSmiles(smiles) for smiles in decoded_smiles]\n", + "\n", + "# Compute fingerprints for each molecule\n", + "fps1 = [AllChem.GetMorganFingerprintAsBitVect(mol, 2) for mol in mols1]\n", + "fps2 = [AllChem.GetMorganFingerprintAsBitVect(mol, 2) for mol in mols2]\n", + "\n", + "# Calculate Tanimoto similarities\n", + "tanimoto_similarities = calculate_tanimoto_similarities(fps1, fps2)\n", + "\n", + "# Calculate the mean similarity\n", + "mean_similarity = np.mean(tanimoto_similarities)\n", + "\n", + "# Print the mean similarity\n", + "print(f\"Mean Tanimoto Similarity: {mean_similarity:.2f}\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/models/smi_ted/notebooks/smi_ted_frozen_inference_example1.ipynb b/models/smi_ted/notebooks/smi_ted_frozen_inference_example1.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..5ab26dd0c948661119caeef0ab6437384d74151a --- /dev/null +++ b/models/smi_ted/notebooks/smi_ted_frozen_inference_example1.ipynb @@ -0,0 +1,1412 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# granite.materials.smi-TED - INFERENCE (Classification)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Install extra packages for notebook\n", + "%pip install seaborn xgboost" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "sys.path.append('../inference')" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# materials.smi-ted\n", + "from smi_ted_light.load import load_smi_ted\n", + "\n", + "# Data\n", + "import torch\n", + "import pandas as pd\n", + "\n", + "# Chemistry\n", + "from rdkit import Chem\n", + "from rdkit.Chem import PandasTools\n", + "from rdkit.Chem import Descriptors\n", + "PandasTools.RenderImagesInAllDataFrames(True)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "# function to canonicalize SMILES\n", + "def normalize_smiles(smi, canonical=True, isomeric=False):\n", + " try:\n", + " normalized = Chem.MolToSmiles(\n", + " Chem.MolFromSmiles(smi), canonical=canonical, isomericSmiles=isomeric\n", + " )\n", + " except:\n", + " normalized = None\n", + " return normalized" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Import smi-ted" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Random Seed: 12345\n", + "Using Rotation Embedding\n", + "Using Rotation Embedding\n", + "Using Rotation 
Embedding\n", + "Using Rotation Embedding\n", + "Using Rotation Embedding\n", + "Using Rotation Embedding\n", + "Using Rotation Embedding\n", + "Using Rotation Embedding\n", + "Using Rotation Embedding\n", + "Using Rotation Embedding\n", + "Using Rotation Embedding\n", + "Using Rotation Embedding\n", + "Vocab size: 2393\n", + "[INFERENCE MODE - smi-ted-Light]\n" + ] + } + ], + "source": [ + "model_smi_ted = load_smi_ted(\n", + " folder='../inference/smi_ted_light',\n", + " ckpt_filename='smi-ted-Light_40.pt'\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## BBBP Dataset" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Experiments - Data Load" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "df_train = pd.read_csv(\"../finetune/moleculenet/bbbp/train.csv\")\n", + "df_test = pd.read_csv(\"../finetune/moleculenet/bbbp/test.csv\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### SMILES canonization" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(1634, 5)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[22:56:14] Explicit valence for atom # 1 N, 4, is greater than permitted\n", + "[22:56:14] Explicit valence for atom # 6 N, 4, is greater than permitted\n", + "[22:56:14] WARNING: not removing hydrogen atom without neighbors\n", + "[22:56:14] WARNING: not removing hydrogen atom without neighbors\n", + "[22:56:14] WARNING: not removing hydrogen atom without neighbors\n", + "[22:56:14] WARNING: not removing hydrogen atom without neighbors\n", + "[22:56:14] WARNING: not removing hydrogen atom without neighbors\n", + "[22:56:14] WARNING: not removing hydrogen atom without neighbors\n", + "[22:56:14] Explicit valence for atom # 6 N, 4, is greater than permitted\n", + "[22:56:14] WARNING: not removing hydrogen atom without neighbors\n", + "[22:56:14] WARNING: not removing hydrogen atom without neighbors\n", + "[22:56:14] WARNING: not removing hydrogen atom without neighbors\n", + "[22:56:14] WARNING: not removing hydrogen atom without neighbors\n", + "[22:56:14] Explicit valence for atom # 11 N, 4, is greater than permitted\n", + "[22:56:14] WARNING: not removing hydrogen atom without neighbors\n", + "[22:56:14] WARNING: not removing hydrogen atom without neighbors\n", + "[22:56:14] Explicit valence for atom # 5 N, 4, is greater than permitted\n", + "[22:56:14] WARNING: not removing hydrogen atom without neighbors\n", + "[22:56:14] WARNING: not removing hydrogen atom without neighbors\n", + "[22:56:14] WARNING: not removing hydrogen atom without neighbors\n", + "[22:56:14] WARNING: not removing hydrogen atom without neighbors\n", + "[22:56:14] WARNING: not removing hydrogen atom without neighbors\n", + "[22:56:14] WARNING: not removing hydrogen atom without neighbors\n", + "[22:56:14] WARNING: not removing hydrogen atom without neighbors\n", + "[22:56:14] WARNING: not removing hydrogen atom without neighbors\n", + "[22:56:14] WARNING: not removing hydrogen atom without neighbors\n", + "[22:56:14] WARNING: not removing hydrogen atom without neighbors\n", + "[22:56:14] WARNING: not removing hydrogen atom without neighbors\n", + "[22:56:14] WARNING: not removing hydrogen atom without neighbors\n", + "[22:56:14] WARNING: not removing hydrogen atom without neighbors\n", + "[22:56:14] WARNING: not removing 
hydrogen atom without neighbors\n", + "[22:56:14] WARNING: not removing hydrogen atom without neighbors\n", + "[22:56:14] WARNING: not removing hydrogen atom without neighbors\n", + "[22:56:14] WARNING: not removing hydrogen atom without neighbors\n", + "[22:56:14] WARNING: not removing hydrogen atom without neighbors\n", + "[22:56:14] WARNING: not removing hydrogen atom without neighbors\n", + "[22:56:14] WARNING: not removing hydrogen atom without neighbors\n", + "[22:56:14] WARNING: not removing hydrogen atom without neighbors\n" + ] + }, + { + "data": { + "text/html": [ + "<div>\n", + "<style scoped>\n", + " .dataframe tbody tr th:only-of-type {\n", + " vertical-align: middle;\n", + " }\n", + "\n", + " .dataframe tbody tr th {\n", + " vertical-align: top;\n", + " }\n", + "\n", + " .dataframe thead th {\n", + " text-align: right;\n", + " }\n", + "</style>\n", + "<table border=\"1\" class=\"dataframe\">\n", + " <thead>\n", + " <tr style=\"text-align: right;\">\n", + " <th></th>\n", + " <th>num</th>\n", + " <th>name</th>\n", + " <th>p_np</th>\n", + " <th>smiles</th>\n", + " <th>norm_smiles</th>\n", + " </tr>\n", + " </thead>\n", + " <tbody>\n", + " <tr>\n", + " <th>0</th>\n", + " <td>1</td>\n", + " <td>Propanolol</td>\n", + " <td>1</td>\n", + " <td>[Cl].CC(C)NCC(O)COc1cccc2ccccc12</td>\n", + " <td>CC(C)NCC(O)COc1cccc2ccccc12.[Cl]</td>\n", + " </tr>\n", + " <tr>\n", + " <th>1</th>\n", + " <td>2</td>\n", + " <td>Terbutylchlorambucil</td>\n", + " <td>1</td>\n", + " <td>C(=O)(OC(C)(C)C)CCCc1ccc(cc1)N(CCCl)CCCl</td>\n", + " <td>CC(C)(C)OC(=O)CCCc1ccc(N(CCCl)CCCl)cc1</td>\n", + " </tr>\n", + " <tr>\n", + " <th>2</th>\n", + " <td>3</td>\n", + " <td>40730</td>\n", + " <td>1</td>\n", + " <td>c12c3c(N4CCN(C)CC4)c(F)cc1c(c(C(O)=O)cn2C(C)CO...</td>\n", + " <td>CC1COc2c(N3CCN(C)CC3)c(F)cc3c(=O)c(C(=O)O)cn1c23</td>\n", + " </tr>\n", + " <tr>\n", + " <th>3</th>\n", + " <td>4</td>\n", + " <td>24</td>\n", + " <td>1</td>\n", + " <td>C1CCN(CC1)Cc1cccc(c1)OCCCNC(=O)C</td>\n", + " <td>CC(=O)NCCCOc1cccc(CN2CCCCC2)c1</td>\n", + " </tr>\n", + " <tr>\n", + " <th>4</th>\n", + " <td>6</td>\n", + " <td>cefoperazone</td>\n", + " <td>1</td>\n", + " <td>CCN1CCN(C(=O)N[C@@H](C(=O)N[C@H]2[C@H]3SCC(=C(...</td>\n", + " <td>CCN1CCN(C(=O)NC(C(=O)NC2C(=O)N3C(C(=O)O)=C(CSc...</td>\n", + " </tr>\n", + " </tbody>\n", + "</table>\n", + "</div>" + ], + "text/plain": [ + " num name p_np \\\n", + "0 1 Propanolol 1 \n", + "1 2 Terbutylchlorambucil 1 \n", + "2 3 40730 1 \n", + "3 4 24 1 \n", + "4 6 cefoperazone 1 \n", + "\n", + " smiles \\\n", + "0 [Cl].CC(C)NCC(O)COc1cccc2ccccc12 \n", + "1 C(=O)(OC(C)(C)C)CCCc1ccc(cc1)N(CCCl)CCCl \n", + "2 c12c3c(N4CCN(C)CC4)c(F)cc1c(c(C(O)=O)cn2C(C)CO... \n", + "3 C1CCN(CC1)Cc1cccc(c1)OCCCNC(=O)C \n", + "4 CCN1CCN(C(=O)N[C@@H](C(=O)N[C@H]2[C@H]3SCC(=C(... \n", + "\n", + " norm_smiles \n", + "0 CC(C)NCC(O)COc1cccc2ccccc12.[Cl] \n", + "1 CC(C)(C)OC(=O)CCCc1ccc(N(CCCl)CCCl)cc1 \n", + "2 CC1COc2c(N3CCN(C)CC3)c(F)cc3c(=O)c(C(=O)O)cn1c23 \n", + "3 CC(=O)NCCCOc1cccc(CN2CCCCC2)c1 \n", + "4 CCN1CCN(C(=O)NC(C(=O)NC2C(=O)N3C(C(=O)O)=C(CSc... 
" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_train['norm_smiles'] = df_train['smiles'].apply(normalize_smiles)\n", + "df_train_normalized = df_train.dropna()\n", + "print(df_train_normalized.shape)\n", + "df_train_normalized.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(192, 5)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[22:56:17] Explicit valence for atom # 12 N, 4, is greater than permitted\n", + "[22:56:17] Explicit valence for atom # 5 N, 4, is greater than permitted\n", + "[22:56:17] WARNING: not removing hydrogen atom without neighbors\n", + "[22:56:17] WARNING: not removing hydrogen atom without neighbors\n", + "[22:56:17] WARNING: not removing hydrogen atom without neighbors\n", + "[22:56:17] WARNING: not removing hydrogen atom without neighbors\n" + ] + }, + { + "data": { + "text/html": [ + "<div>\n", + "<style scoped>\n", + " .dataframe tbody tr th:only-of-type {\n", + " vertical-align: middle;\n", + " }\n", + "\n", + " .dataframe tbody tr th {\n", + " vertical-align: top;\n", + " }\n", + "\n", + " .dataframe thead th {\n", + " text-align: right;\n", + " }\n", + "</style>\n", + "<table border=\"1\" class=\"dataframe\">\n", + " <thead>\n", + " <tr style=\"text-align: right;\">\n", + " <th></th>\n", + " <th>num</th>\n", + " <th>name</th>\n", + " <th>p_np</th>\n", + " <th>smiles</th>\n", + " <th>norm_smiles</th>\n", + " </tr>\n", + " </thead>\n", + " <tbody>\n", + " <tr>\n", + " <th>0</th>\n", + " <td>13</td>\n", + " <td>18</td>\n", + " <td>1</td>\n", + " <td>C(Cl)Cl</td>\n", + " <td>ClCCl</td>\n", + " </tr>\n", + " <tr>\n", + " <th>1</th>\n", + " <td>23</td>\n", + " <td>SKF-93619</td>\n", + " <td>0</td>\n", + " <td>c1cc2c(cc(CC3=CNC(=NC3=O)NCCSCc3oc(cc3)CN(C)C)...</td>\n", + " <td>CN(C)Cc1ccc(CSCCNc2nc(=O)c(Cc3ccc4ccccc4c3)c[n...</td>\n", + " </tr>\n", + " <tr>\n", + " <th>2</th>\n", + " <td>36</td>\n", + " <td>etomidate</td>\n", + " <td>1</td>\n", + " <td>CCOC(=O)c1cncn1C(C)c2ccccc2</td>\n", + " <td>CCOC(=O)c1cncn1C(C)c1ccccc1</td>\n", + " </tr>\n", + " <tr>\n", + " <th>3</th>\n", + " <td>37</td>\n", + " <td>11a</td>\n", + " <td>0</td>\n", + " <td>CN(C)c1cc(C2=NC(N)=NN2)ccn1</td>\n", + " <td>CN(C)c1cc(-c2nc(N)n[nH]2)ccn1</td>\n", + " </tr>\n", + " <tr>\n", + " <th>4</th>\n", + " <td>79</td>\n", + " <td>compound 45</td>\n", + " <td>1</td>\n", + " <td>N1(Cc2cc(OCCCNc3oc4ccccc4n3)ccc2)CCCCC1</td>\n", + " <td>c1cc(CN2CCCCC2)cc(OCCCNc2nc3ccccc3o2)c1</td>\n", + " </tr>\n", + " </tbody>\n", + "</table>\n", + "</div>" + ], + "text/plain": [ + " num name p_np smiles \\\n", + "0 13 18 1 C(Cl)Cl \n", + "1 23 SKF-93619 0 c1cc2c(cc(CC3=CNC(=NC3=O)NCCSCc3oc(cc3)CN(C)C)... \n", + "2 36 etomidate 1 CCOC(=O)c1cncn1C(C)c2ccccc2 \n", + "3 37 11a 0 CN(C)c1cc(C2=NC(N)=NN2)ccn1 \n", + "4 79 compound 45 1 N1(Cc2cc(OCCCNc3oc4ccccc4n3)ccc2)CCCCC1 \n", + "\n", + " norm_smiles \n", + "0 ClCCl \n", + "1 CN(C)Cc1ccc(CSCCNc2nc(=O)c(Cc3ccc4ccccc4c3)c[n... 
\n", + "2 CCOC(=O)c1cncn1C(C)c1ccccc1 \n", + "3 CN(C)c1cc(-c2nc(N)n[nH]2)ccn1 \n", + "4 c1cc(CN2CCCCC2)cc(OCCCNc2nc3ccccc3o2)c1 " + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_test['norm_smiles'] = df_test['smiles'].apply(normalize_smiles)\n", + "df_test_normalized = df_test.dropna()\n", + "print(df_test_normalized.shape)\n", + "df_test_normalized.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Embeddings extraction " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### smi-ted embeddings extraction" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 16/16 [00:21<00:00, 1.35s/it]\n" + ] + }, + { + "data": { + "text/html": [ + "<div>\n", + "<style scoped>\n", + " .dataframe tbody tr th:only-of-type {\n", + " vertical-align: middle;\n", + " }\n", + "\n", + " .dataframe tbody tr th {\n", + " vertical-align: top;\n", + " }\n", + "\n", + " .dataframe thead th {\n", + " text-align: right;\n", + " }\n", + "</style>\n", + "<table border=\"1\" class=\"dataframe\">\n", + " <thead>\n", + " <tr style=\"text-align: right;\">\n", + " <th></th>\n", + " <th>0</th>\n", + " <th>1</th>\n", + " <th>2</th>\n", + " <th>3</th>\n", + " <th>4</th>\n", + " <th>5</th>\n", + " <th>6</th>\n", + " <th>7</th>\n", + " <th>8</th>\n", + " <th>9</th>\n", + " <th>...</th>\n", + " <th>758</th>\n", + " <th>759</th>\n", + " <th>760</th>\n", + " <th>761</th>\n", + " <th>762</th>\n", + " <th>763</th>\n", + " <th>764</th>\n", + " <th>765</th>\n", + " <th>766</th>\n", + " <th>767</th>\n", + " </tr>\n", + " </thead>\n", + " <tbody>\n", + " <tr>\n", + " <th>0</th>\n", + " <td>0.437218</td>\n", + " <td>-0.591727</td>\n", + " <td>0.064328</td>\n", + " <td>0.374019</td>\n", + " <td>0.530676</td>\n", + " <td>-0.644067</td>\n", + " <td>1.308136</td>\n", + " <td>0.089772</td>\n", + " <td>0.790524</td>\n", + " <td>0.208749</td>\n", + " <td>...</td>\n", + " <td>-1.325162</td>\n", + " <td>-0.083578</td>\n", + " <td>0.169544</td>\n", + " <td>0.359247</td>\n", + " <td>-0.652742</td>\n", + " <td>0.720496</td>\n", + " <td>-0.674184</td>\n", + " <td>0.693000</td>\n", + " <td>0.586143</td>\n", + " <td>-0.159641</td>\n", + " </tr>\n", + " <tr>\n", + " <th>1</th>\n", + " <td>0.344508</td>\n", + " <td>-0.417009</td>\n", + " <td>0.095745</td>\n", + " <td>0.355959</td>\n", + " <td>0.573049</td>\n", + " <td>-0.590279</td>\n", + " <td>1.069699</td>\n", + " <td>0.067724</td>\n", + " <td>0.788815</td>\n", + " <td>0.159197</td>\n", + " <td>...</td>\n", + " <td>-1.312421</td>\n", + " <td>-0.108732</td>\n", + " <td>0.217020</td>\n", + " <td>0.303697</td>\n", + " <td>-0.598966</td>\n", + " <td>0.647903</td>\n", + " <td>-0.665967</td>\n", + " <td>0.791804</td>\n", + " <td>0.620691</td>\n", + " <td>-0.107859</td>\n", + " </tr>\n", + " <tr>\n", + " <th>2</th>\n", + " <td>0.429205</td>\n", + " <td>-0.463542</td>\n", + " <td>0.056441</td>\n", + " <td>0.449925</td>\n", + " <td>0.536788</td>\n", + " <td>-0.749906</td>\n", + " <td>1.193816</td>\n", + " <td>0.082596</td>\n", + " <td>0.860276</td>\n", + " <td>0.162548</td>\n", + " <td>...</td>\n", + " <td>-1.304979</td>\n", + " <td>-0.148620</td>\n", + " <td>0.242045</td>\n", + " <td>0.344730</td>\n", + " <td>-0.704636</td>\n", + " <td>0.644773</td>\n", + " <td>-0.781017</td>\n", + " <td>0.737207</td>\n", + " <td>0.585380</td>\n", + " <td>-0.101722</td>\n", + " 
</tr>\n", + " <tr>\n", + " <th>3</th>\n", + " <td>0.433097</td>\n", + " <td>-0.523078</td>\n", + " <td>0.089728</td>\n", + " <td>0.410127</td>\n", + " <td>0.543400</td>\n", + " <td>-0.643014</td>\n", + " <td>1.203858</td>\n", + " <td>0.034177</td>\n", + " <td>0.769413</td>\n", + " <td>0.202445</td>\n", + " <td>...</td>\n", + " <td>-1.358915</td>\n", + " <td>-0.077463</td>\n", + " <td>0.228710</td>\n", + " <td>0.317884</td>\n", + " <td>-0.680220</td>\n", + " <td>0.531601</td>\n", + " <td>-0.709799</td>\n", + " <td>0.731386</td>\n", + " <td>0.567806</td>\n", + " <td>-0.087713</td>\n", + " </tr>\n", + " <tr>\n", + " <th>4</th>\n", + " <td>0.388423</td>\n", + " <td>-0.505908</td>\n", + " <td>0.072539</td>\n", + " <td>0.366502</td>\n", + " <td>0.533689</td>\n", + " <td>-0.701559</td>\n", + " <td>1.035554</td>\n", + " <td>0.038419</td>\n", + " <td>0.822917</td>\n", + " <td>0.163062</td>\n", + " <td>...</td>\n", + " <td>-1.271012</td>\n", + " <td>-0.176412</td>\n", + " <td>0.119734</td>\n", + " <td>0.294143</td>\n", + " <td>-0.677721</td>\n", + " <td>0.647655</td>\n", + " <td>-0.844419</td>\n", + " <td>0.756321</td>\n", + " <td>0.570513</td>\n", + " <td>-0.240003</td>\n", + " </tr>\n", + " </tbody>\n", + "</table>\n", + "<p>5 rows × 768 columns</p>\n", + "</div>" + ], + "text/plain": [ + " 0 1 2 3 4 5 6 \\\n", + "0 0.437218 -0.591727 0.064328 0.374019 0.530676 -0.644067 1.308136 \n", + "1 0.344508 -0.417009 0.095745 0.355959 0.573049 -0.590279 1.069699 \n", + "2 0.429205 -0.463542 0.056441 0.449925 0.536788 -0.749906 1.193816 \n", + "3 0.433097 -0.523078 0.089728 0.410127 0.543400 -0.643014 1.203858 \n", + "4 0.388423 -0.505908 0.072539 0.366502 0.533689 -0.701559 1.035554 \n", + "\n", + " 7 8 9 ... 758 759 760 761 \\\n", + "0 0.089772 0.790524 0.208749 ... -1.325162 -0.083578 0.169544 0.359247 \n", + "1 0.067724 0.788815 0.159197 ... -1.312421 -0.108732 0.217020 0.303697 \n", + "2 0.082596 0.860276 0.162548 ... -1.304979 -0.148620 0.242045 0.344730 \n", + "3 0.034177 0.769413 0.202445 ... -1.358915 -0.077463 0.228710 0.317884 \n", + "4 0.038419 0.822917 0.163062 ... 
-1.271012 -0.176412 0.119734 0.294143 \n", + "\n", + " 762 763 764 765 766 767 \n", + "0 -0.652742 0.720496 -0.674184 0.693000 0.586143 -0.159641 \n", + "1 -0.598966 0.647903 -0.665967 0.791804 0.620691 -0.107859 \n", + "2 -0.704636 0.644773 -0.781017 0.737207 0.585380 -0.101722 \n", + "3 -0.680220 0.531601 -0.709799 0.731386 0.567806 -0.087713 \n", + "4 -0.677721 0.647655 -0.844419 0.756321 0.570513 -0.240003 \n", + "\n", + "[5 rows x 768 columns]" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "with torch.no_grad():\n", + " df_embeddings_train = model_smi_ted.encode(df_train_normalized['norm_smiles'])\n", + "df_embeddings_train.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 1/1 [00:04<00:00, 4.23s/it]\n" + ] + }, + { + "data": { + "text/html": [ + "<div>\n", + "<style scoped>\n", + " .dataframe tbody tr th:only-of-type {\n", + " vertical-align: middle;\n", + " }\n", + "\n", + " .dataframe tbody tr th {\n", + " vertical-align: top;\n", + " }\n", + "\n", + " .dataframe thead th {\n", + " text-align: right;\n", + " }\n", + "</style>\n", + "<table border=\"1\" class=\"dataframe\">\n", + " <thead>\n", + " <tr style=\"text-align: right;\">\n", + " <th></th>\n", + " <th>0</th>\n", + " <th>1</th>\n", + " <th>2</th>\n", + " <th>3</th>\n", + " <th>4</th>\n", + " <th>5</th>\n", + " <th>6</th>\n", + " <th>7</th>\n", + " <th>8</th>\n", + " <th>9</th>\n", + " <th>...</th>\n", + " <th>758</th>\n", + " <th>759</th>\n", + " <th>760</th>\n", + " <th>761</th>\n", + " <th>762</th>\n", + " <th>763</th>\n", + " <th>764</th>\n", + " <th>765</th>\n", + " <th>766</th>\n", + " <th>767</th>\n", + " </tr>\n", + " </thead>\n", + " <tbody>\n", + " <tr>\n", + " <th>0</th>\n", + " <td>0.374249</td>\n", + " <td>-0.319257</td>\n", + " <td>-0.007041</td>\n", + " <td>0.444741</td>\n", + " <td>0.326734</td>\n", + " <td>-0.791476</td>\n", + " <td>1.121707</td>\n", + " <td>-0.082401</td>\n", + " <td>0.611457</td>\n", + " <td>0.289225</td>\n", + " <td>...</td>\n", + " <td>-1.462539</td>\n", + " <td>-0.302055</td>\n", + " <td>0.295551</td>\n", + " <td>-0.058293</td>\n", + " <td>-0.830319</td>\n", + " <td>0.545099</td>\n", + " <td>-0.460271</td>\n", + " <td>1.121117</td>\n", + " <td>0.685016</td>\n", + " <td>-0.452698</td>\n", + " </tr>\n", + " <tr>\n", + " <th>1</th>\n", + " <td>0.429158</td>\n", + " <td>-0.568104</td>\n", + " <td>0.112739</td>\n", + " <td>0.352429</td>\n", + " <td>0.512565</td>\n", + " <td>-0.604153</td>\n", + " <td>1.181846</td>\n", + " <td>0.067963</td>\n", + " <td>0.786978</td>\n", + " <td>0.128077</td>\n", + " <td>...</td>\n", + " <td>-1.226941</td>\n", + " <td>-0.078927</td>\n", + " <td>0.209468</td>\n", + " <td>0.266113</td>\n", + " <td>-0.762261</td>\n", + " <td>0.610685</td>\n", + " <td>-0.755705</td>\n", + " <td>0.734550</td>\n", + " <td>0.592976</td>\n", + " <td>-0.148252</td>\n", + " </tr>\n", + " <tr>\n", + " <th>2</th>\n", + " <td>0.411906</td>\n", + " <td>-0.510477</td>\n", + " <td>0.073015</td>\n", + " <td>0.346871</td>\n", + " <td>0.512772</td>\n", + " <td>-0.617252</td>\n", + " <td>1.191621</td>\n", + " <td>0.040103</td>\n", + " <td>0.722577</td>\n", + " <td>0.188638</td>\n", + " <td>...</td>\n", + " <td>-1.300554</td>\n", + " <td>-0.150735</td>\n", + " <td>0.148252</td>\n", + " <td>0.282791</td>\n", + " <td>-0.694712</td>\n", + " <td>0.556029</td>\n", + " <td>-0.660645</td>\n", + " 
<td>0.771226</td>\n", + " <td>0.558996</td>\n", + " <td>-0.000660</td>\n", + " </tr>\n", + " <tr>\n", + " <th>3</th>\n", + " <td>0.356793</td>\n", + " <td>-0.530959</td>\n", + " <td>0.050350</td>\n", + " <td>0.433593</td>\n", + " <td>0.592601</td>\n", + " <td>-0.573508</td>\n", + " <td>1.221865</td>\n", + " <td>0.025491</td>\n", + " <td>0.833164</td>\n", + " <td>0.214604</td>\n", + " <td>...</td>\n", + " <td>-1.406141</td>\n", + " <td>-0.107165</td>\n", + " <td>0.200131</td>\n", + " <td>0.289469</td>\n", + " <td>-0.770149</td>\n", + " <td>0.572746</td>\n", + " <td>-0.776739</td>\n", + " <td>0.855064</td>\n", + " <td>0.662797</td>\n", + " <td>-0.194417</td>\n", + " </tr>\n", + " <tr>\n", + " <th>4</th>\n", + " <td>0.422133</td>\n", + " <td>-0.490610</td>\n", + " <td>0.044333</td>\n", + " <td>0.367861</td>\n", + " <td>0.579025</td>\n", + " <td>-0.629409</td>\n", + " <td>1.139824</td>\n", + " <td>0.039823</td>\n", + " <td>0.728825</td>\n", + " <td>0.145327</td>\n", + " <td>...</td>\n", + " <td>-1.312777</td>\n", + " <td>-0.105049</td>\n", + " <td>0.175286</td>\n", + " <td>0.336176</td>\n", + " <td>-0.738813</td>\n", + " <td>0.530226</td>\n", + " <td>-0.763357</td>\n", + " <td>0.764998</td>\n", + " <td>0.583681</td>\n", + " <td>-0.109683</td>\n", + " </tr>\n", + " </tbody>\n", + "</table>\n", + "<p>5 rows × 768 columns</p>\n", + "</div>" + ], + "text/plain": [ + " 0 1 2 3 4 5 6 \\\n", + "0 0.374249 -0.319257 -0.007041 0.444741 0.326734 -0.791476 1.121707 \n", + "1 0.429158 -0.568104 0.112739 0.352429 0.512565 -0.604153 1.181846 \n", + "2 0.411906 -0.510477 0.073015 0.346871 0.512772 -0.617252 1.191621 \n", + "3 0.356793 -0.530959 0.050350 0.433593 0.592601 -0.573508 1.221865 \n", + "4 0.422133 -0.490610 0.044333 0.367861 0.579025 -0.629409 1.139824 \n", + "\n", + " 7 8 9 ... 758 759 760 761 \\\n", + "0 -0.082401 0.611457 0.289225 ... -1.462539 -0.302055 0.295551 -0.058293 \n", + "1 0.067963 0.786978 0.128077 ... -1.226941 -0.078927 0.209468 0.266113 \n", + "2 0.040103 0.722577 0.188638 ... -1.300554 -0.150735 0.148252 0.282791 \n", + "3 0.025491 0.833164 0.214604 ... -1.406141 -0.107165 0.200131 0.289469 \n", + "4 0.039823 0.728825 0.145327 ... 
-1.312777 -0.105049 0.175286 0.336176 \n", + "\n", + " 762 763 764 765 766 767 \n", + "0 -0.830319 0.545099 -0.460271 1.121117 0.685016 -0.452698 \n", + "1 -0.762261 0.610685 -0.755705 0.734550 0.592976 -0.148252 \n", + "2 -0.694712 0.556029 -0.660645 0.771226 0.558996 -0.000660 \n", + "3 -0.770149 0.572746 -0.776739 0.855064 0.662797 -0.194417 \n", + "4 -0.738813 0.530226 -0.763357 0.764998 0.583681 -0.109683 \n", + "\n", + "[5 rows x 768 columns]" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "with torch.no_grad():\n", + " df_embeddings_test = model_smi_ted.encode(df_test_normalized['norm_smiles'])\n", + "df_embeddings_test.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Experiments - BBBP prediction using smi-ted latent spaces" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### XGBoost prediction using the whole Latent Space" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "from xgboost import XGBClassifier\n", + "from sklearn.metrics import roc_auc_score" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "<style>#sk-container-id-1 {\n", + " /* Definition of color scheme common for light and dark mode */\n", + " --sklearn-color-text: black;\n", + " --sklearn-color-line: gray;\n", + " /* Definition of color scheme for unfitted estimators */\n", + " --sklearn-color-unfitted-level-0: #fff5e6;\n", + " --sklearn-color-unfitted-level-1: #f6e4d2;\n", + " --sklearn-color-unfitted-level-2: #ffe0b3;\n", + " --sklearn-color-unfitted-level-3: chocolate;\n", + " /* Definition of color scheme for fitted estimators */\n", + " --sklearn-color-fitted-level-0: #f0f8ff;\n", + " --sklearn-color-fitted-level-1: #d4ebff;\n", + " --sklearn-color-fitted-level-2: #b3dbfd;\n", + " --sklearn-color-fitted-level-3: cornflowerblue;\n", + "\n", + " /* Specific color for light theme */\n", + " --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n", + " --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, white)));\n", + " --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n", + " --sklearn-color-icon: #696969;\n", + "\n", + " @media (prefers-color-scheme: dark) {\n", + " /* Redefinition of color scheme for dark theme */\n", + " --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n", + " --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, #111)));\n", + " --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n", + " --sklearn-color-icon: #878787;\n", + " }\n", + "}\n", + "\n", + "#sk-container-id-1 {\n", + " color: var(--sklearn-color-text);\n", + "}\n", + "\n", + "#sk-container-id-1 pre {\n", + " padding: 0;\n", + "}\n", + "\n", + "#sk-container-id-1 input.sk-hidden--visually {\n", + " border: 0;\n", + " clip: rect(1px 1px 1px 1px);\n", + " clip: rect(1px, 1px, 1px, 1px);\n", + " height: 1px;\n", + " margin: -1px;\n", + " overflow: hidden;\n", + " padding: 0;\n", + " position: absolute;\n", + " width: 1px;\n", + "}\n", + "\n", + "#sk-container-id-1 
div.sk-dashed-wrapped {\n", + " border: 1px dashed var(--sklearn-color-line);\n", + " margin: 0 0.4em 0.5em 0.4em;\n", + " box-sizing: border-box;\n", + " padding-bottom: 0.4em;\n", + " background-color: var(--sklearn-color-background);\n", + "}\n", + "\n", + "#sk-container-id-1 div.sk-container {\n", + " /* jupyter's `normalize.less` sets `[hidden] { display: none; }`\n", + " but bootstrap.min.css set `[hidden] { display: none !important; }`\n", + " so we also need the `!important` here to be able to override the\n", + " default hidden behavior on the sphinx rendered scikit-learn.org.\n", + " See: https://github.com/scikit-learn/scikit-learn/issues/21755 */\n", + " display: inline-block !important;\n", + " position: relative;\n", + "}\n", + "\n", + "#sk-container-id-1 div.sk-text-repr-fallback {\n", + " display: none;\n", + "}\n", + "\n", + "div.sk-parallel-item,\n", + "div.sk-serial,\n", + "div.sk-item {\n", + " /* draw centered vertical line to link estimators */\n", + " background-image: linear-gradient(var(--sklearn-color-text-on-default-background), var(--sklearn-color-text-on-default-background));\n", + " background-size: 2px 100%;\n", + " background-repeat: no-repeat;\n", + " background-position: center center;\n", + "}\n", + "\n", + "/* Parallel-specific style estimator block */\n", + "\n", + "#sk-container-id-1 div.sk-parallel-item::after {\n", + " content: \"\";\n", + " width: 100%;\n", + " border-bottom: 2px solid var(--sklearn-color-text-on-default-background);\n", + " flex-grow: 1;\n", + "}\n", + "\n", + "#sk-container-id-1 div.sk-parallel {\n", + " display: flex;\n", + " align-items: stretch;\n", + " justify-content: center;\n", + " background-color: var(--sklearn-color-background);\n", + " position: relative;\n", + "}\n", + "\n", + "#sk-container-id-1 div.sk-parallel-item {\n", + " display: flex;\n", + " flex-direction: column;\n", + "}\n", + "\n", + "#sk-container-id-1 div.sk-parallel-item:first-child::after {\n", + " align-self: flex-end;\n", + " width: 50%;\n", + "}\n", + "\n", + "#sk-container-id-1 div.sk-parallel-item:last-child::after {\n", + " align-self: flex-start;\n", + " width: 50%;\n", + "}\n", + "\n", + "#sk-container-id-1 div.sk-parallel-item:only-child::after {\n", + " width: 0;\n", + "}\n", + "\n", + "/* Serial-specific style estimator block */\n", + "\n", + "#sk-container-id-1 div.sk-serial {\n", + " display: flex;\n", + " flex-direction: column;\n", + " align-items: center;\n", + " background-color: var(--sklearn-color-background);\n", + " padding-right: 1em;\n", + " padding-left: 1em;\n", + "}\n", + "\n", + "\n", + "/* Toggleable style: style used for estimator/Pipeline/ColumnTransformer box that is\n", + "clickable and can be expanded/collapsed.\n", + "- Pipeline and ColumnTransformer use this feature and define the default style\n", + "- Estimators will overwrite some part of the style using the `sk-estimator` class\n", + "*/\n", + "\n", + "/* Pipeline and ColumnTransformer style (default) */\n", + "\n", + "#sk-container-id-1 div.sk-toggleable {\n", + " /* Default theme specific background. 
It is overwritten whether we have a\n", + " specific estimator or a Pipeline/ColumnTransformer */\n", + " background-color: var(--sklearn-color-background);\n", + "}\n", + "\n", + "/* Toggleable label */\n", + "#sk-container-id-1 label.sk-toggleable__label {\n", + " cursor: pointer;\n", + " display: block;\n", + " width: 100%;\n", + " margin-bottom: 0;\n", + " padding: 0.5em;\n", + " box-sizing: border-box;\n", + " text-align: center;\n", + "}\n", + "\n", + "#sk-container-id-1 label.sk-toggleable__label-arrow:before {\n", + " /* Arrow on the left of the label */\n", + " content: \"▸\";\n", + " float: left;\n", + " margin-right: 0.25em;\n", + " color: var(--sklearn-color-icon);\n", + "}\n", + "\n", + "#sk-container-id-1 label.sk-toggleable__label-arrow:hover:before {\n", + " color: var(--sklearn-color-text);\n", + "}\n", + "\n", + "/* Toggleable content - dropdown */\n", + "\n", + "#sk-container-id-1 div.sk-toggleable__content {\n", + " max-height: 0;\n", + " max-width: 0;\n", + " overflow: hidden;\n", + " text-align: left;\n", + " /* unfitted */\n", + " background-color: var(--sklearn-color-unfitted-level-0);\n", + "}\n", + "\n", + "#sk-container-id-1 div.sk-toggleable__content.fitted {\n", + " /* fitted */\n", + " background-color: var(--sklearn-color-fitted-level-0);\n", + "}\n", + "\n", + "#sk-container-id-1 div.sk-toggleable__content pre {\n", + " margin: 0.2em;\n", + " border-radius: 0.25em;\n", + " color: var(--sklearn-color-text);\n", + " /* unfitted */\n", + " background-color: var(--sklearn-color-unfitted-level-0);\n", + "}\n", + "\n", + "#sk-container-id-1 div.sk-toggleable__content.fitted pre {\n", + " /* unfitted */\n", + " background-color: var(--sklearn-color-fitted-level-0);\n", + "}\n", + "\n", + "#sk-container-id-1 input.sk-toggleable__control:checked~div.sk-toggleable__content {\n", + " /* Expand drop-down */\n", + " max-height: 200px;\n", + " max-width: 100%;\n", + " overflow: auto;\n", + "}\n", + "\n", + "#sk-container-id-1 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {\n", + " content: \"▾\";\n", + "}\n", + "\n", + "/* Pipeline/ColumnTransformer-specific style */\n", + "\n", + "#sk-container-id-1 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {\n", + " color: var(--sklearn-color-text);\n", + " background-color: var(--sklearn-color-unfitted-level-2);\n", + "}\n", + "\n", + "#sk-container-id-1 div.sk-label.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n", + " background-color: var(--sklearn-color-fitted-level-2);\n", + "}\n", + "\n", + "/* Estimator-specific style */\n", + "\n", + "/* Colorize estimator box */\n", + "#sk-container-id-1 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {\n", + " /* unfitted */\n", + " background-color: var(--sklearn-color-unfitted-level-2);\n", + "}\n", + "\n", + "#sk-container-id-1 div.sk-estimator.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n", + " /* fitted */\n", + " background-color: var(--sklearn-color-fitted-level-2);\n", + "}\n", + "\n", + "#sk-container-id-1 div.sk-label label.sk-toggleable__label,\n", + "#sk-container-id-1 div.sk-label label {\n", + " /* The background is the default theme color */\n", + " color: var(--sklearn-color-text-on-default-background);\n", + "}\n", + "\n", + "/* On hover, darken the color of the background */\n", + "#sk-container-id-1 div.sk-label:hover label.sk-toggleable__label {\n", + " color: var(--sklearn-color-text);\n", + " background-color: 
var(--sklearn-color-unfitted-level-2);\n", + "}\n", + "\n", + "/* Label box, darken color on hover, fitted */\n", + "#sk-container-id-1 div.sk-label.fitted:hover label.sk-toggleable__label.fitted {\n", + " color: var(--sklearn-color-text);\n", + " background-color: var(--sklearn-color-fitted-level-2);\n", + "}\n", + "\n", + "/* Estimator label */\n", + "\n", + "#sk-container-id-1 div.sk-label label {\n", + " font-family: monospace;\n", + " font-weight: bold;\n", + " display: inline-block;\n", + " line-height: 1.2em;\n", + "}\n", + "\n", + "#sk-container-id-1 div.sk-label-container {\n", + " text-align: center;\n", + "}\n", + "\n", + "/* Estimator-specific */\n", + "#sk-container-id-1 div.sk-estimator {\n", + " font-family: monospace;\n", + " border: 1px dotted var(--sklearn-color-border-box);\n", + " border-radius: 0.25em;\n", + " box-sizing: border-box;\n", + " margin-bottom: 0.5em;\n", + " /* unfitted */\n", + " background-color: var(--sklearn-color-unfitted-level-0);\n", + "}\n", + "\n", + "#sk-container-id-1 div.sk-estimator.fitted {\n", + " /* fitted */\n", + " background-color: var(--sklearn-color-fitted-level-0);\n", + "}\n", + "\n", + "/* on hover */\n", + "#sk-container-id-1 div.sk-estimator:hover {\n", + " /* unfitted */\n", + " background-color: var(--sklearn-color-unfitted-level-2);\n", + "}\n", + "\n", + "#sk-container-id-1 div.sk-estimator.fitted:hover {\n", + " /* fitted */\n", + " background-color: var(--sklearn-color-fitted-level-2);\n", + "}\n", + "\n", + "/* Specification for estimator info (e.g. \"i\" and \"?\") */\n", + "\n", + "/* Common style for \"i\" and \"?\" */\n", + "\n", + ".sk-estimator-doc-link,\n", + "a:link.sk-estimator-doc-link,\n", + "a:visited.sk-estimator-doc-link {\n", + " float: right;\n", + " font-size: smaller;\n", + " line-height: 1em;\n", + " font-family: monospace;\n", + " background-color: var(--sklearn-color-background);\n", + " border-radius: 1em;\n", + " height: 1em;\n", + " width: 1em;\n", + " text-decoration: none !important;\n", + " margin-left: 1ex;\n", + " /* unfitted */\n", + " border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n", + " color: var(--sklearn-color-unfitted-level-1);\n", + "}\n", + "\n", + ".sk-estimator-doc-link.fitted,\n", + "a:link.sk-estimator-doc-link.fitted,\n", + "a:visited.sk-estimator-doc-link.fitted {\n", + " /* fitted */\n", + " border: var(--sklearn-color-fitted-level-1) 1pt solid;\n", + " color: var(--sklearn-color-fitted-level-1);\n", + "}\n", + "\n", + "/* On hover */\n", + "div.sk-estimator:hover .sk-estimator-doc-link:hover,\n", + ".sk-estimator-doc-link:hover,\n", + "div.sk-label-container:hover .sk-estimator-doc-link:hover,\n", + ".sk-estimator-doc-link:hover {\n", + " /* unfitted */\n", + " background-color: var(--sklearn-color-unfitted-level-3);\n", + " color: var(--sklearn-color-background);\n", + " text-decoration: none;\n", + "}\n", + "\n", + "div.sk-estimator.fitted:hover .sk-estimator-doc-link.fitted:hover,\n", + ".sk-estimator-doc-link.fitted:hover,\n", + "div.sk-label-container:hover .sk-estimator-doc-link.fitted:hover,\n", + ".sk-estimator-doc-link.fitted:hover {\n", + " /* fitted */\n", + " background-color: var(--sklearn-color-fitted-level-3);\n", + " color: var(--sklearn-color-background);\n", + " text-decoration: none;\n", + "}\n", + "\n", + "/* Span, style for the box shown on hovering the info icon */\n", + ".sk-estimator-doc-link span {\n", + " display: none;\n", + " z-index: 9999;\n", + " position: relative;\n", + " font-weight: normal;\n", + " right: .2ex;\n", + " padding: 
.5ex;\n", + " margin: .5ex;\n", + " width: min-content;\n", + " min-width: 20ex;\n", + " max-width: 50ex;\n", + " color: var(--sklearn-color-text);\n", + " box-shadow: 2pt 2pt 4pt #999;\n", + " /* unfitted */\n", + " background: var(--sklearn-color-unfitted-level-0);\n", + " border: .5pt solid var(--sklearn-color-unfitted-level-3);\n", + "}\n", + "\n", + ".sk-estimator-doc-link.fitted span {\n", + " /* fitted */\n", + " background: var(--sklearn-color-fitted-level-0);\n", + " border: var(--sklearn-color-fitted-level-3);\n", + "}\n", + "\n", + ".sk-estimator-doc-link:hover span {\n", + " display: block;\n", + "}\n", + "\n", + "/* \"?\"-specific style due to the `<a>` HTML tag */\n", + "\n", + "#sk-container-id-1 a.estimator_doc_link {\n", + " float: right;\n", + " font-size: 1rem;\n", + " line-height: 1em;\n", + " font-family: monospace;\n", + " background-color: var(--sklearn-color-background);\n", + " border-radius: 1rem;\n", + " height: 1rem;\n", + " width: 1rem;\n", + " text-decoration: none;\n", + " /* unfitted */\n", + " color: var(--sklearn-color-unfitted-level-1);\n", + " border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n", + "}\n", + "\n", + "#sk-container-id-1 a.estimator_doc_link.fitted {\n", + " /* fitted */\n", + " border: var(--sklearn-color-fitted-level-1) 1pt solid;\n", + " color: var(--sklearn-color-fitted-level-1);\n", + "}\n", + "\n", + "/* On hover */\n", + "#sk-container-id-1 a.estimator_doc_link:hover {\n", + " /* unfitted */\n", + " background-color: var(--sklearn-color-unfitted-level-3);\n", + " color: var(--sklearn-color-background);\n", + " text-decoration: none;\n", + "}\n", + "\n", + "#sk-container-id-1 a.estimator_doc_link.fitted:hover {\n", + " /* fitted */\n", + " background-color: var(--sklearn-color-fitted-level-3);\n", + "}\n", + "</style><div id=\"sk-container-id-1\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>XGBClassifier(base_score=None, booster=None, callbacks=None,\n", + " colsample_bylevel=None, colsample_bynode=None,\n", + " colsample_bytree=None, device=None, early_stopping_rounds=None,\n", + " enable_categorical=False, eval_metric=None, feature_types=None,\n", + " gamma=None, grow_policy=None, importance_type=None,\n", + " interaction_constraints=None, learning_rate=0.04, max_bin=None,\n", + " max_cat_threshold=None, max_cat_to_onehot=None,\n", + " max_delta_step=None, max_depth=8, max_leaves=None,\n", + " min_child_weight=None, missing=nan, monotone_constraints=None,\n", + " multi_strategy=None, n_estimators=2000, n_jobs=None,\n", + " num_parallel_tree=None, random_state=None, ...)</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. 
<br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-1\" type=\"checkbox\" checked><label for=\"sk-estimator-id-1\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow fitted\"> XGBClassifier<span class=\"sk-estimator-doc-link fitted\">i<span>Fitted</span></span></label><div class=\"sk-toggleable__content fitted\"><pre>XGBClassifier(base_score=None, booster=None, callbacks=None,\n", + " colsample_bylevel=None, colsample_bynode=None,\n", + " colsample_bytree=None, device=None, early_stopping_rounds=None,\n", + " enable_categorical=False, eval_metric=None, feature_types=None,\n", + " gamma=None, grow_policy=None, importance_type=None,\n", + " interaction_constraints=None, learning_rate=0.04, max_bin=None,\n", + " max_cat_threshold=None, max_cat_to_onehot=None,\n", + " max_delta_step=None, max_depth=8, max_leaves=None,\n", + " min_child_weight=None, missing=nan, monotone_constraints=None,\n", + " multi_strategy=None, n_estimators=2000, n_jobs=None,\n", + " num_parallel_tree=None, random_state=None, ...)</pre></div> </div></div></div></div>" + ], + "text/plain": [ + "XGBClassifier(base_score=None, booster=None, callbacks=None,\n", + " colsample_bylevel=None, colsample_bynode=None,\n", + " colsample_bytree=None, device=None, early_stopping_rounds=None,\n", + " enable_categorical=False, eval_metric=None, feature_types=None,\n", + " gamma=None, grow_policy=None, importance_type=None,\n", + " interaction_constraints=None, learning_rate=0.04, max_bin=None,\n", + " max_cat_threshold=None, max_cat_to_onehot=None,\n", + " max_delta_step=None, max_depth=8, max_leaves=None,\n", + " min_child_weight=None, missing=nan, monotone_constraints=None,\n", + " multi_strategy=None, n_estimators=2000, n_jobs=None,\n", + " num_parallel_tree=None, random_state=None, ...)" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "xgb_predict = XGBClassifier(n_estimators=2000, learning_rate=0.04, max_depth=8)\n", + "xgb_predict.fit(df_embeddings_train, df_train_normalized['p_np'])" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "# get XGBoost predictions\n", + "y_prob = xgb_predict.predict_proba(df_embeddings_test)[:, 1]" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "ROC-AUC Score: 0.9194\n" + ] + } + ], + "source": [ + "roc_auc = roc_auc_score(df_test_normalized[\"p_np\"], y_prob)\n", + "print(f\"ROC-AUC Score: {roc_auc:.4f}\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/models/smi_ted/notebooks/smi_ted_frozen_inference_example2.ipynb b/models/smi_ted/notebooks/smi_ted_frozen_inference_example2.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..045044361900ecf0876c5e05b51f023d733197f3 --- /dev/null +++ 
b/models/smi_ted/notebooks/smi_ted_frozen_inference_example2.ipynb @@ -0,0 +1,1327 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# materials.smi-TED - INFERENCE (Regression)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Install extra packages for notebook\n", + "%pip install seaborn xgboost" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "sys.path.append('../inference')" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# materials.smi-ted (smi-ted)\n", + "from smi_ted_light.load import load_smi_ted\n", + "\n", + "# Data\n", + "import torch\n", + "import pandas as pd\n", + "import numpy as np\n", + "\n", + "# Chemistry\n", + "from rdkit import Chem\n", + "from rdkit.Chem import PandasTools\n", + "from rdkit.Chem import Descriptors\n", + "PandasTools.RenderImagesInAllDataFrames(True)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "# function to canonicalize SMILES\n", + "def normalize_smiles(smi, canonical=True, isomeric=False):\n", + " try:\n", + " normalized = Chem.MolToSmiles(\n", + " Chem.MolFromSmiles(smi), canonical=canonical, isomericSmiles=isomeric\n", + " )\n", + " except:\n", + " normalized = None\n", + " return normalized" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Import smi-ted" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Random Seed: 12345\n", + "Using Rotation Embedding\n", + "Using Rotation Embedding\n", + "Using Rotation Embedding\n", + "Using Rotation Embedding\n", + "Using Rotation Embedding\n", + "Using Rotation Embedding\n", + "Using Rotation Embedding\n", + "Using Rotation Embedding\n", + "Using Rotation Embedding\n", + "Using Rotation Embedding\n", + "Using Rotation Embedding\n", + "Using Rotation Embedding\n", + "Vocab size: 2393\n", + "[INFERENCE MODE - smi-ted-Light]\n" + ] + } + ], + "source": [ + "model_smi_ted = load_smi_ted(\n", + " folder='../inference/smi_ted_light',\n", + " ckpt_filename='smi-ted-Light_40.pt'\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Lipophilicity Dataset" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Experiments - Data Load" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "df_train = pd.read_csv(\"../finetune/moleculenet/lipophilicity/train.csv\")\n", + "df_test = pd.read_csv(\"../finetune/moleculenet/lipophilicity/test.csv\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### SMILES canonization" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(3360, 3)\n" + ] + }, + { + "data": { + "text/html": [ + "<div>\n", + "<style scoped>\n", + " .dataframe tbody tr th:only-of-type {\n", + " vertical-align: middle;\n", + " }\n", + "\n", + " .dataframe tbody tr th {\n", + " vertical-align: top;\n", + " }\n", + "\n", + " .dataframe thead th {\n", + " text-align: right;\n", + " }\n", + "</style>\n", + "<table border=\"1\" class=\"dataframe\">\n", + " <thead>\n", + " <tr 
style=\"text-align: right;\">\n", + " <th></th>\n", + " <th>smiles</th>\n", + " <th>y</th>\n", + " <th>norm_smiles</th>\n", + " </tr>\n", + " </thead>\n", + " <tbody>\n", + " <tr>\n", + " <th>0</th>\n", + " <td>Nc1ncnc2c1c(COc3cccc(Cl)c3)nn2C4CCOCC4</td>\n", + " <td>0.814313</td>\n", + " <td>Nc1ncnc2c1c(COc1cccc(Cl)c1)nn2C1CCOCC1</td>\n", + " </tr>\n", + " <tr>\n", + " <th>1</th>\n", + " <td>COc1cc(cc2cnc(Nc3ccc(cc3)[C@@H](C)NC(=O)C)nc12...</td>\n", + " <td>0.446346</td>\n", + " <td>COc1cc(-c2ccncc2)cc2cnc(Nc3ccc(C(C)NC(C)=O)cc3...</td>\n", + " </tr>\n", + " <tr>\n", + " <th>2</th>\n", + " <td>CC(=O)Nc1ccc2ccn(c3cc(Nc4ccn(C)n4)n5ncc(C#N)c5...</td>\n", + " <td>1.148828</td>\n", + " <td>CC(=O)Nc1ccc2ccn(-c3cc(Nc4ccn(C)n4)n4ncc(C#N)c...</td>\n", + " </tr>\n", + " <tr>\n", + " <th>3</th>\n", + " <td>Oc1ccc(CCNCCS(=O)(=O)CCCOCCSc2ccccc2)c3sc(O)nc13</td>\n", + " <td>0.404532</td>\n", + " <td>O=S(=O)(CCCOCCSc1ccccc1)CCNCCc1ccc(O)c2nc(O)sc12</td>\n", + " </tr>\n", + " <tr>\n", + " <th>4</th>\n", + " <td>Clc1ccc2C(=O)C3=C(Nc2c1)C(=O)NN(Cc4cc5ccccc5s4...</td>\n", + " <td>-0.164144</td>\n", + " <td>O=c1[nH]n(Cc2cc3ccccc3s2)c(=O)c2c(=O)c3ccc(Cl)...</td>\n", + " </tr>\n", + " </tbody>\n", + "</table>\n", + "</div>" + ], + "text/plain": [ + " smiles y \\\n", + "0 Nc1ncnc2c1c(COc3cccc(Cl)c3)nn2C4CCOCC4 0.814313 \n", + "1 COc1cc(cc2cnc(Nc3ccc(cc3)[C@@H](C)NC(=O)C)nc12... 0.446346 \n", + "2 CC(=O)Nc1ccc2ccn(c3cc(Nc4ccn(C)n4)n5ncc(C#N)c5... 1.148828 \n", + "3 Oc1ccc(CCNCCS(=O)(=O)CCCOCCSc2ccccc2)c3sc(O)nc13 0.404532 \n", + "4 Clc1ccc2C(=O)C3=C(Nc2c1)C(=O)NN(Cc4cc5ccccc5s4... -0.164144 \n", + "\n", + " norm_smiles \n", + "0 Nc1ncnc2c1c(COc1cccc(Cl)c1)nn2C1CCOCC1 \n", + "1 COc1cc(-c2ccncc2)cc2cnc(Nc3ccc(C(C)NC(C)=O)cc3... \n", + "2 CC(=O)Nc1ccc2ccn(-c3cc(Nc4ccn(C)n4)n4ncc(C#N)c... \n", + "3 O=S(=O)(CCCOCCSc1ccccc1)CCNCCc1ccc(O)c2nc(O)sc12 \n", + "4 O=c1[nH]n(Cc2cc3ccccc3s2)c(=O)c2c(=O)c3ccc(Cl)... 
" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_train['norm_smiles'] = df_train['smiles'].apply(normalize_smiles)\n", + "df_train_normalized = df_train.dropna()\n", + "print(df_train_normalized.shape)\n", + "df_train_normalized.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(420, 3)\n" + ] + }, + { + "data": { + "text/html": [ + "<div>\n", + "<style scoped>\n", + " .dataframe tbody tr th:only-of-type {\n", + " vertical-align: middle;\n", + " }\n", + "\n", + " .dataframe tbody tr th {\n", + " vertical-align: top;\n", + " }\n", + "\n", + " .dataframe thead th {\n", + " text-align: right;\n", + " }\n", + "</style>\n", + "<table border=\"1\" class=\"dataframe\">\n", + " <thead>\n", + " <tr style=\"text-align: right;\">\n", + " <th></th>\n", + " <th>smiles</th>\n", + " <th>y</th>\n", + " <th>norm_smiles</th>\n", + " </tr>\n", + " </thead>\n", + " <tbody>\n", + " <tr>\n", + " <th>0</th>\n", + " <td>N(c1ccccc1)c2ccnc3ccccc23</td>\n", + " <td>0.488161</td>\n", + " <td>c1ccc(Nc2ccnc3ccccc23)cc1</td>\n", + " </tr>\n", + " <tr>\n", + " <th>1</th>\n", + " <td>Clc1ccc2Oc3ccccc3N=C(N4CCNCC4)c2c1</td>\n", + " <td>0.070017</td>\n", + " <td>Clc1ccc2c(c1)C(N1CCNCC1)=Nc1ccccc1O2</td>\n", + " </tr>\n", + " <tr>\n", + " <th>2</th>\n", + " <td>NC1(CCC1)c2ccc(cc2)c3ncc4cccnc4c3c5ccccc5</td>\n", + " <td>-0.415030</td>\n", + " <td>NC1(c2ccc(-c3ncc4cccnc4c3-c3ccccc3)cc2)CCC1</td>\n", + " </tr>\n", + " <tr>\n", + " <th>3</th>\n", + " <td>OC[C@H](O)CN1C(=O)[C@@H](Cc2ccccc12)NC(=O)c3cc...</td>\n", + " <td>0.897942</td>\n", + " <td>O=C(NC1Cc2ccccc2N(CC(O)CO)C1=O)c1cc2cc(Cl)sc2[...</td>\n", + " </tr>\n", + " <tr>\n", + " <th>4</th>\n", + " <td>NS(=O)(=O)c1nc2ccccc2s1</td>\n", + " <td>-0.707731</td>\n", + " <td>NS(=O)(=O)c1nc2ccccc2s1</td>\n", + " </tr>\n", + " </tbody>\n", + "</table>\n", + "</div>" + ], + "text/plain": [ + " smiles y \\\n", + "0 N(c1ccccc1)c2ccnc3ccccc23 0.488161 \n", + "1 Clc1ccc2Oc3ccccc3N=C(N4CCNCC4)c2c1 0.070017 \n", + "2 NC1(CCC1)c2ccc(cc2)c3ncc4cccnc4c3c5ccccc5 -0.415030 \n", + "3 OC[C@H](O)CN1C(=O)[C@@H](Cc2ccccc12)NC(=O)c3cc... 0.897942 \n", + "4 NS(=O)(=O)c1nc2ccccc2s1 -0.707731 \n", + "\n", + " norm_smiles \n", + "0 c1ccc(Nc2ccnc3ccccc23)cc1 \n", + "1 Clc1ccc2c(c1)C(N1CCNCC1)=Nc1ccccc1O2 \n", + "2 NC1(c2ccc(-c3ncc4cccnc4c3-c3ccccc3)cc2)CCC1 \n", + "3 O=C(NC1Cc2ccccc2N(CC(O)CO)C1=O)c1cc2cc(Cl)sc2[... 
\n", + "4 NS(=O)(=O)c1nc2ccccc2s1 " + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_test['norm_smiles'] = df_test['smiles'].apply(normalize_smiles)\n", + "df_test_normalized = df_test.dropna()\n", + "print(df_test_normalized.shape)\n", + "df_test_normalized.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Embeddings extraction " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### smi-ted embeddings extraction" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 33/33 [00:38<00:00, 1.15s/it]\n" + ] + }, + { + "data": { + "text/html": [ + "<div>\n", + "<style scoped>\n", + " .dataframe tbody tr th:only-of-type {\n", + " vertical-align: middle;\n", + " }\n", + "\n", + " .dataframe tbody tr th {\n", + " vertical-align: top;\n", + " }\n", + "\n", + " .dataframe thead th {\n", + " text-align: right;\n", + " }\n", + "</style>\n", + "<table border=\"1\" class=\"dataframe\">\n", + " <thead>\n", + " <tr style=\"text-align: right;\">\n", + " <th></th>\n", + " <th>0</th>\n", + " <th>1</th>\n", + " <th>2</th>\n", + " <th>3</th>\n", + " <th>4</th>\n", + " <th>5</th>\n", + " <th>6</th>\n", + " <th>7</th>\n", + " <th>8</th>\n", + " <th>9</th>\n", + " <th>...</th>\n", + " <th>758</th>\n", + " <th>759</th>\n", + " <th>760</th>\n", + " <th>761</th>\n", + " <th>762</th>\n", + " <th>763</th>\n", + " <th>764</th>\n", + " <th>765</th>\n", + " <th>766</th>\n", + " <th>767</th>\n", + " </tr>\n", + " </thead>\n", + " <tbody>\n", + " <tr>\n", + " <th>0</th>\n", + " <td>0.367646</td>\n", + " <td>-0.504889</td>\n", + " <td>0.040485</td>\n", + " <td>0.385314</td>\n", + " <td>0.564923</td>\n", + " <td>-0.684497</td>\n", + " <td>1.160397</td>\n", + " <td>0.071218</td>\n", + " <td>0.799428</td>\n", + " <td>0.181323</td>\n", + " <td>...</td>\n", + " <td>-1.379994</td>\n", + " <td>-0.167221</td>\n", + " <td>0.104886</td>\n", + " <td>0.239571</td>\n", + " <td>-0.744390</td>\n", + " <td>0.590423</td>\n", + " <td>-0.808946</td>\n", + " <td>0.792584</td>\n", + " <td>0.550898</td>\n", + " <td>-0.176831</td>\n", + " </tr>\n", + " <tr>\n", + " <th>1</th>\n", + " <td>0.455316</td>\n", + " <td>-0.485554</td>\n", + " <td>0.062206</td>\n", + " <td>0.387994</td>\n", + " <td>0.567590</td>\n", + " <td>-0.713285</td>\n", + " <td>1.144267</td>\n", + " <td>-0.057046</td>\n", + " <td>0.753016</td>\n", + " <td>0.112180</td>\n", + " <td>...</td>\n", + " <td>-1.332142</td>\n", + " <td>-0.096662</td>\n", + " <td>0.221944</td>\n", + " <td>0.327923</td>\n", + " <td>-0.739358</td>\n", + " <td>0.659803</td>\n", + " <td>-0.775723</td>\n", + " <td>0.745837</td>\n", + " <td>0.566330</td>\n", + " <td>-0.111946</td>\n", + " </tr>\n", + " <tr>\n", + " <th>2</th>\n", + " <td>0.442309</td>\n", + " <td>-0.484732</td>\n", + " <td>0.084945</td>\n", + " <td>0.384787</td>\n", + " <td>0.564752</td>\n", + " <td>-0.704130</td>\n", + " <td>1.159491</td>\n", + " <td>0.021168</td>\n", + " <td>0.846539</td>\n", + " <td>0.118463</td>\n", + " <td>...</td>\n", + " <td>-1.324177</td>\n", + " <td>-0.110403</td>\n", + " <td>0.207824</td>\n", + " <td>0.281665</td>\n", + " <td>-0.780818</td>\n", + " <td>0.693484</td>\n", + " <td>-0.832626</td>\n", + " <td>0.763095</td>\n", + " <td>0.532460</td>\n", + " <td>-0.196708</td>\n", + " </tr>\n", + " <tr>\n", + " <th>3</th>\n", + " <td>0.527961</td>\n", + " <td>-0.519151</td>\n", + " 
<td>0.091635</td>\n", + " <td>0.353518</td>\n", + " <td>0.421795</td>\n", + " <td>-0.724220</td>\n", + " <td>1.093752</td>\n", + " <td>0.148574</td>\n", + " <td>0.804047</td>\n", + " <td>0.194627</td>\n", + " <td>...</td>\n", + " <td>-1.358414</td>\n", + " <td>-0.111483</td>\n", + " <td>0.151692</td>\n", + " <td>0.186741</td>\n", + " <td>-0.601867</td>\n", + " <td>0.641591</td>\n", + " <td>-0.747422</td>\n", + " <td>0.794239</td>\n", + " <td>0.640765</td>\n", + " <td>-0.239649</td>\n", + " </tr>\n", + " <tr>\n", + " <th>4</th>\n", + " <td>0.464432</td>\n", + " <td>-0.511090</td>\n", + " <td>0.038785</td>\n", + " <td>0.346217</td>\n", + " <td>0.492919</td>\n", + " <td>-0.619387</td>\n", + " <td>1.048157</td>\n", + " <td>0.095910</td>\n", + " <td>0.738604</td>\n", + " <td>0.119270</td>\n", + " <td>...</td>\n", + " <td>-1.223927</td>\n", + " <td>-0.109863</td>\n", + " <td>0.151280</td>\n", + " <td>0.244834</td>\n", + " <td>-0.686610</td>\n", + " <td>0.759327</td>\n", + " <td>-0.756338</td>\n", + " <td>0.766427</td>\n", + " <td>0.610454</td>\n", + " <td>-0.197345</td>\n", + " </tr>\n", + " </tbody>\n", + "</table>\n", + "<p>5 rows × 768 columns</p>\n", + "</div>" + ], + "text/plain": [ + " 0 1 2 3 4 5 6 \\\n", + "0 0.367646 -0.504889 0.040485 0.385314 0.564923 -0.684497 1.160397 \n", + "1 0.455316 -0.485554 0.062206 0.387994 0.567590 -0.713285 1.144267 \n", + "2 0.442309 -0.484732 0.084945 0.384787 0.564752 -0.704130 1.159491 \n", + "3 0.527961 -0.519151 0.091635 0.353518 0.421795 -0.724220 1.093752 \n", + "4 0.464432 -0.511090 0.038785 0.346217 0.492919 -0.619387 1.048157 \n", + "\n", + " 7 8 9 ... 758 759 760 761 \\\n", + "0 0.071218 0.799428 0.181323 ... -1.379994 -0.167221 0.104886 0.239571 \n", + "1 -0.057046 0.753016 0.112180 ... -1.332142 -0.096662 0.221944 0.327923 \n", + "2 0.021168 0.846539 0.118463 ... -1.324177 -0.110403 0.207824 0.281665 \n", + "3 0.148574 0.804047 0.194627 ... -1.358414 -0.111483 0.151692 0.186741 \n", + "4 0.095910 0.738604 0.119270 ... 
-1.223927 -0.109863 0.151280 0.244834 \n", + "\n", + " 762 763 764 765 766 767 \n", + "0 -0.744390 0.590423 -0.808946 0.792584 0.550898 -0.176831 \n", + "1 -0.739358 0.659803 -0.775723 0.745837 0.566330 -0.111946 \n", + "2 -0.780818 0.693484 -0.832626 0.763095 0.532460 -0.196708 \n", + "3 -0.601867 0.641591 -0.747422 0.794239 0.640765 -0.239649 \n", + "4 -0.686610 0.759327 -0.756338 0.766427 0.610454 -0.197345 \n", + "\n", + "[5 rows x 768 columns]" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "with torch.no_grad():\n", + " df_embeddings_train = model_smi_ted.encode(df_train_normalized['norm_smiles'])\n", + "df_embeddings_train.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 4/4 [00:05<00:00, 1.46s/it]\n" + ] + }, + { + "data": { + "text/html": [ + "<div>\n", + "<style scoped>\n", + " .dataframe tbody tr th:only-of-type {\n", + " vertical-align: middle;\n", + " }\n", + "\n", + " .dataframe tbody tr th {\n", + " vertical-align: top;\n", + " }\n", + "\n", + " .dataframe thead th {\n", + " text-align: right;\n", + " }\n", + "</style>\n", + "<table border=\"1\" class=\"dataframe\">\n", + " <thead>\n", + " <tr style=\"text-align: right;\">\n", + " <th></th>\n", + " <th>0</th>\n", + " <th>1</th>\n", + " <th>2</th>\n", + " <th>3</th>\n", + " <th>4</th>\n", + " <th>5</th>\n", + " <th>6</th>\n", + " <th>7</th>\n", + " <th>8</th>\n", + " <th>9</th>\n", + " <th>...</th>\n", + " <th>758</th>\n", + " <th>759</th>\n", + " <th>760</th>\n", + " <th>761</th>\n", + " <th>762</th>\n", + " <th>763</th>\n", + " <th>764</th>\n", + " <th>765</th>\n", + " <th>766</th>\n", + " <th>767</th>\n", + " </tr>\n", + " </thead>\n", + " <tbody>\n", + " <tr>\n", + " <th>0</th>\n", + " <td>0.392252</td>\n", + " <td>-0.504846</td>\n", + " <td>0.056791</td>\n", + " <td>0.356297</td>\n", + " <td>0.475918</td>\n", + " <td>-0.648899</td>\n", + " <td>1.157862</td>\n", + " <td>-0.022914</td>\n", + " <td>0.703240</td>\n", + " <td>0.192023</td>\n", + " <td>...</td>\n", + " <td>-1.208714</td>\n", + " <td>-0.094441</td>\n", + " <td>0.128845</td>\n", + " <td>0.403995</td>\n", + " <td>-0.782782</td>\n", + " <td>0.541907</td>\n", + " <td>-0.707272</td>\n", + " <td>0.901041</td>\n", + " <td>0.629461</td>\n", + " <td>-0.020630</td>\n", + " </tr>\n", + " <tr>\n", + " <th>1</th>\n", + " <td>0.387422</td>\n", + " <td>-0.481142</td>\n", + " <td>0.049675</td>\n", + " <td>0.353058</td>\n", + " <td>0.601170</td>\n", + " <td>-0.646099</td>\n", + " <td>1.142392</td>\n", + " <td>0.060092</td>\n", + " <td>0.763799</td>\n", + " <td>0.110331</td>\n", + " <td>...</td>\n", + " <td>-1.248282</td>\n", + " <td>-0.139790</td>\n", + " <td>0.075585</td>\n", + " <td>0.202242</td>\n", + " <td>-0.729794</td>\n", + " <td>0.705914</td>\n", + " <td>-0.771751</td>\n", + " <td>0.843173</td>\n", + " <td>0.618850</td>\n", + " <td>-0.213584</td>\n", + " </tr>\n", + " <tr>\n", + " <th>2</th>\n", + " <td>0.390975</td>\n", + " <td>-0.510056</td>\n", + " <td>0.070656</td>\n", + " <td>0.380695</td>\n", + " <td>0.601486</td>\n", + " <td>-0.595827</td>\n", + " <td>1.182193</td>\n", + " <td>0.011085</td>\n", + " <td>0.688093</td>\n", + " <td>0.056453</td>\n", + " <td>...</td>\n", + " <td>-1.294595</td>\n", + " <td>-0.164846</td>\n", + " <td>0.194435</td>\n", + " <td>0.240742</td>\n", + " <td>-0.773443</td>\n", + " <td>0.608631</td>\n", + " <td>-0.747181</td>\n", + " 
<td>0.791911</td>\n", + " <td>0.611874</td>\n", + " <td>-0.125455</td>\n", + " </tr>\n", + " <tr>\n", + " <th>3</th>\n", + " <td>0.423924</td>\n", + " <td>-0.557325</td>\n", + " <td>0.083810</td>\n", + " <td>0.328703</td>\n", + " <td>0.399589</td>\n", + " <td>-0.622818</td>\n", + " <td>1.079945</td>\n", + " <td>0.097611</td>\n", + " <td>0.724030</td>\n", + " <td>0.135976</td>\n", + " <td>...</td>\n", + " <td>-1.412060</td>\n", + " <td>-0.106541</td>\n", + " <td>0.153314</td>\n", + " <td>0.209962</td>\n", + " <td>-0.699690</td>\n", + " <td>0.648061</td>\n", + " <td>-0.716241</td>\n", + " <td>0.757986</td>\n", + " <td>0.615963</td>\n", + " <td>-0.258693</td>\n", + " </tr>\n", + " <tr>\n", + " <th>4</th>\n", + " <td>0.335576</td>\n", + " <td>-0.559591</td>\n", + " <td>0.119437</td>\n", + " <td>0.364141</td>\n", + " <td>0.375474</td>\n", + " <td>-0.639833</td>\n", + " <td>1.144707</td>\n", + " <td>0.077512</td>\n", + " <td>0.791759</td>\n", + " <td>0.164201</td>\n", + " <td>...</td>\n", + " <td>-1.279041</td>\n", + " <td>-0.186733</td>\n", + " <td>0.106963</td>\n", + " <td>0.254949</td>\n", + " <td>-0.651694</td>\n", + " <td>0.594167</td>\n", + " <td>-0.680426</td>\n", + " <td>0.887482</td>\n", + " <td>0.651587</td>\n", + " <td>-0.144996</td>\n", + " </tr>\n", + " </tbody>\n", + "</table>\n", + "<p>5 rows × 768 columns</p>\n", + "</div>" + ], + "text/plain": [ + " 0 1 2 3 4 5 6 \\\n", + "0 0.392252 -0.504846 0.056791 0.356297 0.475918 -0.648899 1.157862 \n", + "1 0.387422 -0.481142 0.049675 0.353058 0.601170 -0.646099 1.142392 \n", + "2 0.390975 -0.510056 0.070656 0.380695 0.601486 -0.595827 1.182193 \n", + "3 0.423924 -0.557325 0.083810 0.328703 0.399589 -0.622818 1.079945 \n", + "4 0.335576 -0.559591 0.119437 0.364141 0.375474 -0.639833 1.144707 \n", + "\n", + " 7 8 9 ... 758 759 760 761 \\\n", + "0 -0.022914 0.703240 0.192023 ... -1.208714 -0.094441 0.128845 0.403995 \n", + "1 0.060092 0.763799 0.110331 ... -1.248282 -0.139790 0.075585 0.202242 \n", + "2 0.011085 0.688093 0.056453 ... -1.294595 -0.164846 0.194435 0.240742 \n", + "3 0.097611 0.724030 0.135976 ... -1.412060 -0.106541 0.153314 0.209962 \n", + "4 0.077512 0.791759 0.164201 ... 
-1.279041 -0.186733 0.106963 0.254949 \n", + "\n", + " 762 763 764 765 766 767 \n", + "0 -0.782782 0.541907 -0.707272 0.901041 0.629461 -0.020630 \n", + "1 -0.729794 0.705914 -0.771751 0.843173 0.618850 -0.213584 \n", + "2 -0.773443 0.608631 -0.747181 0.791911 0.611874 -0.125455 \n", + "3 -0.699690 0.648061 -0.716241 0.757986 0.615963 -0.258693 \n", + "4 -0.651694 0.594167 -0.680426 0.887482 0.651587 -0.144996 \n", + "\n", + "[5 rows x 768 columns]" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "with torch.no_grad():\n", + " df_embeddings_test = model_smi_ted.encode(df_test_normalized['norm_smiles'])\n", + "df_embeddings_test.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Experiments - Lipophilicity prediction using smi-ted latent spaces" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### XGBoost prediction using the whole Latent Space" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "from xgboost import XGBRegressor\n", + "from sklearn.metrics import mean_squared_error" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "<style>#sk-container-id-1 {\n", + " /* Definition of color scheme common for light and dark mode */\n", + " --sklearn-color-text: black;\n", + " --sklearn-color-line: gray;\n", + " /* Definition of color scheme for unfitted estimators */\n", + " --sklearn-color-unfitted-level-0: #fff5e6;\n", + " --sklearn-color-unfitted-level-1: #f6e4d2;\n", + " --sklearn-color-unfitted-level-2: #ffe0b3;\n", + " --sklearn-color-unfitted-level-3: chocolate;\n", + " /* Definition of color scheme for fitted estimators */\n", + " --sklearn-color-fitted-level-0: #f0f8ff;\n", + " --sklearn-color-fitted-level-1: #d4ebff;\n", + " --sklearn-color-fitted-level-2: #b3dbfd;\n", + " --sklearn-color-fitted-level-3: cornflowerblue;\n", + "\n", + " /* Specific color for light theme */\n", + " --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n", + " --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, white)));\n", + " --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n", + " --sklearn-color-icon: #696969;\n", + "\n", + " @media (prefers-color-scheme: dark) {\n", + " /* Redefinition of color scheme for dark theme */\n", + " --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n", + " --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, #111)));\n", + " --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n", + " --sklearn-color-icon: #878787;\n", + " }\n", + "}\n", + "\n", + "#sk-container-id-1 {\n", + " color: var(--sklearn-color-text);\n", + "}\n", + "\n", + "#sk-container-id-1 pre {\n", + " padding: 0;\n", + "}\n", + "\n", + "#sk-container-id-1 input.sk-hidden--visually {\n", + " border: 0;\n", + " clip: rect(1px 1px 1px 1px);\n", + " clip: rect(1px, 1px, 1px, 1px);\n", + " height: 1px;\n", + " margin: -1px;\n", + " overflow: hidden;\n", + " padding: 0;\n", + " position: absolute;\n", + " width: 1px;\n", + "}\n", + "\n", + 
"#sk-container-id-1 div.sk-dashed-wrapped {\n", + " border: 1px dashed var(--sklearn-color-line);\n", + " margin: 0 0.4em 0.5em 0.4em;\n", + " box-sizing: border-box;\n", + " padding-bottom: 0.4em;\n", + " background-color: var(--sklearn-color-background);\n", + "}\n", + "\n", + "#sk-container-id-1 div.sk-container {\n", + " /* jupyter's `normalize.less` sets `[hidden] { display: none; }`\n", + " but bootstrap.min.css set `[hidden] { display: none !important; }`\n", + " so we also need the `!important` here to be able to override the\n", + " default hidden behavior on the sphinx rendered scikit-learn.org.\n", + " See: https://github.com/scikit-learn/scikit-learn/issues/21755 */\n", + " display: inline-block !important;\n", + " position: relative;\n", + "}\n", + "\n", + "#sk-container-id-1 div.sk-text-repr-fallback {\n", + " display: none;\n", + "}\n", + "\n", + "div.sk-parallel-item,\n", + "div.sk-serial,\n", + "div.sk-item {\n", + " /* draw centered vertical line to link estimators */\n", + " background-image: linear-gradient(var(--sklearn-color-text-on-default-background), var(--sklearn-color-text-on-default-background));\n", + " background-size: 2px 100%;\n", + " background-repeat: no-repeat;\n", + " background-position: center center;\n", + "}\n", + "\n", + "/* Parallel-specific style estimator block */\n", + "\n", + "#sk-container-id-1 div.sk-parallel-item::after {\n", + " content: \"\";\n", + " width: 100%;\n", + " border-bottom: 2px solid var(--sklearn-color-text-on-default-background);\n", + " flex-grow: 1;\n", + "}\n", + "\n", + "#sk-container-id-1 div.sk-parallel {\n", + " display: flex;\n", + " align-items: stretch;\n", + " justify-content: center;\n", + " background-color: var(--sklearn-color-background);\n", + " position: relative;\n", + "}\n", + "\n", + "#sk-container-id-1 div.sk-parallel-item {\n", + " display: flex;\n", + " flex-direction: column;\n", + "}\n", + "\n", + "#sk-container-id-1 div.sk-parallel-item:first-child::after {\n", + " align-self: flex-end;\n", + " width: 50%;\n", + "}\n", + "\n", + "#sk-container-id-1 div.sk-parallel-item:last-child::after {\n", + " align-self: flex-start;\n", + " width: 50%;\n", + "}\n", + "\n", + "#sk-container-id-1 div.sk-parallel-item:only-child::after {\n", + " width: 0;\n", + "}\n", + "\n", + "/* Serial-specific style estimator block */\n", + "\n", + "#sk-container-id-1 div.sk-serial {\n", + " display: flex;\n", + " flex-direction: column;\n", + " align-items: center;\n", + " background-color: var(--sklearn-color-background);\n", + " padding-right: 1em;\n", + " padding-left: 1em;\n", + "}\n", + "\n", + "\n", + "/* Toggleable style: style used for estimator/Pipeline/ColumnTransformer box that is\n", + "clickable and can be expanded/collapsed.\n", + "- Pipeline and ColumnTransformer use this feature and define the default style\n", + "- Estimators will overwrite some part of the style using the `sk-estimator` class\n", + "*/\n", + "\n", + "/* Pipeline and ColumnTransformer style (default) */\n", + "\n", + "#sk-container-id-1 div.sk-toggleable {\n", + " /* Default theme specific background. 
It is overwritten whether we have a\n", + " specific estimator or a Pipeline/ColumnTransformer */\n", + " background-color: var(--sklearn-color-background);\n", + "}\n", + "\n", + "/* Toggleable label */\n", + "#sk-container-id-1 label.sk-toggleable__label {\n", + " cursor: pointer;\n", + " display: block;\n", + " width: 100%;\n", + " margin-bottom: 0;\n", + " padding: 0.5em;\n", + " box-sizing: border-box;\n", + " text-align: center;\n", + "}\n", + "\n", + "#sk-container-id-1 label.sk-toggleable__label-arrow:before {\n", + " /* Arrow on the left of the label */\n", + " content: \"▸\";\n", + " float: left;\n", + " margin-right: 0.25em;\n", + " color: var(--sklearn-color-icon);\n", + "}\n", + "\n", + "#sk-container-id-1 label.sk-toggleable__label-arrow:hover:before {\n", + " color: var(--sklearn-color-text);\n", + "}\n", + "\n", + "/* Toggleable content - dropdown */\n", + "\n", + "#sk-container-id-1 div.sk-toggleable__content {\n", + " max-height: 0;\n", + " max-width: 0;\n", + " overflow: hidden;\n", + " text-align: left;\n", + " /* unfitted */\n", + " background-color: var(--sklearn-color-unfitted-level-0);\n", + "}\n", + "\n", + "#sk-container-id-1 div.sk-toggleable__content.fitted {\n", + " /* fitted */\n", + " background-color: var(--sklearn-color-fitted-level-0);\n", + "}\n", + "\n", + "#sk-container-id-1 div.sk-toggleable__content pre {\n", + " margin: 0.2em;\n", + " border-radius: 0.25em;\n", + " color: var(--sklearn-color-text);\n", + " /* unfitted */\n", + " background-color: var(--sklearn-color-unfitted-level-0);\n", + "}\n", + "\n", + "#sk-container-id-1 div.sk-toggleable__content.fitted pre {\n", + " /* unfitted */\n", + " background-color: var(--sklearn-color-fitted-level-0);\n", + "}\n", + "\n", + "#sk-container-id-1 input.sk-toggleable__control:checked~div.sk-toggleable__content {\n", + " /* Expand drop-down */\n", + " max-height: 200px;\n", + " max-width: 100%;\n", + " overflow: auto;\n", + "}\n", + "\n", + "#sk-container-id-1 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {\n", + " content: \"▾\";\n", + "}\n", + "\n", + "/* Pipeline/ColumnTransformer-specific style */\n", + "\n", + "#sk-container-id-1 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {\n", + " color: var(--sklearn-color-text);\n", + " background-color: var(--sklearn-color-unfitted-level-2);\n", + "}\n", + "\n", + "#sk-container-id-1 div.sk-label.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n", + " background-color: var(--sklearn-color-fitted-level-2);\n", + "}\n", + "\n", + "/* Estimator-specific style */\n", + "\n", + "/* Colorize estimator box */\n", + "#sk-container-id-1 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {\n", + " /* unfitted */\n", + " background-color: var(--sklearn-color-unfitted-level-2);\n", + "}\n", + "\n", + "#sk-container-id-1 div.sk-estimator.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n", + " /* fitted */\n", + " background-color: var(--sklearn-color-fitted-level-2);\n", + "}\n", + "\n", + "#sk-container-id-1 div.sk-label label.sk-toggleable__label,\n", + "#sk-container-id-1 div.sk-label label {\n", + " /* The background is the default theme color */\n", + " color: var(--sklearn-color-text-on-default-background);\n", + "}\n", + "\n", + "/* On hover, darken the color of the background */\n", + "#sk-container-id-1 div.sk-label:hover label.sk-toggleable__label {\n", + " color: var(--sklearn-color-text);\n", + " background-color: 
var(--sklearn-color-unfitted-level-2);\n", + "}\n", + "\n", + "/* Label box, darken color on hover, fitted */\n", + "#sk-container-id-1 div.sk-label.fitted:hover label.sk-toggleable__label.fitted {\n", + " color: var(--sklearn-color-text);\n", + " background-color: var(--sklearn-color-fitted-level-2);\n", + "}\n", + "\n", + "/* Estimator label */\n", + "\n", + "#sk-container-id-1 div.sk-label label {\n", + " font-family: monospace;\n", + " font-weight: bold;\n", + " display: inline-block;\n", + " line-height: 1.2em;\n", + "}\n", + "\n", + "#sk-container-id-1 div.sk-label-container {\n", + " text-align: center;\n", + "}\n", + "\n", + "/* Estimator-specific */\n", + "#sk-container-id-1 div.sk-estimator {\n", + " font-family: monospace;\n", + " border: 1px dotted var(--sklearn-color-border-box);\n", + " border-radius: 0.25em;\n", + " box-sizing: border-box;\n", + " margin-bottom: 0.5em;\n", + " /* unfitted */\n", + " background-color: var(--sklearn-color-unfitted-level-0);\n", + "}\n", + "\n", + "#sk-container-id-1 div.sk-estimator.fitted {\n", + " /* fitted */\n", + " background-color: var(--sklearn-color-fitted-level-0);\n", + "}\n", + "\n", + "/* on hover */\n", + "#sk-container-id-1 div.sk-estimator:hover {\n", + " /* unfitted */\n", + " background-color: var(--sklearn-color-unfitted-level-2);\n", + "}\n", + "\n", + "#sk-container-id-1 div.sk-estimator.fitted:hover {\n", + " /* fitted */\n", + " background-color: var(--sklearn-color-fitted-level-2);\n", + "}\n", + "\n", + "/* Specification for estimator info (e.g. \"i\" and \"?\") */\n", + "\n", + "/* Common style for \"i\" and \"?\" */\n", + "\n", + ".sk-estimator-doc-link,\n", + "a:link.sk-estimator-doc-link,\n", + "a:visited.sk-estimator-doc-link {\n", + " float: right;\n", + " font-size: smaller;\n", + " line-height: 1em;\n", + " font-family: monospace;\n", + " background-color: var(--sklearn-color-background);\n", + " border-radius: 1em;\n", + " height: 1em;\n", + " width: 1em;\n", + " text-decoration: none !important;\n", + " margin-left: 1ex;\n", + " /* unfitted */\n", + " border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n", + " color: var(--sklearn-color-unfitted-level-1);\n", + "}\n", + "\n", + ".sk-estimator-doc-link.fitted,\n", + "a:link.sk-estimator-doc-link.fitted,\n", + "a:visited.sk-estimator-doc-link.fitted {\n", + " /* fitted */\n", + " border: var(--sklearn-color-fitted-level-1) 1pt solid;\n", + " color: var(--sklearn-color-fitted-level-1);\n", + "}\n", + "\n", + "/* On hover */\n", + "div.sk-estimator:hover .sk-estimator-doc-link:hover,\n", + ".sk-estimator-doc-link:hover,\n", + "div.sk-label-container:hover .sk-estimator-doc-link:hover,\n", + ".sk-estimator-doc-link:hover {\n", + " /* unfitted */\n", + " background-color: var(--sklearn-color-unfitted-level-3);\n", + " color: var(--sklearn-color-background);\n", + " text-decoration: none;\n", + "}\n", + "\n", + "div.sk-estimator.fitted:hover .sk-estimator-doc-link.fitted:hover,\n", + ".sk-estimator-doc-link.fitted:hover,\n", + "div.sk-label-container:hover .sk-estimator-doc-link.fitted:hover,\n", + ".sk-estimator-doc-link.fitted:hover {\n", + " /* fitted */\n", + " background-color: var(--sklearn-color-fitted-level-3);\n", + " color: var(--sklearn-color-background);\n", + " text-decoration: none;\n", + "}\n", + "\n", + "/* Span, style for the box shown on hovering the info icon */\n", + ".sk-estimator-doc-link span {\n", + " display: none;\n", + " z-index: 9999;\n", + " position: relative;\n", + " font-weight: normal;\n", + " right: .2ex;\n", + " padding: 
.5ex;\n", + " margin: .5ex;\n", + " width: min-content;\n", + " min-width: 20ex;\n", + " max-width: 50ex;\n", + " color: var(--sklearn-color-text);\n", + " box-shadow: 2pt 2pt 4pt #999;\n", + " /* unfitted */\n", + " background: var(--sklearn-color-unfitted-level-0);\n", + " border: .5pt solid var(--sklearn-color-unfitted-level-3);\n", + "}\n", + "\n", + ".sk-estimator-doc-link.fitted span {\n", + " /* fitted */\n", + " background: var(--sklearn-color-fitted-level-0);\n", + " border: var(--sklearn-color-fitted-level-3);\n", + "}\n", + "\n", + ".sk-estimator-doc-link:hover span {\n", + " display: block;\n", + "}\n", + "\n", + "/* \"?\"-specific style due to the `<a>` HTML tag */\n", + "\n", + "#sk-container-id-1 a.estimator_doc_link {\n", + " float: right;\n", + " font-size: 1rem;\n", + " line-height: 1em;\n", + " font-family: monospace;\n", + " background-color: var(--sklearn-color-background);\n", + " border-radius: 1rem;\n", + " height: 1rem;\n", + " width: 1rem;\n", + " text-decoration: none;\n", + " /* unfitted */\n", + " color: var(--sklearn-color-unfitted-level-1);\n", + " border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n", + "}\n", + "\n", + "#sk-container-id-1 a.estimator_doc_link.fitted {\n", + " /* fitted */\n", + " border: var(--sklearn-color-fitted-level-1) 1pt solid;\n", + " color: var(--sklearn-color-fitted-level-1);\n", + "}\n", + "\n", + "/* On hover */\n", + "#sk-container-id-1 a.estimator_doc_link:hover {\n", + " /* unfitted */\n", + " background-color: var(--sklearn-color-unfitted-level-3);\n", + " color: var(--sklearn-color-background);\n", + " text-decoration: none;\n", + "}\n", + "\n", + "#sk-container-id-1 a.estimator_doc_link.fitted:hover {\n", + " /* fitted */\n", + " background-color: var(--sklearn-color-fitted-level-3);\n", + "}\n", + "</style><div id=\"sk-container-id-1\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>XGBRegressor(base_score=None, booster=None, callbacks=None,\n", + " colsample_bylevel=None, colsample_bynode=None,\n", + " colsample_bytree=None, device=None, early_stopping_rounds=None,\n", + " enable_categorical=False, eval_metric=None, feature_types=None,\n", + " gamma=None, grow_policy=None, importance_type=None,\n", + " interaction_constraints=None, learning_rate=0.05, max_bin=None,\n", + " max_cat_threshold=None, max_cat_to_onehot=None,\n", + " max_delta_step=None, max_depth=4, max_leaves=None,\n", + " min_child_weight=None, missing=nan, monotone_constraints=None,\n", + " multi_strategy=None, n_estimators=2000, n_jobs=None,\n", + " num_parallel_tree=None, random_state=None, ...)</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. 
<br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-1\" type=\"checkbox\" checked><label for=\"sk-estimator-id-1\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow fitted\"> XGBRegressor<span class=\"sk-estimator-doc-link fitted\">i<span>Fitted</span></span></label><div class=\"sk-toggleable__content fitted\"><pre>XGBRegressor(base_score=None, booster=None, callbacks=None,\n", + " colsample_bylevel=None, colsample_bynode=None,\n", + " colsample_bytree=None, device=None, early_stopping_rounds=None,\n", + " enable_categorical=False, eval_metric=None, feature_types=None,\n", + " gamma=None, grow_policy=None, importance_type=None,\n", + " interaction_constraints=None, learning_rate=0.05, max_bin=None,\n", + " max_cat_threshold=None, max_cat_to_onehot=None,\n", + " max_delta_step=None, max_depth=4, max_leaves=None,\n", + " min_child_weight=None, missing=nan, monotone_constraints=None,\n", + " multi_strategy=None, n_estimators=2000, n_jobs=None,\n", + " num_parallel_tree=None, random_state=None, ...)</pre></div> </div></div></div></div>" + ], + "text/plain": [ + "XGBRegressor(base_score=None, booster=None, callbacks=None,\n", + " colsample_bylevel=None, colsample_bynode=None,\n", + " colsample_bytree=None, device=None, early_stopping_rounds=None,\n", + " enable_categorical=False, eval_metric=None, feature_types=None,\n", + " gamma=None, grow_policy=None, importance_type=None,\n", + " interaction_constraints=None, learning_rate=0.05, max_bin=None,\n", + " max_cat_threshold=None, max_cat_to_onehot=None,\n", + " max_delta_step=None, max_depth=4, max_leaves=None,\n", + " min_child_weight=None, missing=nan, monotone_constraints=None,\n", + " multi_strategy=None, n_estimators=2000, n_jobs=None,\n", + " num_parallel_tree=None, random_state=None, ...)" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "xgb_predict = XGBRegressor(n_estimators=2000, learning_rate=0.05, max_depth=4)\n", + "xgb_predict.fit(df_embeddings_train, df_train_normalized['y'])" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "# get XGBoost predictions\n", + "y_pred = xgb_predict.predict(df_embeddings_test)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "RMSE Score: 0.6485\n" + ] + } + ], + "source": [ + "rmse = np.sqrt(mean_squared_error(df_test_normalized[\"y\"], y_pred))\n", + "print(f\"RMSE Score: {rmse:.4f}\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/models/smi_ted/paper/smi-ted_preprint.pdf b/models/smi_ted/paper/smi-ted_preprint.pdf new file mode 100644 index 0000000000000000000000000000000000000000..5c6c3282b163b354a9beded102db50495605dc75 --- /dev/null +++ b/models/smi_ted/paper/smi-ted_preprint.pdf @@ -0,0 +1,3 @@ 
+version https://git-lfs.github.com/spec/v1 +oid sha256:75b2e2dc74c2d9a87cb19d19a4770e0db353b4ade883d8cc366b1faef4ba053f +size 3343180 diff --git a/models/smi_ted/requirements.txt b/models/smi_ted/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..c8d2ad17095979cb7654b7588808a4b8015131d4 --- /dev/null +++ b/models/smi_ted/requirements.txt @@ -0,0 +1,9 @@ +transformers +torch-optimizer +datasets +scikit-learn +scipy>=1.12.0 +numpy==1.26.4 +pandas==1.4.0 +tqdm>=4.66.4 +rdkit>=2024.3.5 \ No newline at end of file diff --git a/models/smi_ted/smi_ted_light/load.py b/models/smi_ted/smi_ted_light/load.py index aa08775357ac66b2a1fd3f077e2f90ed5a2fc9ae..adbb9239b5a3e37af7197abbea1ced068a8f99d7 100644 --- a/models/smi_ted/smi_ted_light/load.py +++ b/models/smi_ted/smi_ted_light/load.py @@ -675,4 +675,6 @@ def load_smi_ted(folder="./smi_ted_light", model.eval() print('Vocab size:', len(tokenizer.vocab)) print(f'[INFERENCE MODE - {str(model)}]') - return model \ No newline at end of file + return model + + diff --git a/models/smi_ted/training/args.py b/models/smi_ted/training/args.py new file mode 100644 index 0000000000000000000000000000000000000000..282d78326279fbfb1179c880c454bafd0df8ad84 --- /dev/null +++ b/models/smi_ted/training/args.py @@ -0,0 +1,254 @@ +import argparse + + +def get_parser(parser=None): + if parser is None: + parser = argparse.ArgumentParser() + + # Model + #model_arg = parser.add_argument_group('Model') + parser.add_argument('--n_head', + type=int, default=8, + help='GPT number of heads') + parser.add_argument('--n_layer', + type=int, default=12, + help='GPT number of layers') + parser.add_argument('--q_dropout', + type=float, default=0.5, + help='Encoder layers dropout') + parser.add_argument('--d_dropout', + type=float, default=0.1, + help='Decoder layers dropout') + parser.add_argument('--n_embd', + type=int, default=768, + help='Latent vector dimensionality') + parser.add_argument('--fc_h', + type=int, default=512, + help='Fully connected hidden dimensionality') + + + # Train + #train_arg = parser.add_argument_group('Train') + parser.add_argument('--n_batch', + type=int, default=512, + help='Batch size') + parser.add_argument('--unlike_alpha', + type=float, default=1.0, + help='unlikelihood loss alpha weight') + parser.add_argument('--from_scratch', + action='store_true', default=False, + help='train on qm9 from scratch') + parser.add_argument('--unlikelihood', + action='store_true', default=False, + help='use unlikelihood loss with gpt pretrain') + parser.add_argument('--grad_acc', + type=int, default=1, + help='number of batches to accumulate gradients') + parser.add_argument('--checkpoint_every', + type=int, default=1000, + help='save checkpoint every x iterations') + parser.add_argument('--clip_grad', + type=int, default=50, + help='Clip gradients to this value') + parser.add_argument('--lr_start', + type=float, default=3 * 1e-4, + help='Initial lr value') + parser.add_argument('--lr_end', + type=float, default=3 * 1e-4, + help='Maximum lr weight value') + parser.add_argument('--lr_multiplier', + type=int, default=1, + help='lr weight multiplier') + parser.add_argument('--n_last', + type=int, default=1000, + help='Number of iters to smooth loss calc') + parser.add_argument('--n_jobs', + type=int, default=1, + help='Number of threads') + parser.add_argument('--accelerator', + type=str, default='ddp', + help='The accelerator backend to use (previously known as distributed_backend)') + parser.add_argument('--num_nodes', + type=int, 
default=1, + help='number of GPU nodes for distributed training') + parser.add_argument('--device', + type=str, default='cuda', + help='Device to run: "cpu" or "cuda:<device number>"') + parser.add_argument('--seed', + type=int, default=12345, + help='Seed') + parser.add_argument('--init_params_from', + type=str, default='', + help='Path to a ckpt used to initialize the parameters if no restart_path is provided') + parser.add_argument('--train_decoder_every', + type=int, default=10, + help='Optimize decoder params every n batches') + parser.add_argument('--lr_decoder', + type=float, default=1e-4, + help='Learning rate for decoder part') + parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus") + parser.add_argument('--gpu', default=None, type=int, help='GPU id to use.') + parser.add_argument('--dist-backend', default='nccl', type=str, help='distributed backend') + parser.add_argument('--save_checkpoint_path', default='/data', help='checkpoint saving path') + parser.add_argument('--load_checkpoint_path', default='', help='checkpoint loading path') + + #common_arg = parser.add_argument_group('Common') + parser.add_argument('--vocab_load', + type=str, required=False, + help='Where to load the vocab') + parser.add_argument('--n_samples', + type=int, required=False, + help='Number of samples to sample') + parser.add_argument('--gen_save', + type=str, required=False, + help='Where to save the gen molecules') + parser.add_argument("--max_len", + type=int, default=100, + help="Max length of SMILES") + parser.add_argument('--train_load', + type=str, required=False, + help='Where to load the training data') + parser.add_argument('--val_load', + type=str, required=False, + help='Where to load the validation data') + parser.add_argument('--n_workers', + type=int, required=False, default=1, + help='Number of data loader workers') + #beam search hyper parameters + parser.add_argument('--beam_size', type=int, default=0, + help="Number of beams to generate") + parser.add_argument('--num_seq_returned', type=int, default=0, + help="number of beams to be returned (must be <= beam_size)") + parser.add_argument('--min_len', type=int, default=1, + help="minimum length to be generated") + parser.add_argument('--nucleus_thresh', type=float, default=.9, + help="nucleus sampling threshold") + parser.add_argument('--finetune_path', + type=str, default="", + help='path to trainer file to continue training') + parser.add_argument('--restart_path', + type=str, default="", + help='path to trainer file to continue training') + parser.add_argument('--data_path', + type=str, default="", + help='path to pubchem file') + parser.add_argument('--pretext_size', + type=int, default=0, + help='number of k-mers to pretext') + parser.add_argument('--model_save_dir', + type=str, required=False, default='./models_dump/', + help='Where to save the models/log/config/vocab') + parser.add_argument('--model_save', + type=str, required=False, default='model.pt', + help='Where to save the model') + #parser.add_argument('--save_frequency', + # type=int, default=20, + # help='How often to save the model') + parser.add_argument('--num_epoch', + type=int, default=1, + help='number of epochs to train') + #parser.add_argument('--num_iter', + # type=int, default=-1, + # help='how many iterations per epoch (for unlikelihood tuning)') + parser.add_argument('--log_file', + type=str, required=False, + help='Where to save the log') + parser.add_argument('--tb_loc', + type=str, required=False, + help='Where to save the 
tensorboard logs') + parser.add_argument('--config_save', + type=str, required=False, + help='Where to save the config') + parser.add_argument('--vocab_save', + type=str, + help='Where to save the vocab') + + # resume_arg = parser.add_argument_group('Resume') + parser.add_argument('--debug', + default=False, action='store_true', + help='do not erase cache at end of program') + parser.add_argument('--fast_dev_run', + default=False, + help='Runs a quick “unit test”: n training and validation batch(es) if set to an int n, or 1 if set to True.') + parser.add_argument('--freeze_model', + default=False, action='store_true', + help='freeze weights of bert model during fine tuning') + parser.add_argument('--resume', + default=False, action='store_true', + help='Resume from a saved model') + parser.add_argument('--rotate', + default=False, action='store_true', + help='use rotational relative embedding') + parser.add_argument('--model_load', + type=str, required=False, + help='Where to load the model') + parser.add_argument('--root_dir', + type=str, required=False, default='.', + help='location of root dir') + parser.add_argument('--config_load', + type=str, required=False, + help='Where to load the config') + parser.add_argument('--gpus', + type=int, required=False, default=1, + help='number of gpus to use') + #parser.add_argument('--start_epoch', + # type=int, required=False, default=0, + # help='Where to load the config') + + parser.add_argument('--model_arch', + type=str, required=False, + help='used to track model arch in params') + parser.add_argument('--eval_every', + type=int, default=50000, + help='run evaluation every x iterations') + parser.add_argument('--num_feats', + type=int, required=False, default=32, + help='number of random features for FAVOR+') + parser.add_argument('--max_epochs', + type=int, required=False, default=1, + help='max number of epochs') + + # debug() FINE TUNING + # parser.add_argument('--save_dir', type=str, required=True) + parser.add_argument('--mode', + type=str, default='cls', + help='type of pooling to use') + parser.add_argument("--dataset_length", type=int, default=None, required=False) + parser.add_argument("--num_workers", type=int, default=0, required=False) + parser.add_argument("--dropout", type=float, default=0.1, required=False) + #parser.add_argument("--dims", type=int, nargs="*", default="", required=False) + parser.add_argument( + "--smiles_embedding", + type=str, + default="/dccstor/medscan7/smallmolecule/runs/ba-predictor/small-data/embeddings/protein/ba_embeddings_tanh_512_2986138_2.pt", + ) + # parser.add_argument("--train_pct", type=str, required=False, default="95") + #parser.add_argument("--aug", type=int, required=True) + parser.add_argument("--dataset_name", type=str, required=False, default="sol") + parser.add_argument("--measure_name", type=str, required=False, default="measure") + parser.add_argument("--smi_ted_version", type=str, required=True, default="v1") + #parser.add_argument("--emb_type", type=str, required=True) + #parser.add_argument("--checkpoints_folder", type=str, required=True) + #parser.add_argument("--results_dir", type=str, required=True) + #parser.add_argument("--patience_epochs", type=int, required=True) + + parser.add_argument( + "--data_root", + type=str, + required=False, + default="/dccstor/medscan7/smallmolecule/runs/ba-predictor/small-data/affinity", + ) + # parser.add_argument("--use_bn", type=int, default=0) + parser.add_argument("--use_linear", type=int, default=0) + + 
parser.add_argument("--lr", type=float, default=0.001) + # parser.add_argument("--weight_decay", type=float, default=5e-4) + # parser.add_argument("--val_check_interval", type=float, default=1.0) + parser.add_argument("--batch_size", type=int, default=64) + + return parser +def parse_args(): + parser = get_parser() + args = parser.parse_args() + return args + diff --git a/models/smi_ted/training/bert_vocab_curated.txt b/models/smi_ted/training/bert_vocab_curated.txt new file mode 100644 index 0000000000000000000000000000000000000000..4de43b6c87add8821aa8d2d33d58232929c86fbd --- /dev/null +++ b/models/smi_ted/training/bert_vocab_curated.txt @@ -0,0 +1,2393 @@ +<bos> +<eos> +<pad> +<mask> +C +c +( +) +1 +O +N +2 += +n +3 +[C@H] +[C@@H] +F +S +4 +Cl +- +o +s +[nH] +# +/ +Br +[C@] +[C@@] +[N+] +[O-] +5 +\ +. +I +6 +[S@] +[S@@] +P +[N-] +[Si] +7 +[n+] +[2H] +8 +[NH+] +B +9 +[C-] +[Na+] +[Cl-] +[c-] +[CH] +%10 +[NH2+] +[P+] +[B] +[I-] +%11 +[CH2-] +[O+] +[NH3+] +[C] +[Br-] +[IH2] +[S-] +[cH-] +%12 +[nH+] +[B-] +[K+] +[Sn] +[Se] +[CH-] +[HH] +[Y] +[n-] +[CH3-] +[SiH] +[S+] +%13 +[SiH2] +[Li+] +[NH-] +%14 +[Na] +[CH2] +[O-2] +[U+2] +[W] +[Al] +[P@] +[Fe+2] +[PH+] +%15 +[Cl+3] +[Zn+2] +[Ir] +[Mg+2] +[Pt+2] +[OH2+] +[As] +[Fe] +[OH+] +[Zr+2] +[3H] +[Ge] +[SiH3] +[OH-] +[NH4+] +[Cu+2] +[P@@] +p +[Pt] +%16 +[Ca+2] +[Zr] +[F-] +[C+] +[Ti] +[P-] +[V] +[se] +[U] +[O] +[Ni+2] +[Zn] +[Co] +[Ni] +[Pd+2] +[Cu] +%17 +[Cu+] +[Te] +[H+] +[CH+] +[Li] +[Pd] +[Mo] +[Ru+2] +[o+] +[Re] +[SH+] +%18 +[Ac] +[Cr] +[NH2-] +[K] +[13CH2] +[c] +[Zr+4] +[Tl] +[13C] +[Mn] +[N@+] +[Hg] +[Rh] +[Ti+4] +[Sb] +[Co+2] +[Ag+] +[Ru] +%19 +[N@@+] +[Ti+2] +[Al+3] +[Pb] +[I+] +[18F] +[s+] +[Rb+] +[Ba+2] +[H-] +[Fe+3] +[Ir+3] +[13cH] +%20 +[AlH2] +[Au+] +[13c] +[SH2+] +[Sn+2] +[Mn+2] +[Si-] +[Ag] +[N] +[Bi] +%21 +[In] +[CH2+] +[Y+3] +[Ga] +%22 +[Co+3] +[Au] +[13CH3] +[Mg] +[Cs+] +[W+2] +[Hf] +[Zn+] +[Se-] +[S-2] +[Ca] +[pH] +[ClH+] +[Ti+3] +%23 +[Ru+] +[SH-] +[13CH] +[IH+] +[Hf+4] +[Rf] +[OH3+] +%24 +[Pt+4] +[Zr+3] +[PH3+] +[Sr+2] +[Cd+2] +[Cd] +%25 +[Os] +[BH-] +[Sn+4] +[Cr+3] +[Ru+3] +[PH2+] +[Rh+2] +[V+2] +%26 +[Gd+3] +[Pb+2] +[PH] +[Hg+] +[Mo+2] +[AlH] +[Sn+] +%27 +[Pd+] +b +[Rh+3] +[Hg+2] +[15NH] +[14C] +%28 +[Mn+3] +[Si+] +[SeH] +[13C@H] +[NH] +[Ga+3] +[SiH-] +[13C@@H] +[Ce] +[Au+3] +[Bi+3] +[15N] +%29 +[BH3-] +[14cH] +[Ti+] +[Gd] +[cH+] +[Cr+2] +[Sb-] +%30 +[Be+2] +[Al+] +[te] +[11CH3] +[Sm] +[Pr] +[La] +%31 +[Al-] +[Ta] +[125I] +[BH2-] +[Nb] +[Si@] +%32 +[14c] +[Sb+3] +[Ba] +%33 +[Os+2] +[Si@@] +[La+3] +[15n] +[15NH2] +[Nd+3] +%34 +[14CH2] +[18O] +[Nd] +[GeH] +[Ni+3] +[Eu] +[Dy+3] +[Sc] +%36 +[Se-2] +[As+] +%35 +[AsH] +[Tb] +[Sb+5] +[Se+] +[Ce+3] +[c+] +[In+3] +[SnH] +[Mo+4] +%37 +[V+4] +[Eu+3] +[Hf+2] +%38 +[Pt+] +[p+] +[123I] +[Tl+] +[Sm+3] +%39 +[Yb+3] +%40 +[Yb] +[Os+] +%41 +[10B] +[Sc+3] +[Al+2] +%42 +[Sr] +[Tb+3] +[Po] +[Tc] +[PH-] +[AlH3] +[Ar] +[U+4] +[SnH2] +[Cl+2] +[si] +[Fe+] +[14CH3] +[U+3] +[Cl+] +%43 +[GeH2] +%44 +[Er+3] +[Mo+3] +[I+2] +[Fe+4] +[99Tc] +%45 +[11C] +%46 +[SnH3] +[S] +[Te+] +[Er] +[Lu+3] +[11B] +%47 +%48 +[P] +[Tm] +[Th] +[Dy] +[Pr+3] +[Ta+5] +[Nb+5] +[Rb] +[GeH3] +[Br+2] +%49 +[131I] +[Fm] +[Cs] +[BH4-] +[Lu] +[15nH] +%50 +[Ru+6] +[b-] +[Ho] +[Th+4] +[Ru+4] +%52 +[14CH] +%51 +[Cr+6] +[18OH] +[Ho+3] +[Ce+4] +[Bi+2] +[Co+] +%53 +[Yb+2] +[Fe+6] +[Be] +%54 +[SH3+] +[Np] +[As-] +%55 +[14C@@H] +[Ir+2] +[GaH3] +[p-] +[GeH4] +[Sn+3] +[Os+4] +%56 +[14C@H] +[sH+] +[19F] +[Eu+2] +[TlH] +%57 +[Cr+4] +%58 +[B@@-] +[SiH+] +[At] +[Am] +[Fe+5] +[AsH2] +[Si+4] +[B@-] +[Pu] +[SbH] +[P-2] +[Tm+3] +* +%59 +[se+] +[IH-] 
+%60 +[oH+] +[1H] +[15N+] +[124I] +[S@@+] +[P-3] +[H] +[IH2+] +[TeH] +[Xe] +[PH4+] +[Cr+] +[Cm] +[I+3] +%61 +[Nb+2] +[Ru+5] +%62 +[Ta+2] +[Tc+4] +[CH3+] +[Pm] +[Si@H] +[No] +%63 +[Cr+5] +[Th+2] +[Zn-2] +[13C@] +[Lr] +%64 +[99Tc+3] +%65 +[13C@@] +%66 +[Fe-] +[17O] +[siH] +[Sb+] +[OH] +[IH] +[11CH2] +[Cf] +[SiH2+] +[Gd+2] +[In+] +[Si@@H] +[Mn+] +[99Tc+4] +[Ga-] +%67 +[S@+] +[Ge+4] +[Tl+3] +[16OH] +%68 +[2H-] +[Ra] +[si-] +[NiH2] +[P@@H] +[Rh+] +[12C] +[35S] +[32P] +[SiH2-] +[AlH2+] +[16O] +%69 +[BiH] +[BiH2] +[Zn-] +[BH] +[Tc+3] +[Ir+] +[Ni+] +%70 +[InH2] +[InH] +[Nb+3] +[PbH] +[Bi+] +%71 +[As+3] +%72 +[18O-] +[68Ga+3] +%73 +[Pa] +[76Br] +[Tc+5] +[pH+] +[64Cu+2] +[Ru+8] +%74 +[PH2-] +[Si+2] +[17OH] +[RuH] +[111In+3] +[AlH+] +%75 +%76 +[W+] +[SbH2] +[PoH] +[Ru-] +[XeH] +[Tc+2] +[13C-] +[Br+] +[Pt-2] +[Es] +[Cu-] +[Mg+] +[3HH] +[P@H] +[ClH2+] +%77 +[SH] +[Au-] +[2HH] +%78 +[Sn-] +[11CH] +[PdH2] +0 +[Os+6] +%79 +[Mo+] +%80 +[al] +[PbH2] +[64Cu] +[Cl] +[12CH3] +%81 +[Tc+7] +[11c] +%82 +[Li-] +[99Tc+5] +[He] +[12c] +[Kr] +[RuH+2] +[35Cl] +[Pd-2] +[GaH2] +[4H] +[Sg] +[Cu-2] +[Br+3] +%83 +[37Cl] +[211At] +[IrH+2] +[Mt] +[Ir-2] +[In-] +[12cH] +[12CH2] +[RuH2] +[99Tc+7] +%84 +[15n+] +[ClH2+2] +[16N] +[111In] +[Tc+] +[Ru-2] +[12CH] +[si+] +[Tc+6] +%85 +%86 +[90Y] +[Pd-] +[188Re] +[RuH+] +[NiH] +[SiH3-] +[14n] +[CH3] +[14N] +[10BH2] +%88 +%89 +%90 +[34S] +[77Br] +[GaH] +[Br] +[Ge@] +[B@@H-] +[CuH] +[SiH4] +[3H-] +%87 +%91 +%92 +[67Cu] +[I] +[177Lu] +[ReH] +[67Ga+3] +[Db] +[177Lu+3] +[AlH2-] +[Si+3] +[Ti-2] +[RuH+3] +[al+] +[68Ga] +[2H+] +[B@H-] +[WH2] +[OsH] +[Ir-3] +[AlH-] +[Bk] +[75Se] +[14C@] +[Pt-] +[N@@H+] +[Nb-] +[13NH2] +%93 +[186Re] +[Tb+4] +[PtH] +[IrH2] +[Hg-2] +[AlH3-] +[PdH+] +[Md] +[RhH+2] +[11cH] +[Co-2] +[15N-] +[ZrH2] +%94 +[Hg-] +[127I] +[AsH2+] +[MoH2] +[Te+4] +[14C@@] +[As+5] +[SnH+3] +[Ge@@] +[6Li+] +[WH] +[Ne] +[14NH2] +[14NH] +[12C@@H] +[Os+7] +[RhH] +[Al-3] +[SnH+] +[15NH3+] +[Zr+] +[197Hg+] +%95 +%96 +[90Y+3] +[Os-2] +[98Tc+5] +[15NH3] +[bH-] +[33P] +[Zr-2] +[15O] +[Rh-] +[PbH3] +[PH2] +[Ni-] +[CuH+] +%97 +%98 +%99 +[Os+5] +[PtH+] +[ReH4] +[16NH] +[82Br] +[W-] +[18F-] +[15NH4+] +[Se+4] +[SeH-] +[SH4] +[67Cu+2] +[12C@H] +[AsH3] +[HgH] +[10B-] +[99Tc+6] +[117Sn+4] +[Te@] +[P@+] +[35SH] +[SeH+] +[Ni-2] +[Al-2] +[TeH2] +[Bh] +[99Tc+2] +[Os+8] +[PH-2] +[7Li+] +[14nH] +[AlH+2] +[18FH] +[SnH4] +[18O-2] +[IrH] +[13N] +[Te@@] +[Rh-3] +[15NH+] +[AsH3+] +[SeH2] +[AsH+] +[CoH2] +[16NH2] +[AsH-] +[203Hg+] +[P@@+] +[166Ho+3] +[60Co+3] +[13CH2-] +[SeH2+] +[75Br] +[TlH2] +[80Br] +[siH+] +[Ca+] +[153Sm+3] +[PdH] +[225Ac] +[13CH3-] +[AlH4-] +[FeH] +[13CH-] +[14C-] +[11C-] +[153Sm] +[Re-] +[te+] +[13CH4] +[ClH+2] +[8CH2] +[99Mo] +[ClH3+3] +[SbH3] +[25Mg+2] +[16N+] +[SnH2+] +[PH4] +[11C@H] +[122I] +[Re-2] +[RuH2+2] +[ZrH] +[Bi-] +[Pr+] +[Rn] +[Fr] +[36Cl] +[18o] +[YH] +[79Br] +[121I] +[113In+3] +[InH4-] +[TaH] +[RhH2] +[Ta-] +[67Ga] +[ZnH+] +[SnH2-] +[OsH2] +[16F] +[FeH2] +[14O] +[PbH2+2] +[BH2] +[6H] +[125Te] +[197Hg] +[TaH2] +[TaH3] +[76As] +[Nb-2] +[14N+] +[125I-] +[33S] +[IH2+2] +[NH2] +[PtH2] +[MnH] +[19C] +[17F] +[1H-] +[SnH4+2] +[Mn-2] +[15NH2+] +[TiH2] +[ReH7] +[Cd-2] +[Fe-3] +[SH2] +[17O-] +[siH-] +[CoH+] +[VH] +[10BH] +[Ru-3] +[13O] +[5H] +[CoH] +[PH5] +[15n-] +[153Gd] +[12C@] +[11CH3-] +[IrH3] +[RuH3] +[74Se] +[Se@] +[Hf+] +[77Se] +[166Ho] +[59Fe+2] +[203Hg] +[18OH-] +[8CH] +[12C@@] +[11CH4] +[15C] +[249Cf] +[PbH4] +[64Zn] +[PH3] +[99Tc+] +[14c-] +[149Pm] +[IrH4] +[Se@@] +[13OH] +[14CH3-] +[28Si] +[Rh-2] +[Fe-2] +[131I-] +[51Cr] +[62Cu+2] +[81Br] +[121Sb] +[7Li] +[89Zr+4] +[SbH3+] 
+[11C@@H] +[98Tc] +[59Fe+3] +[BiH2+] +[SbH+] +[TiH] +[14NH3] +[15OH] +[119Sn] +[201Hg] +[MnH+] +[201Tl] +[51Cr+3] +[123I-] +[MoH] +[AlH6-3] +[MnH2] +[WH3] +[213Bi+3] +[SnH2+2] +[123IH] +[13CH+] +[Zr-] +[74As] +[13C+] +[32P+] +[KrH] +[SiH+2] +[ClH3+2] +[13NH] +[9CH2] +[ZrH2+2] +[87Sr+2] +[35s] +[239Pu] +[198Au] +[241Am] +[203Hg+2] +[V+] +[YH2] +[SH5] +[195Pt] +[203Pb] +[RuH4] +[ThH2] +[AuH] +[66Ga+3] +[11B-] +[F] +[24Na+] +[85Sr+2] +[201Tl+] +[14CH4] +[32S] +[TeH2+] +[ClH2+3] +[AgH] +[Ge@H] +[44Ca+2] +[Os-] +[31P] +[15nH+] +[SbH4] +[TiH+] +[Ba+] +[57Co+2] +[Ta+] +[125IH] +[77As] +[129I] +[Fe-4] +[Ta-2] +[19O] +[12O] +[BiH3] +[237Np] +[252Cf] +[86Y] +[Cr-2] +[89Y] +[195Pt+2] +[si+2] +[58Fe+2] +[Hs] +[S@@H] +[OsH6] +[GdH2] +[IH3] +[8CH4] +[164Dy+3] +[47Ca+2] +[57Co] +[NbH2] +[ReH2] +[ZnH2] +[CrH2] +[17NH] +[ZrH3] +[RhH3] +[12C-] +[18O+] +[Bi-2] +[ClH4+3] +[Ni-3] +[Ag-] +[111In-] +[Mo-2] +[55Fe+3] +[204Hg+] +[35Cl-] +[211Pb] +[75Ge] +[8B] +[TeH3] +[SnH3+] +[Zr-3] +[28F] +[249Bk] +[169Yb] +[34SH] +[6Li] +[94Tc] +[197Au] +[195Pt+4] +[169Yb+3] +[32Cl] +[82Se] +[159Gd+3] +[213Bi] +[CoH+2] +[36S] +[35P] +[Ru-4] +[Cr-3] +[60Co] +[1H+] +[18CH2] +[Cd-] +[152Sm+3] +[106Ru] +[238Pu] +[220Rn] +[45Ca+2] +[89Sr+2] +[239Np] +[90Sr+2] +[137Cs+] +[165Dy] +[68GaH3] +[65Zn+2] +[89Zr] +[BiH2+2] +[62Cu] +[165Dy+3] +[238U] +[105Rh+3] +[70Zn] +[12B] +[12OH] +[18CH] +[17CH] +[OsH3] +[SbH-] +[SH6] +[AlH2-2] +[42K] +[76Br-] +[71As] +[NbH3] +[ReH3] +[OsH-] +[WH4] +[MoH3] +[OsH4] +[RuH6] +[PtH3] +[CuH2] +[CoH3] +[TiH4] +[64Zn+2] +[Si-2] +[79BrH] +[14CH2-] +[PtH2+2] +[Os-3] +[29Si] +[Ti-] +[Se+6] +[22Na+] +[42K+] +[131Cs+] +[86Rb+] +[134Cs+] +[209Po] +[208Po] +[81Rb+] +[203Tl+] +[Zr-4] +[148Sm] +[147Sm] +[37Cl-] +[12CH4] +[Ge@@H] +[63Cu] +[13CH2+] +[AsH2-] +[CeH] +[SnH-] +[UH] +[9c] +[21CH3] +[TeH+] +[57Co+3] +[8BH2] +[12BH2] +[19BH2] +[9BH2] +[YbH2] +[CrH+2] +[208Bi] +[152Gd] +[61Cu] +[115In] +[60Co+2] +[13NH2-] +[120I] +[18OH2] +[75SeH] +[SbH2+] +[144Ce] +[16n] +[113In] +[22nH] +[129I-] +[InH3] +[32PH3] +[234U] +[235U] +[59Fe] +[82Rb+] +[65Zn] +[244Cm] +[147Pm] +[91Y] +[237Pu] +[231Pa] +[253Cf] +[127Te] +[187Re] +[236Np] +[235Np] +[72Zn] +[253Es] +[159Dy] +[62Zn] +[101Tc] +[149Tb] +[124I-] +[SeH3+] +[210Pb] +[40K] +[210Po] +[214Pb] +[218Po] +[214Po] +[7Be] +[212Pb] +[205Pb] +[209Pb] +[123Te] +[202Pb] +[72As] +[201Pb] +[70As] +[73Ge] +[200Pb] +[198Pb] +[66Ga] +[73Se] +[195Pb] +[199Pb] +[144Ce+3] +[235U+2] +[90Tc] +[114In+3] +[128I] +[100Tc+] +[82Br-] +[191Pt+2] +[191Pt+4] +[193Pt+4] +[31PH3] +[125I+2] +[131I+2] +[125Te+4] +[82Sr+2] +[149Sm] +[81BrH] +[129Xe] +[193Pt+2] +[123I+2] +[Cr-] +[Co-] +[227Th+4] +[249Cf+3] +[252Cf+3] +[187Os] +[16O-] +[17O+] +[16OH-] +[98Tc+7] +[58Co+2] +[69Ga+3] +[57Fe+2] +[43K+] +[16C] +[52Fe+3] +[SeH5] +[194Pb] +[196Pb] +[197Pb] +[213Pb] +[9B] +[19B] +[11CH-] +[9CH] +[20OH] +[25OH] +[8cH] +[TiH+3] +[SnH6+3] +[N@H+] +[ZnH] +[VH3] +[52Mn+2] +[64Ga] +[13B] +[216Bi] +[117Sn+2] +[232Th] +[SnH+2] +[BiH5] +[77Kr] +[103Cd] +[62Ni] +[LaH3] +[SmH3] +[EuH3] +[MoH5] +[64Ni] +[66Zn] +[68Zn] +[186W] +[FeH4] +[MoH4] +[HgH2] +[15NH2-] +[UH2] +[204Hg] +[GaH4-] +[ThH4] +[WH6] +[PtH4] +[VH2] +[UH3] +[FeH3] +[RuH5] +[BiH4] +[80Br-] +[CeH3] +[37ClH] +[157Gd+3] +[205Tl] +[203Tl] +[62Cu+] +[64Cu+] +[61Cu+] +[37SH2] +[30Si] +[28Al] +[19OH2] +[8He] +[6He] +[153Pm] +[209Bi] +[66Zn+2] +[10CH4] +[191Ir] +[66Cu] +[16O+] +[25O] +[10c] +[Co-3] +[Sn@@] +[17OH-] +[206Po] +[204Po] +[202Po] +[201Po] +[200Po] +[199Po] +[198Po] +[197Po] +[196Po] +[195Po] +[194Po] +[193Po] +[192Po] +[191Po] +[190Po] +[217Po] +[BiH4-] +[TeH4] 
+[222Ra] +[62Ga] +[39Ar] +[144Sm] +[58Fe] +[153Eu] +[85Rb] +[171Yb] +[172Yb] +[114Cd] +[51Fe] +[142Ce] +[207Tl] +[92Mo] +[115Sn] +[140Ce] +[202Hg] +[180W] +[182W] +[183W] +[184W] +[96Mo] +[47Ti] +[111Cd] +[143Nd] +[145Nd] +[126Te] +[128Te] +[130Te] +[185Re] +[97Mo] +[98Mo] +[183Re] +[52V] +[80Se] +[87Kr] +[137Xe] +[196Au] +[146Ce] +[88Kr] +[51Ti] +[138Xe] +[112Cd] +[116Sn] +[120Sn] +[28SiH3] +[35S-] +[15NH-] +[13CH3+] +[34S+] +[34s] +[SiH4-] +[100Tc+5] +[NiH2+2] +[239Th] +[186Lu] +[AuH3] +[I@@-] +[XeH2] +[B+] +[16CH2] +[8C] +[TaH5] +[FeH4-] +[19C@H] +[10NH] +[FeH6-3] +[22CH] +[25N] +[25N+] +[25N-] +[21CH2] +[18cH] +[113I] +[ScH3] +[30PH3] +[43Ca+2] +[41Ca+2] +[106Cd] +[122Sn] +[18CH3] +[58Co+3] +[98Tc+4] +[70Ge] +[76Ge] +[108Cd] +[116Cd] +[130Xe] +[94Mo] +[124Sn] +[186Os] +[188Os] +[190Os] +[192Os] +[106Pd] +[110Pd] +[120Te] +[132Ba] +[134Ba] +[136Ba] +[136Ce] +[138Ce] +[156Dy] +[158Dy] +[160Dy] +[163Dy] +[162Er] +[164Er] +[167Er] +[176Hf] +[26Mg] +[144Nd] +[150Nd] +[41K] +[46Ti] +[48Ti] +[49Ti] +[50Ti] +[170Yb] +[173Yb] +[91Zr] +[92Zr] +[96Zr] +[34S-] +[CuH2-] +[38Cl] +[25Mg] +[51V] +[93Nb] +[95Mo] +[45Sc] +[123Sb] +[139La] +[9Be] +[99Y+3] +[99Y] +[156Ho] +[67Zn] +[144Ce+4] +[210Tl] +[42Ca] +[54Fe] +[193Ir] +[92Nb] +[141Cs] +[52Cr] +[35ClH] +[46Ca] +[139Cs] +[65Cu] +[71Ga] +[60Ni] +[16NH3] +[148Nd] +[72Ge] +[161Dy] +[49Ca] +[43Ca] +[8Be] +[48Ca] +[44Ca] +[120Xe] +[80Rb] +[215At] +[180Re] +[146Sm] +[19Ne] +[74Kr] +[134La] +[76Kr] +[219Fr] +[121Xe] +[220Fr] +[216At] +[223Ac] +[218At] +[37Ar] +[135I] +[110Cd] +[94Tc+7] +[86Y+3] +[135I-] +[15O-2] +[151Eu+3] +[161Tb+3] +[197Hg+2] +[109Cd+2] +[191Os+4] +[170Tm+3] +[205Bi+3] +[233U+4] +[126Sb+3] +[127Sb+3] +[132Cs+] +[136Eu+3] +[136Eu] +[125Sn+4] +[175Yb+3] +[100Mo] +[22Ne] +[13c-] +[13NH4+] +[17C] +[9C] +[31S] +[31SH] +[133I] +[126I] +[36SH] +[30S] +[32SH] +[19CH2] +[19c] +[18c] +[15F] +[10C] +[RuH-] +[62Zn+2] +[32ClH] +[33ClH] +[78BrH] +[12Li+] +[12Li] +[233Ra] +[68Ge+4] +[44Sc+3] +[91Y+3] +[106Ru+3] +[PoH2] +[AtH] +[55Fe] +[233U] +[210PoH2] +[230Th] +[228Th] +[222Rn] +[35SH2] +[227Th] +[192Ir] +[133Xe] +[81Kr] +[95Zr] +[240Pu] +[54Mn] +[103Ru] +[95Nb] +[109Cd] +[141Ce] +[85Kr] +[110Ag] +[58Co] +[241Pu] +[234Th] +[140La] +[63Ni] +[152Eu] +[132IH] +[226Rn] +[154Eu] +[36ClH] +[228Ac] +[155Eu] +[106Rh] +[243Am] +[227Ac] +[243Cm] +[236U] +[144Pr] +[232U] +[32SH2] +[88Y] +[82BrH] +[135IH] +[242Cm] +[115Cd] +[242Pu] +[46Sc] +[56Mn] +[234Pa] +[41Ar] +[147Nd] +[187W] +[151Sm] +[59Ni] +[233Pa] +[52Mn] +[94Nb] +[219Rn] +[236Pu] +[13NH3] +[93Zr] +[51Cr+6] +[TlH3] +[123Xe] +[160Tb] +[170Tm] +[182Ta] +[175Yb] +[93Mo] +[143Ce] +[191Os] +[126IH] +[48V] +[113Cd] +[47Sc] +[181Hf] +[185W] +[143Pr] +[191Pt] +[181W] +[33PH3] +[97Ru] +[97Tc] +[111Ag] +[169Er] +[107Pd] +[103Ru+2] +[34SH2] +[137Ce] +[242Am] +[117SnH2] +[57Ni] +[239U] +[60Cu] +[250Cf] +[193Au] +[69Zn] +[55Co] +[139Ce] +[127Xe] +[159Gd] +[56Co] +[177Hf] +[244Pu] +[38ClH] +[142Pr] +[199Hg] +[179Hf] +[178Hf] +[237U] +[156Eu] +[157Eu] +[105Ru] +[171Tm] +[199Au] +[155Sm] +[80BrH] +[108Ag] +[128IH] +[48Sc] +[45Ti] +[176Lu] +[121SnH2] +[148Pm] +[57Fe] +[10BH3] +[96Tc] +[133IH] +[143Pm] +[105Rh] +[130IH] +[134IH] +[131IH] +[71Zn] +[105Ag] +[97Zr] +[235Pu] +[231Th] +[109Pd] +[93Y] +[190Ir] +[135Xe] +[53Mn] +[134Ce] +[234Np] +[240Am] +[246Cf] +[240Cm] +[241Cm] +[226Th] +[39ClH] +[229Th] +[245Cm] +[240U] +[240Np] +[249Cm] +[243Pu] +[145Pm] +[199Pt] +[246Bk] +[193Pt] +[230U] +[250Cm] +[44Ti] +[175Hf] +[254Fm] +[255Fm] +[257Fm] +[92Y] +[188Ir] +[171Lu] +[257Md] +[247Bk] +[121IH] +[250Bk] +[179Lu] +[224Ac] 
+[195Hg] +[244Am] +[246Pu] +[194Au] +[252Fm] +[173Hf] +[246Cm] +[135Ce] +[49Cr] +[248Cf] +[247Cm] +[248Cm] +[174Ta] +[176Ta] +[154Tb] +[172Ta] +[177Ta] +[175Ta] +[180Ta] +[158Tb] +[115Ag] +[189Os] +[251Cf] +[145Pr] +[147Pr] +[76BrH] +[102Rh] +[238Np] +[185Os] +[246Am] +[233Np] +[166Dy] +[254Es] +[244Cf] +[193Os] +[245Am] +[245Bk] +[239Am] +[238Am] +[97Nb] +[245Pu] +[254Cf] +[188W] +[250Es] +[251Es] +[237Am] +[182Hf] +[258Md] +[232Np] +[238Cm] +[60Fe] +[109Pd+2] +[234Pu] +[141Ce+3] +[136Nd] +[136Pr] +[173Ta] +[110Ru] +[147Tb] +[253Fm] +[139Nd] +[178Re] +[177Re] +[200Au] +[182Re] +[156Tb] +[155Tb] +[157Tb] +[161Tb] +[161Ho] +[167Tm] +[173Lu] +[179Ta] +[171Er] +[44Sc] +[49Sc] +[49V] +[51Mn] +[90Nb] +[88Nb] +[88Zr] +[36SH2] +[174Yb] +[178Lu] +[179W] +[83BrH] +[107Cd] +[75BrH] +[62Co] +[48Cr] +[63Zn] +[102Ag] +[154Sm] +[168Er] +[65Ni] +[137La] +[187Ir] +[144Pm] +[146Pm] +[160Gd] +[166Yb] +[162Dy] +[47V] +[141Nd] +[141Sm] +[166Er] +[150Sm] +[146Eu] +[149Eu] +[174Lu] +[17NH3] +[102Ru] +[170Hf] +[188Pt] +[61Ni] +[56Ni] +[149Gd] +[151Gd] +[141Pm] +[147Gd] +[146Gd] +[161Er] +[103Ag] +[145Eu] +[153Tb] +[155Dy] +[184Re] +[180Os] +[182Os] +[186Pt] +[181Os] +[181Re] +[151Tb] +[178Ta] +[178W] +[189Pt] +[194Hg] +[145Sm] +[150Tb] +[132La] +[158Gd] +[104Ag] +[193Hg] +[94Ru] +[137Pr] +[155Ho] +[117Cd] +[99Ru] +[146Nd] +[218Rn] +[95Y] +[79Kr] +[120IH] +[138Pr] +[100Pd] +[166Tm] +[90Mo] +[151Nd] +[231U] +[138Nd] +[89Nb] +[98Nb] +[162Ho] +[142Sm] +[186Ta] +[104Tc] +[184Ta] +[185Ta] +[170Er] +[107Rh] +[131La] +[169Lu] +[74BrH] +[150Pm] +[172Tm] +[197Pt] +[230Pu] +[170Lu] +[86Zr] +[176W] +[177W] +[101Pd] +[105Pd] +[108Pd] +[149Nd] +[164Ho] +[159Ho] +[167Ho] +[176Yb] +[156Sm] +[77BrH] +[189Re] +[99Rh] +[100Rh] +[151Pm] +[232Pa] +[228Pa] +[230Pa] +[66Ni] +[194Os] +[135La] +[138La] +[141La] +[142La] +[195Ir] +[96Nb] +[157Ho] +[183Hf] +[162Tm] +[172Er] +[148Eu] +[150Eu] +[15CH4] +[89Kr] +[143La] +[58Ni] +[61Co] +[158Eu] +[165Er] +[167Yb] +[173Tm] +[175Tm] +[172Hf] +[172Lu] +[93Tc] +[177Yb] +[124IH] +[194Ir] +[147Eu] +[101Mo] +[180Hf] +[189Ir] +[87Y] +[43Sc] +[195Au] +[112Ag] +[84BrH] +[106Ag] +[109Ag] +[101Rh] +[162Yb] +[228Rn] +[139Pr] +[94Y] +[201Au] +[40PH3] +[110Ag+] +[104Cd] +[133Ba+2] +[226Ac] +[145Gd] +[186Ir] +[184Ir] +[224Rn] +[185Ir] +[182Ir] +[184Hf] +[200Pt] +[227Pa] +[178Yb] +[72Br-] +[72BrH] +[248Am] +[238Th] +[161Gd] +[35S-2] +[107Ag] +[FeH6-4] +[89Sr] +[SnH3-] +[SeH3] +[TeH3+] +[SbH4+] +[AsH4+] +[4He] +[AsH3-] +[1HH] +[3H+] +[82Rb] +[85Sr] +[90Sr] +[137Cs] +[133Ba] +[131Cs] +[SbH5] +[224Ra] +[22Na] +[210Bi] +[214Bi] +[228Ra] +[127Sb] +[136Cs] +[125Sb] +[134Cs] +[140Ba] +[45Ca] +[206Pb] +[207Pb] +[24Na] +[86Rb] +[212Bi] +[208Pb] +[124Sb] +[204Pb] +[44K] +[129Te] +[113Sn] +[204Tl] +[87Sr] +[208Tl] +[87Rb] +[47Ca] +[135Cs] +[216Po] +[137Ba] +[207Bi] +[212Po] +[79Se] +[223Ra] +[86Sr] +[122Sb] +[26Al] +[32Si] +[126Sn] +[225Ra] +[114In] +[72Ga] +[132Te] +[10Be] +[125Sn] +[73As] +[206Bi] +[117Sn] +[40Ca] +[41Ca] +[89Rb] +[116In] +[129Sb] +[91Sr] +[71Ge] +[139Ba] +[69Ga] +[120Sb] +[121Sn] +[123Sn] +[131Te] +[77Ge] +[135Ba] +[82Sr] +[43K] +[131Ba] +[92Sr] +[88Rb] +[129Cs] +[144Cs] +[127Cs] +[200Tl] +[202Tl] +[141Ba] +[117Sb] +[116Sb] +[78As] +[131Sb] +[126Sb] +[128Sb] +[130Sb] +[67Ge] +[68Ge] +[78Ge] +[66Ge] +[223Fr] +[132Cs] +[125Cs] +[138Cs] +[133Te] +[84Rb] +[83Rb] +[81Rb] +[142Ba] +[200Bi] +[115Sb] +[194Tl] +[70Se] +[112In] +[118Sb] +[70Ga] +[27Mg] +[202Bi] +[83Se] +[9Li] +[69As] +[79Rb] +[81Sr] +[83Sr] +[78Se] +[109In] +[29Al] +[118Sn] +[117In] +[119Sb] +[114Sn] +[138Ba] +[69Ge] +[73Ga] +[74Ge] +[206Tl] 
+[199Tl] +[130Cs] +[28Mg] +[116Te] +[112Sn] +[126Ba] +[211Bi] +[81Se] +[127Sn] +[143Cs] +[134Te] +[80Sr] +[45K] +[215Po] +[207Po] +[111Sn] +[211Po] +[128Ba] +[198Tl] +[227Ra] +[213Po] +[220Ra] +[128Sn] +[203Po] +[205Po] +[65Ga] +[197Tl] +[88Sr] +[110In] +[31Si] +[201Bi] +[121Te] +[205Bi] +[203Bi] +[195Tl] +[209Tl] +[110Sn] +[222Fr] +[207At] +[119In] +[As@] +[129IH] +[157Dy] +[111IH] +[230Ra] +[144Pr+3] +[SiH3+] +[3He] +[AsH5] +[72Se] +[95Tc] +[103Pd] +[121Sn+2] +[211Rn] +[38SH2] +[127IH] +[74Br-] +[133I-] +[100Tc+4] +[100Tc] +[36Cl-] +[89Y+3] +[104Rh] +[152Sm] +[226Ra] +[19FH] +[104Pd] +[148Gd] +[157Lu] +[33SH2] +[121I-] +[17FH] +[71Se] +[157Sm] +[148Tb] +[164Dy] +[15OH2] +[15O+] +[39K] +[40Ar] +[50Cr+3] +[50Cr] +[52Ti] +[103Pd+2] +[130Ba] +[142Pm] +[153Gd+3] +[151Eu] +[103Rh] +[124Xe] +[152Tb] +[17OH2] +[20Ne] +[52Fe] +[94Zr+4] +[94Zr] +[149Pr] +[16OH2] +[53Cr+6] +[53Cr] +[81Br-] +[112Pd] +[125Xe] +[155Gd] +[157Gd] +[168Yb] +[184Os] +[166Tb] +[221Fr] +[212Ra] +[75Br-] +[79Br-] +[113Ag] +[23Na] +[34Cl-] +[34ClH] +[38Cl-] +[56Fe] +[68Cu] +[77Br-] +[90Zr+4] +[90Zr] +[102Pd] +[154Eu+3] +[57Mn] +[165Tm] +[152Dy] +[217At] +[77se] +[13cH-] +[122Te] +[156Gd] +[124Te] +[53Ni] +[131Xe] +[174Hf+4] +[174Hf] +[76Se] +[168Tm] +[167Dy] +[154Gd] +[95Ru] +[210At] +[85Br] +[59Co] +[122Xe] +[27Al] +[54Cr] +[198Hg] +[85Rb+] +[214Tl] +[229Rn] +[218Pb] +[218Bi] +[167Tm+3] +[18o+] +[P@@H+] +[P@H+] +[13N+] +[212Pb+2] +[217Bi] +[249Cf+2] +[18OH3+] +[90Sr-] +[Cf+3] +[200Hg] +[86Tc] +[141Pr+3] +[141Pr] +[16nH] +[14NH4+] +[132Xe] +[83Kr] +[70Zn+2] +[137Ba+2] +[36Ar] +[38Ar] +[21Ne] +[126Xe] +[136Xe] +[128Xe] +[134Xe] +[84Kr] +[86Kr] +[78Kr] +[80Kr] +[82Kr] +[67Zn+2] +[65Cu+2] +[110Te] +[58Fe+3] +[142Nd] +[38K] +[198Au+3] +[122IH] +[38PH3] +[130I-] +[40K+] +[38K+] +[28Mg+2] +[208Tl+] +[13OH2] +[198Bi] +[192Bi] +[194Bi] +[196Bi] +[132I-] +[83Sr+2] +[169Er+3] +[122I-] +[120I-] +[92Sr+2] +[126I-] +[24Mg] +[84Sr] +[118Pd+2] +[118Pd] +[AsH4] +[127I-] +[9C-] +[11CH3+] +[17B] +[7B] +[4HH] +[18C-] +[22CH3-] +[22CH4] +[17C-] +[15CH3] +[16CH3] +[11NH3] +[21NH3] +[11N-] +[11NH] +[16CH] +[17CH2] +[99Ru+2] +[181Ta+2] +[181Ta] +[20CH] +[32PH2] +[55Fe+2] +[SH3] +[S@H] +[Mn-] +[IH4] +[ThH] +[GaH-] +[BiH+] +[EuH2] +[FeH4-3] +[FeH6] +[IH5] +[NiH+] +[SrH2] +[VH4] +[YH3] +[seH+] +<unk> diff --git a/models/smi_ted/training/pubchem_canon_script.py b/models/smi_ted/training/pubchem_canon_script.py new file mode 100644 index 0000000000000000000000000000000000000000..a26146bd42fc365db5d534226d77499b5e960e9c --- /dev/null +++ b/models/smi_ted/training/pubchem_canon_script.py @@ -0,0 +1,71 @@ +import logging +from dataclasses import dataclass +import pyarrow as pa + +import datasets + + +logger = logging.getLogger(__name__) + + +FEATURES = datasets.Features( + { + "text": datasets.Value("string"), + } +) + + +@dataclass +class PubChemConfig(datasets.BuilderConfig): + """BuilderConfig for text files.""" + + encoding: str = "utf-8" + chunksize: int = 10 << 20 # 10MB + + +class PubChem(datasets.ArrowBasedBuilder): + + BUILDER_CONFIG_CLASS = PubChemConfig + + def _info(self): + return datasets.DatasetInfo(features=FEATURES) + + def _split_generators(self, dl_manager): + """The `data_files` kwarg in load_dataset() can be a str, List[str], Dict[str,str], or Dict[str,List[str]]. + + If str or List[str], then the dataset returns only the 'train' split. + If dict, then keys should be from the `datasets.Split` enum. 
+ """ + if not self.config.data_files: + raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}") + data_files = dl_manager.download_and_extract(self.config.data_files) + if isinstance(data_files, (str, list, tuple)): + files = data_files + if isinstance(files, str): + files = [files] + return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"files": files})] + splits = [] + for split_name, files in data_files.items(): + if isinstance(files, str): + files = [files] + splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": files})) + return splits + + def _generate_tables(self, files): + + for file_idx, file in enumerate(files): + batch_idx = 0 + with open(file, "r", encoding=self.config.encoding) as f: + while True: + batch = f.read(self.config.chunksize) + if not batch: + break + batch += f.readline() # finish current line + batch = batch.splitlines() + #batch = [word.split()[-1] for word in batch] + pa_table = pa.Table.from_arrays([pa.array(batch)], schema=pa.schema({"text": pa.string()})) + # Uncomment for debugging (will print the Arrow table size and elements) + #logger.warning(f"pa_table: {pa_table} num rows: {pa_table.num_rows}") + #logger.warning('\n'.join(str(pa_table.slice(i, 1).to_pydict()) for i in range(pa_table.num_rows))) + yield (file_idx, batch_idx), pa_table + batch_idx += 1 diff --git a/models/smi_ted/training/pubchem_canon_script.py.lock b/models/smi_ted/training/pubchem_canon_script.py.lock new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/smi_ted/training/pubchem_canon_zinc_final_vocab_sorted_curated.pth b/models/smi_ted/training/pubchem_canon_zinc_final_vocab_sorted_curated.pth new file mode 100644 index 0000000000000000000000000000000000000000..ab685a3a5a7c95a8900dc186d7d72fae5cb73bc8 --- /dev/null +++ b/models/smi_ted/training/pubchem_canon_zinc_final_vocab_sorted_curated.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cc71f36557571ecf91d3e82a917113692974b0ef9dfc73bcfaaf0e2c080eaf09 +size 43635 diff --git a/models/smi_ted/training/pubchem_encoder.py b/models/smi_ted/training/pubchem_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..6ea38ea90fbfe9e69a93fbbadccd9180db31342b --- /dev/null +++ b/models/smi_ted/training/pubchem_encoder.py @@ -0,0 +1,235 @@ +import regex as re +import torch +import numpy as np +import random +import collections + +class Encoder(): + + def __init__(self, max_length=500, add_bos=True, add_eos=True, feature_size=32): + self.vocab_encoder = torch.load('pubchem_canon_zinc_final_vocab_sorted_curated.pth') + + self.max_length = max_length + self.min_length = 1 + self.mod_length = 42 + self.mlm_probability = .15 + self.avg_length = 66 + self.tail = 122 + self.b0_cache=collections.deque() + self.b1_cache=collections.deque() + self.b2_cache=collections.deque() + self.b3_cache=collections.deque() + self.bucket0=collections.deque() + self.bucket1=collections.deque() + self.bucket2=collections.deque() + self.bucket3=collections.deque() + if feature_size == 32: + self.b0_max=1100 + self.b1_max=700 + self.b2_max=150 + self.b3_max=50 + else: + self.b0_max=1382 + self.b1_max=871 + self.b2_max=516 + self.b3_max=311 + values = list(self.vocab_encoder.values()) + num_top = 0 + middle_top = 0 + bottom = 0 + for count in values: + if count > 100000: + num_top += 1 + if count > 50: + middle_top += 1 + middle_top = middle_top - num_top + 
self.cutoffs = [num_top+4, middle_top] + self.char2id = {"<bos>":0, "<eos>":1, "<pad>":2, "<mask>":3} + self.id2char = {0:"<bos>", 1:"<eos>", 2:"<pad>", 3:"<mask>"} + self.pad = self.char2id['<pad>'] + self.mask = self.char2id['<mask>'] + self.eos = self.char2id['<eos>'] + self.bos = self.char2id['<bos>'] + pos = 0 + for key, value in self.vocab_encoder.items(): + #for pos, key in enumerate(self.vocab_encoder.keys()): + self.char2id[key] = pos+4 + self.id2char[pos+4] = key + pos += 1 + self.char2id["<unk>"] = pos + 4 + self.id2char[pos+4] = "<unk>" + self.pattern = "(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])" + self.regex = re.compile(self.pattern) + self.add_bos = add_bos + self.add_eos = add_eos + #print(self.char2id) + + def encode(self, char): + #if len(char) > self.max_length: + # char = char[:self.max_length] + if self.add_bos == True: + char = ['<bos>'] + char + if self.add_eos == True: + char = char + ['<eos>'] + + return torch.tensor([self.char2id.get(word, self.char2id["<unk>"]) for word in char]) + + def encoder(self, tokens): + #return *map(lambda x: self.encode(x), tokens) + return [self.encode(mol) for mol in tokens] + + def process_text(self, text): + #print(text) + #random length sequences seems to help training + mod_length = self.mod_length #+ random.randint(-1, 3) + avg_length = self.avg_length #+ random.randint(-3, 5) + for mol in text: + #fill up buckets and caches + if '\n' in mol['text']: + print('carriage return in mol') + raw_regex = self.regex.findall(mol['text'].strip('\n')) + length = len(raw_regex) + if length > self.min_length and length < mod_length: + if len(self.bucket0) < self.b0_max: + self.bucket0.append(raw_regex) + else: + self.b0_cache.append(raw_regex) + elif length >= mod_length and length < avg_length: + if len(self.bucket1) < self.b1_max: + self.bucket1.append(raw_regex) + else: + self.b1_cache.append(raw_regex) + elif length >= avg_length and length < self.tail: + if len(self.bucket2) < self.b2_max: + self.bucket2.append(raw_regex) + else: + self.b2_cache.append(raw_regex) + elif length >= self.tail and length < self.max_length: + if len(self.bucket3) < self.b3_max: + self.bucket3.append(raw_regex) + else: + self.b3_cache.append(raw_regex) + # elif length >= avg_length and length < self.tail: + # self.b2_cache.append(raw_regex) + # #if len(bucket2) < self.b2_max: + # # bucket2.append(raw_regex) + # #else: + # # self.b2_cache.append(raw_regex) + # elif length >= self.tail and length < self.max_length: + # self.b3_cache.append(raw_regex) + # #if len(bucket3) < self.b3_max: + # # bucket3.append(raw_regex) + # #else: + # # self.b3_cache.append(raw_regex) + + #print('before Cache size {} {} {} {}'.format(len(self.b0_cache), len(self.b1_cache), len(self.b2_cache), len(self.b3_cache))) + #pour cache elements into any open bucket + if len(self.bucket0) < self.b0_max and len(self.b0_cache) > 0: + cache_size = len(self.b0_cache) + max_margin = self.b0_max-len(self.bucket0) + range0 = min(cache_size, max_margin) + outbucket0 = [self.bucket0.pop() for item in range(len(self.bucket0))] + [self.b0_cache.pop() for i in range(range0)] + #self.b0_cache = collections.deque(self.b0_cache[:self.b0_max-len(bucket0)]) + #print('0 type {}'.format(type(self.b0_cache))) + else: + outbucket0 = [self.bucket0.pop() for item in range(len(self.bucket0))] + + if len(self.bucket1) < self.b1_max and len(self.b1_cache) > 0: + cache_size = len(self.b1_cache) + max_margin = self.b1_max-len(self.bucket1) + range1 = 
min(cache_size, max_margin) + outbucket1 = [self.bucket1.pop() for item in range(len(self.bucket1))] + [self.b1_cache.pop() for i in range(range1)] + else: + outbucket1 = [self.bucket1.pop() for item in range(len(self.bucket1))] + + if len(self.bucket2) < self.b2_max and len(self.b2_cache) > 0: + cache_size = len(self.b2_cache) + max_margin = self.b2_max-len(self.bucket2) + range2 = min(cache_size, max_margin) + outbucket2 = [self.bucket2.pop() for item in range(len(self.bucket2))] + [self.b2_cache.pop() for i in range(range2)] + else: + outbucket2 = [self.bucket2.pop() for item in range(len(self.bucket2))] + + if len(self.bucket3) < self.b3_max and len(self.b3_cache) > 0: + cache_size = len(self.b3_cache) + max_margin = self.b3_max-len(self.bucket3) + range3 = min(cache_size, max_margin) + outbucket3 = [self.bucket3.pop() for item in range(len(self.bucket3))] + [self.b3_cache.pop() for i in range(range3)] + else: + outbucket3 = [self.bucket3.pop() for item in range(len(self.bucket3))] + + # if len(self.b2_cache) > self.b2_max: + # cache_size = len(self.b2_cache) + # max_margin = self.b2_max + # range2 = min(cache_size, max_margin) + # outbucket2 = [self.b2_cache.pop() for i in range(range2)] + # else: + # outbucket2=[] + + # if len(self.b3_cache) > self.b3_max: + # cache_size = len(self.b3_cache) + # max_margin = self.b3_max + # range3 = min(cache_size, max_margin) + # outbucket3 = [self.b3_cache.pop() for i in range(range3)] + # else: + # outbucket3 = [] + + return outbucket0, outbucket1, outbucket2, outbucket3 + + def mask_tokens( self, inputs, special_tokens_mask= None): + """ + Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. + """ + labels = inputs.clone() + # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`) + probability_matrix = torch.full(labels.size(), self.mlm_probability) + if special_tokens_mask is None: + special_tokens_mask = [ + self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist() + ] + special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool) + else: + special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool) + #special_tokens_mask = special_tokens_mask.bool() + + #print(special_tokens_mask.size()) + probability_matrix.masked_fill_(special_tokens_mask, value=0.0) + masked_indices = torch.bernoulli(probability_matrix).bool() + labels[~masked_indices] = -100 # We only compute loss on masked tokens + + # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK]) + indices_replaced = torch.bernoulli(torch.full(labels.size(), 0.8)).bool() & masked_indices + inputs[indices_replaced] = self.mask + + # 10% of the time, we replace masked input tokens with random word + indices_random = torch.bernoulli(torch.full(labels.size(), 0.5)).bool() & masked_indices & ~indices_replaced + random_words = torch.randint(len(self.char2id.keys()), labels.size(), dtype=torch.long) + inputs[indices_random] = random_words[indices_random] + + # The rest of the time (10% of the time) we keep the masked input tokens unchanged + return inputs, labels + def pack_tensors(self, tokens): + array_ids = self.encoder(tokens) + array = torch.nn.utils.rnn.pad_sequence(array_ids, batch_first=True, padding_value=self.pad) + lengths = (array!=self.pad).sum(dim=-1) + #Bert tokenization + special_token_mask = [list(map(lambda x: 1 if x in [self.bos, self.eos, self.pad] else 0, stuff)) for stuff in 
array.tolist()] + masked_array, masked_labels = self.mask_tokens(array, special_token_mask) + return masked_array, masked_labels, array_ids, lengths + def process(self, text): + arrays = [] + lengths = [] + targets = [] + arrays_ids = [] + for tokens in self.process_text(text): + if len(tokens) > 0: + array, target, array_ids, lgt = self.pack_tensors(tokens) + arrays.append(array) + targets.append(target) + arrays_ids.append(array_ids) + lengths.append(lgt) + return arrays, targets, arrays_ids, lengths + +if __name__ == '__main__': + + text_encoder = Encoder() diff --git a/models/smi_ted/training/pubchem_script.py b/models/smi_ted/training/pubchem_script.py new file mode 100644 index 0000000000000000000000000000000000000000..2164c3b7abf504a145f669fdb5a0d13fef25a14d --- /dev/null +++ b/models/smi_ted/training/pubchem_script.py @@ -0,0 +1,71 @@ +import logging +from dataclasses import dataclass +import pyarrow as pa + +import datasets + + +logger = logging.getLogger(__name__) + + +FEATURES = datasets.Features( + { + "text": datasets.Value("string"), + } +) + + +@dataclass +class PubChemConfig(datasets.BuilderConfig): + """BuilderConfig for text files.""" + + encoding: str = "utf-8" + chunksize: int = 10 << 20 # 10MB + + +class PubChem(datasets.ArrowBasedBuilder): + + BUILDER_CONFIG_CLASS = PubChemConfig + + def _info(self): + return datasets.DatasetInfo(features=FEATURES) + + def _split_generators(self, dl_manager): + """The `data_files` kwarg in load_dataset() can be a str, List[str], Dict[str,str], or Dict[str,List[str]]. + + If str or List[str], then the dataset returns only the 'train' split. + If dict, then keys should be from the `datasets.Split` enum. + """ + if not self.config.data_files: + raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}") + data_files = dl_manager.download_and_extract(self.config.data_files) + if isinstance(data_files, (str, list, tuple)): + files = data_files + if isinstance(files, str): + files = [files] + return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"files": files})] + splits = [] + for split_name, files in data_files.items(): + if isinstance(files, str): + files = [files] + splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": files})) + return splits + + def _generate_tables(self, files): + + for file_idx, file in enumerate(files): + batch_idx = 0 + with open(file, "r", encoding=self.config.encoding) as f: + while True: + batch = f.read(self.config.chunksize) + if not batch: + break + batch += f.readline() # finish current line + batch = batch.splitlines() + batch = [word.split()[-1] for word in batch] + pa_table = pa.Table.from_arrays([pa.array(batch)], schema=pa.schema({"text": pa.string()})) + # Uncomment for debugging (will print the Arrow table size and elements) + #logger.warning(f"pa_table: {pa_table} num rows: {pa_table.num_rows}") + #logger.warning('\n'.join(str(pa_table.slice(i, 1).to_pydict()) for i in range(pa_table.num_rows))) + yield (file_idx, batch_idx), pa_table + batch_idx += 1 diff --git a/models/smi_ted/training/pubchem_script.py.lock b/models/smi_ted/training/pubchem_script.py.lock new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/smi_ted/training/run_model_large_training.sh b/models/smi_ted/training/run_model_large_training.sh new file mode 100644 index 0000000000000000000000000000000000000000..6ccd9eebe2015a6ebed6e5ca485b9966b9baadfd --- /dev/null +++ 
b/models/smi_ted/training/run_model_large_training.sh @@ -0,0 +1,32 @@ +#!/bin/bash +torchrun \ + --standalone \ + --nnodes=1 \ + --nproc_per_node=1 \ + train_model_D.py \ + --device cuda \ + --n_batch 48 \ + --n_layer 24 \ + --n_head 16 \ + --n_embd 1024 \ + --max_len 202 \ + --d_dropout 0.2 \ + --lr_start 3e-5 \ + --lr_multiplier 4 \ + --lr_decoder 3e-5 \ + --n_workers 1 \ + --max_epochs 51 \ + --gpu -1 \ + --num_nodes 1 \ + --num_feats 32 \ + --root_dir . \ + --checkpoint_every 10000 \ + --grad_acc 1 \ + --train_load 'pubchem' \ + --smi_ted_version 'v2' \ + --data_root './pubchem/pubchem_rd-canonical_smiles.smi' \ + --save_checkpoint_path './large_checkpoints' \ + --load_checkpoint_path '' \ + --rotate \ + --debug \ + --model_arch 'BERT__both_rotate' \ \ No newline at end of file diff --git a/models/smi_ted/training/run_model_light_training.sh b/models/smi_ted/training/run_model_light_training.sh new file mode 100644 index 0000000000000000000000000000000000000000..6abb952762608c842c31a711ee3551be38b33966 --- /dev/null +++ b/models/smi_ted/training/run_model_light_training.sh @@ -0,0 +1,32 @@ +#!/bin/bash +torchrun \ + --standalone \ + --nnodes=1 \ + --nproc_per_node=1 \ + train_model_ED.py \ + --device cuda \ + --n_batch 288 \ + --n_layer 12 \ + --n_head 12 \ + --n_embd 768 \ + --max_len 202 \ + --d_dropout 0.2 \ + --lr_start 3e-5 \ + --lr_multiplier 4 \ + --lr_decoder 3e-5 \ + --n_workers 1 \ + --max_epochs 51 \ + --gpu -1 \ + --num_nodes 1 \ + --num_feats 32 \ + --root_dir . \ + --checkpoint_every 10000 \ + --grad_acc 1 \ + --train_load 'pubchem' \ + --smi_ted_version 'v1' \ + --data_root './pubchem/pubchem_rd-canonical_smiles.smi' \ + --save_checkpoint_path './light_checkpoints' \ + --load_checkpoint_path '' \ + --rotate \ + --debug \ + --model_arch 'BERT__both_rotate' \ \ No newline at end of file diff --git a/models/smi_ted/training/send_job_large.slurm b/models/smi_ted/training/send_job_large.slurm new file mode 100644 index 0000000000000000000000000000000000000000..ede940a075d76f3a23a3dfae0865e649db008417 --- /dev/null +++ b/models/smi_ted/training/send_job_large.slurm @@ -0,0 +1,60 @@ +#!/bin/bash + +# Example of running python script in a batch mode + +#SBATCH -J smi-ted-train +#SBATCH -t 30:00:00 +#SBATCH -o output_smi_ted_large_epoch40_%j.out +#SBATCH --mem=64G +#SBATCH --nodes=10 +#SBATCH --ntasks=10 +#SBATCH --gpus-per-task=5 +#SBATCH --cpus-per-task=20 + +nodes=( $( scontrol show hostnames $SLURM_JOB_NODELIST ) ) +nodes_array=($nodes) +head_node=${nodes_array[0]} +head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) + +echo Node IP: $head_node_ip +export LOGLEVEL=INFO + +# Load software +# module load anaconda3 +source /home/.bashrc +conda activate smi-ted-env + +# Run python script +srun torchrun \ + --nnodes 10 \ + --nproc_per_node 5 \ + --rdzv_id $RANDOM \ + --rdzv_backend c10d \ + --rdzv_endpoint $head_node_ip:29500 \ + train_model_D.py \ + --device cuda \ + --n_batch 48 \ + --n_layer 24 \ + --n_head 16 \ + --n_embd 1024 \ + --max_len 202 \ + --d_dropout 0.2 \ + --lr_start 3e-5 \ + --lr_multiplier 4 \ + --lr_decoder 3e-5 \ + --n_workers 20 \ + --max_epochs 51 \ + --gpu -1 \ + --num_nodes 1 \ + --num_feats 32 \ + --root_dir . 
\ + --checkpoint_every 10000 \ + --grad_acc 1 \ + --train_load 'pubchem' \ + --smi_ted_version 'v2' \ + --data_root './pubchem/pubchem_rd-canonical_smiles.smi' \ + --save_checkpoint_path './large_checkpoints' \ + --load_checkpoint_path '' \ + --rotate \ + --debug \ + --model_arch 'BERT__both_rotate' \ \ No newline at end of file diff --git a/models/smi_ted/training/send_job_light.slurm b/models/smi_ted/training/send_job_light.slurm new file mode 100644 index 0000000000000000000000000000000000000000..5fb4f24f99ce91c4f618fcb0312272669e506fea --- /dev/null +++ b/models/smi_ted/training/send_job_light.slurm @@ -0,0 +1,60 @@ +#!/bin/bash + +# Example of running python script in a batch mode + +#SBATCH -J smi-ted-train +#SBATCH -t 6:00:00 +#SBATCH -o output_smi_ted_light_epoch50_%j.out +#SBATCH --mem=64G +#SBATCH --nodes=6 +#SBATCH --ntasks=6 +#SBATCH --gpus-per-task=4 +#SBATCH --cpus-per-task=12 + +nodes=( $( scontrol show hostnames $SLURM_JOB_NODELIST ) ) +nodes_array=($nodes) +head_node=${nodes_array[0]} +head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) + +echo Node IP: $head_node_ip +export LOGLEVEL=INFO + +# Load software +# module load anaconda3 +source /home/.bashrc +conda activate smi-ted-env + +# Run python script +srun torchrun \ + --nnodes 6 \ + --nproc_per_node 4 \ + --rdzv_id $RANDOM \ + --rdzv_backend c10d \ + --rdzv_endpoint $head_node_ip:29500 \ + train_model_ED.py \ + --device cuda \ + --n_batch 288 \ + --n_layer 12 \ + --n_head 12 \ + --n_embd 768 \ + --max_len 202 \ + --d_dropout 0.2 \ + --lr_start 3e-5 \ + --lr_multiplier 4 \ + --lr_decoder 3e-5 \ + --n_workers 12 \ + --max_epochs 51 \ + --gpu -1 \ + --num_nodes 1 \ + --num_feats 32 \ + --root_dir . \ + --checkpoint_every 10000 \ + --grad_acc 1 \ + --train_load 'pubchem' \ + --smi_ted_version 'v1' \ + --data_root './pubchem/pubchem_rd-canonical_smiles.smi' \ + --save_checkpoint_path './light_checkpoints' \ + --load_checkpoint_path '' \ + --rotate \ + --debug \ + --model_arch 'BERT__both_rotate' \ \ No newline at end of file diff --git a/models/smi_ted/training/smi_ted_large/load.py b/models/smi_ted/training/smi_ted_large/load.py new file mode 100644 index 0000000000000000000000000000000000000000..febd043c4d4fb5fc1664e13315d0e758365a383f --- /dev/null +++ b/models/smi_ted/training/smi_ted_large/load.py @@ -0,0 +1,382 @@ +PATTERN = "(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])" +# Deep learning +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.backends.cudnn as cudnn + +# Transformers +from fast_transformers.attention import AttentionLayer +from fast_transformers.events import EventDispatcher, QKVEvent +from fast_transformers.transformers import TransformerEncoder, TransformerEncoderLayer +from fast_transformers.builders.base import BaseBuilder +from fast_transformers.builders.transformer_builders import BaseTransformerEncoderBuilder +from fast_transformers.builders.attention_builders import AttentionBuilder +from fast_transformers.feature_maps import GeneralizedRandomFeatures +from fast_transformers.masking import LengthMask + +from transformers import BertTokenizer + +# Data +import numpy as np + +# Standard library +from functools import partial +import regex as re +import random + + +class MolTranBertTokenizer(BertTokenizer): + def __init__(self, vocab_file: str = '', + do_lower_case=False, + unk_token='<pad>', + sep_token='<eos>', + pad_token='<pad>', + cls_token='<bos>', + mask_token='<mask>', + 
**kwargs): + super().__init__(vocab_file, + unk_token=unk_token, + sep_token=sep_token, + pad_token=pad_token, + cls_token=cls_token, + mask_token=mask_token, + **kwargs) + + self.regex_tokenizer = re.compile(PATTERN) + self.wordpiece_tokenizer = None + self.basic_tokenizer = None + + def _tokenize(self, text): + split_tokens = self.regex_tokenizer.findall(text) + return split_tokens + + def convert_idx_to_tokens(self, idx_tensor): + tokens = [self.convert_ids_to_tokens(idx) for idx in idx_tensor.tolist()] + return tokens + + def convert_tokens_to_string(self, tokens): + stopwords = ['<bos>', '<eos>'] + clean_tokens = [word for word in tokens if word not in stopwords] + out_string = ''.join(clean_tokens) + return out_string + +## Transformer layers + +class RotaryEmbedding(torch.nn.Module): + + def __init__(self, dim, base=10000): + super().__init__() + inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer('inv_freq', inv_freq) + self.seq_len_cached = 0 + self.cos_cached = None + self.sin_cached = None + + def forward(self, x, seq_dim=1): + seq_len = x.shape[seq_dim] + if seq_len != self.seq_len_cached: + self.seq_len_cached = seq_len + + t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + freqs = torch.einsum('i,j->ij', t, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1).to(x.device) + + self.cos_cached = emb.cos()[None,:, None, :] + self.sin_cached = emb.sin()[None,:, None, :] + + return self.cos_cached, self.sin_cached + +def rotate_half(x): + x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in earlier torch versions + +@torch.jit.script +def apply_rotary_pos_emb(q, k, cos, sin): + return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) + +class RotateAttentionLayer(AttentionLayer): + """Rotate attention layer inherits from fast_transformer attention layer. + The only thing added is an Embedding encoding, for more information + on the attention layer see the fast_transformers code + """ + def __init__(self, attention, d_model, n_heads, d_keys=None, + d_values=None, event_dispatcher=""): + super(RotateAttentionLayer, self).__init__(attention,d_model, n_heads, d_keys=d_keys, + d_values=d_values, event_dispatcher=event_dispatcher) + + self.rotaryemb = RotaryEmbedding(d_keys) + print('Using Rotation Embedding') + + def forward(self, queries, keys, values, attn_mask, query_lengths, + key_lengths): + """ + Using the same frame work as the fast_Transformers attention layer + but injecting rotary information to the queries and the keys + after the keys and queries are projected. 
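+        The rotary position embedding is applied head-wise to the projected
+        queries and keys before the (linear) attention is evaluated.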
+ In the argument description we make use of the following sizes + - N: the batch size + - L: The maximum length of the queries + - S: The maximum length of the keys (the actual length per sequence + is given by the length mask) + - D: The input feature dimensionality passed in the constructor as + 'd_model' + Arguments + --------- + queries: (N, L, D) The tensor containing the queries + keys: (N, S, D) The tensor containing the keys + values: (N, S, D) The tensor containing the values + attn_mask: An implementation of BaseMask that encodes where each + query can attend to + query_lengths: An implementation of BaseMask that encodes how + many queries each sequence in the batch consists of + key_lengths: An implementation of BaseMask that encodes how + many queries each sequence in the batch consists of + Returns + ------- + The new value for each query as a tensor of shape (N, L, D). + """ + # Extract the dimensions into local variables + N, L, _ = queries.shape + _, S, _ = keys.shape + H = self.n_heads + + # Project the queries/keys/values + queries = self.query_projection(queries).view(N, L, H, -1) + keys = self.key_projection(keys).view(N, S, H, -1) + cos, sin = self.rotaryemb(queries) + queries, keys = apply_rotary_pos_emb(queries, keys, cos, sin) + values = self.value_projection(values).view(N, S, H, -1) + # Let the world know of the qkv + self.event_dispatcher.dispatch(QKVEvent(self, queries, keys, values)) + + + # Compute the attention + new_values = self.inner_attention( + queries, + keys, + values, + attn_mask, + query_lengths, + key_lengths + ).view(N, L, -1) + + # Project the output and return + return self.out_projection(new_values) + +class RotateEncoderBuilder(BaseTransformerEncoderBuilder): + """Build a batch transformer encoder with Relative Rotary embeddings + for training or processing of sequences all elements at a time. 
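+    Compared to the stock fast_transformers encoder builder, the only change
+    is that attention layers are built as RotateAttentionLayer, so rotary
+    position embeddings are combined with the chosen attention type.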
+ Example usage: + builder = RotateEncoderBuilder() + builder.n_layers = 12 + builder.n_heads = 8 + builder.feed_forward_dimensions = 1024 + builder.query_dimensions = 64 + builder.value_dimensions = 64 + builder.dropout = 0.1 + builder.attention_dropout = 0.1 + builder.attention_type = "linear" + transformer = builder.get() + """ + def _get_attention_builder(self): + """Return an instance of the appropriate attention builder.""" + return AttentionBuilder() + + def _get_attention_layer_class(self): + """Return the class for the layer that projects queries keys and + values.""" + return RotateAttentionLayer + + def _get_encoder_class(self): + """Return the class for the transformer encoder.""" + return TransformerEncoder + + def _get_encoder_layer_class(self): + """Return the class for the transformer encoder layer.""" + return TransformerEncoderLayer + + +class AutoEncoderLayer(nn.Module): + + def __init__(self, feature_size, latent_size): + super().__init__() + self.encoder = self.Encoder(feature_size, latent_size) + self.decoder = self.Decoder(feature_size, latent_size) + + class Encoder(nn.Module): + + def __init__(self, feature_size, latent_size): + super().__init__() + self.is_cuda_available = torch.cuda.is_available() + self.fc1 = nn.Linear(feature_size, latent_size) + self.ln_f = nn.LayerNorm(latent_size) + self.lat = nn.Linear(latent_size, latent_size, bias=False) + + def forward(self, x): + if self.is_cuda_available: + self.fc1.cuda() + self.ln_f.cuda() + self.lat.cuda() + x = x.cuda() + x = F.gelu(self.fc1(x)) + x = self.ln_f(x) + x = self.lat(x) + return x # -> (N, D) + + class Decoder(nn.Module): + + def __init__(self, feature_size, latent_size): + super().__init__() + self.is_cuda_available = torch.cuda.is_available() + self.fc1 = nn.Linear(latent_size, latent_size) + self.ln_f = nn.LayerNorm(latent_size) + self.rec = nn.Linear(latent_size, feature_size, bias=False) + + def forward(self, x): + if self.is_cuda_available: + self.fc1.cuda() + self.ln_f.cuda() + self.rec.cuda() + x = x.cuda() + x = F.gelu(self.fc1(x)) + x = self.ln_f(x) + x = self.rec(x) + return x # -> (N, L*D) + + +class LangLayer(nn.Module): + def __init__(self, n_embd, n_vocab): + super().__init__() + self.is_cuda_available = torch.cuda.is_available() + self.embed = nn.Linear(n_embd, n_embd) + self.ln_f = nn.LayerNorm(n_embd) + self.head = nn.Linear(n_embd, n_vocab, bias=False) + def forward(self, tensor): + if self.is_cuda_available: + self.embed.cuda() + self.ln_f.cuda() + self.head.cuda() + tensor = tensor.cuda() + tensor = self.embed(tensor) + tensor = F.gelu(tensor) + tensor = self.ln_f(tensor) + tensor = self.head(tensor) + return tensor + + +class MoLEncoder(nn.Module): + + def __init__(self, config, n_vocab): + super(MoLEncoder, self).__init__() + + # embeddings + self.tok_emb = nn.Embedding(n_vocab, config.n_embd) + self.drop = nn.Dropout(config.d_dropout) + + # transformer + builder = RotateEncoderBuilder.from_kwargs( + n_layers=config.n_layer, + n_heads=config.n_head, + query_dimensions=config.n_embd//config.n_head, + value_dimensions=config.n_embd//config.n_head, + feed_forward_dimensions=None, + attention_type='linear', + # unless we do deterministic_eval here, we will have random outputs + feature_map=partial(GeneralizedRandomFeatures, + n_dims=config.num_feats, + deterministic_eval=False), + activation='gelu' + ) + self.blocks = builder.get() + + # classification + self.lang_model = LangLayer(config.n_embd, n_vocab) + + def forward(self, idx, mask=None, inference=False): + if not inference: + x 
= self.tok_emb(idx) # each index maps to a (learnable) vector + x = self.drop(x) + + #masking of the length of the inputs its handled in the Masked language part of the code + #do not attempt to handle it in the forward of the transformer + x = self.blocks(x) + logits = self.lang_model(x) + + return logits + else: + x = self.tok_emb(idx) # each index maps to a (learnable) vector + x = self.drop(x) + + #masking of the length of the inputs its handled in the Masked language part of the code + #do not attempt to handle it in the forward of the transformer + x = self.blocks(x, length_mask=LengthMask(mask.sum(-1), max_len=idx.shape[1])) + + # mean pooling + token_embeddings = x + input_mask_expanded = mask.unsqueeze(-1).expand(token_embeddings.size()).float() + sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1) + sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9) + true_set = sum_embeddings / sum_mask + + return true_set, token_embeddings + + +class MoLDecoder(nn.Module): + + def __init__(self, n_vocab, max_len, n_embd, n_gpu=None): + super(MoLDecoder, self).__init__() + + self.max_len = max_len + self.n_embd = n_embd + self.n_gpu = n_gpu + self.autoencoder = AutoEncoderLayer(n_embd*max_len, n_embd) + self.lang_model = LangLayer(n_embd, n_vocab) + + def forward(self, token_embeddings): + pred_set = self.autoencoder.encoder(token_embeddings) # (N, D) + pred_cte = self.autoencoder.decoder(pred_set) # (N, L*D) + pred_ids = self.lang_model(pred_cte.view(-1, self.max_len, self.n_embd)) + return pred_set, pred_ids + + +class Smi_ted(nn.Module): + """materials.smi-ted-Large 738M Parameters""" + + def __init__(self, config, vocab): + super(Smi_ted, self).__init__() + + self.config = config + self.padding_idx = 2 + self.is_cuda_available = torch.cuda.is_available() + n_vocab = len(vocab.keys()) + print(n_vocab, config.n_embd) + + self.encoder = MoLEncoder(config, n_vocab) + self.decoder = MoLDecoder(n_vocab, config.max_len, config.n_embd) + + self._set_seed(config.seed) + print('Vocab size:', n_vocab) + print(f'[PRE-TRAINING MODE - {str(self)}]') + + def _init_weights(self, module): + if isinstance(module, (nn.Linear, nn.Embedding)): + module.weight.data.normal_(mean=0.0, std=0.02) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + def _set_seed(self, value): + print('Random Seed:', value) + random.seed(value) + torch.manual_seed(value) + torch.cuda.manual_seed(value) + torch.cuda.manual_seed_all(value) + np.random.seed(value) + cudnn.deterministic = True + cudnn.benchmark = False + + def __str__(self): + return 'smi-ted-Large' \ No newline at end of file diff --git a/models/smi_ted/training/smi_ted_light/load.py b/models/smi_ted/training/smi_ted_light/load.py new file mode 100644 index 0000000000000000000000000000000000000000..61598cbb1c4ff0d8f1aca5917a3dbd7cddcdf735 --- /dev/null +++ b/models/smi_ted/training/smi_ted_light/load.py @@ -0,0 +1,382 @@ +PATTERN = "(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])" +# Deep learning +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.backends.cudnn as cudnn + +# Transformers +from fast_transformers.attention import AttentionLayer +from fast_transformers.events import EventDispatcher, QKVEvent +from fast_transformers.transformers import TransformerEncoder, TransformerEncoderLayer +from 
fast_transformers.builders.base import BaseBuilder +from fast_transformers.builders.transformer_builders import BaseTransformerEncoderBuilder +from fast_transformers.builders.attention_builders import AttentionBuilder +from fast_transformers.feature_maps import GeneralizedRandomFeatures +from fast_transformers.masking import LengthMask + +from transformers import BertTokenizer + +# Data +import numpy as np + +# Standard library +from functools import partial +import regex as re +import random + + +class MolTranBertTokenizer(BertTokenizer): + def __init__(self, vocab_file: str = '', + do_lower_case=False, + unk_token='<pad>', + sep_token='<eos>', + pad_token='<pad>', + cls_token='<bos>', + mask_token='<mask>', + **kwargs): + super().__init__(vocab_file, + unk_token=unk_token, + sep_token=sep_token, + pad_token=pad_token, + cls_token=cls_token, + mask_token=mask_token, + **kwargs) + + self.regex_tokenizer = re.compile(PATTERN) + self.wordpiece_tokenizer = None + self.basic_tokenizer = None + + def _tokenize(self, text): + split_tokens = self.regex_tokenizer.findall(text) + return split_tokens + + def convert_idx_to_tokens(self, idx_tensor): + tokens = [self.convert_ids_to_tokens(idx) for idx in idx_tensor.tolist()] + return tokens + + def convert_tokens_to_string(self, tokens): + stopwords = ['<bos>', '<eos>'] + clean_tokens = [word for word in tokens if word not in stopwords] + out_string = ''.join(clean_tokens) + return out_string + +## Transformer layers + +class RotaryEmbedding(torch.nn.Module): + + def __init__(self, dim, base=10000): + super().__init__() + inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer('inv_freq', inv_freq) + self.seq_len_cached = 0 + self.cos_cached = None + self.sin_cached = None + + def forward(self, x, seq_dim=1): + seq_len = x.shape[seq_dim] + if seq_len != self.seq_len_cached: + self.seq_len_cached = seq_len + + t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + freqs = torch.einsum('i,j->ij', t, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1).to(x.device) + + self.cos_cached = emb.cos()[None,:, None, :] + self.sin_cached = emb.sin()[None,:, None, :] + + return self.cos_cached, self.sin_cached + +def rotate_half(x): + x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in earlier torch versions + +@torch.jit.script +def apply_rotary_pos_emb(q, k, cos, sin): + return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) + +class RotateAttentionLayer(AttentionLayer): + """Rotate attention layer inherits from fast_transformer attention layer. + The only thing added is an Embedding encoding, for more information + on the attention layer see the fast_transformers code + """ + def __init__(self, attention, d_model, n_heads, d_keys=None, + d_values=None, event_dispatcher=""): + super(RotateAttentionLayer, self).__init__(attention,d_model, n_heads, d_keys=d_keys, + d_values=d_values, event_dispatcher=event_dispatcher) + + self.rotaryemb = RotaryEmbedding(d_keys) + print('Using Rotation Embedding') + + def forward(self, queries, keys, values, attn_mask, query_lengths, + key_lengths): + """ + Using the same frame work as the fast_Transformers attention layer + but injecting rotary information to the queries and the keys + after the keys and queries are projected. 
+ In the argument description we make use of the following sizes + - N: the batch size + - L: The maximum length of the queries + - S: The maximum length of the keys (the actual length per sequence + is given by the length mask) + - D: The input feature dimensionality passed in the constructor as + 'd_model' + Arguments + --------- + queries: (N, L, D) The tensor containing the queries + keys: (N, S, D) The tensor containing the keys + values: (N, S, D) The tensor containing the values + attn_mask: An implementation of BaseMask that encodes where each + query can attend to + query_lengths: An implementation of BaseMask that encodes how + many queries each sequence in the batch consists of + key_lengths: An implementation of BaseMask that encodes how + many queries each sequence in the batch consists of + Returns + ------- + The new value for each query as a tensor of shape (N, L, D). + """ + # Extract the dimensions into local variables + N, L, _ = queries.shape + _, S, _ = keys.shape + H = self.n_heads + + # Project the queries/keys/values + queries = self.query_projection(queries).view(N, L, H, -1) + keys = self.key_projection(keys).view(N, S, H, -1) + cos, sin = self.rotaryemb(queries) + queries, keys = apply_rotary_pos_emb(queries, keys, cos, sin) + values = self.value_projection(values).view(N, S, H, -1) + # Let the world know of the qkv + self.event_dispatcher.dispatch(QKVEvent(self, queries, keys, values)) + + + # Compute the attention + new_values = self.inner_attention( + queries, + keys, + values, + attn_mask, + query_lengths, + key_lengths + ).view(N, L, -1) + + # Project the output and return + return self.out_projection(new_values) + +class RotateEncoderBuilder(BaseTransformerEncoderBuilder): + """Build a batch transformer encoder with Relative Rotary embeddings + for training or processing of sequences all elements at a time. 
+ Example usage: + builder = RotateEncoderBuilder() + builder.n_layers = 12 + builder.n_heads = 8 + builder.feed_forward_dimensions = 1024 + builder.query_dimensions = 64 + builder.value_dimensions = 64 + builder.dropout = 0.1 + builder.attention_dropout = 0.1 + builder.attention_type = "linear" + transformer = builder.get() + """ + def _get_attention_builder(self): + """Return an instance of the appropriate attention builder.""" + return AttentionBuilder() + + def _get_attention_layer_class(self): + """Return the class for the layer that projects queries keys and + values.""" + return RotateAttentionLayer + + def _get_encoder_class(self): + """Return the class for the transformer encoder.""" + return TransformerEncoder + + def _get_encoder_layer_class(self): + """Return the class for the transformer encoder layer.""" + return TransformerEncoderLayer + + +class AutoEncoderLayer(nn.Module): + + def __init__(self, feature_size, latent_size): + super().__init__() + self.encoder = self.Encoder(feature_size, latent_size) + self.decoder = self.Decoder(feature_size, latent_size) + + class Encoder(nn.Module): + + def __init__(self, feature_size, latent_size): + super().__init__() + self.is_cuda_available = torch.cuda.is_available() + self.fc1 = nn.Linear(feature_size, latent_size) + self.ln_f = nn.LayerNorm(latent_size) + self.lat = nn.Linear(latent_size, latent_size, bias=False) + + def forward(self, x): + if self.is_cuda_available: + self.fc1.cuda() + self.ln_f.cuda() + self.lat.cuda() + x = x.cuda() + x = F.gelu(self.fc1(x)) + x = self.ln_f(x) + x = self.lat(x) + return x # -> (N, D) + + class Decoder(nn.Module): + + def __init__(self, feature_size, latent_size): + super().__init__() + self.is_cuda_available = torch.cuda.is_available() + self.fc1 = nn.Linear(latent_size, latent_size) + self.ln_f = nn.LayerNorm(latent_size) + self.rec = nn.Linear(latent_size, feature_size, bias=False) + + def forward(self, x): + if self.is_cuda_available: + self.fc1.cuda() + self.ln_f.cuda() + self.rec.cuda() + x = x.cuda() + x = F.gelu(self.fc1(x)) + x = self.ln_f(x) + x = self.rec(x) + return x # -> (N, L*D) + + +class LangLayer(nn.Module): + def __init__(self, n_embd, n_vocab): + super().__init__() + self.is_cuda_available = torch.cuda.is_available() + self.embed = nn.Linear(n_embd, n_embd) + self.ln_f = nn.LayerNorm(n_embd) + self.head = nn.Linear(n_embd, n_vocab, bias=False) + def forward(self, tensor): + if self.is_cuda_available: + self.embed.cuda() + self.ln_f.cuda() + self.head.cuda() + tensor = tensor.cuda() + tensor = self.embed(tensor) + tensor = F.gelu(tensor) + tensor = self.ln_f(tensor) + tensor = self.head(tensor) + return tensor + + +class MoLEncoder(nn.Module): + + def __init__(self, config, n_vocab): + super(MoLEncoder, self).__init__() + + # embeddings + self.tok_emb = nn.Embedding(n_vocab, config.n_embd) + self.drop = nn.Dropout(config.d_dropout) + + # transformer + builder = RotateEncoderBuilder.from_kwargs( + n_layers=config.n_layer, + n_heads=config.n_head, + query_dimensions=config.n_embd//config.n_head, + value_dimensions=config.n_embd//config.n_head, + feed_forward_dimensions=config.n_embd, + attention_type='linear', + # unless we do deterministic_eval here, we will have random outputs + feature_map=partial(GeneralizedRandomFeatures, + n_dims=config.num_feats, + deterministic_eval=False), + activation='gelu' + ) + self.blocks = builder.get() + + # classification + self.lang_model = LangLayer(config.n_embd, n_vocab) + + def forward(self, idx, mask=None, inference=False): + if not 
inference: + x = self.tok_emb(idx) # each index maps to a (learnable) vector + x = self.drop(x) + + #masking of the length of the inputs its handled in the Masked language part of the code + #do not attempt to handle it in the forward of the transformer + x = self.blocks(x) + logits = self.lang_model(x) + + return logits + else: + x = self.tok_emb(idx) # each index maps to a (learnable) vector + x = self.drop(x) + + #masking of the length of the inputs its handled in the Masked language part of the code + #do not attempt to handle it in the forward of the transformer + x = self.blocks(x, length_mask=LengthMask(mask.sum(-1), max_len=idx.shape[1])) + + # mean pooling + token_embeddings = x + input_mask_expanded = mask.unsqueeze(-1).expand(token_embeddings.size()).float() + sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1) + sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9) + true_set = sum_embeddings / sum_mask + + return true_set, token_embeddings + + +class MoLDecoder(nn.Module): + + def __init__(self, n_vocab, max_len, n_embd, n_gpu=None): + super(MoLDecoder, self).__init__() + + self.max_len = max_len + self.n_embd = n_embd + self.n_gpu = n_gpu + self.autoencoder = AutoEncoderLayer(n_embd*max_len, n_embd) + self.lang_model = LangLayer(n_embd, n_vocab) + + def forward(self, token_embeddings): + pred_set = self.autoencoder.encoder(token_embeddings) # (N, D) + pred_cte = self.autoencoder.decoder(pred_set) # (N, L*D) + pred_ids = self.lang_model(pred_cte.view(-1, self.max_len, self.n_embd)) + return pred_set, pred_ids + + +class Smi_ted(nn.Module): + """materials.smi-ted-Light 289M Parameters""" + + def __init__(self, config, vocab): + super(Smi_ted, self).__init__() + + self.config = config + self.padding_idx = 2 + self.is_cuda_available = torch.cuda.is_available() + n_vocab = len(vocab.keys()) + print(n_vocab, config.n_embd) + + self.encoder = MoLEncoder(config, n_vocab) + self.decoder = MoLDecoder(n_vocab, config.max_len, config.n_embd) + + self._set_seed(config.seed) + print('Vocab size:', n_vocab) + print(f'[PRE-TRAINING MODE - {str(self)}]') + + def _init_weights(self, module): + if isinstance(module, (nn.Linear, nn.Embedding)): + module.weight.data.normal_(mean=0.0, std=0.02) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + def _set_seed(self, value): + print('Random Seed:', value) + random.seed(value) + torch.manual_seed(value) + torch.cuda.manual_seed(value) + torch.cuda.manual_seed_all(value) + np.random.seed(value) + cudnn.deterministic = True + cudnn.benchmark = False + + def __str__(self): + return 'smi-ted-Light' \ No newline at end of file diff --git a/models/smi_ted/training/train_model_D.py b/models/smi_ted/training/train_model_D.py new file mode 100644 index 0000000000000000000000000000000000000000..1d5b5ab3300694f1c80cbe1e20187e3dcacd60b7 --- /dev/null +++ b/models/smi_ted/training/train_model_D.py @@ -0,0 +1,98 @@ +# This code uses the decoder loss directly. 
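+# The full Smi_ted model (MoLEncoder + MoLDecoder) is updated by a single
+# AdamW optimizer; TrainerDirectDecoder backpropagates the decoder losses
+# (token cross-entropy plus latent-set MSE) through the encoder as well.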
+# +# + +# Deep learning +import torch +from torch_optimizer.lamb import Lamb +from trainer import TrainerDirectDecoder + +# Parallel +from torch.utils.data.distributed import DistributedSampler +from torch.distributed import init_process_group, destroy_process_group + +# Data +from utils import MoleculeModule, get_optim_groups +from torch.utils.data import DataLoader + +# Standard library +import os +import args + + +def ddp_setup(): + init_process_group(backend="nccl") + torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) + + +def load_train_objs(config): + # load data + train_loader = MoleculeModule( + config.max_len, + config.train_load, + config.data_root + ) + train_loader.setup() + + loader = DataLoader( + train_loader.pubchem, + batch_size=config.n_batch, + pin_memory=True, + shuffle=False, + collate_fn=train_loader.text_encoder.process, + sampler=DistributedSampler(train_loader.pubchem), + num_workers=config.n_workers + ) + + # load model + if config.smi_ted_version == 'v1': + from smi_ted_light.load import Smi_ted + elif config.smi_ted_version == 'v2': + from smi_ted_large.load import Smi_ted + + model = Smi_ted(config, train_loader.get_vocab()).to('cuda') + model.apply(model._init_weights) + + # load optimizer + optim_groups = get_optim_groups(model) + optimizer = torch.optim.AdamW(optim_groups, lr=config.lr_decoder, betas=(0.9, 0.99), fused=True) + + return loader, model, optimizer + + +def main( + config, + save_every: int, + total_epochs: int, + save_checkpoint_path: str, + load_checkpoint_path: str + ): + ddp_setup() + + # training objects + train_data, model, optimizer = load_train_objs(config) + + # init trainer + trainer = TrainerDirectDecoder( + model, + train_data, + optimizer, + save_every, + save_checkpoint_path, + load_checkpoint_path, + config + ) + trainer.train(total_epochs) + destroy_process_group() + + +if __name__ == '__main__': + parser = args.get_parser() + args = parser.parse_args() + main( + args, + args.checkpoint_every, + args.max_epochs, + save_checkpoint_path=args.save_checkpoint_path, + load_checkpoint_path=args.load_checkpoint_path, + ) diff --git a/models/smi_ted/training/train_model_ED.py b/models/smi_ted/training/train_model_ED.py new file mode 100644 index 0000000000000000000000000000000000000000..ca83d9caaaa673b93cc83542d479e319b9da2879 --- /dev/null +++ b/models/smi_ted/training/train_model_ED.py @@ -0,0 +1,100 @@ +# This code uses both encoder and decoder losses. 
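+# TrainerEncoderDecoder alternates per batch: batches up to batch_thresh update
+# only the encoder on the masked-token cross-entropy (Lamb optimizer), while the
+# remaining batches of the epoch update only the decoder on the reconstruction
+# losses (Adam optimizer).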
+# +# + +# Deep learning +import torch +from torch_optimizer.lamb import Lamb +from trainer import TrainerEncoderDecoder + +# Parallel +from torch.utils.data.distributed import DistributedSampler +from torch.distributed import init_process_group, destroy_process_group + +# Data +from utils import MoleculeModule, get_optim_groups +from torch.utils.data import DataLoader + +# Standard library +import os +import args + + +def ddp_setup(): + init_process_group(backend="nccl") + torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) + + +def load_train_objs(config): + # load data + train_loader = MoleculeModule( + config.max_len, + config.train_load, + config.data_root + ) + train_loader.setup() + + loader = DataLoader( + train_loader.pubchem, + batch_size=config.n_batch, + pin_memory=True, + shuffle=False, + collate_fn=train_loader.text_encoder.process, + sampler=DistributedSampler(train_loader.pubchem), + num_workers=config.n_workers + ) + + # load model + if config.smi_ted_version == 'v1': + from smi_ted_light.load import Smi_ted + elif config.smi_ted_version == 'v2': + from smi_ted_large.load import Smi_ted + + model = Smi_ted(config, train_loader.get_vocab()) + model.apply(model._init_weights) + + # load optimizer + optim_groupsE = get_optim_groups(model.encoder) + optim_groupsD = get_optim_groups(model.decoder) + optimizerE = Lamb(optim_groupsE, lr=config.lr_start*config.lr_multiplier, betas=(0.9, 0.99)) + optimizerD = torch.optim.Adam(optim_groupsD, lr=config.lr_decoder, betas=(0.9, 0.99)) + + return loader, model, (optimizerE, optimizerD) + + +def main( + config, + save_every: int, + total_epochs: int, + save_checkpoint_path: str, + load_checkpoint_path: str + ): + ddp_setup() + + # training objects + train_data, model, optimizers = load_train_objs(config) + + # init trainer + trainer = TrainerEncoderDecoder( + model, + train_data, + optimizers, + save_every, + save_checkpoint_path, + load_checkpoint_path, + config + ) + trainer.train(total_epochs) + destroy_process_group() + + +if __name__ == '__main__': + parser = args.get_parser() + args = parser.parse_args() + main( + args, + args.checkpoint_every, + args.max_epochs, + save_checkpoint_path=args.save_checkpoint_path, + load_checkpoint_path=args.load_checkpoint_path, + ) diff --git a/models/smi_ted/training/trainer.py b/models/smi_ted/training/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..ffff34bb51c0896e7bf7f1dffe047ae15bbcdb41 --- /dev/null +++ b/models/smi_ted/training/trainer.py @@ -0,0 +1,454 @@ +# Deep learning +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from torch.utils.data import DataLoader +from torch.nn.parallel import DistributedDataParallel as DDP +from fast_transformers.masking import LengthMask + +# Standard library +from tqdm import tqdm +import pandas as pd +import numpy as np +import random +import os + + +class Trainer: + + def __init__( + self, + model: torch.nn.Module, + train_data: DataLoader, + optimizer: torch.optim.Optimizer, + save_every: int, + save_checkpoint_path: str, + load_checkpoint_path: str, + config, + ) -> None: + self.local_rank = int(os.environ["LOCAL_RANK"]) + self.global_rank = int(os.environ["RANK"]) + self.model = model.to(self.local_rank) + self.train_data = train_data + self.optimizer = optimizer + self.save_every = save_every + self.epochs_run = 0 + self.last_batch_idx = -1 + self.save_checkpoint_path = save_checkpoint_path + self.config = config + + if 
os.path.exists(load_checkpoint_path): + print(f"Loading checkpoint at {load_checkpoint_path}...") + self._load_checkpoint(load_checkpoint_path) + + self.model = DDP(self.model, device_ids=[self.local_rank]) + + def _load_checkpoint(self, checkpoint_path): + opt_dict = None + loc = f"cuda:{self.local_rank}" + ckpt_dict = torch.load(checkpoint_path, map_location=loc) + if os.path.exists(os.path.join(self.save_checkpoint_path, 'OPTIMIZER_STATES.pt')): + opt_dict = torch.load(os.path.join(self.save_checkpoint_path, 'OPTIMIZER_STATES.pt'), map_location=loc) + + self.model.load_state_dict(ckpt_dict["MODEL_STATE"]) + if opt_dict is not None: + self.optimizer.load_state_dict(opt_dict["OPTIMIZER_STATE"]) + print('Optimizer states restored!') + + self.last_batch_idx = ckpt_dict["last_batch_idx"] if 'last_batch_idx' in ckpt_dict else -1 + self.epochs_run = ckpt_dict["EPOCHS_RUN"] + 1 if self.last_batch_idx == -1 else ckpt_dict["EPOCHS_RUN"] + + # load RNG states each time the model and states are loaded from checkpoint + if 'rng' in ckpt_dict: + rng = ckpt_dict['rng'] + for key, value in rng.items(): + if key =='torch_state': + torch.set_rng_state(value.cpu()) + elif key =='cuda_state': + torch.cuda.set_rng_state(value.cpu()) + elif key =='numpy_state': + np.random.set_state(value) + elif key =='python_state': + random.setstate(value) + else: + print('unrecognized state') + + print(f"Resuming training from checkpoint at Epoch {self.epochs_run}.") + + def _save_checkpoint(self, epoch, config, last_idx): + # save RNG states each time the model and states are saved + out_dict = dict() + out_dict['torch_state'] = torch.get_rng_state() + out_dict['cuda_state'] = torch.cuda.get_rng_state() + if np: + out_dict['numpy_state'] = np.random.get_state() + if random: + out_dict['python_state'] = random.getstate() + + # model states + ckpt_dict = { + "MODEL_STATE": self.model.module.state_dict(), + "EPOCHS_RUN": epoch, + "hparams": vars(config), + "last_batch_idx": last_idx, + "rng": out_dict + } + + # optimizer states + opt_dict = { + "OPTIMIZER_STATE": self.optimizer.state_dict(), + } + + if last_idx == -1: + filename = f'{str(self.model.module)}_{epoch}.pt' + else: + filename = f'{str(self.model.module)}_{last_idx}_{epoch}.pt' + + torch.save(ckpt_dict, os.path.join(self.save_checkpoint_path, filename)) + torch.save(opt_dict, os.path.join(self.save_checkpoint_path, 'OPTIMIZER_STATES.pt')) + + print(f"Epoch {epoch} | Training checkpoint saved at {os.path.join(self.save_checkpoint_path, filename)}.") + + def train(self, max_epochs: int): + for epoch in range(self.epochs_run, max_epochs): + self._run_epoch(epoch) + if self.local_rank == 0: + self._save_checkpoint(epoch, self.config, last_idx=-1) + + def _run_epoch(self, epoch): + print(f"[GPU{self.global_rank}] Epoch {epoch} | Batchsize: {self.config.n_batch} | Steps: {len(self.train_data)} | Last batch: {self.last_batch_idx}") + self.train_data.sampler.set_epoch(epoch) + loss_list = pd.Series() + + for idx, data in enumerate(tqdm(self.train_data)): + # skip batches + if idx <= self.last_batch_idx: + continue + + # run batch + bucket_idx_masked = data[0] + bucket_targets = data[1] + bucket_idx_not_masked = data[2] + loss = self._run_batch(bucket_idx_masked, bucket_targets, bucket_idx_not_masked) + torch.cuda.empty_cache() + + # track loss + if self.local_rank == 0: + loss_list = pd.concat([loss_list, pd.Series([loss])], axis=0) + + # checkpoint + if self.local_rank == 0 and idx % self.save_every == 0 and idx != 0: + self._save_checkpoint(epoch, self.config, idx) + 
# WARN: due to job limit time - save loss for each iter + loss_list.to_csv(os.path.join(self.config.save_checkpoint_path, f'training_loss_{idx}_epoch{epoch}.csv'), index=False) + loss_list = pd.Series() + + self.last_batch_idx = -1 + + if self.local_rank == 0: + loss_list.to_csv(os.path.join(self.config.save_checkpoint_path, f'training_loss_epoch{epoch}.csv'), index=False) + + def _run_batch(self, bucket_idx_masked, bucket_targets, bucket_idx_not_masked): + raise NotImplementedError + + +class TrainerEncoderDecoder(Trainer): + + def __init__( + self, + model: torch.nn.Module, + train_data: DataLoader, + optimizer: torch.optim.Optimizer, + save_every: int, + save_checkpoint_path: str, + load_checkpoint_path: str, + config, + ) -> None: + super().__init__(model, train_data, optimizer, save_every, save_checkpoint_path, load_checkpoint_path, config) + self.criterionC = nn.CrossEntropyLoss(ignore_index=-100) + self.criterionR = nn.MSELoss() + + self.optimE = self.optimizer[0] + self.optimD = self.optimizer[1] + + self.ngpus_per_node = torch.cuda.device_count() + self.total_batches = len(self.train_data) + self.batch_thresh = int(self.total_batches - (self.total_batches * 0.05 * self.ngpus_per_node)) + print('batch_thresh:', self.batch_thresh) + + def _load_checkpoint(self, checkpoint_path): + opt_dict = None + loc = f"cuda:{self.local_rank}" + ckpt_dict = torch.load(checkpoint_path, map_location=loc) + if os.path.exists(os.path.join(self.save_checkpoint_path, 'OPTIMIZER_STATES.pt')): + opt_dict = torch.load(os.path.join(self.save_checkpoint_path, 'OPTIMIZER_STATES.pt'), map_location=loc) + + self.model.load_state_dict(ckpt_dict["MODEL_STATE"]) + if opt_dict is not None: + self.optimizer[0].load_state_dict(opt_dict["OPTIMIZER_STATE_ENCODER"]) + self.optimizer[1].load_state_dict(opt_dict["OPTIMIZER_STATE_DECODER"]) + print('Optimizer states restored!') + + self.last_batch_idx = ckpt_dict["last_batch_idx"] if 'last_batch_idx' in ckpt_dict else -1 + self.epochs_run = ckpt_dict["EPOCHS_RUN"] + 1 if self.last_batch_idx == -1 else ckpt_dict["EPOCHS_RUN"] + + # load RNG states each time the model and states are loaded from checkpoint + if 'rng' in ckpt_dict: + rng = ckpt_dict['rng'] + for key, value in rng.items(): + if key =='torch_state': + torch.set_rng_state(value.cpu()) + elif key =='cuda_state': + torch.cuda.set_rng_state(value.cpu()) + elif key =='numpy_state': + np.random.set_state(value) + elif key =='python_state': + random.setstate(value) + else: + print('unrecognized state') + + print(f"Resuming training from checkpoint at Epoch {self.epochs_run}.") + + def _save_checkpoint(self, epoch, config, last_idx): + # save RNG states each time the model and states are saved + out_dict = dict() + out_dict['torch_state'] = torch.get_rng_state() + out_dict['cuda_state'] = torch.cuda.get_rng_state() + if np: + out_dict['numpy_state'] = np.random.get_state() + if random: + out_dict['python_state'] = random.getstate() + + # model states + ckpt_dict = { + "MODEL_STATE": self.model.module.state_dict(), + "EPOCHS_RUN": epoch, + "hparams": vars(config), + "last_batch_idx": last_idx, + "rng": out_dict + } + + # optimizer states + opt_dict = { + "OPTIMIZER_STATE_ENCODER": self.optimizer[0].state_dict(), + "OPTIMIZER_STATE_DECODER": self.optimizer[1].state_dict(), + } + + if last_idx == -1: + filename = f'{str(self.model.module)}_{epoch}.pt' + else: + filename = f'{str(self.model.module)}_{last_idx}_{epoch}.pt' + + torch.save(ckpt_dict, os.path.join(self.save_checkpoint_path, filename)) + torch.save(opt_dict, 
os.path.join(self.save_checkpoint_path, 'OPTIMIZER_STATES.pt')) + + print(f"Epoch {epoch} | Training checkpoint saved at {os.path.join(self.save_checkpoint_path, filename)}.") + + def _run_epoch(self, epoch): + print(f"[GPU{self.global_rank}] Epoch {epoch} | Batchsize: {self.config.n_batch} | Steps: {len(self.train_data)}") + self.train_data.sampler.set_epoch(epoch) + loss_list = pd.DataFrame() + + for idx, data in enumerate(tqdm(self.train_data)): + bucket_idx_masked = data[0] + bucket_targets = data[1] + bucket_idx_not_masked = data[2] + lossE, lossD = self._run_batch(idx, bucket_idx_masked, bucket_targets, bucket_idx_not_masked) + torch.cuda.empty_cache() + + if self.local_rank == 0: + df = pd.DataFrame({ + 'lossE': [lossE.cpu().item()], + 'lossD': [lossD.cpu().item()], + }) + loss_list = pd.concat([loss_list, df], axis=0) + + if self.local_rank == 0: + loss_list.to_csv(os.path.join(self.config.save_checkpoint_path, f'training_loss_epoch{epoch}.csv'), index=False) + + def custom(self, module): + def custom_forward(*inputs): + inputs = module(inputs[0]) + return inputs + return custom_forward + + def _run_batch(self, batch_idx, bucket_idx_masked, bucket_targets, bucket_idx_not_masked): + self.optimE.zero_grad(set_to_none=True) + self.optimD.zero_grad(set_to_none=True) + + can_train_encoder = (batch_idx + 1) <= self.batch_thresh + can_train_decoder = (batch_idx + 1) > self.batch_thresh + + padding_idx = 2 + errorE = torch.zeros(1).to(self.local_rank) + errorD = torch.zeros(1).to(self.local_rank) + errorE_tmp = .0 + errorD_tmp = .0 + + for chunk in range(len(bucket_idx_masked)): + idx_masked = bucket_idx_masked[chunk].to(self.local_rank) + targets = bucket_targets[chunk].to(self.local_rank) + idx_not_masked = bucket_idx_not_masked[chunk] + idx_not_masked = list(map(lambda x: F.pad(x, pad=(0, self.config.max_len - x.shape[0]), value=2).unsqueeze(0), idx_not_masked)) + idx_not_masked = torch.cat(idx_not_masked, dim=0).to(self.local_rank) + mask = (idx_masked != padding_idx) + + ########### + # Encoder # + ########### + if can_train_encoder: + for param in self.model.module.encoder.parameters(): + param.requires_grad = True + for param in self.model.module.decoder.parameters(): + param.requires_grad = False + + # encoder forward + x = self.model.module.encoder.tok_emb(idx_masked) + x = self.model.module.encoder.drop(x) + x = checkpoint.checkpoint(self.custom(self.model.module.encoder.blocks), x) + logits = self.model.module.encoder.lang_model(x) + + # loss function + logits = logits.view(-1, logits.size(-1)) + targets = targets.view(-1) + errorE_tmp = self.criterionC(logits, targets) / len(bucket_idx_masked) + + if chunk < len(bucket_idx_masked)-1: + errorE_tmp.backward() + errorE += errorE_tmp.detach() + else: + errorE += errorE_tmp + + + ########### + # Decoder # + ########### + if can_train_decoder: + for param in self.model.module.encoder.parameters(): + param.requires_grad = False + for param in self.model.module.decoder.parameters(): + param.requires_grad = True + + self.model.module.encoder.eval() + + # encoder forward + with torch.no_grad(): + true_set, true_cte = self.model.module.encoder(idx_masked, mask=mask, inference=True) + + # add padding + input_mask_expanded = mask.unsqueeze(-1).expand(true_cte.size()).float() + mask_embeddings = (true_cte * input_mask_expanded) + true_cte = F.pad(mask_embeddings, pad=(0, 0, 0, self.config.max_len - mask_embeddings.shape[1]), value=0) + true_cte = true_cte.view(-1, self.config.max_len*self.config.n_embd) + + # decoder forward + pred_set, 
pred_ids = self.model.module.decoder(true_cte) + + # losses + pred_ids = pred_ids.view(-1, pred_ids.size(-1)) + true_ids = idx_not_masked.view(-1) + + error_ids = self.criterionC(pred_ids, true_ids) / len(bucket_idx_masked) + error_set = self.criterionR(pred_set, true_set) / len(bucket_idx_masked) + errorD_tmp = error_ids + error_set + + if chunk < len(bucket_idx_masked)-1: + errorD_tmp.backward() + errorD += errorD_tmp.detach() + else: + errorD += errorD_tmp + + if can_train_decoder: + errorD.backward() + self.optimD.step() + elif can_train_encoder: + errorE.backward() + self.optimE.step() + + if self.local_rank == 0: + print(f'LossE: {errorE.item()} | LossD: {errorD.item()}') + return errorE, errorD + + +class TrainerDirectDecoder(Trainer): + + def __init__( + self, + model: torch.nn.Module, + train_data: DataLoader, + optimizer: torch.optim.Optimizer, + save_every: int, + save_checkpoint_path: str, + load_checkpoint_path: str, + config, + ) -> None: + super().__init__(model, train_data, optimizer, save_every, save_checkpoint_path, load_checkpoint_path, config) + self.criterionC = nn.CrossEntropyLoss(ignore_index=-100) + self.criterionR = nn.MSELoss() + + def custom(self, module): + def custom_forward(*inputs): + inputs = module(inputs[0], length_mask=inputs[1]) + return inputs + return custom_forward + + def _run_batch(self, bucket_idx_masked, bucket_targets, bucket_idx_not_masked): + padding_idx = 2 + error = torch.zeros(1).to(self.local_rank) + error_tmp = .0 + self.optimizer.zero_grad(set_to_none=True) + + for chunk in range(len(bucket_idx_masked)): + idx_masked = bucket_idx_masked[chunk].to(self.local_rank) + targets = bucket_targets[chunk].to(self.local_rank) + idx_not_masked = bucket_idx_not_masked[chunk] + idx_not_masked = list(map(lambda x: F.pad(x, pad=(0, self.config.max_len - x.shape[0]), value=2).unsqueeze(0), idx_not_masked)) + idx_not_masked = torch.cat(idx_not_masked, dim=0).to(self.local_rank) + mask = (idx_masked != padding_idx) + + # encoder forward + x = self.model.module.encoder.tok_emb(idx_masked) + x = self.model.module.encoder.drop(x) + x = checkpoint.checkpoint(self.custom(self.model.module.encoder.blocks), x, LengthMask(mask.sum(-1), max_len=idx_masked.shape[1])) + + # mean pooling + input_masked_expanded = mask.unsqueeze(-1).expand(x.size()).float() + sum_embeddings = torch.sum(x*input_masked_expanded, 1) + sum_mask = torch.clamp(input_masked_expanded.sum(1), min=1e-9) + true_set = sum_embeddings/sum_mask + true_cte = x + del x + torch.cuda.empty_cache() + + # add padding + input_mask_expanded = mask.unsqueeze(-1).expand(true_cte.size()).float() + mask_embeddings = (true_cte * input_mask_expanded) + true_cte = F.pad(mask_embeddings, pad=(0, 0, 0, self.config.max_len - mask_embeddings.shape[1]), value=0) + true_cte = true_cte.view(-1, self.config.max_len*self.config.n_embd) + + # decoder forward + pred_set, pred_ids = self.model.module.decoder(true_cte) + + # losses + pred_ids = pred_ids.view(-1, pred_ids.size(-1)) + true_ids = idx_not_masked.view(-1) + + error_ids = self.criterionC(pred_ids, true_ids) / len(bucket_idx_masked) + error_set = self.criterionR(pred_set, true_set) / len(bucket_idx_masked) + error_tmp = error_ids + error_set + + if chunk < len(bucket_idx_masked)-1: + error_tmp.backward() + error += error_tmp.detach() + else: + error += error_tmp + + torch.cuda.empty_cache() + + error.backward() + self.optimizer.step() + + if self.local_rank == 0: + print(f'Loss: {error.item()}') + return error.item() diff --git a/models/smi_ted/training/utils.py 
b/models/smi_ted/training/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e0634255d5452391eeedecdbf007da92654fca8e --- /dev/null +++ b/models/smi_ted/training/utils.py @@ -0,0 +1,96 @@ +# Deep learning +import torch + +# Data +from pubchem_encoder import Encoder +from datasets import load_dataset + +# Standard library +import os +import getpass +import glob + + +class MoleculeModule: + def __init__(self, max_len, dataset, data_path): + super().__init__() + self.dataset = dataset + self.data_path = data_path + self.text_encoder = Encoder(max_len) + + def prepare_data(self): + pass + + def get_vocab(self): + #using home made tokenizer, should look into existing tokenizer + return self.text_encoder.char2id + + def get_cache(self): + return self.cache_files + + def setup(self, stage=None): + #using huggingface dataloader + # create cache in tmp directory of locale mabchine under the current users name to prevent locking issues + pubchem_path = {'train': self.data_path} + if 'canonical' in pubchem_path['train'].lower(): + pubchem_script = './pubchem_canon_script.py' + else: + pubchem_script = './pubchem_script.py' + zinc_path = './data/ZINC' + global dataset_dict + if 'ZINC' in self.dataset or 'zinc' in self.dataset: + zinc_files = [f for f in glob.glob(os.path.join(zinc_path,'*.smi'))] + for zfile in zinc_files: + print(zfile) + self.dataset = {'train': zinc_files} + dataset_dict = load_dataset('./zinc_script.py', data_files=self.dataset, cache_dir=os.path.join('/tmp',getpass.getuser(), 'zinc'),split='train', trust_remote_code=True) + + elif 'pubchem' in self.dataset: + dataset_dict = load_dataset(pubchem_script, data_files=pubchem_path, cache_dir=os.path.join('/tmp',getpass.getuser(), 'pubchem'), split='train') + elif 'both' in self.dataset or 'Both' in self.dataset or 'BOTH' in self.dataset: + dataset_dict_pubchem = load_dataset(pubchem_script, data_files=pubchem_path, cache_dir=os.path.join('/tmp',getpass.getuser(), 'pubchem'),split='train', trust_remote_code=True) + zinc_files = [f for f in glob.glob(os.path.join(zinc_path,'*.smi'))] + for zfile in zinc_files: + print(zfile) + self.dataset = {'train': zinc_files} + dataset_dict_zinc = load_dataset('./zinc_script.py', data_files=self.dataset, cache_dir=os.path.join('/tmp',getpass.getuser(), 'zinc'),split='train', trust_remote_code=True) + dataset_dict = concatenate_datasets([dataset_dict_zinc, dataset_dict_pubchem]) + self.pubchem= dataset_dict + print(dataset_dict.cache_files) + self.cache_files = [] + + for cache in dataset_dict.cache_files: + tmp = '/'.join(cache['filename'].split('/')[:4]) + self.cache_files.append(tmp) + + +def get_optim_groups(module): + # setup optimizer + # separate out all parameters to those that will and won't experience regularizing weight decay + decay = set() + no_decay = set() + whitelist_weight_modules = (torch.nn.Linear,) + blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding) + for mn, m in module.named_modules(): + for pn, p in m.named_parameters(): + fpn = '%s.%s' % (mn, pn) if mn else pn # full param name + if pn.endswith('bias'): + # all biases will not be decayed + no_decay.add(fpn) + elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules): + # weights of whitelist modules will be weight decayed + decay.add(fpn) + elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules): + # weights of blacklist modules will NOT be weight decayed + no_decay.add(fpn) + + # validate that we considered every parameter + param_dict = {pn: p for pn, p 
in module.named_parameters()} + + # create the pytorch optimizer object + optim_groups = [ + {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": 0.0}, + {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0}, + ] + + return optim_groups \ No newline at end of file diff --git a/representation/.gitattributes b/representation/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..3307846513fc22ee6369f98b69102d089545bdd5 --- /dev/null +++ b/representation/.gitattributes @@ -0,0 +1,9 @@ +bace_mhg.pkl filter=lfs diff=lfs merge=lfs -text +esol_mhg.pkl filter=lfs diff=lfs merge=lfs -text +esol_mol-xl.pkl filter=lfs diff=lfs merge=lfs -text +bace_smi-ted.pkl filter=lfs diff=lfs merge=lfs -text +esol_bart.pkl filter=lfs diff=lfs merge=lfs -text +esol_smi-ted.pkl filter=lfs diff=lfs merge=lfs -text +bace_MorganFingerprint.pkl filter=lfs diff=lfs merge=lfs -text +bace_bart.pkl filter=lfs diff=lfs merge=lfs -text +bace_mol-xl.pkl filter=lfs diff=lfs merge=lfs -text diff --git a/representation/bace_MorganFingerprint.pkl b/representation/bace_MorganFingerprint.pkl new file mode 100644 index 0000000000000000000000000000000000000000..99011cc1c280a96acbbcfebfa10bbddff207a6b7 --- /dev/null +++ b/representation/bace_MorganFingerprint.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b64ea75a5ac268dd9fd69c5ec85dd32358a15e2f9069a78c0c3c3148cf54f257 +size 11161440 diff --git a/representation/esol_MorganFingerprint.pkl b/representation/esol_MorganFingerprint.pkl new file mode 100755 index 0000000000000000000000000000000000000000..6b1d2fafdec4fe4c97710f8a7b395c03b42703c8 --- /dev/null +++ b/representation/esol_MorganFingerprint.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f112772073dc4603631cf5845f9f13a10abb5c611687adde13a30a3537c00d3b +size 7889672 diff --git a/requirements.txt b/requirements.txt index 17053f0ff3b341f0fb5e925ef1da3112d2b84e6a..4f4bc906f62f3678e53f612462949f36e6b0e802 100644 --- a/requirements.txt +++ b/requirements.txt @@ -25,3 +25,4 @@ umap-learn torch-optimizer tqdm>=4.66.4 pandas==2.2.3 +mordred
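The `mordred` package added to requirements.txt and the new `*_MorganFingerprint.pkl` representation files point to descriptor- and fingerprint-based features alongside the learned embeddings. A minimal sketch of how such features can be computed from SMILES; the helper names below are illustrative and not part of the repository:

import numpy as np
import pandas as pd
from rdkit import Chem
from rdkit.Chem import AllChem
from mordred import Calculator, descriptors


def morgan_features(smiles_list, radius=2, n_bits=2048):
    """Radius-2 Morgan bit vectors; rows for unparsable SMILES stay all-zero."""
    feats = np.zeros((len(smiles_list), n_bits), dtype=np.float32)
    for i, smi in enumerate(smiles_list):
        mol = Chem.MolFromSmiles(smi)
        if mol is not None:
            fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=n_bits)
            feats[i, list(fp.GetOnBits())] = 1.0
    return feats


def mordred_features(smiles_list):
    """2D Mordred descriptors as a DataFrame; non-numeric results become NaN."""
    mols = [Chem.MolFromSmiles(smi) for smi in smiles_list]  # assumes valid SMILES
    calc = Calculator(descriptors, ignore_3D=True)
    df = calc.pandas(mols)
    return df.apply(pd.to_numeric, errors="coerce")

Both helpers return one row per molecule, so the resulting matrices could be cached to .pkl files and fed to downstream regressors in the same way as the neural representations.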