Spaces:
Running
on
T4
Running
on
T4
Update app.py
Browse files
app.py
CHANGED
@@ -147,6 +147,32 @@ async def chat(query,history,sources,reports,subtype, client_ip=None, session_id
|
|
147 |
sources=sources,subtype=subtype)
|
148 |
end_time = time.time()
|
149 |
print("Time for retriever:",end_time - start_time)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
150 |
context_retrieved_formatted = "||".join(doc.page_content for doc in context_retrieved)
|
151 |
context_retrieved_lst = [doc.page_content for doc in context_retrieved]
|
152 |
|
@@ -245,6 +271,7 @@ async def chat(query,history,sources,reports,subtype, client_ip=None, session_id
|
|
245 |
answer_yet += token
|
246 |
parsed_answer = parse_output_llm_with_sources(answer_yet)
|
247 |
history[-1] = (query, parsed_answer)
|
|
|
248 |
logs_data["answer"] = parsed_answer
|
249 |
yield [tuple(x) for x in history], docs_html, logs_data, session_id
|
250 |
end_time = time.time()
|
@@ -416,7 +443,7 @@ with gr.Blocks(title="Audit Q&A", css= "style.css", theme=theme,elem_id = "main-
|
|
416 |
label = "Filter for Sub-Type",
|
417 |
interactive=True)
|
418 |
|
419 |
-
#----------- update the
|
420 |
def rs_change(rs):
|
421 |
if rs: # Only update choices if a value is selected
|
422 |
return gr.update(choices=new_files[rs], value=None) # Set value to None (no preselection)
|
@@ -649,9 +676,19 @@ with gr.Blocks(title="Audit Q&A", css= "style.css", theme=theme,elem_id = "main-
|
|
649 |
|
650 |
|
651 |
|
652 |
-
def show_feedback(
|
653 |
-
"""
|
654 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
655 |
|
656 |
def submit_feedback_okay(logs_data):
|
657 |
"""Handle 'okay' feedback submission"""
|
@@ -683,30 +720,260 @@ with gr.Blocks(title="Audit Q&A", css= "style.css", theme=theme,elem_id = "main-
|
|
683 |
def get_client_ip_handler(dummy_input="", request: gr.Request = None):
|
684 |
"""Handler for getting client IP in Gradio context"""
|
685 |
return get_client_ip(request)
|
686 |
-
|
687 |
|
688 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
689 |
|
690 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
691 |
(textbox
|
692 |
-
.submit(
|
693 |
-
|
694 |
-
|
695 |
-
[
|
696 |
-
|
697 |
-
|
698 |
-
.then(
|
699 |
-
|
700 |
-
|
701 |
-
|
702 |
-
|
703 |
-
|
704 |
-
|
705 |
-
|
706 |
-
|
707 |
-
|
708 |
-
|
709 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
710 |
|
711 |
demo.queue()
|
712 |
|
|
|
147 |
sources=sources,subtype=subtype)
|
148 |
end_time = time.time()
|
149 |
print("Time for retriever:",end_time - start_time)
|
150 |
+
|
151 |
+
# WARNING FOR NO CONTEXT: Check if any paragraphs were retrieved, add warning if none found
|
152 |
+
# We use this in the Gradio UI below (displays in the chat dialogue box)
|
153 |
+
if not context_retrieved or len(context_retrieved) == 0:
|
154 |
+
warning_message = "⚠️ **No relevant information was found in the audit reports pertaining your query.** Please try rephrasing your question or selecting different report filters."
|
155 |
+
history[-1] = (query, warning_message)
|
156 |
+
# Update logs with the warning instead of answer
|
157 |
+
logs_data = {
|
158 |
+
"record_id": str(uuid4()),
|
159 |
+
"session_id": session_id,
|
160 |
+
"session_duration_seconds": session_duration,
|
161 |
+
"client_location": session_data['location_info'],
|
162 |
+
"platform": session_data['platform_info'],
|
163 |
+
"year": year,
|
164 |
+
"question": query,
|
165 |
+
"retriever": model_config.get('retriever','MODEL'),
|
166 |
+
"endpoint_type": model_config.get('reader','TYPE'),
|
167 |
+
"reader": model_config.get('reader','NVIDIA_MODEL'),
|
168 |
+
"answer": warning_message,
|
169 |
+
"no_results": True # Flag to indicate no results were found
|
170 |
+
}
|
171 |
+
yield [tuple(x) for x in history], "", logs_data, session_id
|
172 |
+
# Save log for the warning response
|
173 |
+
save_logs(scheduler, JSON_DATASET_PATH, logs_data)
|
174 |
+
return
|
175 |
+
|
176 |
context_retrieved_formatted = "||".join(doc.page_content for doc in context_retrieved)
|
177 |
context_retrieved_lst = [doc.page_content for doc in context_retrieved]
|
178 |
|
|
|
271 |
answer_yet += token
|
272 |
parsed_answer = parse_output_llm_with_sources(answer_yet)
|
273 |
history[-1] = (query, parsed_answer)
|
274 |
+
# update logs_data with current answer
|
275 |
logs_data["answer"] = parsed_answer
|
276 |
yield [tuple(x) for x in history], docs_html, logs_data, session_id
|
277 |
end_time = time.time()
|
|
|
443 |
label = "Filter for Sub-Type",
|
444 |
interactive=True)
|
445 |
|
446 |
+
#----------- update the second level filter based on values from first level ----------------
|
447 |
def rs_change(rs):
|
448 |
if rs: # Only update choices if a value is selected
|
449 |
return gr.update(choices=new_files[rs], value=None) # Set value to None (no preselection)
|
|
|
676 |
|
677 |
|
678 |
|
679 |
+
def show_feedback(logs_data):
|
680 |
+
"""Handle feedback display with proper output format"""
|
681 |
+
if logs_data is None:
|
682 |
+
return (
|
683 |
+
gr.update(visible=False), # feedback_row
|
684 |
+
gr.update(visible=False), # feedback_thanks
|
685 |
+
None # feedback_state
|
686 |
+
)
|
687 |
+
return (
|
688 |
+
gr.update(visible=True), # feedback_row
|
689 |
+
gr.update(visible=False), # feedback_thanks
|
690 |
+
logs_data # feedback_state
|
691 |
+
)
|
692 |
|
693 |
def submit_feedback_okay(logs_data):
|
694 |
"""Handle 'okay' feedback submission"""
|
|
|
720 |
def get_client_ip_handler(dummy_input="", request: gr.Request = None):
|
721 |
"""Handler for getting client IP in Gradio context"""
|
722 |
return get_client_ip(request)
|
|
|
723 |
|
724 |
+
#-------------------- No Filters Set Warning -------------------------
|
725 |
+
# Warn users when no filters are selected
|
726 |
+
warning_state = gr.State(False)
|
727 |
+
pending_query = gr.State(None)
|
728 |
+
|
729 |
+
def show_warning():
|
730 |
+
"""Show warning popup when no filters selected"""
|
731 |
+
return gr.update(visible=True)
|
732 |
+
|
733 |
+
def hide_warning():
|
734 |
+
"""Hide warning popup"""
|
735 |
+
return gr.update(visible=False)
|
736 |
+
|
737 |
+
# Logic needs to be changed to accomodate default filter values (currently I have them all set to None)
|
738 |
+
def check_filters(check_status, textbox_value, sources, reports, subtype):
|
739 |
+
"""Check if any filters are selected"""
|
740 |
+
# If a previous check failed, don't continue with this check
|
741 |
+
if check_status is not None:
|
742 |
+
return (
|
743 |
+
check_status, # keep current check status
|
744 |
+
False, # keep warning state unchanged
|
745 |
+
gr.update(visible=False), # keep warning row visibility unchanged
|
746 |
+
textbox_value, # keep the textbox value
|
747 |
+
None # no need to store query
|
748 |
+
)
|
749 |
+
|
750 |
+
no_filters = (not reports) and (not sources) and (not subtype)
|
751 |
+
if no_filters:
|
752 |
+
# If no filters, show warning and set status
|
753 |
+
return (
|
754 |
+
"filter", # check status - no filters selected
|
755 |
+
True, # warning state
|
756 |
+
gr.update(visible=True), # warning row visibility
|
757 |
+
gr.update(value=""), # clear textbox
|
758 |
+
textbox_value # store the query
|
759 |
+
)
|
760 |
+
# If filters exist, proceed normally
|
761 |
+
return (
|
762 |
+
None, # no check failed
|
763 |
+
False, # normal state
|
764 |
+
gr.update(visible=False), # hide warning
|
765 |
+
textbox_value, # keep the original value
|
766 |
+
None # no need to store query
|
767 |
+
)
|
768 |
|
769 |
+
async def handle_chat_flow(check_status, warning_active, short_query_warning_active, query, chatbot, sources, reports, subtype, client_ip, session_id):
|
770 |
+
"""Handle chat flow with explicit check for status"""
|
771 |
+
# Don't proceed if any check failed or query is None
|
772 |
+
if check_status is not None or warning_active or short_query_warning_active or query is None or query == "":
|
773 |
+
yield (
|
774 |
+
chatbot, # unchanged chatbot
|
775 |
+
"", # empty sources
|
776 |
+
None, # no feedback state
|
777 |
+
session_id # keep session
|
778 |
+
)
|
779 |
+
return # Exit the generator
|
780 |
+
|
781 |
+
# Include start_chat functionality here
|
782 |
+
history = chatbot + [(query, None)]
|
783 |
+
history = [tuple(x) for x in history]
|
784 |
+
|
785 |
+
# Proceed with chat and yield each update
|
786 |
+
async for update in chat(query, history, sources, reports, subtype, client_ip, session_id):
|
787 |
+
yield update
|
788 |
+
|
789 |
+
#-------------------- Short Query Warning -------------------------
|
790 |
+
# Warn users when query is too short (less than 4 words)
|
791 |
+
short_query_warning_state = gr.State(False)
|
792 |
+
check_status = gr.State(None)
|
793 |
+
|
794 |
+
def check_query_length(textbox_value):
|
795 |
+
"""Check if query has at least 4 words"""
|
796 |
+
if textbox_value and len(textbox_value.split()) < 3:
|
797 |
+
# If query is too short, show warning and set status
|
798 |
+
return (
|
799 |
+
"short", # check status - this query is too short
|
800 |
+
True, # short query warning state
|
801 |
+
gr.update(visible=True), # short query warning row visibility
|
802 |
+
gr.update(value=""), # clear textbox
|
803 |
+
textbox_value # store the query
|
804 |
+
)
|
805 |
+
# If query is long enough, proceed normally
|
806 |
+
return (
|
807 |
+
None, # no check failed
|
808 |
+
False, # normal state
|
809 |
+
gr.update(visible=False), # hide warning
|
810 |
+
gr.update(value=textbox_value), # keep the textbox value
|
811 |
+
None # no need to store query
|
812 |
+
)
|
813 |
+
|
814 |
+
|
815 |
+
#-------------------- Gradio Handlers -------------------------
|
816 |
+
|
817 |
+
# Hanlders: Text input from Textbox
|
818 |
(textbox
|
819 |
+
.submit(
|
820 |
+
check_query_length,
|
821 |
+
[textbox],
|
822 |
+
[check_status, short_query_warning_state, short_query_warning_row, textbox, pending_query],
|
823 |
+
api_name="check_query_length_textbox"
|
824 |
+
)
|
825 |
+
.then(
|
826 |
+
check_filters,
|
827 |
+
[check_status, textbox, dropdown_sources, dropdown_reports, dropdown_category],
|
828 |
+
[check_status, warning_state, warning_row, textbox, pending_query],
|
829 |
+
api_name="submit_textbox",
|
830 |
+
show_progress=False
|
831 |
+
)
|
832 |
+
.then(
|
833 |
+
get_client_ip_handler,
|
834 |
+
[textbox],
|
835 |
+
[client_ip],
|
836 |
+
show_progress=False,
|
837 |
+
api_name="get_client_ip_textbox"
|
838 |
+
)
|
839 |
+
.then(
|
840 |
+
handle_chat_flow,
|
841 |
+
[check_status, warning_state, short_query_warning_state, textbox, chatbot, dropdown_sources, dropdown_reports, dropdown_category, client_ip, session_id],
|
842 |
+
[chatbot, sources_textbox, feedback_state, session_id],
|
843 |
+
queue=True,
|
844 |
+
api_name="handle_chat_flow_textbox"
|
845 |
+
)
|
846 |
+
.then(
|
847 |
+
show_feedback,
|
848 |
+
[feedback_state],
|
849 |
+
[feedback_row, feedback_thanks, feedback_state],
|
850 |
+
api_name="show_feedback_textbox"
|
851 |
+
)
|
852 |
+
.then(
|
853 |
+
finish_chat,
|
854 |
+
None,
|
855 |
+
[textbox],
|
856 |
+
api_name="finish_chat_textbox"
|
857 |
+
))
|
858 |
+
|
859 |
+
# Hanlders: Text input from Examples (same chain as textbox)
|
860 |
+
examples_hidden.change(
|
861 |
+
lambda x: x,
|
862 |
+
inputs=examples_hidden,
|
863 |
+
outputs=textbox,
|
864 |
+
api_name="submit_examples"
|
865 |
+
).then(
|
866 |
+
check_query_length,
|
867 |
+
[textbox],
|
868 |
+
[check_status, short_query_warning_state, short_query_warning_row, textbox, pending_query],
|
869 |
+
api_name="check_query_length_examples"
|
870 |
+
).then(
|
871 |
+
check_filters,
|
872 |
+
[check_status, textbox, dropdown_sources, dropdown_reports, dropdown_category],
|
873 |
+
[check_status, warning_state, warning_row, textbox, pending_query],
|
874 |
+
api_name="check_filters_examples",
|
875 |
+
show_progress=False
|
876 |
+
).then(
|
877 |
+
get_client_ip_handler,
|
878 |
+
[textbox],
|
879 |
+
[client_ip],
|
880 |
+
show_progress=False,
|
881 |
+
api_name="get_client_ip_examples"
|
882 |
+
).then(
|
883 |
+
handle_chat_flow,
|
884 |
+
[check_status, warning_state, short_query_warning_state, textbox, chatbot, dropdown_sources, dropdown_reports, dropdown_category, client_ip, session_id],
|
885 |
+
[chatbot, sources_textbox, feedback_state, session_id],
|
886 |
+
queue=True,
|
887 |
+
api_name="handle_chat_flow_examples"
|
888 |
+
).then(
|
889 |
+
show_feedback,
|
890 |
+
[feedback_state],
|
891 |
+
[feedback_row, feedback_thanks, feedback_state],
|
892 |
+
api_name="show_feedback_examples"
|
893 |
+
).then(
|
894 |
+
finish_chat,
|
895 |
+
None,
|
896 |
+
[textbox],
|
897 |
+
api_name="finish_chat_examples"
|
898 |
+
)
|
899 |
+
|
900 |
+
|
901 |
+
# Handlers for the warning buttons
|
902 |
+
proceed_btn.click(
|
903 |
+
lambda query: (
|
904 |
+
None, # reset check status
|
905 |
+
False, # warning state
|
906 |
+
gr.update(visible=False), # warning row
|
907 |
+
gr.update(value=query if query else "", interactive=True), # restore query
|
908 |
+
None # clear pending query
|
909 |
+
),
|
910 |
+
pending_query,
|
911 |
+
[check_status, warning_state, warning_row, textbox, pending_query]
|
912 |
+
).then(
|
913 |
+
get_client_ip_handler,
|
914 |
+
[textbox],
|
915 |
+
[client_ip]
|
916 |
+
).then(
|
917 |
+
handle_chat_flow,
|
918 |
+
[check_status, warning_state, short_query_warning_state, textbox, chatbot, dropdown_sources, dropdown_reports, dropdown_category, client_ip, session_id],
|
919 |
+
[chatbot, sources_textbox, feedback_state, session_id],
|
920 |
+
queue=True
|
921 |
+
).then(
|
922 |
+
show_feedback,
|
923 |
+
[feedback_state],
|
924 |
+
[feedback_row, feedback_thanks, feedback_state]
|
925 |
+
).then(
|
926 |
+
finish_chat,
|
927 |
+
None,
|
928 |
+
[textbox]
|
929 |
+
)
|
930 |
+
|
931 |
+
# Cancel button for no filters
|
932 |
+
cancel_btn.click(
|
933 |
+
lambda: (
|
934 |
+
None, # reset check status
|
935 |
+
False, # warning state
|
936 |
+
gr.update(visible=False), # warning row
|
937 |
+
gr.update(value="", interactive=True), # clear textbox
|
938 |
+
None # clear pending query
|
939 |
+
),
|
940 |
+
None,
|
941 |
+
[check_status, warning_state, warning_row, textbox, pending_query]
|
942 |
+
)
|
943 |
+
|
944 |
+
# short query warning OK button
|
945 |
+
short_query_proceed_btn.click(
|
946 |
+
lambda query: (
|
947 |
+
None, # reset check status
|
948 |
+
False, # short query warning state
|
949 |
+
gr.update(visible=False), # short query warning row
|
950 |
+
gr.update(value=query if query else "", interactive=True), # restore query
|
951 |
+
None # clear pending query
|
952 |
+
),
|
953 |
+
pending_query,
|
954 |
+
[check_status, short_query_warning_state, short_query_warning_row, textbox, pending_query]
|
955 |
+
)
|
956 |
+
|
957 |
+
|
958 |
+
#(textbox
|
959 |
+
# .submit(get_client_ip_handler, [textbox], [client_ip], api_name="get_ip_textbox")
|
960 |
+
# .then(start_chat, [textbox, chatbot], [textbox, tabs, chatbot], queue=False, api_name="start_chat_textbox")
|
961 |
+
# .then(chat,
|
962 |
+
# [textbox, chatbot, dropdown_sources, dropdown_reports, dropdown_category, client_ip, session_id],
|
963 |
+
# [chatbot, sources_textbox, feedback_state, session_id],
|
964 |
+
# queue=True, concurrency_limit=8, api_name="chat_textbox")
|
965 |
+
# .then(show_feedback, [feedback_state], [feedback_row, feedback_thanks, feedback_state], api_name="show_feedback_textbox")
|
966 |
+
# .then(finish_chat, None, [textbox], api_name="finish_chat_textbox"))
|
967 |
+
|
968 |
+
#(examples_hidden
|
969 |
+
# .change(start_chat, [examples_hidden, chatbot], [textbox, tabs, chatbot], queue=False, api_name="start_chat_examples")
|
970 |
+
# .then(get_client_ip_handler, [examples_hidden], [client_ip], api_name="get_ip_examples")
|
971 |
+
# .then(chat,
|
972 |
+
# [examples_hidden, chatbot, dropdown_sources, dropdown_reports, dropdown_category, client_ip, session_id],
|
973 |
+
# [chatbot, sources_textbox, feedback_state, session_id],
|
974 |
+
# concurrency_limit=8, api_name="chat_examples")
|
975 |
+
# .then(show_feedback, [feedback_state], [feedback_row, feedback_thanks, feedback_state], api_name="show_feedback_examples")
|
976 |
+
# .then(finish_chat, None, [textbox], api_name="finish_chat_examples"))
|
977 |
|
978 |
demo.queue()
|
979 |
|