"""Streamlit app for browsing GDELT GKG records and viewing each one as a graph.

Records are read with DuckDB from a Hugging Face parquet dataset, listed in an
interactive AgGrid table, and the selected record is rendered as a link-analysis
graph built by ``StLinkBuilder``.
"""

import duckdb
import pandas as pd
import streamlit as st
from st_aggrid import AgGrid, GridOptionsBuilder, GridUpdateMode
from st_link_analysis import st_link_analysis, NodeStyle, EdgeStyle

from graph_builder import StLinkBuilder

# Node styles configuration
NODE_STYLES = [
    NodeStyle("EVENT", "#FF7F3E", "name", "description"),
    NodeStyle("PERSON", "#4CAF50", "name", "person"),
    NodeStyle("NAME", "#2A629A", "created_at", "badge"),
    NodeStyle("ORGANIZATION", "#9C27B0", "name", "business"),
    NodeStyle("LOCATION", "#2196F3", "name", "place"),
    NodeStyle("THEME", "#FFC107", "name", "sell"),
    NodeStyle("COUNT", "#795548", "name", "inventory"),
    NodeStyle("AMOUNT", "#607D8B", "name", "wallet"),
]

# Edge styles configuration
EDGE_STYLES = [
    EdgeStyle("MENTIONED_IN", caption="label", directed=True),
    EdgeStyle("LOCATED_IN", caption="label", directed=True),
    EdgeStyle("CATEGORIZED_AS", caption="label", directed=True),
]

# Dataset column -> grid header, in display order.
_GRID_COLUMNS = {
    "GKGRECORDID": "ID",
    "DATE": "Date",
    "SourceCommonName": "Source",
    "tone": "Tone",
    "DocumentIdentifier": "Doc ID",
    "SourceCollectionIdentifier": "Source Collection ID",
}


def initialize_db():
    """Open an in-memory DuckDB connection and create the ``negative_tone`` view.

    Returns:
        The open DuckDB connection. The caller is responsible for closing it
        (for example by using it as a context manager).
    """
    con = duckdb.connect()
    con.execute("""
        CREATE VIEW negative_tone AS (
            SELECT *
            FROM read_parquet('hf://datasets/dwb2023/gdelt-gkg-march2020-v2@~parquet/default/negative_tone/*.parquet')
        );
    """)
    return con


def fetch_data(con, source_filter=None, start_date=None, end_date=None,
               limit=50, include_all_columns=False):
    """Fetch filtered rows from the ``negative_tone`` view.

    Args:
        con: Open DuckDB connection on which the view has been created.
        source_filter: Optional substring matched case-insensitively against
            ``SourceCommonName``.
        start_date: Optional inclusive lower bound compared against ``DATE``.
        end_date: Optional inclusive upper bound compared against ``DATE``.
        limit: Maximum number of rows; a falsy value disables the limit.
        include_all_columns: Select every column when true, otherwise only a
            fixed summary subset.

    Returns:
        A DataFrame with the matching rows. If the query fails, the error is
        shown via ``st.error`` and an empty DataFrame is returned.
    """
    if include_all_columns:
        columns = "*"
    else:
        # "V2.1Quotations" must be double-quoted: it contains a dot, and single
        # quotes would make it a string constant instead of a column reference.
        columns = (
            "GKGRECORDID, DATE, SourceCommonName, tone, DocumentIdentifier, "
            '"V2.1Quotations", SourceCollectionIdentifier'
        )

    query = f"""
        SELECT {columns}
        FROM negative_tone
        WHERE TRUE
    """
    params = []

    if source_filter:
        query += " AND SourceCommonName ILIKE ?"
        params.append(f"%{source_filter}%")
    if start_date:
        query += " AND DATE >= ?"
        params.append(start_date)
    if end_date:
        query += " AND DATE <= ?"
        params.append(end_date)
    if limit:
        # Coerce to int so that only a number can ever reach the SQL text.
        query += f" LIMIT {int(limit)}"

    try:
        return con.execute(query, params).fetchdf()
    except Exception as e:
        # UI boundary: report the failure instead of crashing the app.
        st.error(f"Query execution failed: {str(e)}")
        return pd.DataFrame()


