File size: 2,082 Bytes
a96c72f c585309 75eb9ca a96c72f 6d7b830 e22ba0b c585309 75eb9ca a96c72f 6d7b830 a96c72f 6d7b830 a96c72f 9000ced a96c72f 75eb9ca 56843e1 9000ced 75eb9ca 9000ced 1d45e7e 9000ced e22ba0b 9000ced 9417eab b434f91 9000ced 9417eab 6d7b830 75eb9ca 56843e1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 |
import streamlit as st
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch
@st.cache_resource
def load_model():
    """Load the mT5-small tokenizer and seq2seq model.

    The function is wrapped in ``st.cache_resource`` so Streamlit can reuse
    the returned objects instead of reloading them on every script rerun.

    Returns:
        A ``(tokenizer, model)`` tuple for the ``google/mt5-small`` checkpoint.
    """
    checkpoint = "google/mt5-small"
    # Slow (non-fast) tokenizer, padding on the left side.
    mt5_tokenizer = AutoTokenizer.from_pretrained(
        checkpoint,
        padding_side="left",
        use_fast=False,
    )
    mt5_model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
    return mt5_tokenizer, mt5_model
st.title("Український Чат-бот")

# Seed the per-session state on the first run only: the chat transcript
# (alternating user / bot entries) and the pending user message.
for state_key, default_factory in (("history", list), ("user_input", str)):
    if state_key not in st.session_state:
        st.session_state[state_key] = default_factory()

tokenizer, model = load_model()
def send_message():
    """Generate a bot reply for the pending user message and record both.

    Reads the pending message from ``st.session_state.user_input``. If it is
    empty or whitespace-only, nothing happens. Otherwise the recent
    conversation plus the new message is encoded as ONE prompt, a reply is
    generated with the module-level ``tokenizer`` / ``model``, and the pair
    ``[user message, reply]`` is appended to ``st.session_state.history``.
    Finally the pending message is reset to an empty string.
    """
    user_text = st.session_state.user_input.strip()
    if not user_text:
        return

    # Passing a list of strings to the tokenizer encodes a *batch* of
    # independent sequences, and decoding outputs[0] would then yield the
    # generation for the oldest history entry rather than for the new
    # message. Build a single prompt instead. Only the most recent entries
    # are kept so that truncation (which cuts from the end by default)
    # does not drop the newest message once the transcript grows.
    max_context_entries = 6
    recent_entries = st.session_state.history[-max_context_entries:]
    prompt = "\n".join([*recent_entries, user_text])

    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
    with torch.no_grad():  # inference only, no gradients needed
        outputs = model.generate(**inputs, max_length=100)
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)

    st.session_state.history.extend([user_text, response])
    st.session_state.user_input = ""  # clear the stored user input

    # NOTE: the widget-bound key "temp_user_input" is intentionally NOT
    # written here. Streamlit raises StreamlitAPIException when a widget's
    # session-state key is assigned after that widget has been instantiated
    # in the current run, which is the situation whenever this function is
    # called from the script body below the text input. Clearing the field
    # has to happen inside a widget callback instead.
def update_user_input():
    """Copy the text widget's current value into the pending-message slot.

    Registered as the ``on_change`` callback of the text input whose key is
    ``temp_user_input``.
    """
    st.session_state["user_input"] = st.session_state["temp_user_input"]
# Clear temp_user_input after the button is pressed.
def clear_input():
    """Reset the text widget's session-state value to an empty string."""
    st.session_state["temp_user_input"] = ""
st.text_input("Ви:", key="temp_user_input", on_change=update_user_input)

# The button's on_click callback empties the text field; the message itself
# is sent from the script body when the button reports a click.
send_clicked = st.button("Надіслати", on_click=clear_input)
if send_clicked:
    send_message()

# Handle pressing Enter: send only when the field holds text that differs
# from the last text submitted this way.
typed_text = st.session_state.get("temp_user_input", "")
previously_sent = st.session_state.get("last_input", "")
if typed_text and previously_sent != typed_text:
    st.session_state["last_input"] = typed_text
    send_message()

# Render the transcript; entries alternate user message, bot reply.
chat_log = st.session_state.history
for start in range(0, len(chat_log), 2):
    turn = chat_log[start:start + 2]
    st.write(f"Ви: {turn[0]}")
    if len(turn) == 2:
        st.write(f"Бот: {turn[1]}")