codewithdark committed
Commit d86d806 · verified · 1 Parent(s): 2dc2198

Update app.py

Files changed (1):
app.py +29 -55
app.py CHANGED
@@ -1,66 +1,40 @@
  import gradio as gr
- from huggingface_hub import InferenceClient
- from transformers import AutoModelForCausalLM, AutoTokenizer
  import torch

- # Initialize Hugging Face Inference API client
- hf_client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
-
- # Load the second model (local)
- local_model_name = "codewithdark/latent-recurrent-depth-lm"
- tokenizer = AutoTokenizer.from_pretrained(local_model_name)
- model = AutoModelForCausalLM.from_pretrained(local_model_name, trust_remote_code=True)
  device = "cuda" if torch.cuda.is_available() else "cpu"
- model.to(device).eval()  # Set model to evaluation mode

- def generate_response(
-     message, history: list[tuple[str, str]], system_message, max_tokens, temperature, top_p, model_choice
- ):
-     messages = [{"role": "system", "content": system_message}]

-     for val in history:
-         if val[0]:
-             messages.append({"role": "user", "content": val[0]})
-         if val[1]:
-             messages.append({"role": "assistant", "content": val[1]})

-     messages.append({"role": "user", "content": message})
-
-     if model_choice == "Zephyr-7B (API)":
-         response = ""
-         try:
-             for message in hf_client.chat_completion(
-                 messages=messages,
-                 max_tokens=max_tokens,
-                 stream=True,
-                 temperature=temperature,
-                 top_p=top_p,
-             ):
-                 token = message.choices[0].delta.content if message.choices else ""
-                 response += token
-                 yield response
-         except Exception as e:
-             yield f"Error in API response: {e}"
-     else:
-         input_text = tokenizer.apply_chat_template(messages, return_tensors="pt").to(device)
-         with torch.no_grad():
-             output = model.generate(input_text, max_length=max_tokens, temperature=temperature, top_p=top_p)
-         response = tokenizer.decode(output[0], skip_special_tokens=True).strip()
-
-         for i in range(len(response)):
-             yield response[: i + 1]

- # Gradio UI
- demo = gr.ChatInterface(
-     generate_response,
-     additional_inputs=[
-         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-         gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
-         gr.Radio(["Zephyr-7B (API)", "Latent Recurrent Depth LM"], value="Zephyr-7B (API)", label="Select Model"),
-     ],
- )

  if __name__ == "__main__":
      demo.launch()
 
  import gradio as gr
  import torch
+ from transformers import AutoModel, AutoTokenizer

+ # Load the local model
+ model_name = "codewithdark/latent-recurrent-depth-lm"
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
  device = "cuda" if torch.cuda.is_available() else "cpu"
+ model.to(device).eval()  # Set to evaluation mode

+ # Define inference function
+ def chat_with_model(input_text, model_choice):
+     if model_choice == "Latent Recurrent Depth LM":
+         input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)
+         with torch.no_grad():
+             output = model.generate(input_ids, max_length=512)
+         response = tokenizer.decode(output[0], skip_special_tokens=True)
+         return response
+     return "Model not available yet!"

+ # Create Gradio Interface
+ with gr.Blocks() as demo:
+     gr.Markdown("# 🤖 Chat with Latent Recurrent Depth LM")

+     model_choice = gr.Radio(
+         ["Latent Recurrent Depth LM"],  # Add more models if needed
+         label="Select Model",
+         value="Latent Recurrent Depth LM"
+     )
+
+     text_input = gr.Textbox(label="Enter your message")
+     submit_button = gr.Button("Generate Response")
+     output_text = gr.Textbox(label="Model Response")

+     submit_button.click(fn=chat_with_model, inputs=[text_input, model_choice], outputs=output_text)

+ # Launch the Gradio app
  if __name__ == "__main__":
      demo.launch()
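
A quick way to sanity-check the new inference path outside the Gradio UI is a short script like the one below. This is a minimal sketch, not part of this commit: it assumes the custom model class behind codewithdark/latent-recurrent-depth-lm (loaded with trust_remote_code=True) exposes a generate() method, as the updated app.py expects; if it does not, loading the checkpoint with AutoModelForCausalLM, as the previous version of app.py did, would be the usual fallback.

# Minimal smoke test for the new inference path (sketch, not part of the commit).
# Assumes the repo's custom model class provides a generate() method.
import torch
from transformers import AutoModel, AutoTokenizer

model_name = "codewithdark/latent-recurrent-depth-lm"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device).eval()

prompt = "Hello! What can you do?"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
with torch.no_grad():
    # Shorter max_length than the app's 512, just for a quick check
    output = model.generate(input_ids, max_length=128)
print(tokenizer.decode(output[0], skip_special_tokens=True))

If generation works here, the app's submit_button.click(...) wiring only passes text_input and model_choice through chat_with_model, so the UI side of the commit needs no further changes.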