giovanni correia committed
Commit b02d52a · verified · 1 Parent(s): f9565b3

Update app.py

Files changed (1)
  1. app.py +128 -42
app.py CHANGED
@@ -1,64 +1,150 @@
 import gradio as gr
-from huggingface_hub import InferenceClient

 """
-For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
-"""
-client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")


-def respond(
-    message,
-    history: list[tuple[str, str]],
-    system_message,
-    max_tokens,
-    temperature,
-    top_p,
-):
-    messages = [{"role": "system", "content": system_message}]

-    for val in history:
-        if val[0]:
-            messages.append({"role": "user", "content": val[0]})
-        if val[1]:
-            messages.append({"role": "assistant", "content": val[1]})

-    messages.append({"role": "user", "content": message})

-    response = ""

-    for message in client.chat_completion(
-        messages,
-        max_tokens=max_tokens,
-        stream=True,
-        temperature=temperature,
-        top_p=top_p,
-    ):
-        token = message.choices[0].delta.content

-        response += token
-        yield response


-"""
-For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
-"""
-demo = gr.ChatInterface(
-    respond,
     additional_inputs=[
-        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
         gr.Slider(
-            minimum=0.1,
             maximum=1.0,
-            value=0.95,
             step=0.05,
-            label="Top-p (nucleus sampling)",
         ),
     ],
 )


 if __name__ == "__main__":
-    demo.launch()
 
+import os
+from threading import Thread
+from typing import Iterator
+
 import gradio as gr
+import spaces
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

+# Token-limit configuration
+MAX_MAX_NEW_TOKENS = 6048
+DEFAULT_MAX_NEW_TOKENS = 3024
+MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
+
+# Description shown at the top of the user interface
+DESCRIPTION = """\
+# DeepSeek-6.7B-Chat
+
+This Space demonstrates [DeepSeek-Coder](https://huggingface.co/deepseek-ai/deepseek-coder-6.7b-instruct), a 6.7B-parameter code model by DeepSeek fine-tuned for chat instructions.
 """

+if not torch.cuda.is_available():
+    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo may run slowly on CPU.</p>"

+# Load the model and tokenizer
+model_id = "deepseek-ai/deepseek-coder-6.7b-instruct"
+if torch.cuda.is_available():
+    model = AutoModelForCausalLM.from_pretrained(model_id,
+                                                 torch_dtype=torch.bfloat16,
+                                                 device_map="auto")
+else:
+    model = AutoModelForCausalLM.from_pretrained(model_id)

+tokenizer = AutoTokenizer.from_pretrained(model_id)
+tokenizer.use_default_system_prompt = False  # do not inject the tokenizer's built-in system prompt

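+# On Hugging Face Spaces, @spaces.GPU requests GPU time (ZeroGPU hardware)
+# for each call to the function it decorates.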
+@spaces.GPU
+def generate(
+    message: str,
+    chat_history: list[tuple[str, str]],
+    system_prompt: str,
+    max_new_tokens: int = 1024,
+    top_p: float = 0.9,
+    top_k: int = 50,
+    repetition_penalty: float = 1.0,
+) -> Iterator[str]:
+    # Assemble the conversation in chat-template format
+    conversation = []
+    if system_prompt:
+        conversation.append({"role": "system", "content": system_prompt})
+    for user, assistant in chat_history:
+        conversation.extend([
+            {"role": "user", "content": user},
+            {"role": "assistant", "content": assistant},
+        ])
+    conversation.append({"role": "user", "content": message})

+    # Encode the input, keeping only the most recent tokens if it is too long
+    input_ids = tokenizer.apply_chat_template(conversation,
+                                              return_tensors="pt",
+                                              add_generation_prompt=True)
+    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+        gr.Warning(
+            f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens."
+        )
+    input_ids = input_ids.to(model.device)

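+    # model.generate is blocking, so it runs on a background thread while
+    # this generator forwards decoded text to the UI as it arrives.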
+    # Create the output stream
+    streamer = TextIteratorStreamer(tokenizer,
+                                    timeout=10.0,
+                                    skip_prompt=True,
+                                    skip_special_tokens=True)
+    # Decoding is greedy (do_sample=False), so the Top-p and Top-k sliders
+    # are shown in the UI but do not alter this generation call.
+    generate_kwargs = dict(
+        {"input_ids": input_ids},
+        streamer=streamer,
+        max_new_tokens=max_new_tokens,
+        do_sample=False,
+        num_beams=1,
+        repetition_penalty=repetition_penalty,
+        eos_token_id=tokenizer.eos_token_id,
+    )
+    t = Thread(target=model.generate, kwargs=generate_kwargs)
+    t.start()

+    outputs = []
+    for text in streamer:
+        outputs.append(text)
+        # Strip DeepSeek-Coder's <|EOT|> end-of-turn marker from the reply
+        yield "".join(outputs).replace("<|EOT|>", "")

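+# gr.ChatInterface passes additional_inputs to generate() positionally after
+# (message, chat_history), so the widget order below must match the function
+# signature above.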
+
+# Gradio chat interface
+chat_interface = gr.ChatInterface(
+    fn=generate,
     additional_inputs=[
+        gr.Textbox(label="System prompt", lines=6),
+        gr.Slider(
+            label="Max new tokens",
+            minimum=1,
+            maximum=MAX_MAX_NEW_TOKENS,
+            step=1,
+            value=DEFAULT_MAX_NEW_TOKENS,
+        ),
         gr.Slider(
+            label="Top-p (nucleus sampling)",
+            minimum=0.05,
             maximum=1.0,
             step=0.05,
+            value=0.9,
+        ),
+        gr.Slider(
+            label="Top-k",
+            minimum=1,
+            maximum=1000,
+            step=1,
+            value=50,
         ),
+        gr.Slider(
+            label="Repetition penalty",
+            minimum=1.0,
+            maximum=2.0,
+            step=0.05,
+            value=1.0,
+        ),
+    ],
+    stop_btn=None,
+    examples=[
+        ["implement snake game using pygame"],
+        ["Can you explain briefly to me what is the Python programming language?"],
+        ["write a program to find the factorial of a number"],
     ],
 )

+# Assemble the page: description above the rendered chat interface
+with gr.Blocks(css="style.css") as demo:
+    gr.Markdown(DESCRIPTION)
+    chat_interface.render()

 if __name__ == "__main__":
+    demo.queue().launch(share=True)
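For reference, the core of the new app.py is the thread-plus-streamer pattern: model.generate blocks until generation finishes, so it runs on a worker thread while TextIteratorStreamer feeds decoded text back to the caller. Below is a minimal, self-contained sketch of that pattern; the checkpoint sshleifer/tiny-gpt2 and the prompt are assumptions chosen so it runs quickly on CPU, not part of this commit.

    # Minimal sketch of threaded streaming with transformers (assumed
    # checkpoint: sshleifer/tiny-gpt2; any causal LM works the same way).
    from threading import Thread

    from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

    model_id = "sshleifer/tiny-gpt2"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)

    inputs = tokenizer("def factorial(n):", return_tensors="pt")
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # generate() blocks, so it runs on a worker thread; the main thread
    # drains the streamer, printing each decoded chunk as it arrives.
    thread = Thread(target=model.generate,
                    kwargs=dict(**inputs, streamer=streamer, max_new_tokens=32))
    thread.start()
    for chunk in streamer:
        print(chunk, end="", flush=True)
    thread.join()

In app.py the same loop accumulates the chunks and yields the running text, which is what lets gr.ChatInterface repaint the partial reply on every iteration instead of waiting for the full completion.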