File size: 2,082 Bytes
a96c72f
c585309
75eb9ca
a96c72f
6d7b830
 
e22ba0b
c585309
75eb9ca
a96c72f
6d7b830
a96c72f
6d7b830
 
a96c72f
9000ced
 
a96c72f
75eb9ca
 
56843e1
9000ced
 
75eb9ca
 
 
9000ced
1d45e7e
 
9000ced
 
e22ba0b
9000ced
9417eab
 
 
 
b434f91
9000ced
9417eab
 
 
 
 
 
 
6d7b830
 
75eb9ca
 
 
56843e1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import streamlit as st
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch

@st.cache_resource
def load_model():
    """Load the mT5-small tokenizer and sequence-to-sequence model.

    Wrapped in ``st.cache_resource`` so Streamlit builds these objects once and
    hands the same instances back on later reruns instead of reloading them.

    Returns:
        tuple: ``(tokenizer, model)`` for the ``google/mt5-small`` checkpoint.
        The tokenizer is the slow (``use_fast=False``) implementation configured
        with left-side padding.

    NOTE(review): to my knowledge the stock ``google/mt5-small`` checkpoint is
    pre-trained only (no chat/instruction fine-tuning), so replies may be poor;
    confirm this is the intended checkpoint.
    """
    checkpoint = "google/mt5-small"
    tok = AutoTokenizer.from_pretrained(
        checkpoint,
        use_fast=False,
        padding_side="left",
    )
    seq2seq_model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
    return tok, seq2seq_model

st.title("Український Чат-бот")

# Seed per-session state on the first run of each browser session; later
# reruns find the keys already present and leave their values untouched.
for state_key, initial_value in (("history", []), ("user_input", "")):
    if state_key not in st.session_state:
        st.session_state[state_key] = initial_value

# load_model is decorated with st.cache_resource, so reruns reuse the same
# tokenizer/model objects rather than loading them again.
tokenizer, model = load_model()

def send_message(max_prompt_tokens: int = 512):
    """Generate a bot reply to the pending message and append both to history.

    The pending message is read from ``st.session_state.user_input``; nothing
    happens when it is empty or whitespace-only. Otherwise the conversation so
    far plus the new message is sent to the model as ONE prompt, the message
    and the decoded reply are appended to ``st.session_state.history`` (in that
    order), and both the stored message and the text box are cleared.

    Must run inside a widget callback: Streamlit rejects assignments to a
    widget's key (``temp_user_input``) once that widget has been rendered in
    the current run, and callbacks execute before the script body renders it.

    Args:
        max_prompt_tokens: Positive upper bound on the prompt length in tokens.
            When the conversation is longer, the *oldest* tokens are dropped so
            the new message is always seen by the model.
            NOTE(review): 512 is a conservative guess for mT5 inputs; tune it.
    """
    message = st.session_state.user_input.strip()
    if not message:
        return

    # A single string, not a list: a list is tokenized as a *batch* of separate
    # sequences, and outputs[0] would then be the reply to the oldest history
    # entry instead of to the new message.
    prompt = "\n".join([*st.session_state.history, message])
    encoded = tokenizer(prompt, return_tensors="pt")
    # Left-truncate by hand (the tokenizer's default truncation cuts the END of
    # the prompt, which is exactly where the new message is).
    input_ids = encoded["input_ids"][:, -max_prompt_tokens:]
    attention_mask = encoded["attention_mask"][:, -max_prompt_tokens:]

    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=100,
        )
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)

    st.session_state.history.extend([message, response])
    st.session_state.user_input = ""  # clear the stored user input
    clear_input()  # clear the text input field (legal: we are in a callback)


def update_user_input():
    """Copy the text box value into ``st.session_state.user_input``."""
    st.session_state.user_input = st.session_state.temp_user_input


def clear_input():
    """Empty the text box; only valid inside a widget callback."""
    st.session_state.temp_user_input = ""


def _submit():
    """Shared callback for a text-box change (e.g. Enter) and the send button.

    If both fire in the same rerun, the second call finds the box already
    cleared, copies an empty string and ``send_message`` returns early, so a
    message is never sent twice.
    """
    update_user_input()
    send_message()


# Both submission paths go through callbacks, which run before the widgets are
# rendered -- the only point where the text box value may be reset.
st.text_input("Ви:", key="temp_user_input", on_change=_submit)
st.button("Надіслати", on_click=_submit)

# Render the transcript. History holds alternating entries: the user's turn at
# even positions and the bot's reply at odd positions. A trailing user turn
# without a reply is still shown.
for turn_index, turn_text in enumerate(st.session_state.history):
    if turn_index % 2 == 0:
        st.write(f"Ви: {turn_text}")
    else:
        st.write(f"Бот: {turn_text}")