DemahAlmutairi commited on
Commit
8c3215b
·
verified ·
1 Parent(s): c95f394

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -0
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import gradio as gr
2
  from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
3
  import torch
 
4
 
5
  def load_model(model_name):
6
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -21,6 +22,7 @@ def load_model(model_name):
21
  )
22
  return generator
23
 
 
24
  def generate_text(prompt, model_name):
25
  generator = load_model(model_name)
26
  messages = [{"role": "user", "content": prompt}]
@@ -48,3 +50,4 @@ demo = gr.Interface(
48
  )
49
 
50
  demo.launch()
 
 
1
  import gradio as gr
2
  from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
3
  import torch
4
+ import spaces
5
 
6
  def load_model(model_name):
7
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
22
  )
23
  return generator
24
 
25
+ @spaces.GPU
26
  def generate_text(prompt, model_name):
27
  generator = load_model(model_name)
28
  messages = [{"role": "user", "content": prompt}]
 
50
  )
51
 
52
  demo.launch()
53
+