sfaezella commited on
Commit
bd47e61
·
verified ·
1 Parent(s): 27bd3e3

Add full model description and caution note to Gradio app

Browse files
Files changed (1) hide show
  1. app.py +102 -65
app.py CHANGED
@@ -1,66 +1,103 @@
1
- import torch
2
- import gradio as gr
3
- import numpy as np
4
- from transformers import T5Tokenizer, T5EncoderModel
5
- import esm
6
- from inference import load_models, predict_ensemble
7
- from transformers import AutoTokenizer, AutoModel
8
- import spaces
9
-
10
- # Load trained models
11
- model_protT5, model_cat = load_models()
12
-
13
- # Load ProtT5 model
14
- tokenizer_t5 = T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_uniref50", do_lower_case=False)
15
- model_t5 = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_uniref50")
16
- model_t5 = model_t5.eval()
17
-
18
- # Load the tokenizer and model
19
- model_name = "facebook/esm2_t33_650M_UR50D"
20
- tokenizer_esm = AutoTokenizer.from_pretrained(model_name)
21
- esm_model = AutoModel.from_pretrained(model_name)
22
-
23
- def extract_prott5_embedding(sequence):
24
- sequence = sequence.replace(" ", "")
25
- seq = " ".join(list(sequence))
26
- ids = tokenizer_t5(seq, return_tensors="pt", padding=True)
27
- with torch.no_grad():
28
- embedding = model_t5(**ids).last_hidden_state
29
- return torch.mean(embedding, dim=1)
30
-
31
-
32
- # Extract ESM2 embedding
33
- def extract_esm_embedding(sequence):
34
- # Tokenize the sequence
35
- inputs = tokenizer_esm(sequence, return_tensors="pt", padding=True, truncation=True)
36
-
37
- # Forward pass through the model
38
- with torch.no_grad():
39
- outputs = esm_model(**inputs)
40
-
41
- # Extract the embeddings from the 33rd layer (ESM2 layer)
42
- token_representations = outputs.last_hidden_state # This is the default layer
43
- return torch.mean(token_representations[0, 1:len(sequence)+1], dim=0).unsqueeze(0)
44
-
45
- def estimate_duration(sequence):
46
- # Estimate duration based on sequence length
47
- base_time = 30 # Base time in seconds
48
- time_per_residue = 0.5 # Estimated time per residue
49
- estimated_time = base_time + len(sequence) * time_per_residue
50
- return min(int(estimated_time), 300) # Cap at 300 seconds
51
-
52
- @spaces.GPU(duration=120)
53
- def classify(sequence):
54
- protT5_emb = extract_prott5_embedding(sequence)
55
- esm_emb = extract_esm_embedding(sequence)
56
- concat = torch.cat((esm_emb, protT5_emb), dim=1)
57
- pred = predict_ensemble(protT5_emb, concat, model_protT5, model_cat)
58
- return "Potential Allergen" if pred.item() == 1 else "Non-Allergen"
59
-
60
-
61
- demo = gr.Interface(fn=classify,
62
- inputs=gr.Textbox(lines=3, placeholder="Enter protein sequence..."),
63
- outputs=gr.Label(label="Prediction"))
64
-
65
- if __name__ == "__main__":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  demo.launch()
 