def render_data_grid(df):
    """Render an interactive, filterable grid and return the selected row.

    Every column of ``df`` is shown with filtering, sorting and resizing
    enabled, and single-row selection is turned on.

    Args:
        df: DataFrame to display.

    Returns:
        The selected row as a dict, or ``None`` when nothing is selected.
    """
    st.subheader("Search and Filter Records")

    # Build grid options with AgGrid
    gb = GridOptionsBuilder.from_dataframe(df)
    gb.configure_default_column(filter=True, sortable=True, resizable=True)
    # Enable single row selection
    gb.configure_selection('single', use_checkbox=False)
    grid_options = gb.build()

    # Render AgGrid (the grid will have a filter field for each column)
    grid_response = AgGrid(
        df,
        gridOptions=grid_options,
        update_mode=GridUpdateMode.SELECTION_CHANGED,
        height=400,
        fit_columns_on_grid_load=True,
    )

    selected = grid_response.get('selected_rows')
    # The selection is handled both as a DataFrame and as a list of dicts.
    if isinstance(selected, pd.DataFrame):
        if not selected.empty:
            return selected.iloc[0].to_dict()
        return None
    if isinstance(selected, list) and len(selected) > 0:
        return selected[0]
    return None


def render_graph(record):
    """Render a graph visualization for the selected record.

    The record is wrapped in a one-row DataFrame, converted to graph elements
    by ``StLinkBuilder.build_graph``, and displayed with ``st_link_analysis``.

    Args:
        record: The full record (a pandas Series, or any mapping with a
            ``get`` method) containing at least ``GKGRECORDID``.

    Returns:
        Whatever ``st_link_analysis`` returns for the rendered component.
    """
    st.subheader(f"Event Graph: {record.get('GKGRECORDID', 'Unknown')}")

    stlink_builder = StLinkBuilder()
    # Convert the record into a DataFrame with one row
    record_df = pd.DataFrame([record])
    graph_data = stlink_builder.build_graph(record_df)

    return st_link_analysis(
        elements=graph_data,
        layout="fcose",  # Layout options: cose, fcose, breadthfirst, cola
        node_styles=NODE_STYLES,
        edge_styles=EDGE_STYLES,
    )


def main():
    """Build the page: sidebar filters, record grid, and graph of the selection."""
    st.title("🔍 COVID Event Graph Explorer")
    st.markdown("""
    **Interactive Event Graph Viewer**
    Filter and select individual COVID-19 event records to display their detailed graph representations.
    Analyze relationships between events and associated entities using the interactive graph below.
    """)

    # The connection is closed automatically when the with-block exits.
    with initialize_db() as con:
        # Sidebar controls
        with st.sidebar:
            st.header("Search Filters")
            source = st.text_input("Filter by source name")
            start_date = st.text_input("Start date (YYYYMMDD)", "20200314")
            end_date = st.text_input("End date (YYYYMMDD)", "20200315")
            limit = st.slider("Number of results to display", 10, 500, 100)

        # Fetch the full records once. The grid view is derived from this same
        # frame so that the row picked in the grid is guaranteed to exist in
        # it (two separate LIMIT queries without ORDER BY need not agree).
        df_full = fetch_data(
            con=con,
            source_filter=source,
            start_date=start_date,
            end_date=end_date,
            limit=limit,
            include_all_columns=True,
        )

        if df_full.empty:
            # Covers both "no rows matched" and a failed query (fetch_data
            # returns an empty frame after reporting the error).
            st.warning("No matching records found.")
            return

        # Create a DataFrame for the grid with only the key columns
        grid_df = df_full[list(_GRID_COLUMNS)].rename(columns=_GRID_COLUMNS)

        # Render the interactive data grid at the top
        selected_row = render_data_grid(grid_df)

        if not selected_row:
            st.info("Use the grid filters above to search and select a record.")
            return

        # Find the full record in the original DataFrame using the selected ID
        selected_id = selected_row['ID']
        matches = df_full[df_full['GKGRECORDID'] == selected_id]
        if matches.empty:
            # Guard against a stale selection that is no longer in the results.
            st.warning("The selected record is no longer in the current results.")
            return

        # Display the graph below the grid
        render_graph(matches.iloc[0])


if __name__ == "__main__":
    main()