# import spaces
import time

print(f"Starting up: {time.strftime('%Y-%m-%d %H:%M:%S')}")

# source openalex_env_map/bin/activate

# Standard library imports
import os

# Enforce local caching:
# Pip:
# os.makedirs("./pip_cache", exist_ok=True)
# os.environ["PIP_CACHE_DIR"] = os.path.abspath("./pip_cache")
#
# MPL:
# os.makedirs("./mpl_cache", exist_ok=True)
# os.environ["MPLCONFIGDIR"] = os.path.abspath("./mpl_cache")
#
# Transformers:
# os.makedirs("./transformers_cache", exist_ok=True)
# os.environ["TRANSFORMERS_CACHE"] = os.path.abspath("./transformers_cache")

# import numba
# print(numba.config)
# print("Numba threads:", numba.get_num_threads())
# numba.set_num_threads(16)
# print("Updated Numba threads:", numba.get_num_threads())

# import datamapplot.medoids
# print(help(datamapplot.medoids))

from pathlib import Path
from datetime import datetime
from itertools import chain
import ast
import base64
import json
import pickle

# Third-party imports
import numpy as np
import pandas as pd
import torch
import gradio as gr
print(f"Gradio version: {gr.__version__}")
import subprocess
import re
from color_utils import rgba_to_hex


def print_datamapplot_version():
    try:
        # On Unix systems, commands can be piped by setting shell=True.
        version = subprocess.check_output("pip freeze | grep datamapplot", shell=True, text=True)
        print("datamapplot version:", version.strip())
    except subprocess.CalledProcessError:
        print("datamapplot not found in pip freeze output.")

print_datamapplot_version()

from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
import uvicorn

import matplotlib.pyplot as plt
import tqdm
import colormaps
import matplotlib.colors as mcolors
from matplotlib.colors import Normalize
import random
import opinionated  # for fonts
plt.style.use("opinionated_rc")

from sklearn.neighbors import NearestNeighbors


def is_running_in_hf_zero_gpu():
    print(os.environ.get("SPACES_ZERO_GPU"))
    return os.environ.get("SPACES_ZERO_GPU")

is_running_in_hf_zero_gpu()


def is_running_in_hf_space():
    return "SPACE_ID" in os.environ

# if is_running_in_hf_space():
#     from spaces.zero.client import _get_token
try:
    import spaces
    from spaces.zero.client import _get_token
    HAS_SPACES = True
except (ImportError, ModuleNotFoundError):
    HAS_SPACES = False

# Provide a harmless fallback so decorators don't explode
if not HAS_SPACES:
    class _Dummy:
        def GPU(self, *a, **k):
            def deco(f):  # no-op decorator
                return f
            return deco

    spaces = _Dummy()  # fake module object

    def _get_token(request):  # stub, never called off-Space
        return ""

# if is_running_in_hf_space():
#     import spaces  # necessary to run on Zero.
# print(f"Spaces version: {spaces.__version__}")

import datamapplot
import pyalex

# Local imports
from openalex_utils import (
    openalex_url_to_pyalex_query,
    get_field,
    process_records_to_df,
    openalex_url_to_filename,
    get_records_from_dois,
    openalex_url_to_readable_name,
)
from ui_utils import highlight_queries
from styles import DATAMAP_CUSTOM_CSS
from data_setup import (
    download_required_files,
    setup_basemap_data,
    setup_mapper,
    setup_embedding_model,
)
from network_utils import create_citation_graph, draw_citation_graph

# Colormap chooser imports
from colormap_chooser import ColormapChooser, setup_colormaps

# Legend builder imports
try:
    from legend_builders import continuous_legend_html_css, categorical_legend_html_css
    HAS_LEGEND_BUILDERS = True
except ImportError:
    print("Warning: legend_builders.py not found. Legends will be disabled.")
    HAS_LEGEND_BUILDERS = False

# Configure OpenAlex
pyalex.config.email = "maximilian.noichl@uni-bamberg.de"

print(f"Imports completed: {time.strftime('%Y-%m-%d %H:%M:%S')}")

# Set up colormaps for the chooser
print("Setting up colormaps...")
colormap_categories = setup_colormaps(
    included_collections=['matplotlib', 'cmocean', 'scientific', 'cmasher'],
    excluded_collections=['colorcet', 'carbonplan', 'sciviz']
)

colormap_chooser = ColormapChooser(
    categories=colormap_categories,
    smooth_steps=10,
    strip_width=200,
    strip_height=50,
    css_height=200,
    # show_search=False,
    # show_category=False,
    # show_preview=False,
    # show_selected_name=True,
    # show_selected_info=False,
    gallery_kwargs=dict(columns=3, allow_preview=False, height="200px")
)

# Create a static directory to store the dynamic HTML files
static_dir = Path("./static")
static_dir.mkdir(parents=True, exist_ok=True)

# Tell Gradio which absolute paths are allowed to be served
os.environ["GRADIO_ALLOWED_PATHS"] = str(static_dir.resolve())
print("os.environ['GRADIO_ALLOWED_PATHS'] =", os.environ["GRADIO_ALLOWED_PATHS"])

# Create FastAPI app
app = FastAPI()

# Mount the static directory
app.mount("/static", StaticFiles(directory="static"), name="static")

# Resource configuration
REQUIRED_FILES = {
    "100k_filtered_OA_sample_cluster_and_positions_supervised.pkl":
        "https://huggingface.co/datasets/m7n/intermediate_sci_pickle/resolve/main/100k_filtered_OA_sample_cluster_and_positions_supervised.pkl",
    "umap_mapper_250k_random_OA_discipline_tuned_specter_2_params.pkl":
        "https://huggingface.co/datasets/m7n/intermediate_sci_pickle/resolve/main/umap_mapper_250k_random_OA_discipline_tuned_specter_2_params.pkl"
}
BASEMAP_PATH = "100k_filtered_OA_sample_cluster_and_positions_supervised.pkl"
MAPPER_PARAMS_PATH = "umap_mapper_250k_random_OA_discipline_tuned_specter_2_params.pkl"
MODEL_NAME = "m7n/discipline-tuned_specter_2_024"

# Initialize models and data
start_time = time.time()
print("Initializing resources...")
download_required_files(REQUIRED_FILES)
basedata_df = setup_basemap_data(BASEMAP_PATH)
mapper = setup_mapper(MAPPER_PARAMS_PATH)
model = setup_embedding_model(MODEL_NAME)
print(f"Resources initialized in {time.time() - start_time:.2f} seconds")


# Setting up decorators for embedding on HF-Zero:
def no_op_decorator(func):
    """A no-op (no operation) decorator that simply returns the function."""
    def wrapper(*args, **kwargs):
        # Do nothing special
        return func(*args, **kwargs)
    return wrapper

# # Decide which decorator to use based on environment
# decorator_to_use = spaces.GPU() if is_running_in_hf_space() else no_op_decorator
# # duration=120


@spaces.GPU(duration=1)  # ← forces the detector to see a GPU-aware fn
def _warmup():
    print("Warming up...")

_warmup()


# if is_running_in_hf_space():
@spaces.GPU(duration=30)
def create_embeddings_30(texts_to_embedd):
    """Create embeddings for the input texts using the loaded model."""
    return model.encode(texts_to_embedd, show_progress_bar=True, batch_size=192)

@spaces.GPU(duration=59)
def create_embeddings_59(texts_to_embedd):
    """Create embeddings for the input texts using the loaded model."""
    return model.encode(texts_to_embedd, show_progress_bar=True, batch_size=192)

@spaces.GPU(duration=120)
def create_embeddings_120(texts_to_embedd):
    """Create embeddings for the input texts using the loaded model."""
    return model.encode(texts_to_embedd, show_progress_bar=True, batch_size=192)

@spaces.GPU(duration=299)
def create_embeddings_299(texts_to_embedd):
    """Create embeddings for the input texts using
the loaded model.""" return model.encode(texts_to_embedd, show_progress_bar=True, batch_size=192) # else: def create_embeddings(texts_to_embedd): """Create embeddings for the input texts using the loaded model.""" return model.encode(texts_to_embedd, show_progress_bar=True, batch_size=192) def predict(request: gr.Request, text_input, sample_size_slider, reduce_sample_checkbox, sample_reduction_method, plot_type_dropdown, locally_approximate_publication_date_checkbox, download_csv_checkbox, download_png_checkbox, citation_graph_checkbox, csv_upload, highlight_color, selected_colormap_name, seed_value, progress=gr.Progress()): """ Main prediction pipeline that processes OpenAlex queries and creates visualizations. Args: request (gr.Request): Gradio request object text_input (str): OpenAlex query URL sample_size_slider (int): Maximum number of samples to process reduce_sample_checkbox (bool): Whether to reduce sample size sample_reduction_method (str): Method for sample reduction ("Random" or "Order of Results") plot_type_dropdown (str): Type of plot coloring ("No special coloring", "Time-based coloring", "Categorical coloring") locally_approximate_publication_date_checkbox (bool): Whether to approximate publication date locally before plotting. download_csv_checkbox (bool): Whether to download CSV data download_png_checkbox (bool): Whether to download PNG data citation_graph_checkbox (bool): Whether to add citation graph csv_upload (str): Path to uploaded CSV file highlight_color (str): Color for highlighting points selected_colormap_name (str): Name of the selected colormap for time-based coloring progress (gr.Progress): Gradio progress tracker Returns: tuple: (link to visualization, iframe HTML) """ # Initialize start_time at the beginning of the function start_time = time.time() # Convert dropdown selection to boolean flags for backward compatibility plot_time_checkbox = plot_type_dropdown == "Time-based coloring" treat_as_categorical_checkbox = plot_type_dropdown == "Categorical coloring" # Helper function to generate error responses def create_error_response(error_message): return [ error_message, gr.DownloadButton(label="Download Interactive Visualization", value='html_file_path', visible=False), gr.DownloadButton(label="Download CSV Data", value='csv_file_path', visible=False), gr.DownloadButton(label="Download Static Plot", value='png_file_path', visible=False), gr.Button(visible=False) ] # Get the authentication token if is_running_in_hf_space(): token = _get_token(request) payload = token.split('.')[1] payload = f"{payload}{'=' * ((4 - len(payload) % 4) % 4)}" payload = json.loads(base64.urlsafe_b64decode(payload).decode()) print(payload) user = payload['user'] if user == None: user_type = "anonymous" elif '[pro]' in user: user_type = "pro" else: user_type = "registered" print(f"User type: {user_type}") # Check if a file has been uploaded or if we need to use OpenAlex query if csv_upload is not None: print(f"Using uploaded file instead of OpenAlex query: {csv_upload}") try: file_extension = os.path.splitext(csv_upload)[1].lower() if file_extension == '.csv': # Read the CSV file records_df = pd.read_csv(csv_upload) filename = os.path.splitext(os.path.basename(csv_upload))[0] # Check if this is a DOI-list CSV (single column, named 'doi' or similar) if (len(records_df.columns) == 1 and records_df.columns[0].lower() in ['doi', 'dois']): from openalex_utils import get_records_from_dois doi_list = records_df.iloc[:,0].dropna().astype(str).tolist() print(f"Detected DOI list with 
{len(doi_list)} DOIs. Downloading records from OpenAlex...") records_df = get_records_from_dois(doi_list) filename = f"doilist_{len(doi_list)}" else: # Convert *every* cell that looks like a serialized list/dict def _try_parse_obj(cell): if isinstance(cell, str): txt = cell.strip() if (txt.startswith('{') and txt.endswith('}')) or (txt.startswith('[') and txt.endswith(']')): # Try JSON first try: return json.loads(txt) except Exception: pass # Fallback to Python-repr (single quotes etc.) try: return ast.literal_eval(txt) except Exception: pass return cell records_df = records_df.map(_try_parse_obj) print(records_df.head()) else: error_message = f"Error: Unsupported file type. Please upload a CSV or PKL file." return create_error_response(error_message) records_df = process_records_to_df(records_df) # Make sure we have the required columns required_columns = ['title', 'abstract', 'publication_year'] missing_columns = [col for col in required_columns if col not in records_df.columns] if missing_columns: error_message = f"Error: Uploaded file is missing required columns: {', '.join(missing_columns)}" return create_error_response(error_message) print(f"Successfully loaded {len(records_df)} records from uploaded file") progress(0.2, desc="Processing uploaded data...") # For uploaded files, set all records to query_index 0 records_df['query_index'] = 0 except Exception as e: error_message = f"Error processing uploaded file: {str(e)}" return create_error_response(error_message) else: # Check if input is empty or whitespace print(f"Input: {text_input}") if not text_input or text_input.isspace(): error_message = "Error: Please enter a valid OpenAlex URL in the 'OpenAlex-search URL'-field or upload a CSV file" return create_error_response(error_message) print('Starting data projection pipeline') progress(0.1, desc="Starting...") # Split input into multiple URLs if present urls = [url.strip() for url in text_input.split(';')] records = [] query_indices = [] # Track which query each record comes from total_query_length = 0 expected_download_count = 0 # Track expected number of records to download for progress # Use first URL for filename first_query, first_params = openalex_url_to_pyalex_query(urls[0]) filename = openalex_url_to_filename(urls[0]) print(f"Filename: {filename}") # Process each URL for i, url in enumerate(urls): query, params = openalex_url_to_pyalex_query(url) query_length = query.count() total_query_length += query_length # Calculate expected download count for this query if reduce_sample_checkbox and sample_reduction_method == "First n samples": expected_for_this_query = min(sample_size_slider, query_length) elif reduce_sample_checkbox and sample_reduction_method == "n random samples": expected_for_this_query = min(sample_size_slider, query_length) else: # "All" expected_for_this_query = query_length expected_download_count += expected_for_this_query print(f'Requesting {query_length} entries from query {i+1}/{len(urls)} (expecting to download {expected_for_this_query})...') # Use PyAlex sampling for random samples - much more efficient! 
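# --- Illustrative sketch, not part of the pipeline: how the server-side sampling
# used in the branch below behaves in isolation. The Works() filter is a made-up
# example query; the sample()/paginate() calls mirror the ones used further down.
def _pyalex_random_sample_sketch(n=500, seed=42):
    from pyalex import Works
    query = Works().filter(publication_year=2020)  # hypothetical example filter
    sampled = query.sample(n, seed=seed)
    records = []
    # sample() requires plain page-based pagination (method='page'), not cursors,
    # which is the same constraint the code below works around.
    for page in sampled.paginate(per_page=200, method='page', n_max=None):
        records.extend(page)
    return records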
if reduce_sample_checkbox and sample_reduction_method == "n random samples": # Use PyAlex's built-in sample method for efficient server-side sampling target_size = min(sample_size_slider, query_length) try: seed_int = int(seed_value) if seed_value.strip() else 42 except ValueError: seed_int = 42 print(f"Invalid seed value '{seed_value}', using default: 42") print(f'Attempting PyAlex sampling: {target_size} from {query_length} (seed={seed_int})') try: # Check if PyAlex sample method exists and works if hasattr(query, 'sample'): sampled_query = query.sample(target_size, seed=seed_int) # IMPORTANT: When using sample(), must use method='page' for pagination! sampled_records = [] records_count = 0 for page in sampled_query.paginate(per_page=200, method='page', n_max=None): for record in page: sampled_records.append(record) records_count += 1 progress(0.1 + (0.15 * records_count / target_size), desc=f"Getting sampled data from query {i+1}/{len(urls)}... ({records_count}/{target_size})") print(f'PyAlex sampling successful: got {len(sampled_records)} records') else: raise AttributeError("sample method not available") except Exception as e: print(f"PyAlex sampling failed ({e}), using fallback method...") # Fallback: get all records and sample manually all_records = [] records_count = 0 # Use default cursor pagination for non-sampled queries for page in query.paginate(per_page=200, n_max=None): for record in page: all_records.append(record) records_count += 1 progress(0.1 + (0.15 * records_count / query_length), desc=f"Downloading for sampling from query {i+1}/{len(urls)}...") # Now sample manually if len(all_records) > target_size: import random random.seed(seed_int) sampled_records = random.sample(all_records, target_size) else: sampled_records = all_records print(f'Fallback sampling: got {len(sampled_records)} from {len(all_records)} total') # Add the sampled records for idx, record in enumerate(sampled_records): records.append(record) query_indices.append(i) # Safe progress calculation if expected_download_count > 0: progress_val = 0.1 + (0.2 * len(records) / expected_download_count) else: progress_val = 0.1 progress(progress_val, desc=f"Processing sampled data from query {i+1}/{len(urls)}...") else: # Keep existing logic for "First n samples" and "All" target_size = sample_size_slider if reduce_sample_checkbox and sample_reduction_method == "First n samples" else query_length records_per_query = 0 print(f"Query {i+1}: target_size={target_size}, query_length={query_length}, method={sample_reduction_method}") should_break_current_query = False # For "First n samples", limit the maximum records fetched to avoid over-downloading max_records_to_fetch = target_size if reduce_sample_checkbox and sample_reduction_method == "First n samples" else None for page in query.paginate(per_page=200, n_max=max_records_to_fetch): # Add retry mechanism for processing each page max_retries = 5 base_wait_time = 1 # Starting wait time in seconds exponent = 1.5 # Exponential factor for retry_attempt in range(max_retries): try: for record in page: # Safety check: don't process if we've already reached target if reduce_sample_checkbox and sample_reduction_method == "First n samples" and records_per_query >= target_size: print(f"Reached target size before processing: {records_per_query}/{target_size}, breaking from download") should_break_current_query = True break records.append(record) query_indices.append(i) # Track which query this record comes from records_per_query += 1 # Safe progress calculation if 
expected_download_count > 0: progress_val = 0.1 + (0.2 * len(records) / expected_download_count) else: progress_val = 0.1 progress(progress_val, desc=f"Getting data from query {i+1}/{len(urls)}...") if reduce_sample_checkbox and sample_reduction_method == "First n samples" and records_per_query >= target_size: print(f"Reached target size: {records_per_query}/{target_size}, breaking from download") should_break_current_query = True break # If we get here without an exception, break the retry loop break except Exception as e: print(f"Error processing page: {e}") if retry_attempt < max_retries - 1: wait_time = base_wait_time * (exponent ** retry_attempt) + random.random() print(f"Retrying in {wait_time:.2f} seconds (attempt {retry_attempt + 1}/{max_retries})...") time.sleep(wait_time) else: print(f"Maximum retries reached. Continuing with next page.") # Break out of retry loop if we've reached target if should_break_current_query: break if should_break_current_query: print(f"Successfully broke from page loop for query {i+1}") break # Continue to next query - don't break out of the main query loop print(f"Query completed in {time.time() - start_time:.2f} seconds") print(f"Total records collected: {len(records)}") print(f"Expected to download: {expected_download_count}") print(f"Available from all queries: {total_query_length}") print(f"Sample method used: {sample_reduction_method}") print(f"Reduce sample enabled: {reduce_sample_checkbox}") if sample_reduction_method == "n random samples": print(f"Seed value: {seed_value}") # Process records processing_start = time.time() records_df = process_records_to_df(records) # Add query_index to the dataframe records_df['query_index'] = query_indices[:len(records_df)] if reduce_sample_checkbox and sample_reduction_method != "All" and sample_reduction_method != "n random samples": # Note: We skip "n random samples" here because PyAlex sampling is already done above sample_size = min(sample_size_slider, len(records_df)) # Check if we have multiple queries for sampling logic urls = [url.strip() for url in text_input.split(';')] if text_input else [''] has_multiple_queries = len(urls) > 1 and not csv_upload # If using categorical coloring with multiple queries, sample each query independently if treat_as_categorical_checkbox and has_multiple_queries: # Sample the full sample_size from each query independently unique_queries = sorted(records_df['query_index'].unique()) sampled_dfs = [] for query_idx in unique_queries: query_records = records_df[records_df['query_index'] == query_idx] # Apply the full sample size to each query (only for "First n samples") current_sample_size = min(sample_size_slider, len(query_records)) if sample_reduction_method == "First n samples": sampled_query = query_records.iloc[:current_sample_size] sampled_dfs.append(sampled_query) print(f"Query {query_idx+1}: sampled {len(sampled_query)} records from {len(query_records)} available") records_df = pd.concat(sampled_dfs, ignore_index=True) print(f"Total after independent sampling: {len(records_df)} records") print(f"Query distribution: {records_df['query_index'].value_counts().sort_index()}") else: # Original sampling logic for single query or non-categorical (only "First n samples" now) if sample_reduction_method == "First n samples": records_df = records_df.iloc[:sample_size] print(f"Records processed in {time.time() - processing_start:.2f} seconds") # Create embeddings - this happens regardless of data source embedding_start = time.time() progress(0.3, desc="Embedding Data...") 
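# --- Illustrative sketch, assuming the thresholds used just below: the ZeroGPU
# duration tier is picked from the number of texts to embed, with anonymous users
# capped at the 59-second tier. The pipeline calls the tiered functions directly;
# this helper only restates that decision for readability.
def _pick_embedding_fn_sketch(n_texts, user_type="registered"):
    if n_texts < 2000:
        return create_embeddings_30
    if n_texts < 4000 or user_type == "anonymous":
        return create_embeddings_59
    if n_texts < 8000:
        return create_embeddings_120
    return create_embeddings_299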
texts_to_embedd = [f"{title} {abstract}" for title, abstract in zip(records_df['title'], records_df['abstract'])] if is_running_in_hf_space(): if len(texts_to_embedd) < 2000: embeddings = create_embeddings_30(texts_to_embedd) elif len(texts_to_embedd) < 4000 or user_type == "anonymous": embeddings = create_embeddings_59(texts_to_embedd) elif len(texts_to_embedd) < 8000: embeddings = create_embeddings_120(texts_to_embedd) else: embeddings = create_embeddings_299(texts_to_embedd) else: embeddings = create_embeddings(texts_to_embedd) print(f"Embeddings created in {time.time() - embedding_start:.2f} seconds") # Project embeddings projection_start = time.time() progress(0.5, desc="Project into UMAP-embedding...") umap_embeddings = mapper.transform(embeddings) records_df[['x','y']] = umap_embeddings print(f"Projection completed in {time.time() - projection_start:.2f} seconds") # Prepare visualization data viz_prep_start = time.time() progress(0.6, desc="Preparing visualization data...") # Set up colors: basedata_df['color'] = '#ced4d211' # Convert highlight_color to hex if it isn't already if not highlight_color.startswith('#'): highlight_color = rgba_to_hex(highlight_color) highlight_color = rgba_to_hex(highlight_color) print('Highlight color:', highlight_color) # Check if we have multiple queries and categorical coloring is enabled urls = [url.strip() for url in text_input.split(';')] if text_input else [''] has_multiple_queries = len(urls) > 1 and not csv_upload if treat_as_categorical_checkbox and has_multiple_queries: # Use categorical coloring for multiple queries print("Using categorical coloring for multiple queries") # Get colors from selected colormap or use default categorical colors unique_queries = sorted(records_df['query_index'].unique()) num_queries = len(unique_queries) if selected_colormap_name and selected_colormap_name.strip(): try: # Use selected colormap to generate distinct colors categorical_cmap = plt.get_cmap(selected_colormap_name) # Sample colors evenly spaced across the colormap categorical_colors = [mcolors.to_hex(categorical_cmap(i / max(1, num_queries - 1))) for i in range(num_queries)] except Exception as e: print(f"Warning: Could not load colormap '{selected_colormap_name}' for categorical coloring: {e}") # Fallback to default categorical colors categorical_colors = [ '#e41a1c', # Red '#377eb8', # Blue '#4daf4a', # Green '#984ea3', # Purple '#ff7f00', # Orange '#ffff33', # Yellow '#a65628', # Brown '#f781bf', # Pink '#999999', # Gray '#66c2a5', # Teal '#fc8d62', # Light Orange '#8da0cb', # Light Blue '#e78ac3', # Light Pink '#a6d854', # Light Green '#ffd92f', # Light Yellow '#e5c494', # Beige '#b3b3b3', # Light Gray ] else: # Use default categorical colors categorical_colors = [ '#e41a1c', # Red '#377eb8', # Blue '#4daf4a', # Green '#984ea3', # Purple '#ff7f00', # Orange '#ffff33', # Yellow '#a65628', # Brown '#f781bf', # Pink '#999999', # Gray '#66c2a5', # Teal '#fc8d62', # Light Orange '#8da0cb', # Light Blue '#e78ac3', # Light Pink '#a6d854', # Light Green '#ffd92f', # Light Yellow '#e5c494', # Beige '#b3b3b3', # Light Gray ] # Assign colors based on query_index query_color_map = {query_idx: categorical_colors[i % len(categorical_colors)] for i, query_idx in enumerate(unique_queries)} records_df['color'] = records_df['query_index'].map(query_color_map) # Add query_label for better identification records_df['query_label'] = records_df['query_index'].apply(lambda x: f"Query {x+1}") elif plot_time_checkbox: # Use selected colormap if provided, otherwise default 
to haline if selected_colormap_name and selected_colormap_name.strip(): try: time_cmap = plt.get_cmap(selected_colormap_name) except Exception as e: print(f"Warning: Could not load colormap '{selected_colormap_name}': {e}") time_cmap = colormaps.haline else: time_cmap = colormaps.haline if not locally_approximate_publication_date_checkbox: # Create color mapping based on publication years years = pd.to_numeric(records_df['publication_year']) norm = mcolors.Normalize(vmin=years.min(), vmax=years.max()) records_df['color'] = [mcolors.to_hex(time_cmap(norm(year))) for year in years] # Store for legend generation years_for_legend = years legend_label = "Publication Year" legend_cmap = time_cmap else: n_neighbors = 10 # Adjust this value to control smoothing nn = NearestNeighbors(n_neighbors=n_neighbors) nn.fit(umap_embeddings) distances, indices = nn.kneighbors(umap_embeddings) # Calculate local average publication year for each point local_years = np.array([ np.mean(records_df['publication_year'].iloc[idx]) for idx in indices ]) norm = mcolors.Normalize(vmin=local_years.min(), vmax=local_years.max()) records_df['color'] = [mcolors.to_hex(time_cmap(norm(year))) for year in local_years] # Store for legend generation years_for_legend = local_years legend_label = "Approx. Year" legend_cmap = time_cmap else: # No special coloring - use highlight color records_df['color'] = highlight_color stacked_df = pd.concat([basedata_df, records_df], axis=0, ignore_index=True) stacked_df = stacked_df.fillna("Unlabelled") stacked_df['parsed_field'] = [get_field(row) for ix, row in stacked_df.iterrows()] # Create marker size array: basemap points = 2, query result points = 4 marker_sizes = np.concatenate([ np.full(len(basedata_df), 1.), # Basemap points np.full(len(records_df), 2.5) # Query result points ]) extra_data = pd.DataFrame(stacked_df['doi']) print(f"Visualization data prepared in {time.time() - viz_prep_start:.2f} seconds") # Prepare file paths html_file_name = f"{filename}.html" html_file_path = static_dir / html_file_name csv_file_path = static_dir / f"{filename}.csv" png_file_path = static_dir / f"{filename}.png" if citation_graph_checkbox: citation_graph_start = time.time() citation_graph = create_citation_graph(records_df) graph_file_name = f"{filename}_citation_graph.jpg" graph_file_path = static_dir / graph_file_name draw_citation_graph(citation_graph,path=graph_file_path,bundle_edges=True, min_max_coordinates=[np.min(stacked_df['x']),np.max(stacked_df['x']),np.min(stacked_df['y']),np.max(stacked_df['y'])]) print(f"Citation graph created and saved in {time.time() - citation_graph_start:.2f} seconds") # Create and save plot plot_start = time.time() progress(0.7, desc="Creating interactive plot...") # Create a solid black colormap black_cmap = mcolors.LinearSegmentedColormap.from_list('black', ['#000000', '#000000']) # Generate legends based on plot type custom_html = "" legend_css = "" if HAS_LEGEND_BUILDERS: if treat_as_categorical_checkbox and has_multiple_queries: # Create categorical legend for multiple queries unique_queries = sorted(records_df['query_index'].unique()) color_mapping = {} # Get readable names for each query URL for i, query_idx in enumerate(unique_queries): try: if query_idx < len(urls): readable_name = openalex_url_to_readable_name(urls[query_idx]) # Truncate long names for legend display if len(readable_name) > 25: readable_name = readable_name[:22] + "..." 
else: readable_name = f"Query {query_idx + 1}" except Exception: readable_name = f"Query {query_idx + 1}" color_mapping[readable_name] = query_color_map[query_idx] legend_html, legend_css = categorical_legend_html_css( color_mapping, title="Queries" if len(color_mapping) > 1 else "Query", anchor="top-left", container_id="dmp-query-legend" ) custom_html += legend_html elif plot_time_checkbox and 'years_for_legend' in locals(): # Create continuous legend for time-based coloring using the stored variables # Create ticks every 5 years within the range, ignoring endpoints year_min, year_max = int(years_for_legend.min()), int(years_for_legend.max()) year_range = year_max - year_min # Find the first multiple of 5 that's greater than year_min first_tick = ((year_min // 5) + 1) * 5 # Generate ticks every 5 years until we reach year_max ticks = [] current_tick = first_tick while current_tick < year_max: ticks.append(current_tick) current_tick += 5 # For ranges under 15 years, include both endpoints if year_range < 15: if not ticks: # No 5-year ticks, just show endpoints ticks = [year_min, year_max] else: # Add endpoints to existing 5-year ticks if year_min not in ticks: ticks.insert(0, year_min) if year_max not in ticks: ticks.append(year_max) legend_html, legend_css = continuous_legend_html_css( legend_cmap, year_min, year_max, ticks=ticks, label=legend_label, anchor="top-right", container_id="dmp-year-legend" ) custom_html += legend_html # Add custom CSS to make legend titles equally large and bold legend_title_css = """ /* Make all legend titles equally large and bold */ #dmp-query-legend .legend-title, #dmp-year-legend .colorbar-label { font-size: 16px !important; font-weight: bold !important; font-family: 'Roboto Condensed', sans-serif !important; } """ # Combine legend CSS with existing custom CSS combined_css = DATAMAP_CUSTOM_CSS + "\n" + legend_css + "\n" + legend_title_css plot = datamapplot.create_interactive_plot( stacked_df[['x','y']].values, np.array(stacked_df['cluster_2_labels']), np.array(['Unlabelled' if pd.isna(x) else x for x in stacked_df['parsed_field']]), hover_text=[str(row['title']) for ix, row in stacked_df.iterrows()], marker_color_array=stacked_df['color'], marker_size_array=marker_sizes, use_medoids=True, # Switch back once efficient mediod caclulation comes out! 
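# Descriptive note: the stacked array interleaves the ~100k-point basemap with the
# query results; marker_color_array / marker_size_array give query points their
# highlight colour and a larger radius, extra_point_data plus on_click make each
# point open its DOI, and custom_html / custom_css inject the legends built above.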
width=1000, height=1000, # point_size_scale=1.5, point_radius_min_pixels=1, text_outline_width=5, point_hover_color=highlight_color, point_radius_max_pixels=5, cmap=black_cmap, background_image=graph_file_name if citation_graph_checkbox else None, #color_label_text=False, font_family="Roboto Condensed", font_weight=600, tooltip_font_weight=600, tooltip_font_family="Roboto Condensed", extra_point_data=extra_data, on_click="window.open(`{doi}`)", custom_html=custom_html, custom_css=combined_css, initial_zoom_fraction=.8, enable_search=False, offline_mode=False ) # Save plot plot.save(html_file_path) print(f"Plot created and saved in {time.time() - plot_start:.2f} seconds") # Save additional files if requested if download_csv_checkbox: # Export relevant column export_df = records_df[['title', 'abstract', 'doi', 'publication_year', 'x', 'y','id','primary_topic']] export_df['parsed_field'] = [get_field(row) for ix, row in export_df.iterrows()] export_df['referenced_works'] = [', '.join(x) for x in records_df['referenced_works']] # Add query information if categorical coloring is used if treat_as_categorical_checkbox and has_multiple_queries: export_df['query_index'] = records_df['query_index'] export_df['query_label'] = records_df['query_label'] if locally_approximate_publication_date_checkbox and plot_type_dropdown == "Time-based coloring" and 'years_for_legend' in locals(): export_df['approximate_publication_year'] = years_for_legend export_df.to_csv(csv_file_path, index=False) if download_png_checkbox: png_start_time = time.time() print("Starting PNG generation...") # Sample and prepare data sample_prep_start = time.time() sample_to_plot = basedata_df#.sample(20000) labels1 = np.array(sample_to_plot['cluster_2_labels']) labels2 = np.array(['Unlabelled' if pd.isna(x) else x for x in sample_to_plot['parsed_field']]) ratio = 0.6 mask = np.random.random(size=len(labels1)) < ratio combined_labels = np.where(mask, labels1, labels2) # Get the 30 most common labels unique_labels, counts = np.unique(combined_labels, return_counts=True) top_30_labels = set(unique_labels[np.argsort(counts)[-80:]]) # Replace less common labels with 'Unlabelled' combined_labels = np.array(['Unlabelled' if label not in top_30_labels else label for label in combined_labels]) colors_base = ['#536878' for _ in range(len(labels1))] print(f"Sample preparation completed in {time.time() - sample_prep_start:.2f} seconds") # Create main plot main_plot_start = time.time() fig, ax = datamapplot.create_plot( sample_to_plot[['x','y']].values, combined_labels, label_wrap_width=12, label_over_points=True, dynamic_label_size=True, use_medoids=True, # Switch back once efficient mediod caclulation comes out! 
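# Descriptive note: this is the static PNG fallback. force_matplotlib renders the
# basemap with matplotlib; combined_labels keeps only the most frequent labels
# (the code above keeps 80, despite the top_30 naming) and marks the rest as
# 'Unlabelled'. The query points themselves are drawn by the scatter calls below.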
point_size=2, marker_color_array=colors_base, force_matplotlib=True, max_font_size=12, min_font_size=4, min_font_weight=100, max_font_weight=300, font_family="Roboto Condensed", color_label_text=False, add_glow=False, highlight_labels=list(np.unique(labels1)), label_font_size=8, highlight_label_keywords={"fontsize": 12, "fontweight": "bold", "bbox":{"boxstyle":"circle", "pad":0.75,'alpha':0.}}, ) print(f"Main plot creation completed in {time.time() - main_plot_start:.2f} seconds") if citation_graph_checkbox: # Read and add the graph image graph_img = plt.imread(graph_file_path) ax.imshow(graph_img, extent=[np.min(stacked_df['x']),np.max(stacked_df['x']),np.min(stacked_df['y']),np.max(stacked_df['y'])], alpha=0.9, aspect='auto') if len(records_df) > 50_000: point_size = .5 elif len(records_df) > 10_000: point_size = 1 else: point_size = 5 # Time-based visualization scatter_start = time.time() if plot_type_dropdown == "Time-based coloring": # Use selected colormap if provided, otherwise default to haline if selected_colormap_name and selected_colormap_name.strip(): try: static_cmap = plt.get_cmap(selected_colormap_name) except Exception as e: print(f"Warning: Could not load colormap '{selected_colormap_name}': {e}") static_cmap = colormaps.haline else: static_cmap = colormaps.haline if locally_approximate_publication_date_checkbox and 'years_for_legend' in locals(): scatter = plt.scatter( umap_embeddings[:,0], umap_embeddings[:,1], c=years_for_legend, cmap=static_cmap, alpha=0.8, s=point_size ) else: years = pd.to_numeric(records_df['publication_year']) scatter = plt.scatter( umap_embeddings[:,0], umap_embeddings[:,1], c=years, cmap=static_cmap, alpha=0.8, s=point_size ) plt.colorbar(scatter, shrink=0.5, format='%d') else: scatter = plt.scatter( umap_embeddings[:,0], umap_embeddings[:,1], c=records_df['color'], alpha=0.8, s=point_size ) print(f"Scatter plot creation completed in {time.time() - scatter_start:.2f} seconds") # Save plot save_start = time.time() plt.axis('off') plt.savefig(png_file_path, dpi=300, bbox_inches='tight') plt.close() print(f"Plot saving completed in {time.time() - save_start:.2f} seconds") print(f"Total PNG generation completed in {time.time() - png_start_time:.2f} seconds") progress(1.0, desc="Done!") print(f"Total pipeline completed in {time.time() - start_time:.2f} seconds") iframe = f"""""" # Return iframe and download buttons with appropriate visibility return [ iframe, gr.DownloadButton(label="Download Interactive Visualization", value=html_file_path, visible=True, variant='secondary'), gr.DownloadButton(label="Download CSV Data", value=csv_file_path, visible=download_csv_checkbox, variant='secondary'), gr.DownloadButton(label="Download Static Plot", value=png_file_path, visible=download_png_checkbox, variant='secondary'), gr.Button(visible=False) # Return hidden state for cancel button ] predict.zerogpu = True theme = gr.themes.Monochrome( font=[gr.themes.GoogleFont("Roboto Condensed"), "ui-sans-serif", "system-ui", "sans-serif"], text_size="lg", ).set( button_secondary_background_fill="white", button_secondary_background_fill_hover="#f3f4f6", button_secondary_border_color="black", button_secondary_text_color="black", button_border_width="2px", ) # JS to enforce light theme by refreshing the page js_light = """ function refresh() { const url = new URL(window.location); if (url.searchParams.get('__theme') !== 'light') { url.searchParams.set('__theme', 'light'); window.location.href = url.href; } } """ # Gradio interface setup with gr.Blocks(theme=theme, 
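# Descriptive note: the Monochrome theme above pins Roboto Condensed and white
# secondary buttons, and js_light reloads the page with ?__theme=light so the app
# always renders in light mode. The f-string CSS that follows styles links and
# pulls in the colormap chooser's own CSS via colormap_chooser.css().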
css=f""" .gradio-container a {{ color: black !important; text-decoration: none !important; /* Force remove default underline */ font-weight: bold; transition: color 0.2s ease-in-out, border-bottom-color 0.2s ease-in-out; display: inline-block; /* Enable proper spacing for descenders */ line-height: 1.1; /* Adjust line height */ padding-bottom: 2px; /* Add space for descenders */ }} .gradio-container a:hover {{ color: #b23310 !important; border-bottom: 3px solid #b23310; /* Wider underline, only on hover */ }} /* Colormap chooser styles */ {colormap_chooser.css()} """, js=js_light) as demo: gr.Markdown("""
The visualization map will appear here after running a query