|
""" |
|
허깅페이스 Space를 위한, 개선된 사고 분석 앱 (분류 결과 순위 강조) |
|
app.py 파일로 저장하세요 |
|
""" |
|
|
|
import gradio as gr |
|
import torch |
|
import numpy as np |
|
import re |
|
import pandas as pd |
|
import os |
|
import tempfile |
|
from transformers import ( |
|
AutoTokenizer, |
|
AutoModelForSequenceClassification, |
|
BartForConditionalGeneration |
|
) |
|
|
|
|
|
MODEL_CONFIG = { |
|
|
|
"classification_repo": "jennalee1385/accident_classification", |
|
|
|
"summarization_repo": "gogamza/kobart-summarization", |
|
|
|
"max_length": 256, |
|
|
|
"summary_max_length": 128 |
|
} |
|
|
|
|
|
model_cache = { |
|
"classification_model": None, |
|
"classification_tokenizer": None, |
|
"summarization_model": None, |
|
"summarization_tokenizer": None, |
|
"label_map": None |
|
} |
|
|
|
|
|
def create_rank_badge(rank): |
|
"""순위에 따른 배지 HTML을 생성합니다""" |
|
if rank == 1: |
|
return f"<span style='display:inline-block; background-color:#FFD700; color:#000; padding:3px 8px; border-radius:10px; font-weight:bold;'>🥇 {rank}순위</span>" |
|
elif rank == 2: |
|
return f"<span style='display:inline-block; background-color:#C0C0C0; color:#000; padding:3px 8px; border-radius:10px; font-weight:bold;'>🥈 {rank}순위</span>" |
|
elif rank == 3: |
|
return f"<span style='display:inline-block; background-color:#CD7F32; color:#000; padding:3px 8px; border-radius:10px; font-weight:bold;'>🥉 {rank}순위</span>" |
|
else: |
|
return f"<span style='display:inline-block; background-color:#E0E0E0; color:#000; padding:3px 8px; border-radius:10px;'>{rank}순위</span>" |
|
|
|
|
|
def create_probability_bar(probability): |
|
"""확률에 따른 진행 바 HTML을 생성합니다""" |
|
width = int(probability * 100) |
|
bar_color = "#4CAF50" if width > 80 else "#FFC107" if width > 50 else "#F44336" |
|
|
|
return f""" |
|
<div style='width:100%; background-color:#f1f1f1; border-radius:5px; margin:5px 0;'> |
|
<div style='width:{width}%; height:10px; background-color:{bar_color}; border-radius:5px;'></div> |
|
</div> |
|
<span style='font-size:0.9em;'>{probability:.2%}</span> |
|
""" |
|
|
|
|
|
def create_sample_excel(): |
|
try: |
|
df = pd.DataFrame({ |
|
'사고경위': [ |
|
'작업자가 계단을 내려오던 중 발을 헛디뎌 넘어졌다.', |
|
'지게차 운전 중 장애물과 부딪혀 머리를 부딪힘', |
|
'용접 작업 중 불꽃이 튀어 화상을 입음' |
|
] |
|
}) |
|
file_path = "./사고분석_양식.xlsx" |
|
df.to_excel(file_path, index=False) |
|
|
|
print(f"샘플 파일 생성 완료: {os.path.abspath(file_path)}") |
|
return file_path |
|
except Exception as e: |
|
print(f"샘플 엑셀 파일 생성 오류: {e}") |
|
return None |
|
|
|
class AccidentAnalysisModel: |
|
"""사고 경위 분석 모델 클래스""" |
|
|
|
def __init__(self, model_repo=None): |
|
self.model_repo = model_repo or MODEL_CONFIG["classification_repo"] |
|
self.summarization_repo = MODEL_CONFIG["summarization_repo"] |
|
self.device = "cuda" if torch.cuda.is_available() else "cpu" |
|
print(f"사용 디바이스: {self.device}") |
|
|
|
|
|
self._load_from_huggingface() |
|
|
|
def _load_from_huggingface(self): |
|
"""허깅페이스에서 모델을 로드합니다.""" |
|
try: |
|
|
|
if model_cache["classification_model"] is not None: |
|
self.classification_model = model_cache["classification_model"] |
|
self.classification_tokenizer = model_cache["classification_tokenizer"] |
|
self.label_map = model_cache["label_map"] |
|
self.summarization_model = model_cache["summarization_model"] |
|
self.summarization_tokenizer = model_cache["summarization_tokenizer"] |
|
print("캐시된 모델 로드 완료!") |
|
return |
|
|
|
print("모델 로드 시작...") |
|
|
|
|
|
print(f"분류 모델 '{self.model_repo}' 로드 중...") |
|
self.classification_tokenizer = AutoTokenizer.from_pretrained(self.model_repo) |
|
self.classification_model = AutoModelForSequenceClassification.from_pretrained(self.model_repo) |
|
self.classification_model.to(self.device) |
|
|
|
|
|
try: |
|
import json |
|
import requests |
|
|
|
label_map_url = f"https://huggingface.co/{self.model_repo}/raw/main/label_map.json" |
|
response = requests.get(label_map_url) |
|
if response.status_code == 200: |
|
self.label_map = response.json() |
|
else: |
|
print("레이블 맵을 찾을 수 없어 기본값 사용") |
|
|
|
self.label_map = {str(k): v for k, v in self.classification_model.config.id2label.items()} |
|
|
|
print(f"레이블 맵 로드 완료: {self.label_map}") |
|
except Exception as e: |
|
print(f"레이블 맵 로드 오류: {e}") |
|
self.label_map = {str(i): f"클래스 {i}" for i in range(self.classification_model.config.num_labels)} |
|
|
|
|
|
if self.summarization_repo: |
|
try: |
|
print(f"요약 모델 '{self.summarization_repo}' 로드 중...") |
|
self.summarization_tokenizer = AutoTokenizer.from_pretrained(self.summarization_repo) |
|
self.summarization_model = BartForConditionalGeneration.from_pretrained(self.summarization_repo) |
|
self.summarization_model.to(self.device) |
|
print("요약 모델 로드 완료!") |
|
except Exception as e: |
|
print(f"요약 모델 로드 오류: {e}") |
|
self.summarization_model = None |
|
self.summarization_tokenizer = None |
|
else: |
|
self.summarization_model = None |
|
self.summarization_tokenizer = None |
|
|
|
|
|
model_cache["classification_model"] = self.classification_model |
|
model_cache["classification_tokenizer"] = self.classification_tokenizer |
|
model_cache["summarization_model"] = self.summarization_model |
|
model_cache["summarization_tokenizer"] = self.summarization_tokenizer |
|
model_cache["label_map"] = self.label_map |
|
|
|
print("모델 로드 완료!") |
|
|
|
except Exception as e: |
|
print(f"모델 로드 오류: {e}") |
|
raise |
|
|
|
def enhanced_generate_summary(self, text): |
|
"""개선된 요약 생성 함수""" |
|
try: |
|
|
|
if not text or len(text.strip()) < 80: |
|
return text |
|
|
|
|
|
if self.summarization_model is None or self.summarization_tokenizer is None: |
|
return text |
|
|
|
|
|
text = re.sub(r'\([^)]*\)', '', text) |
|
text = re.sub(r'\[[^\]]*\]', '', text) |
|
text = re.sub(r'\'[^\']*\'', '', text) |
|
text = re.sub(r'\"[^\"]*\"', '', text) |
|
text = re.sub(r'\s+', ' ', text).strip() |
|
|
|
|
|
inputs = self.summarization_tokenizer( |
|
text, |
|
return_tensors="pt", |
|
max_length=512, |
|
truncation=True, |
|
padding=True |
|
).to(self.device) |
|
|
|
|
|
self.summarization_model.eval() |
|
with torch.no_grad(): |
|
summary_ids = self.summarization_model.generate( |
|
inputs["input_ids"], |
|
max_length=MODEL_CONFIG["summary_max_length"], |
|
min_length=30, |
|
num_beams=5, |
|
early_stopping=True, |
|
length_penalty=1.5, |
|
repetition_penalty=2.5, |
|
no_repeat_ngram_size=3 |
|
) |
|
|
|
|
|
summary = self.summarization_tokenizer.decode(summary_ids[0], skip_special_tokens=True) |
|
|
|
return summary.strip() |
|
|
|
except Exception as e: |
|
print(f"요약 생성 오류: {e}") |
|
return text |
|
|
|
def analyze(self, text, top_k=3): |
|
"""사고 경위를 분석하여 요약과 분류 결과를 반환합니다.""" |
|
|
|
if not text or len(text.strip()) == 0: |
|
return { |
|
"summary": "입력된 텍스트가 없습니다.", |
|
"classification": [], |
|
"error": "텍스트를 입력해주세요." |
|
} |
|
|
|
try: |
|
|
|
text = re.sub(r'\s+', ' ', text) |
|
text = text.strip() |
|
|
|
|
|
summary = self.enhanced_generate_summary(text) |
|
|
|
|
|
|
|
text_for_classification = summary if self.summarization_model else text |
|
|
|
|
|
self.classification_model.eval() |
|
|
|
|
|
inputs = self.classification_tokenizer( |
|
text_for_classification, |
|
padding="max_length", |
|
truncation=True, |
|
max_length=MODEL_CONFIG["max_length"], |
|
return_tensors="pt" |
|
).to(self.device) |
|
|
|
|
|
with torch.no_grad(): |
|
outputs = self.classification_model(**inputs) |
|
logits = outputs.logits |
|
probabilities = torch.nn.functional.softmax(logits, dim=1) |
|
|
|
|
|
topk_values, topk_indices = torch.topk(probabilities, k=min(top_k, len(self.label_map)), dim=1) |
|
|
|
|
|
classification_results = [] |
|
for i in range(min(top_k, len(self.label_map))): |
|
idx = topk_indices[0][i].item() |
|
label_text = self.label_map.get(str(idx), f"클래스 {idx}") |
|
classification_results.append({ |
|
"rank": i + 1, |
|
"class": idx, |
|
"class_name": label_text, |
|
"probability": float(topk_values[0][i].item()) |
|
}) |
|
|
|
|
|
return { |
|
"summary": summary if summary != text else "사고 경위 텍스트가 짧아서 요약을 생략하겠습니다다.", |
|
"classification": classification_results, |
|
"error": None |
|
} |
|
|
|
except Exception as e: |
|
print(f"분석 중 오류 발생: {e}") |
|
return { |
|
"summary": "오류 발생", |
|
"classification": [], |
|
"error": str(e) |
|
} |
|
|
|
def analyze_batch(self, texts): |
|
"""여러 사고 경위를 배치로 분석합니다.""" |
|
results = [] |
|
for text in texts: |
|
if not text or len(text.strip()) == 0: |
|
results.append({ |
|
"input_text": text, |
|
"summary": "입력된 텍스트가 없습니다.", |
|
"classification": [], |
|
"error": "텍스트를 입력해주세요." |
|
}) |
|
continue |
|
|
|
try: |
|
|
|
result = self.analyze(text) |
|
|
|
result["input_text"] = text |
|
results.append(result) |
|
except Exception as e: |
|
results.append({ |
|
"input_text": text, |
|
"summary": "오류 발생", |
|
"classification": [], |
|
"error": str(e) |
|
}) |
|
|
|
return results |
|
|
|
def analyze_excel(self, file_path): |
|
"""엑셀 파일의 사고 경위를 분석합니다.""" |
|
try: |
|
|
|
df = pd.read_excel(file_path) |
|
|
|
|
|
if '사고경위' not in df.columns: |
|
return None, "엑셀 파일에 '사고경위' 열이 없습니다. 양식을 다운로드하여 사용해주세요." |
|
|
|
|
|
accident_texts = df['사고경위'].fillna('').tolist() |
|
|
|
|
|
results = self.analyze_batch(accident_texts) |
|
|
|
|
|
result_data = [] |
|
for result in results: |
|
|
|
top_classification = result["classification"][0] if result["classification"] else {"class_name": "분류 실패", "probability": 0.0} |
|
|
|
|
|
second_classification = result["classification"][1] if len(result["classification"]) > 1 else {"class_name": "-", "probability": 0.0} |
|
third_classification = result["classification"][2] if len(result["classification"]) > 2 else {"class_name": "-", "probability": 0.0} |
|
|
|
result_data.append({ |
|
"사고경위": result["input_text"], |
|
"요약": result["summary"], |
|
"1순위 사고유형": top_classification["class_name"], |
|
"1순위 확률": f"{top_classification['probability']:.4f}", |
|
"2순위 사고유형": second_classification["class_name"], |
|
"2순위 확률": f"{second_classification['probability']:.4f}", |
|
"3순위 사고유형": third_classification["class_name"], |
|
"3순위 확률": f"{third_classification['probability']:.4f}" |
|
}) |
|
|
|
result_df = pd.DataFrame(result_data) |
|
|
|
|
|
temp_dir = tempfile.gettempdir() |
|
result_file = os.path.join(temp_dir, "사고분석_결과.xlsx") |
|
result_df.to_excel(result_file, index=False) |
|
|
|
return result_file, None |
|
|
|
except Exception as e: |
|
print(f"엑셀 분석 오류: {e}") |
|
return None, f"엑셀 파일 분석 중 오류 발생: {str(e)}" |
|
|
|
|
|
def get_model(): |
|
"""모델 인스턴스를 반환합니다.""" |
|
return AccidentAnalysisModel(model_repo=MODEL_CONFIG["classification_repo"]) |
|
|
|
|
|
def analyze_single(text): |
|
"""단일 텍스트를 분석합니다.""" |
|
model = get_model() |
|
results = model.analyze(text) |
|
|
|
|
|
if results["error"]: |
|
return results["error"], gr.HTML("<div style='color:red'>분석 오류가 발생했습니다.</div>") |
|
|
|
|
|
classification_html = "<div style='background-color:#f8f9fa; padding:15px; border-radius:10px; border:1px solid #ddd;'>" |
|
|
|
for result in results["classification"]: |
|
rank = result["rank"] |
|
class_name = result["class_name"] |
|
probability = result["probability"] |
|
|
|
rank_badge = create_rank_badge(rank) |
|
probability_bar = create_probability_bar(probability) |
|
|
|
classification_html += f""" |
|
<div style='margin-bottom:15px; padding:10px; background-color:white; border-radius:8px; box-shadow:0 1px 3px rgba(0,0,0,0.1);'> |
|
<div style='display:flex; justify-content:space-between; align-items:center; margin-bottom:5px;'> |
|
<div style='font-weight:bold; font-size:1.1em;'>{rank_badge} {class_name}</div> |
|
</div> |
|
{probability_bar} |
|
</div> |
|
""" |
|
|
|
classification_html += "</div>" |
|
|
|
return results["summary"], gr.HTML(classification_html) |
|
|
|
|
|
def analyze_multiple(texts): |
|
"""여러 텍스트를 분석합니다.""" |
|
model = get_model() |
|
|
|
|
|
text_list = [t.strip() for t in texts.split('\n') if t.strip()] |
|
|
|
if not text_list: |
|
return "텍스트를 입력해주세요", gr.HTML("<div style='color:red'>분석할 텍스트가 없습니다.</div>") |
|
|
|
|
|
results = model.analyze_batch(text_list) |
|
|
|
|
|
summary_text = "" |
|
for idx, result in enumerate(results): |
|
summary_text += f"{idx+1}. {result['summary']}\n\n" |
|
|
|
|
|
classification_html = "<div style='background-color:#f8f9fa; padding:15px; border-radius:10px; border:1px solid #ddd;'>" |
|
|
|
for idx, result in enumerate(results): |
|
short_text = text_list[idx][:50] + "..." if len(text_list[idx]) > 50 else text_list[idx] |
|
|
|
classification_html += f""" |
|
<div style='margin-bottom:20px; padding:10px; background-color:white; border-radius:8px; box-shadow:0 2px 5px rgba(0,0,0,0.1);'> |
|
<div style='font-weight:bold; margin-bottom:10px; padding-bottom:5px; border-bottom:1px solid #eee;'> |
|
<span style='background-color:#007bff; color:white; padding:2px 8px; border-radius:12px; margin-right:8px;'>{idx+1}</span> |
|
{short_text} |
|
</div> |
|
""" |
|
|
|
if result["classification"]: |
|
for class_result in result["classification"]: |
|
rank = class_result.get("rank", 1) |
|
class_name = class_result["class_name"] |
|
probability = class_result["probability"] |
|
|
|
rank_badge = create_rank_badge(rank) |
|
probability_bar = create_probability_bar(probability) |
|
|
|
classification_html += f""" |
|
<div style='margin-bottom:8px; padding:8px; background-color:#f7f7f7; border-radius:6px;'> |
|
<div style='display:flex; justify-content:space-between; align-items:center; margin-bottom:3px;'> |
|
<div style='font-weight:bold;'>{rank_badge} {class_name}</div> |
|
</div> |
|
{probability_bar} |
|
</div> |
|
""" |
|
else: |
|
classification_html += "<div style='color:red; padding:10px;'>분류 결과가 없습니다.</div>" |
|
|
|
classification_html += "</div>" |
|
|
|
classification_html += "</div>" |
|
|
|
return summary_text, gr.HTML(classification_html) |
|
|
|
|
|
def analyze_excel_file(file): |
|
"""엑셀 파일을 분석합니다.""" |
|
if file is None: |
|
return None, "파일을 업로드해주세요." |
|
|
|
model = get_model() |
|
|
|
try: |
|
|
|
print(f"업로드된 파일: {file.name}") |
|
|
|
|
|
df = pd.read_excel(file.name) |
|
|
|
|
|
if '사고경위' not in df.columns: |
|
return None, "엑셀 파일에 '사고경위' 열이 없습니다. 양식을 다운로드하여 사용해주세요." |
|
|
|
|
|
accident_texts = df['사고경위'].fillna('').tolist() |
|
|
|
|
|
results = model.analyze_batch(accident_texts) |
|
|
|
|
|
result_df = df.copy() |
|
|
|
|
|
|
|
if '요약' not in result_df.columns: |
|
result_df['요약'] = '' |
|
if '1순위 사고유형' not in result_df.columns: |
|
result_df['1순위 사고유형'] = '' |
|
if '1순위 확률' not in result_df.columns: |
|
result_df['1순위 확률'] = '' |
|
if '2순위 사고유형' not in result_df.columns: |
|
result_df['2순위 사고유형'] = '' |
|
if '2순위 확률' not in result_df.columns: |
|
result_df['2순위 확률'] = '' |
|
if '3순위 사고유형' not in result_df.columns: |
|
result_df['3순위 사고유형'] = '' |
|
if '3순위 확률' not in result_df.columns: |
|
result_df['3순위 확률'] = '' |
|
|
|
|
|
for i, result in enumerate(results): |
|
if i >= len(result_df): |
|
break |
|
|
|
|
|
result_df.at[i, '요약'] = result['summary'] |
|
|
|
|
|
if result['classification']: |
|
|
|
if len(result['classification']) > 0: |
|
top = result['classification'][0] |
|
result_df.at[i, '1순위 사고유형'] = top['class_name'] |
|
result_df.at[i, '1순위 확률'] = f"{top['probability']:.4f}" |
|
|
|
|
|
if len(result['classification']) > 1: |
|
second = result['classification'][1] |
|
result_df.at[i, '2순위 사고유형'] = second['class_name'] |
|
result_df.at[i, '2순위 확률'] = f"{second['probability']:.4f}" |
|
|
|
|
|
if len(result['classification']) > 2: |
|
third = result['classification'][2] |
|
result_df.at[i, '3순위 사고유형'] = third['class_name'] |
|
result_df.at[i, '3순위 확률'] = f"{third['probability']:.4f}" |
|
|
|
|
|
result_file = os.path.join(tempfile.gettempdir(), "사고분석_결과.xlsx") |
|
result_df.to_excel(result_file, index=False) |
|
|
|
print(f"분석 결과 파일 저장 경로: {result_file}") |
|
return result_file, None |
|
|
|
except Exception as e: |
|
print(f"엑셀 파일 분석 중 오류 발생: {e}") |
|
import traceback |
|
traceback.print_exc() |
|
return None, f"엑셀 파일 분석 중 오류 발생: {str(e)}" |
|
|
|
|
|
def download_excel_template(): |
|
"""엑셀 양식 파일을 생성하고 경로를 반환합니다.""" |
|
return create_sample_excel() |
|
|
|
|
|
def create_interface(): |
|
"""Gradio 인터페이스를 생성합니다.""" |
|
with gr.Blocks(title="사고 경위 분석 시스템", css=""" |
|
.container { margin: 0 auto; max-width: 1200px; } |
|
.header { text-align: center; margin-bottom: 20px; } |
|
.result-container { border: 1px solid #ddd; border-radius: 10px; padding: 15px; background-color: #f9f9f9; } |
|
.rank-badge { display: inline-block; padding: 3px 8px; border-radius: 12px; font-weight: bold; margin-right: 5px; } |
|
.rank-1 { background-color: #FFD700; color: #000; } |
|
.rank-2 { background-color: #C0C0C0; color: #000; } |
|
.rank-3 { background-color: #CD7F32; color: #000; } |
|
.prob-bar { height: 10px; background-color: #4CAF50; border-radius: 5px; } |
|
.footer { text-align: center; margin-top: 30px; padding-top: 15px; border-top: 1px solid #eee; } |
|
""") as app: |
|
gr.Markdown( |
|
""" |
|
# 🔍 사고 경위 분석 시스템 |
|
|
|
사고 경위 텍스트를 입력하면 AI가 사고 유형을 분류하고 요약을 제공합니다. |
|
""" |
|
) |
|
|
|
with gr.Tabs() as tabs: |
|
|
|
with gr.TabItem("개별 사고 분석"): |
|
with gr.Row(): |
|
with gr.Column(): |
|
input_text = gr.Textbox( |
|
label="사고 경위", |
|
placeholder="여기에 사고 경위를 입력하세요...", |
|
lines=10 |
|
) |
|
analyze_btn = gr.Button("분석하기", variant="primary") |
|
|
|
with gr.Column(): |
|
summary_output = gr.Textbox( |
|
label="사고 요약", |
|
lines=5 |
|
) |
|
classification_output = gr.HTML( |
|
label="사고 유형 분류 결과", |
|
value="<div style='padding:20px; text-align:center; color:#666;'>분석 결과가 여기에 표시됩니다.</div>" |
|
) |
|
|
|
analyze_btn.click( |
|
fn=analyze_single, |
|
inputs=[input_text], |
|
outputs=[summary_output, classification_output] |
|
) |
|
|
|
gr.Markdown( |
|
""" |
|
### 🔍 개별 사고 분석 사용법 |
|
1. 분석할 사고 경위를 텍스트 상자에 입력하세요. |
|
2. '분석하기' 버튼을 클릭하여 결과를 확인하세요. |
|
3. 분석 결과는 사고 요약과 1~3순위 사고 유형 분류를 포함합니다. |
|
""" |
|
) |
|
|
|
|
|
with gr.TabItem("다중 사고 분석"): |
|
gr.Markdown("### 📊 여러 사고를 한 번에 분석") |
|
gr.Markdown("각 사고 경위를 개별 입력 상자에 입력하세요. 분석 결과는 각 입력 아래에 표시됩니다.") |
|
|
|
|
|
def analyze_multiple_fields(*input_texts): |
|
|
|
valid_inputs = [(i, text) for i, text in enumerate(input_texts) if text and text.strip()] |
|
|
|
if not valid_inputs: |
|
return [gr.update(visible=True, value="텍스트를 입력해주세요")] + [gr.update(visible=False) for _ in range(len(input_texts)*2-1)] |
|
|
|
model = get_model() |
|
results = [] |
|
summary_updates = [] |
|
html_updates = [] |
|
|
|
|
|
for i in range(len(input_texts)): |
|
if i < len(valid_inputs) and valid_inputs[i][0] == i: |
|
idx, text = valid_inputs[i] |
|
result = model.analyze(text) |
|
results.append(result) |
|
|
|
|
|
summary_updates.append(gr.update(visible=True, value=result["summary"])) |
|
|
|
|
|
classification_html = "<div style='background-color:#f8f9fa; padding:15px; border-radius:10px; border:1px solid #ddd;'>" |
|
|
|
for class_result in result["classification"]: |
|
rank = class_result.get("rank", 1) |
|
class_name = class_result["class_name"] |
|
probability = class_result["probability"] |
|
|
|
rank_badge = create_rank_badge(rank) |
|
probability_bar = create_probability_bar(probability) |
|
|
|
classification_html += f""" |
|
<div style='margin-bottom:8px; padding:8px; background-color:white; border-radius:6px;'> |
|
<div style='display:flex; justify-content:space-between; align-items:center; margin-bottom:3px;'> |
|
<div style='font-weight:bold;'>{rank_badge} {class_name}</div> |
|
</div> |
|
{probability_bar} |
|
</div> |
|
""" |
|
|
|
classification_html += "</div>" |
|
html_updates.append(gr.update(visible=True, value=classification_html)) |
|
else: |
|
summary_updates.append(gr.update(visible=False)) |
|
html_updates.append(gr.update(visible=False)) |
|
|
|
|
|
return summary_updates + html_updates |
|
|
|
|
|
analysis_items = [] |
|
summary_outputs = [] |
|
classification_outputs = [] |
|
|
|
|
|
for i in range(3): |
|
with gr.Group(): |
|
with gr.Row(): |
|
with gr.Column(scale=1): |
|
analysis_items.append(gr.Textbox( |
|
label=f"사고 경위 {i+1}", |
|
placeholder="여기에 사고 경위를 입력하세요...", |
|
lines=3 |
|
)) |
|
with gr.Column(scale=1): |
|
summary_outputs.append(gr.Textbox( |
|
label=f"요약 결과 {i+1}", |
|
lines=2, |
|
visible=False |
|
)) |
|
|
|
classification_outputs.append(gr.HTML( |
|
label=f"분류 결과 {i+1}", |
|
visible=False |
|
)) |
|
|
|
|
|
with gr.Row(): |
|
with gr.Column(scale=1): |
|
add_input_btn = gr.Button("입력 필드 추가", variant="secondary") |
|
with gr.Column(scale=1): |
|
remove_input_btn = gr.Button("입력 필드 제거", variant="secondary") |
|
with gr.Column(scale=2): |
|
batch_analyze_btn = gr.Button("일괄 분석", variant="primary") |
|
|
|
|
|
def add_input_field(): |
|
i = len(analysis_items) |
|
with gr.Group(): |
|
with gr.Row(): |
|
with gr.Column(scale=1): |
|
new_input = gr.Textbox( |
|
label=f"사고 경위 {i+1}", |
|
placeholder="여기에 사고 경위를 입력하세요...", |
|
lines=3 |
|
) |
|
analysis_items.append(new_input) |
|
with gr.Column(scale=1): |
|
new_summary = gr.Textbox( |
|
label=f"요약 결과 {i+1}", |
|
lines=2, |
|
visible=False |
|
) |
|
summary_outputs.append(new_summary) |
|
new_classification = gr.HTML( |
|
label=f"분류 결과 {i+1}", |
|
visible=False |
|
) |
|
classification_outputs.append(new_classification) |
|
|
|
|
|
return analysis_items + summary_outputs + classification_outputs |
|
|
|
|
|
def remove_input_field(): |
|
if len(analysis_items) > 1: |
|
analysis_items.pop() |
|
summary_outputs.pop() |
|
classification_outputs.pop() |
|
|
|
|
|
return analysis_items + summary_outputs + classification_outputs |
|
|
|
|
|
add_input_btn.click( |
|
fn=add_input_field, |
|
inputs=[], |
|
outputs=analysis_items + summary_outputs + classification_outputs |
|
) |
|
|
|
remove_input_btn.click( |
|
fn=remove_input_field, |
|
inputs=[], |
|
outputs=analysis_items + summary_outputs + classification_outputs |
|
) |
|
|
|
|
|
batch_analyze_btn.click( |
|
fn=analyze_multiple_fields, |
|
inputs=analysis_items, |
|
outputs=summary_outputs + classification_outputs |
|
) |
|
|
|
gr.Markdown( |
|
""" |
|
### 📊 다중 사고 분석 사용법 |
|
1. 각 사고 경위를 개별 입력 상자에 입력하세요. |
|
2. 필요에 따라 '입력 필드 추가' 버튼으로 더 많은 사고 경위를 추가할 수 있습니다. |
|
3. '일괄 분석' 버튼을 클릭하여 모든 사고를 한 번에 분석하세요. |
|
4. 각 사고에 대한 요약과 사고 유형 분류 결과가 각 입력 필드 아래에 표시됩니다. |
|
""") |
|
|
|
|
|
with gr.TabItem("엑셀 파일 분석"): |
|
gr.Markdown("### 📑 엑셀 파일을 통한 다중 사고 분석") |
|
gr.Markdown(""" |
|
**양식 안내**: 엑셀 파일에는 '사고경위' 열이 포함되어야 합니다. |
|
각 행에 분석할 사고 경위를 입력한 후 업로드해 주세요. |
|
""") |
|
|
|
|
|
with gr.Row(): |
|
excel_template_btn = gr.Button("엑셀 양식 다운로드", variant="secondary") |
|
template_output = gr.File(label="양식 파일") |
|
|
|
excel_template_btn.click( |
|
fn=download_excel_template, |
|
inputs=[], |
|
outputs=[template_output] |
|
) |
|
|
|
|
|
with gr.Row(): |
|
with gr.Column(): |
|
excel_input = gr.File( |
|
label="사고 경위 엑셀 파일 업로드 (.xlsx)", |
|
file_types=[".xlsx"] |
|
) |
|
excel_analyze_btn = gr.Button("엑셀 파일 분석", variant="primary") |
|
|
|
with gr.Column(): |
|
excel_result = gr.File(label="분석 결과 파일") |
|
excel_error = gr.Textbox(label="처리 상태", placeholder="파일을 업로드하고 분석 버튼을 클릭하세요.") |
|
|
|
excel_analyze_btn.click( |
|
fn=analyze_excel_file, |
|
inputs=[excel_input], |
|
outputs=[excel_result, excel_error] |
|
) |
|
|
|
gr.Markdown( |
|
""" |
|
### 📑 엑셀 파일 분석 사용법 |
|
1. '엑셀 양식 다운로드' 버튼을 클릭하여 템플릿을 다운로드하세요. |
|
2. 템플릿에 사고 경위를 입력하고 저장하세요. |
|
3. 저장된 파일을 업로드하세요. |
|
4. '엑셀 파일 분석' 버튼을 클릭하여 결과 파일을 생성하세요. |
|
5. 생성된 결과 파일을 다운로드하여 확인하세요. |
|
""" |
|
) |
|
|
|
gr.Markdown( |
|
""" |
|
## 📊 사용 방법 |
|
- **개별 사고 분석**: 단일 사고 경위를 입력하고 상세 분석 결과를 확인합니다. |
|
- **다중 사고 분석**: 여러 사고 경위를 한 번에 입력하여 일괄 분석합니다. |
|
- **엑셀 파일 분석**: 다수의 사고 경위가 포함된 엑셀 파일을 업로드하여 분석합니다. |
|
|
|
## ℹ️ 모델 정보 |
|
- 분류 모델: {classification_repo} |
|
- 요약 모델: {summarization_repo} |
|
|
|
""".format( |
|
classification_repo=MODEL_CONFIG['classification_repo'], |
|
summarization_repo=MODEL_CONFIG['summarization_repo'] if MODEL_CONFIG['summarization_repo'] else "사용 안 함" |
|
) |
|
) |
|
|
|
return app |
|
|
|
|
|
def main(): |
|
"""메인 실행 함수""" |
|
|
|
try: |
|
_ = get_model() |
|
print("모델 사전 로드 완료!") |
|
except Exception as e: |
|
print(f"모델 사전 로드 실패: {e}") |
|
print("첫 요청 시 모델을 로드합니다.") |
|
|
|
|
|
app = create_interface() |
|
app.launch(share=False) |
|
|
|
if __name__ == "__main__": |
|
main() |