hh1199 commited on
Commit
62d6e64
·
verified ·
1 Parent(s): 6d95508

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -14
app.py CHANGED
@@ -10,10 +10,14 @@ MODELS = {
10
  "ruRoberta": "sberbank-ai/ruRoberta-large"
11
  }
12
 
 
 
 
 
 
 
13
  def get_embeddings(model, tokenizer, text):
14
- # Добавляем промпт
15
- prompted_text = f"Товар: {text}. Категория:"
16
- inputs = tokenizer(prompted_text,
17
  padding=True,
18
  truncation=True,
19
  return_tensors="pt",
@@ -21,18 +25,22 @@ def get_embeddings(model, tokenizer, text):
21
  outputs = model(**inputs)
22
  return outputs.last_hidden_state[:, 0].detach().numpy()
23
 
24
- def classify(model_name: str, item: str, categories: str) -> str:
25
  tokenizer = AutoTokenizer.from_pretrained(MODELS[model_name])
26
  model = AutoModel.from_pretrained(MODELS[model_name])
27
 
28
- # Эмбеддинги для товара с промптом
29
- item_embedding = get_embeddings(model, tokenizer, item)
 
 
 
30
 
31
- # Эмбеддинги для категорий
32
- category_embeddings = []
33
- for category in categories.split(","):
34
- emb = get_embeddings(model, tokenizer, category.strip())
35
- category_embeddings.append(emb)
 
36
 
37
  # Сравнение
38
  similarities = cosine_similarity(item_embedding, np.vstack(category_embeddings))[0]
@@ -43,9 +51,10 @@ def classify(model_name: str, item: str, categories: str) -> str:
43
  gr.Interface(
44
  fn=classify,
45
  inputs=[
46
- gr.Dropdown(list(MODELS.keys())),
47
- gr.Textbox(),
48
- gr.Textbox(value="Инструменты, Овощи, Техника")
 
49
  ],
50
  outputs=gr.Textbox()
51
  ).launch()
 
10
  "ruRoberta": "sberbank-ai/ruRoberta-large"
11
  }
12
 
13
+ PROMPT_TEMPLATES = {
14
+ "basic": "Товар: {item}. Категория:",
15
+ "examples": "Примеры:\n- Молоток → Инструменты\n- Морковь → Овощи\nТовар: {item} → ",
16
+ "strict": "Выбери категорию из [{categories}]. Товар: {item}. Категория:"
17
+ }
18
+
19
  def get_embeddings(model, tokenizer, text):
20
+ inputs = tokenizer(text,
 
 
21
  padding=True,
22
  truncation=True,
23
  return_tensors="pt",
 
25
  outputs = model(**inputs)
26
  return outputs.last_hidden_state[:, 0].detach().numpy()
27
 
28
+ def classify(model_name: str, prompt_type: str, item: str, categories: str) -> str:
29
  tokenizer = AutoTokenizer.from_pretrained(MODELS[model_name])
30
  model = AutoModel.from_pretrained(MODELS[model_name])
31
 
32
+ # Формируем промпт
33
+ prompt = PROMPT_TEMPLATES[prompt_type].format(
34
+ item=item,
35
+ categories=", ".join([c.strip() for c in categories.split(",")])
36
+ )
37
 
38
+ # Эмбеддинги
39
+ item_embedding = get_embeddings(model, tokenizer, prompt)
40
+ category_embeddings = [
41
+ get_embeddings(model, tokenizer, c.strip())
42
+ for c in categories.split(",")
43
+ ]
44
 
45
  # Сравнение
46
  similarities = cosine_similarity(item_embedding, np.vstack(category_embeddings))[0]
 
51
  gr.Interface(
52
  fn=classify,
53
  inputs=[
54
+ gr.Dropdown(list(MODELS.keys()), label="Модель"),
55
+ gr.Dropdown(list(PROMPT_TEMPLATES.keys()), label="Шаблон промпта"),
56
+ gr.Textbox(label="Товар"),
57
+ gr.Textbox(label="Категории", value="Инструменты, Овощи, Техника")
58
  ],
59
  outputs=gr.Textbox()
60
  ).launch()