1
+ import torch
2
+ import gradio as gr
3
+ import numpy as np
4
+ from transformers import T5Tokenizer, T5EncoderModel
5
+ import esm
6
+ from inference import load_models, predict_ensemble
7
+ from transformers import AutoTokenizer, AutoModel
8
+ import spaces
9
+
10
+ # Load trained models
11
+ model_protT5, model_cat = load_models()
12
+
13
+ # Load ProtT5 model
14
+ tokenizer_t5 = T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_uniref50", do_lower_case=False)
15
+ model_t5 = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_uniref50")
16
+ model_t5 = model_t5.eval()
17
+
18
+ # Load the tokenizer and model
19
+ model_name = "facebook/esm2_t33_650M_UR50D"
20
+ tokenizer_esm = AutoTokenizer.from_pretrained(model_name)
21
+ esm_model = AutoModel.from_pretrained(model_name)
22
+
23
+ def extract_prott5_embedding(sequence):
24
+ sequence = sequence.replace(" ", "")
25
+ seq = " ".join(list(sequence))
26
+ ids = tokenizer_t5(seq, return_tensors="pt", padding=True)
27
+ with torch.no_grad():
28
+ embedding = model_t5(**ids).last_hidden_state
29
+ return torch.mean(embedding, dim=1)
30
+
31
+
32
+ # Extract ESM2 embedding
33
+ def extract_esm_embedding(sequence):
34
+ # Tokenize the sequence
35
+ inputs = tokenizer_esm(sequence, return_tensors="pt", padding=True, truncation=True)
36
+
37
+ # Forward pass through the model
38
+ with torch.no_grad():
39
+ outputs = esm_model(**inputs)
40
+
41
+ # Extract the embeddings from the 33rd layer (ESM2 layer)
42
+ token_representations = outputs.last_hidden_state # This is the default layer
43
+ return torch.mean(token_representations[0, 1:len(sequence)+1], dim=0).unsqueeze(0)
44
+
45
+ def estimate_duration(sequence):
46
+ # Estimate duration based on sequence length
47
+ base_time = 30 # Base time in seconds
48
+ time_per_residue = 0.5 # Estimated time per residue
49
+ estimated_time = base_time + len(sequence) * time_per_residue
50
+ return min(int(estimated_time), 300) # Cap at 300 seconds
51
+
52
+ @spaces.GPU(duration=120)
53
+ def classify(sequence):
54
+ protT5_emb = extract_prott5_embedding(sequence)
55
+ esm_emb = extract_esm_embedding(sequence)
56
+ concat = torch.cat((esm_emb, protT5_emb), dim=1)
57
+ pred = predict_ensemble(protT5_emb, concat, model_protT5, model_cat)
58
+ return "Potential Allergen" if pred.item() == 1 else "Non-Allergen"
59
+
60
+
61
+ demo = gr.Interface(fn=classify,
62
+ inputs=gr.Textbox(lines=3, placeholder="Enter protein sequence..."),
63
+ outputs=gr.Label(label="Prediction"))
64
+
65
+ # if __name__ == "__main__":
66
+ # demo.launch()
67
+
68
+
69
+
70
+ description_md = """
71
+ ### ℹ️ **About AllerTrans – Allergenicity Prediction Tool**
72
+
73
+ **🧬 Input Format – FASTA Sequences**
74
+ This tool accepts protein sequences in FASTA format
75
+
76
+ **💡 Accepted Proteins**
77
+ - Natural and recombinant proteins
78
+ - Pharmaceutical and industrial proteins
79
+ - Synthetic sequences (tags or mutations allowed)
80
+
81
+ 🔎 **Note of Caution**:
82
+ While our model demonstrates promising performance—particularly with recombinant proteins, as evidenced by our additional evaluation with a recombinant protein dataset
83
+ from UniProt—**we advise caution when generalizing the results to all recombinant protein scenarios**.
84
+ The specificity of the model to various recombinant constructs and modifications has not been explored.
85
+
86
+ **🧠 Prediction Process**
87
+ - Embeddings via ProtT5 + ESM-2
88
+ - Deep neural network for classification
89
+
90
+ **⚠️ Disclaimer**
91
+ Although AllerTrans provides highly accurate predictions, it is intended as a screening tool. For clinical or regulatory decisions, always confirm results with experimental validation.
92
+ """
93
+
94
+ with gr.Blocks() as demo:
95
+ gr.Markdown(description_md)
96
+ with gr.Row():
97
+ input_box = gr.Textbox(lines=3, placeholder="Enter protein sequence...")
98
+ output_label = gr.Label(label="Prediction")
99
+ classify_btn = gr.Button("Run Prediction")
100
+ classify_btn.click(classify, inputs=input_box, outputs=output_label)
101
+
102
+ if __name__ == "__main__":
103
  demo.launch()