m7n committed on
Commit
f895c88
·
1 Parent(s): cabc445

Many updates, mainly added categorical.

Browse files
Files changed (6) hide show
  1. app.py +392 -72
  2. color_utils.py +2 -2
  3. colormap_chooser.py +933 -0
  4. colormap_chooser_testing_app.py +47 -0
  5. openalex_utils.py +162 -2
  6. ui_utils.py +52 -0
app.py CHANGED
@@ -128,8 +128,10 @@ from openalex_utils import (
128
  get_field,
129
  process_records_to_df,
130
  openalex_url_to_filename,
131
- get_records_from_dois
 
132
  )
 
133
  from styles import DATAMAP_CUSTOM_CSS
134
  from data_setup import (
135
  download_required_files,
@@ -141,7 +143,8 @@ from data_setup import (
141
 
142
  from network_utils import create_citation_graph, draw_citation_graph
143
 
144
-
 
145
 
146
 
147
  # Configure OpenAlex
@@ -149,6 +152,26 @@ pyalex.config.email = "maximilian.noichl@uni-bamberg.de"
149
 
150
  print(f"Imports completed: {time.strftime('%Y-%m-%d %H:%M:%S')}")
151
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
 
153
 
154
  # Create a static directory to store the dynamic HTML files
@@ -236,21 +259,65 @@ def create_embeddings_299(texts_to_embedd):
236
  return model.encode(texts_to_embedd, show_progress_bar=True, batch_size=192)
237
 
238
 
 
239
  # else:
240
  def create_embeddings(texts_to_embedd):
241
  """Create embeddings for the input texts using the loaded model."""
242
  return model.encode(texts_to_embedd, show_progress_bar=True, batch_size=192)
243
 
244
 
 
 
 
 
 
 
 
 
 
 
 
 
 
245
 
 
 
246
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
247
 
248
 
249
  def predict(request: gr.Request, text_input, sample_size_slider, reduce_sample_checkbox,
250
- sample_reduction_method, plot_time_checkbox,
251
  locally_approximate_publication_date_checkbox,
252
  download_csv_checkbox, download_png_checkbox, citation_graph_checkbox,
253
- csv_upload, highlight_color,
254
  progress=gr.Progress()):
255
  """
256
  Main prediction pipeline that processes OpenAlex queries and creates visualizations.
@@ -261,13 +328,14 @@ def predict(request: gr.Request, text_input, sample_size_slider, reduce_sample_c
261
  sample_size_slider (int): Maximum number of samples to process
262
  reduce_sample_checkbox (bool): Whether to reduce sample size
263
  sample_reduction_method (str): Method for sample reduction ("Random" or "Order of Results")
264
- plot_time_checkbox (bool): Whether to color points by publication date
265
  locally_approximate_publication_date_checkbox (bool): Whether to approximate publication date locally before plotting.
266
  download_csv_checkbox (bool): Whether to download CSV data
267
  download_png_checkbox (bool): Whether to download PNG data
268
  citation_graph_checkbox (bool): Whether to add citation graph
269
  csv_upload (str): Path to uploaded CSV file
270
  highlight_color (str): Color for highlighting points
 
271
  progress (gr.Progress): Gradio progress tracker
272
 
273
  Returns:
@@ -276,6 +344,10 @@ def predict(request: gr.Request, text_input, sample_size_slider, reduce_sample_c
276
  # Initialize start_time at the beginning of the function
277
  start_time = time.time()
278
 
 
 
 
 
279
  # Helper function to generate error responses
280
  def create_error_response(error_message):
281
  return [
@@ -358,6 +430,9 @@ def predict(request: gr.Request, text_input, sample_size_slider, reduce_sample_c
358
  print(f"Successfully loaded {len(records_df)} records from uploaded file")
359
  progress(0.2, desc="Processing uploaded data...")
360
 
 
 
 
361
  except Exception as e:
362
  error_message = f"Error processing uploaded file: {str(e)}"
363
  return create_error_response(error_message)
@@ -374,6 +449,7 @@ def predict(request: gr.Request, text_input, sample_size_slider, reduce_sample_c
374
  # Split input into multiple URLs if present
375
  urls = [url.strip() for url in text_input.split(';')]
376
  records = []
 
377
  total_query_length = 0
378
 
379
  # Use first URL for filename
@@ -388,54 +464,154 @@ def predict(request: gr.Request, text_input, sample_size_slider, reduce_sample_c
388
  total_query_length += query_length
389
  print(f'Requesting {query_length} entries from query {i+1}/{len(urls)}...')
390
 
391
- target_size = sample_size_slider if reduce_sample_checkbox and sample_reduction_method == "First n samples" else query_length
392
- records_per_query = 0
393
-
394
- should_break = False
395
- for page in query.paginate(per_page=200, n_max=None):
396
- # Add retry mechanism for processing each page
397
- max_retries = 5
398
- base_wait_time = 1 # Starting wait time in seconds
399
- exponent = 1.5 # Exponential factor
400
 
401
- for retry_attempt in range(max_retries):
402
- try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
403
  for record in page:
404
- records.append(record)
405
- records_per_query += 1
406
- progress(0.1 + (0.2 * len(records) / (total_query_length)),
407
- desc=f"Getting data from query {i+1}/{len(urls)}...")
408
 
409
- if reduce_sample_checkbox and sample_reduction_method == "First n samples" and records_per_query >= target_size:
410
- should_break = True
411
- break
412
- # If we get here without an exception, break the retry loop
413
- break
414
- except Exception as e:
415
- print(f"Error processing page: {e}")
416
- if retry_attempt < max_retries - 1:
417
- wait_time = base_wait_time * (exponent ** retry_attempt) + random.random()
418
- print(f"Retrying in {wait_time:.2f} seconds (attempt {retry_attempt + 1}/{max_retries})...")
419
- time.sleep(wait_time)
420
- else:
421
- print(f"Maximum retries reached. Continuing with next page.")
 
 
 
 
 
 
 
422
 
423
- if should_break:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
424
  break
425
- if should_break:
426
- break
427
  print(f"Query completed in {time.time() - start_time:.2f} seconds")
 
 
 
 
 
 
428
 
429
  # Process records
430
  processing_start = time.time()
431
  records_df = process_records_to_df(records)
432
 
433
- if reduce_sample_checkbox and sample_reduction_method != "All":
434
- sample_size = min(sample_size_slider, len(records_df))
435
- if sample_reduction_method == "n random samples":
436
- records_df = records_df.sample(sample_size)
437
- elif sample_reduction_method == "First n samples":
438
- records_df = records_df.iloc[:sample_size]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
439
  print(f"Records processed in {time.time() - processing_start:.2f} seconds")
440
 
441
  # Create embeddings - this happens regardless of data source
@@ -468,14 +644,68 @@ def predict(request: gr.Request, text_input, sample_size_slider, reduce_sample_c
468
  viz_prep_start = time.time()
469
  progress(0.6, desc="Preparing visualization data...")
470
 
 
 
 
471
  basedata_df['color'] = '#ced4d211'
472
 
 
 
 
473
  highlight_color = rgba_to_hex(highlight_color)
474
 
475
- if not plot_time_checkbox:
476
- records_df['color'] = highlight_color
477
- else:
478
- cmap = colormaps.haline
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
479
  if not locally_approximate_publication_date_checkbox:
480
  # Create color mapping based on publication years
481
  years = pd.to_numeric(records_df['publication_year'])
@@ -495,6 +725,9 @@ def predict(request: gr.Request, text_input, sample_size_slider, reduce_sample_c
495
  ])
496
  norm = mcolors.Normalize(vmin=local_years.min(), vmax=local_years.max())
497
  records_df['color'] = [mcolors.to_hex(cmap(norm(year))) for year in local_years]
 
 
 
498
 
499
  stacked_df = pd.concat([basedata_df, records_df], axis=0, ignore_index=True)
500
  stacked_df = stacked_df.fillna("Unlabelled")
@@ -562,7 +795,13 @@ def predict(request: gr.Request, text_input, sample_size_slider, reduce_sample_c
562
  export_df = records_df[['title', 'abstract', 'doi', 'publication_year', 'x', 'y','id','primary_topic']]
563
  export_df['parsed_field'] = [get_field(row) for ix, row in export_df.iterrows()]
564
  export_df['referenced_works'] = [', '.join(x) for x in records_df['referenced_works']]
565
- if locally_approximate_publication_date_checkbox and plot_time_checkbox:
 
 
 
 
 
 
566
  export_df['approximate_publication_year'] = local_years
567
  export_df.to_csv(csv_file_path, index=False)
568
 
@@ -628,13 +867,23 @@ def predict(request: gr.Request, text_input, sample_size_slider, reduce_sample_c
628
 
629
  # Time-based visualization
630
  scatter_start = time.time()
631
- if plot_time_checkbox:
 
 
 
 
 
 
 
 
 
 
632
  if locally_approximate_publication_date_checkbox:
633
  scatter = plt.scatter(
634
  umap_embeddings[:,0],
635
  umap_embeddings[:,1],
636
  c=local_years,
637
- cmap=colormaps.haline,
638
  alpha=0.8,
639
  s=point_size
640
  )
@@ -644,7 +893,7 @@ def predict(request: gr.Request, text_input, sample_size_slider, reduce_sample_c
644
  umap_embeddings[:,0],
645
  umap_embeddings[:,1],
646
  c=years,
647
- cmap=colormaps.haline,
648
  alpha=0.8,
649
  s=point_size
650
  )
@@ -713,8 +962,8 @@ function refresh() {
713
 
714
 
715
  # Gradio interface setup
716
- with gr.Blocks(theme=theme, css="""
717
- .gradio-container a {
718
  color: black !important;
719
  text-decoration: none !important; /* Force remove default underline */
720
  font-weight: bold;
@@ -722,11 +971,14 @@ with gr.Blocks(theme=theme, css="""
722
  display: inline-block; /* Enable proper spacing for descenders */
723
  line-height: 1.1; /* Adjust line height */
724
  padding-bottom: 2px; /* Add space for descenders */
725
- }
726
- .gradio-container a:hover {
727
  color: #b23310 !important;
728
  border-bottom: 3px solid #b23310; /* Wider underline, only on hover */
729
- }
 
 
 
730
  """, js=js_light) as demo:
731
  gr.Markdown("""
732
  <div style="max-width: 100%; margin: 0 auto;">
@@ -756,6 +1008,13 @@ with gr.Blocks(theme=theme, css="""
756
  text_input = gr.Textbox(label="OpenAlex-search URL",
757
  info="Enter the URL to an OpenAlex-search.")
758
 
 
 
 
 
 
 
 
759
  gr.Markdown("### Sample Settings")
760
  reduce_sample_checkbox = gr.Checkbox(
761
  label="Reduce Sample Size",
@@ -766,7 +1025,8 @@ with gr.Blocks(theme=theme, css="""
766
  ["All", "First n samples", "n random samples"],
767
  label="Sample Selection Method",
768
  value="First n samples",
769
- info="How to choose the samples to keep."
 
770
  )
771
 
772
  if is_running_in_hf_zero_gpu():
@@ -781,20 +1041,32 @@ with gr.Blocks(theme=theme, css="""
781
  step=10,
782
  value=1000,
783
  info="How many samples to keep.",
784
- visible=True
 
 
 
 
 
 
 
 
785
  )
786
 
787
  gr.Markdown("### Plot Settings")
788
- plot_time_checkbox = gr.Checkbox(
789
- label="Plot Time",
790
- value=True,
791
- info="Colour points by their publication date."
 
 
792
  )
793
  locally_approximate_publication_date_checkbox = gr.Checkbox(
794
  label="Locally Approximate Publication Date",
795
  value=True,
796
- info="Colour points by the average publication date in their area."
 
797
  )
 
798
 
799
  gr.Markdown("### Download Options")
800
  download_csv_checkbox = gr.Checkbox(
@@ -821,14 +1093,24 @@ with gr.Blocks(theme=theme, css="""
821
  label="Upload your own CSV file downloaded via pyalex.",
822
  file_types=[".csv"],
823
  )
824
-
825
  # --- Aesthetics Accordion ---
826
  with gr.Accordion("Aesthetics", open=False):
 
 
827
  highlight_color_picker = gr.ColorPicker(
828
  label="Highlight Color",
 
829
  value="#5e2784",
830
- info="Choose the highlight color for your query points."
831
  )
 
 
 
 
 
 
 
832
 
833
  with gr.Column(scale=2):
834
  html = gr.HTML(
@@ -877,15 +1159,43 @@ with gr.Blocks(theme=theme, css="""
877
  </div>
878
  """)
879
 
880
- def update_slider_visibility(method):
881
- return gr.Slider(visible=(method != "All"))
882
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
883
  sample_reduction_method.change(
884
- fn=update_slider_visibility,
885
- inputs=[sample_reduction_method],
886
- outputs=[sample_size_slider]
887
  )
888
 
 
 
 
 
 
 
889
  def show_cancel_button():
890
  return gr.Button(visible=True)
891
 
@@ -908,13 +1218,16 @@ with gr.Blocks(theme=theme, css="""
908
  sample_size_slider,
909
  reduce_sample_checkbox,
910
  sample_reduction_method,
911
- plot_time_checkbox,
912
  locally_approximate_publication_date_checkbox,
 
913
  download_csv_checkbox,
914
  download_png_checkbox,
915
  citation_graph_checkbox,
916
  csv_upload,
917
- highlight_color_picker
 
 
918
  ],
919
  outputs=[html, html_download, csv_download, png_download, cancel_btn]
920
  )
@@ -927,6 +1240,13 @@ with gr.Blocks(theme=theme, css="""
927
  queue=False # Important to make the button hide immediately
928
  )
929
 
 
 
 
 
 
 
 
930
 
931
  # demo.static_dirs = {
932
  # "static": str(static_dir)
 
128
  get_field,
129
  process_records_to_df,
130
  openalex_url_to_filename,
131
+ get_records_from_dois,
132
+ openalex_url_to_readable_name
133
  )
134
+ from ui_utils import highlight_queries
135
  from styles import DATAMAP_CUSTOM_CSS
136
  from data_setup import (
137
  download_required_files,
 
143
 
144
  from network_utils import create_citation_graph, draw_citation_graph
145
 
146
+ # Add colormap chooser imports
147
+ from colormap_chooser import ColormapChooser, setup_colormaps
148
 
149
 
150
  # Configure OpenAlex
 
152
 
153
  print(f"Imports completed: {time.strftime('%Y-%m-%d %H:%M:%S')}")
154
 
155
+ # Set up colormaps for the chooser
156
+ print("Setting up colormaps...")
157
+ colormap_categories = setup_colormaps(
158
+ included_collections=['matplotlib', 'cmocean', 'scientific', 'cmasher'],
159
+ excluded_collections=['colorcet', 'carbonplan', 'sciviz']
160
+ )
161
+
162
+ colormap_chooser = ColormapChooser(
163
+ categories=colormap_categories,
164
+ smooth_steps=10,
165
+ strip_width=200,
166
+ strip_height=50,
167
+ css_height=200,
168
+ # show_search=False,
169
+ # show_category=False,
170
+ # show_preview=False,
171
+ # show_selected_name=True,
172
+ # show_selected_info=False,
173
+ gallery_kwargs=dict(columns=3, allow_preview=False, height="200px")
174
+ )
175
 
176
 
177
  # Create a static directory to store the dynamic HTML files
 
259
  return model.encode(texts_to_embedd, show_progress_bar=True, batch_size=192)
260
 
261
 
262
+
263
  # else:
264
  def create_embeddings(texts_to_embedd):
265
  """Create embeddings for the input texts using the loaded model."""
266
  return model.encode(texts_to_embedd, show_progress_bar=True, batch_size=192)
267
 
268
 
269
def highlight_queries(text: str, *, resolve_name=None) -> str:
    """Render semicolon-separated OpenAlex URLs as a row of colored HTML pills.

    Each non-empty URL becomes one pill labelled with a readable name. The
    label is HTML-escaped before being embedded, because it is derived from
    text the user typed into the query box and is returned as raw HTML.

    Args:
        text: Raw contents of the query textbox; URLs are separated by ``;``.
        resolve_name: Optional callable mapping a URL to its display name.
            Defaults to ``openalex_url_to_readable_name``.

    Returns:
        An HTML fragment: a placeholder message when there is nothing to
        show, otherwise a header ("Query (1):" / "Queries (n):") followed by
        one ``<span>`` pill per URL.

    NOTE(review): a function with this same name is also imported from
    ``ui_utils`` at the top of app.py; this definition shadows that import --
    confirm which one is meant to be used.
    """
    import html  # local import: only needed here, keeps the file's import block untouched

    palette = [
        "#e8f4fd", "#fff2e8", "#f0f9e8", "#fdf2f8",
        "#f3e8ff", "#e8f8f5", "#fef7e8", "#f8f0e8"
    ]

    # Handle empty input
    if not text or not text.strip():
        return "<div style='padding: 10px; color: #666; font-style: italic;'>Enter OpenAlex URLs separated by semicolons to see query descriptions</div>"

    # Split URLs on semicolons and strip whitespace
    urls = [url.strip() for url in text.split(";") if url.strip()]

    if not urls:
        return "<div style='padding: 10px; color: #666; font-style: italic;'>No valid URLs found</div>"

    # Everything in the pill's inline style except the per-pill background color.
    pill_style = (
        'padding: 8px 12px; margin: 4px; '
        'border-radius: 12px; font-weight: 500;'
        'display: inline-block; font-family: \'Roboto Condensed\', sans-serif;'
        'border: 1px solid rgba(0,0,0,0.1); font-size: 14px;'
        'box-shadow: 0 1px 3px rgba(0,0,0,0.1);'
    )

    pills = []
    for i, url in enumerate(urls):
        # Cycle through the palette when there are more queries than colors.
        color = palette[i % len(palette)]
        fallback_name = f"Query {i+1}"
        try:
            # The default resolver is looked up inside the try so that any
            # failure to obtain a name degrades to the generic label.
            resolver = openalex_url_to_readable_name if resolve_name is None else resolve_name
            readable_name = resolver(url)
        except Exception as e:
            # Best-effort display: never let a bad URL break the whole widget.
            print(f"Error processing URL {url}: {e}")
            readable_name = fallback_name

        # An empty/None name would render as a blank (or "None") pill.
        if not readable_name:
            readable_name = fallback_name

        pills.append(
            f'<span style="background:{color};{pill_style}">'
            # Escape: the name originates from user input and is emitted as HTML.
            f'{html.escape(str(readable_name))}</span>'
        )

    return (
        "<div style='padding: 8px 0;'>"
        "<div style='font-size: 12px; color: #666; margin-bottom: 6px; font-weight: 500;'>"
        f"{'Query' if len(urls) == 1 else 'Queries'} ({len(urls)}):</div>"
        "<div style='display: flex; flex-wrap: wrap; gap: 4px;'>"
        + "".join(pills) +
        "</div></div>"
    )
314
 
315
 
316
  def predict(request: gr.Request, text_input, sample_size_slider, reduce_sample_checkbox,
317
+ sample_reduction_method, plot_type_dropdown,
318
  locally_approximate_publication_date_checkbox,
319
  download_csv_checkbox, download_png_checkbox, citation_graph_checkbox,
320
+ csv_upload, highlight_color, selected_colormap_name, seed_value,
321
  progress=gr.Progress()):
322
  """
323
  Main prediction pipeline that processes OpenAlex queries and creates visualizations.
 
328
  sample_size_slider (int): Maximum number of samples to process
329
  reduce_sample_checkbox (bool): Whether to reduce sample size
330
  sample_reduction_method (str): Method for sample reduction ("Random" or "Order of Results")
331
+ plot_type_dropdown (str): Type of plot coloring ("No special coloring", "Time-based coloring", "Categorical coloring")
332
  locally_approximate_publication_date_checkbox (bool): Whether to approximate publication date locally before plotting.
333
  download_csv_checkbox (bool): Whether to download CSV data
334
  download_png_checkbox (bool): Whether to download PNG data
335
  citation_graph_checkbox (bool): Whether to add citation graph
336
  csv_upload (str): Path to uploaded CSV file
337
  highlight_color (str): Color for highlighting points
338
+ selected_colormap_name (str): Name of the selected colormap for time-based coloring
339
  progress (gr.Progress): Gradio progress tracker
340
 
341
  Returns:
 
344
  # Initialize start_time at the beginning of the function
345
  start_time = time.time()
346
 
347
+ # Convert dropdown selection to boolean flags for backward compatibility
348
+ plot_time_checkbox = plot_type_dropdown == "Time-based coloring"
349
+ treat_as_categorical_checkbox = plot_type_dropdown == "Categorical coloring"
350
+
351
  # Helper function to generate error responses
352
  def create_error_response(error_message):
353
  return [
 
430
  print(f"Successfully loaded {len(records_df)} records from uploaded file")
431
  progress(0.2, desc="Processing uploaded data...")
432
 
433
+ # For uploaded files, set all records to query_index 0
434
+ records_df['query_index'] = 0
435
+
436
  except Exception as e:
437
  error_message = f"Error processing uploaded file: {str(e)}"
438
  return create_error_response(error_message)
 
449
  # Split input into multiple URLs if present
450
  urls = [url.strip() for url in text_input.split(';')]
451
  records = []
452
+ query_indices = [] # Track which query each record comes from
453
  total_query_length = 0
454
 
455
  # Use first URL for filename
 
464
  total_query_length += query_length
465
  print(f'Requesting {query_length} entries from query {i+1}/{len(urls)}...')
466
 
467
+ # Use PyAlex sampling for random samples - much more efficient!
468
+ if reduce_sample_checkbox and sample_reduction_method == "n random samples":
469
+ # Use PyAlex's built-in sample method for efficient server-side sampling
470
+ target_size = min(sample_size_slider, query_length)
471
+ try:
472
+ seed_int = int(seed_value) if seed_value.strip() else 42
473
+ except ValueError:
474
+ seed_int = 42
475
+ print(f"Invalid seed value '{seed_value}', using default: 42")
476
 
477
+ print(f'Attempting PyAlex sampling: {target_size} from {query_length} (seed={seed_int})')
478
+
479
+ try:
480
+ # Check if PyAlex sample method exists and works
481
+ if hasattr(query, 'sample'):
482
+ sampled_query = query.sample(target_size, seed=seed_int)
483
+
484
+ # IMPORTANT: When using sample(), must use method='page' for pagination!
485
+ sampled_records = []
486
+ records_count = 0
487
+ for page in sampled_query.paginate(per_page=200, method='page', n_max=None):
488
+ for record in page:
489
+ sampled_records.append(record)
490
+ records_count += 1
491
+ progress(0.1 + (0.15 * records_count / target_size),
492
+ desc=f"Getting sampled data from query {i+1}/{len(urls)}... ({records_count}/{target_size})")
493
+
494
+ print(f'PyAlex sampling successful: got {len(sampled_records)} records')
495
+ else:
496
+ raise AttributeError("sample method not available")
497
+
498
+ except Exception as e:
499
+ print(f"PyAlex sampling failed ({e}), using fallback method...")
500
+
501
+ # Fallback: get all records and sample manually
502
+ all_records = []
503
+ records_count = 0
504
+
505
+ # Use default cursor pagination for non-sampled queries
506
+ for page in query.paginate(per_page=200, n_max=None):
507
  for record in page:
508
+ all_records.append(record)
509
+ records_count += 1
510
+ progress(0.1 + (0.15 * records_count / query_length),
511
+ desc=f"Downloading for sampling from query {i+1}/{len(urls)}...")
512
 
513
+ # Now sample manually
514
+ if len(all_records) > target_size:
515
+ import random
516
+ random.seed(seed_int)
517
+ sampled_records = random.sample(all_records, target_size)
518
+ else:
519
+ sampled_records = all_records
520
+
521
+ print(f'Fallback sampling: got {len(sampled_records)} from {len(all_records)} total')
522
+
523
+ # Add the sampled records
524
+ for idx, record in enumerate(sampled_records):
525
+ records.append(record)
526
+ query_indices.append(i)
527
+ progress(0.1 + (0.2 * len(records) / total_query_length),
528
+ desc=f"Processing sampled data from query {i+1}/{len(urls)}...")
529
+ else:
530
+ # Keep existing logic for "First n samples" and "All"
531
+ target_size = sample_size_slider if reduce_sample_checkbox and sample_reduction_method == "First n samples" else query_length
532
+ records_per_query = 0
533
 
534
+ should_break_current_query = False
535
+ for page in query.paginate(per_page=200, n_max=None):
536
+ # Add retry mechanism for processing each page
537
+ max_retries = 5
538
+ base_wait_time = 1 # Starting wait time in seconds
539
+ exponent = 1.5 # Exponential factor
540
+
541
+ for retry_attempt in range(max_retries):
542
+ try:
543
+ for record in page:
544
+ records.append(record)
545
+ query_indices.append(i) # Track which query this record comes from
546
+ records_per_query += 1
547
+ progress(0.1 + (0.2 * len(records) / (total_query_length)),
548
+ desc=f"Getting data from query {i+1}/{len(urls)}...")
549
+
550
+ if reduce_sample_checkbox and sample_reduction_method == "First n samples" and records_per_query >= target_size:
551
+ should_break_current_query = True
552
+ break
553
+ # If we get here without an exception, break the retry loop
554
+ break
555
+ except Exception as e:
556
+ print(f"Error processing page: {e}")
557
+ if retry_attempt < max_retries - 1:
558
+ wait_time = base_wait_time * (exponent ** retry_attempt) + random.random()
559
+ print(f"Retrying in {wait_time:.2f} seconds (attempt {retry_attempt + 1}/{max_retries})...")
560
+ time.sleep(wait_time)
561
+ else:
562
+ print(f"Maximum retries reached. Continuing with next page.")
563
+
564
+ if should_break_current_query:
565
  break
566
+ # Continue to next query - don't break out of the main query loop
 
567
  print(f"Query completed in {time.time() - start_time:.2f} seconds")
568
+ print(f"Total records collected: {len(records)}")
569
+ print(f"Expected from all queries: {total_query_length}")
570
+ print(f"Sample method used: {sample_reduction_method}")
571
+ print(f"Reduce sample enabled: {reduce_sample_checkbox}")
572
+ if sample_reduction_method == "n random samples":
573
+ print(f"Seed value: {seed_value}")
574
 
575
  # Process records
576
  processing_start = time.time()
577
  records_df = process_records_to_df(records)
578
 
579
+ # Add query_index to the dataframe
580
+ records_df['query_index'] = query_indices[:len(records_df)]
581
+
582
+ if reduce_sample_checkbox and sample_reduction_method != "All" and sample_reduction_method != "n random samples":
583
+ # Note: We skip "n random samples" here because PyAlex sampling is already done above
584
+ sample_size = min(sample_size_slider, len(records_df))
585
+
586
+ # Check if we have multiple queries for sampling logic
587
+ urls = [url.strip() for url in text_input.split(';')] if text_input else ['']
588
+ has_multiple_queries = len(urls) > 1 and not csv_upload
589
+
590
+ # If using categorical coloring with multiple queries, sample each query independently
591
+ if treat_as_categorical_checkbox and has_multiple_queries:
592
+ # Sample the full sample_size from each query independently
593
+ unique_queries = sorted(records_df['query_index'].unique())
594
+
595
+ sampled_dfs = []
596
+ for query_idx in unique_queries:
597
+ query_records = records_df[records_df['query_index'] == query_idx]
598
+
599
+ # Apply the full sample size to each query (only for "First n samples")
600
+ current_sample_size = min(sample_size_slider, len(query_records))
601
+
602
+ if sample_reduction_method == "First n samples":
603
+ sampled_query = query_records.iloc[:current_sample_size]
604
+
605
+ sampled_dfs.append(sampled_query)
606
+ print(f"Query {query_idx+1}: sampled {len(sampled_query)} records from {len(query_records)} available")
607
+
608
+ records_df = pd.concat(sampled_dfs, ignore_index=True)
609
+ print(f"Total after independent sampling: {len(records_df)} records")
610
+ print(f"Query distribution: {records_df['query_index'].value_counts().sort_index()}")
611
+ else:
612
+ # Original sampling logic for single query or non-categorical (only "First n samples" now)
613
+ if sample_reduction_method == "First n samples":
614
+ records_df = records_df.iloc[:sample_size]
615
  print(f"Records processed in {time.time() - processing_start:.2f} seconds")
616
 
617
  # Create embeddings - this happens regardless of data source
 
644
  viz_prep_start = time.time()
645
  progress(0.6, desc="Preparing visualization data...")
646
 
647
+
648
+ # Set up colors:
649
+
650
  basedata_df['color'] = '#ced4d211'
651
 
652
+ # Convert highlight_color to hex if it isn't already
653
+ if not highlight_color.startswith('#'):
654
+ highlight_color = rgba_to_hex(highlight_color)
655
  highlight_color = rgba_to_hex(highlight_color)
656
 
657
+ print('Highlight color:', highlight_color)
658
+
659
+ # Check if we have multiple queries and categorical coloring is enabled
660
+ urls = [url.strip() for url in text_input.split(';')] if text_input else ['']
661
+ has_multiple_queries = len(urls) > 1 and not csv_upload
662
+
663
+ if treat_as_categorical_checkbox and has_multiple_queries:
664
+ # Use categorical coloring for multiple queries
665
+ print("Using categorical coloring for multiple queries")
666
+
667
+ # Define a categorical colormap - using distinct colors
668
+ categorical_colors = [
669
+ '#e41a1c', # Red
670
+ '#377eb8', # Blue
671
+ '#4daf4a', # Green
672
+ '#984ea3', # Purple
673
+ '#ff7f00', # Orange
674
+ '#ffff33', # Yellow
675
+ '#a65628', # Brown
676
+ '#f781bf', # Pink
677
+ '#999999', # Gray
678
+ '#66c2a5', # Teal
679
+ '#fc8d62', # Light Orange
680
+ '#8da0cb', # Light Blue
681
+ '#e78ac3', # Light Pink
682
+ '#a6d854', # Light Green
683
+ '#ffd92f', # Light Yellow
684
+ '#e5c494', # Beige
685
+ '#b3b3b3', # Light Gray
686
+ ]
687
+
688
+ # Assign colors based on query_index
689
+ unique_queries = sorted(records_df['query_index'].unique())
690
+ query_color_map = {query_idx: categorical_colors[i % len(categorical_colors)]
691
+ for i, query_idx in enumerate(unique_queries)}
692
+
693
+ records_df['color'] = records_df['query_index'].map(query_color_map)
694
+
695
+ # Add query_label for better identification
696
+ records_df['query_label'] = records_df['query_index'].apply(lambda x: f"Query {x+1}")
697
+
698
+ elif plot_time_checkbox:
699
+ # Use selected colormap if provided, otherwise default to haline
700
+ if selected_colormap_name and selected_colormap_name.strip():
701
+ try:
702
+ cmap = plt.get_cmap(selected_colormap_name)
703
+ except Exception as e:
704
+ print(f"Warning: Could not load colormap '{selected_colormap_name}': {e}")
705
+ cmap = colormaps.haline
706
+ else:
707
+ cmap = colormaps.haline
708
+
709
  if not locally_approximate_publication_date_checkbox:
710
  # Create color mapping based on publication years
711
  years = pd.to_numeric(records_df['publication_year'])
 
725
  ])
726
  norm = mcolors.Normalize(vmin=local_years.min(), vmax=local_years.max())
727
  records_df['color'] = [mcolors.to_hex(cmap(norm(year))) for year in local_years]
728
+ else:
729
+ # No special coloring - use highlight color
730
+ records_df['color'] = highlight_color
731
 
732
  stacked_df = pd.concat([basedata_df, records_df], axis=0, ignore_index=True)
733
  stacked_df = stacked_df.fillna("Unlabelled")
 
795
  export_df = records_df[['title', 'abstract', 'doi', 'publication_year', 'x', 'y','id','primary_topic']]
796
  export_df['parsed_field'] = [get_field(row) for ix, row in export_df.iterrows()]
797
  export_df['referenced_works'] = [', '.join(x) for x in records_df['referenced_works']]
798
+
799
+ # Add query information if categorical coloring is used
800
+ if treat_as_categorical_checkbox and has_multiple_queries:
801
+ export_df['query_index'] = records_df['query_index']
802
+ export_df['query_label'] = records_df['query_label']
803
+
804
+ if locally_approximate_publication_date_checkbox and plot_type_dropdown == "Time-based coloring":
805
  export_df['approximate_publication_year'] = local_years
806
  export_df.to_csv(csv_file_path, index=False)
807
 
 
867
 
868
  # Time-based visualization
869
  scatter_start = time.time()
870
+ if plot_type_dropdown == "Time-based coloring":
871
+ # Use selected colormap if provided, otherwise default to haline
872
+ if selected_colormap_name and selected_colormap_name.strip():
873
+ try:
874
+ static_cmap = plt.get_cmap(selected_colormap_name)
875
+ except Exception as e:
876
+ print(f"Warning: Could not load colormap '{selected_colormap_name}': {e}")
877
+ static_cmap = colormaps.haline
878
+ else:
879
+ static_cmap = colormaps.haline
880
+
881
  if locally_approximate_publication_date_checkbox:
882
  scatter = plt.scatter(
883
  umap_embeddings[:,0],
884
  umap_embeddings[:,1],
885
  c=local_years,
886
+ cmap=static_cmap,
887
  alpha=0.8,
888
  s=point_size
889
  )
 
893
  umap_embeddings[:,0],
894
  umap_embeddings[:,1],
895
  c=years,
896
+ cmap=static_cmap,
897
  alpha=0.8,
898
  s=point_size
899
  )
 
962
 
963
 
964
  # Gradio interface setup
965
+ with gr.Blocks(theme=theme, css=f"""
966
+ .gradio-container a {{
967
  color: black !important;
968
  text-decoration: none !important; /* Force remove default underline */
969
  font-weight: bold;
 
971
  display: inline-block; /* Enable proper spacing for descenders */
972
  line-height: 1.1; /* Adjust line height */
973
  padding-bottom: 2px; /* Add space for descenders */
974
+ }}
975
+ .gradio-container a:hover {{
976
  color: #b23310 !important;
977
  border-bottom: 3px solid #b23310; /* Wider underline, only on hover */
978
+ }}
979
+
980
+ /* Colormap chooser styles */
981
+ {colormap_chooser.css()}
982
  """, js=js_light) as demo:
983
  gr.Markdown("""
984
  <div style="max-width: 100%; margin: 0 auto;">
 
1008
  text_input = gr.Textbox(label="OpenAlex-search URL",
1009
  info="Enter the URL to an OpenAlex-search.")
1010
 
1011
+ # Add the query highlight display
1012
+ query_display = gr.HTML(
1013
+ value="<div style='padding: 10px; color: #666; font-style: italic;'>Enter OpenAlex URLs separated by semicolons to see query descriptions</div>",
1014
+ label="",
1015
+ show_label=False
1016
+ )
1017
+
1018
  gr.Markdown("### Sample Settings")
1019
  reduce_sample_checkbox = gr.Checkbox(
1020
  label="Reduce Sample Size",
 
1025
  ["All", "First n samples", "n random samples"],
1026
  label="Sample Selection Method",
1027
  value="First n samples",
1028
+ info="How to choose the samples to keep.",
1029
+ visible=True # Will be controlled by reduce_sample_checkbox
1030
  )
1031
 
1032
  if is_running_in_hf_zero_gpu():
 
1041
  step=10,
1042
  value=1000,
1043
  info="How many samples to keep.",
1044
+ visible=True # Will be controlled by reduce_sample_checkbox
1045
+ )
1046
+
1047
+ # Add this new seed field
1048
+ seed_textbox = gr.Textbox(
1049
+ label="Random Seed",
1050
+ value="42",
1051
+ info="Seed for random sampling reproducibility.",
1052
+ visible=False # Will be controlled by both reduce_sample_checkbox and sample_reduction_method
1053
  )
1054
 
1055
  gr.Markdown("### Plot Settings")
1056
+ # Replace plot_time_checkbox with a dropdown
1057
+ plot_type_dropdown = gr.Dropdown(
1058
+ ["No special coloring", "Time-based coloring", "Categorical coloring"],
1059
+ label="Plot Coloring Type",
1060
+ value="Time-based coloring",
1061
+ info="Choose how to color the points on the plot."
1062
  )
1063
  locally_approximate_publication_date_checkbox = gr.Checkbox(
1064
  label="Locally Approximate Publication Date",
1065
  value=True,
1066
+ info="Colour points by the average publication date in their area.",
1067
+ visible=True # Will be controlled by plot_type_dropdown
1068
  )
1069
+ # Remove treat_as_categorical_checkbox since it's now part of the dropdown
1070
 
1071
  gr.Markdown("### Download Options")
1072
  download_csv_checkbox = gr.Checkbox(
 
1093
  label="Upload your own CSV file downloaded via pyalex.",
1094
  file_types=[".csv"],
1095
  )
1096
+
1097
  # --- Aesthetics Accordion ---
1098
  with gr.Accordion("Aesthetics", open=False):
1099
+ gr.Markdown("### Color Selection")
1100
+ gr.Markdown("*Choose an individual color to highlight your data.*")
1101
  highlight_color_picker = gr.ColorPicker(
1102
  label="Highlight Color",
1103
+ show_label=False,
1104
  value="#5e2784",
1105
+ #info="Choose the highlight color for your query points."
1106
  )
1107
+
1108
+ # Add colormap chooser
1109
+ gr.Markdown("### Colormap Selection")
1110
+ gr.Markdown("*Choose a colormap for time-based visualizations (when 'Plot Time' is enabled)*")
1111
+
1112
+ # Render the colormap chooser (created earlier)
1113
+ colormap_chooser.render_tabs()
1114
 
1115
  with gr.Column(scale=2):
1116
  html = gr.HTML(
 
1159
  </div>
1160
  """)
1161
 
1162
+ # Update the visibility control functions
1163
+ def update_sample_controls_visibility(reduce_sample_enabled, sample_method):
1164
+ """Update visibility of sample reduction controls based on checkbox and method"""
1165
+ method_visible = reduce_sample_enabled
1166
+ slider_visible = reduce_sample_enabled and sample_method != "All"
1167
+ seed_visible = reduce_sample_enabled and sample_method == "n random samples"
1168
+
1169
+ return (
1170
+ gr.Dropdown(visible=method_visible),
1171
+ gr.Slider(visible=slider_visible),
1172
+ gr.Textbox(visible=seed_visible)
1173
+ )
1174
+
1175
+ def update_plot_controls_visibility(plot_type):
1176
+ """Update visibility of plot controls based on plot type"""
1177
+ locally_approx_visible = plot_type == "Time-based coloring"
1178
+ return gr.Checkbox(visible=locally_approx_visible)
1179
+
1180
+ # Update event handlers
1181
+ reduce_sample_checkbox.change(
1182
+ fn=update_sample_controls_visibility,
1183
+ inputs=[reduce_sample_checkbox, sample_reduction_method],
1184
+ outputs=[sample_reduction_method, sample_size_slider, seed_textbox]
1185
+ )
1186
+
1187
  sample_reduction_method.change(
1188
+ fn=update_sample_controls_visibility,
1189
+ inputs=[reduce_sample_checkbox, sample_reduction_method],
1190
+ outputs=[sample_reduction_method, sample_size_slider, seed_textbox]
1191
  )
1192
 
1193
+ plot_type_dropdown.change(
1194
+ fn=update_plot_controls_visibility,
1195
+ inputs=[plot_type_dropdown],
1196
+ outputs=[locally_approximate_publication_date_checkbox]
1197
+ )
1198
+
1199
def show_cancel_button():
    """Return a Button update that makes the target button visible."""
    return gr.Button(visible=True)
1201
 
 
1218
  sample_size_slider,
1219
  reduce_sample_checkbox,
1220
  sample_reduction_method,
1221
+ plot_type_dropdown, # Changed from plot_time_checkbox
1222
  locally_approximate_publication_date_checkbox,
1223
+ # Removed treat_as_categorical_checkbox since it's now part of plot_type_dropdown
1224
  download_csv_checkbox,
1225
  download_png_checkbox,
1226
  citation_graph_checkbox,
1227
  csv_upload,
1228
+ highlight_color_picker,
1229
+ colormap_chooser.selected_name,
1230
+ seed_textbox
1231
  ],
1232
  outputs=[html, html_download, csv_download, png_download, cancel_btn]
1233
  )
 
1240
  queue=False # Important to make the button hide immediately
1241
  )
1242
 
1243
+ # Connect text input changes to query display updates
1244
+ text_input.change(
1245
+ fn=highlight_queries,
1246
+ inputs=text_input,
1247
+ outputs=query_display
1248
+ )
1249
+
1250
 
1251
  # demo.static_dirs = {
1252
  # "static": str(static_dir)
color_utils.py CHANGED
@@ -7,8 +7,8 @@ def rgba_to_hex(color):
7
  # If already hex
8
  if color.startswith('#') and (len(color) == 7 or len(color) == 4):
9
  return color
10
- # If rgba or rgb
11
- match = re.match(r"rgba?\\(([^)]+)\\)", color)
12
  if match:
13
  parts = match.group(1).split(',')
14
  r = int(float(parts[0]))
 
7
  # If already hex
8
  if color.startswith('#') and (len(color) == 7 or len(color) == 4):
9
  return color
10
+ # If rgba or rgb - FIX: Remove extra backslashes
11
+ match = re.match(r"rgba?\(([^)]+)\)", color)
12
  if match:
13
  parts = match.group(1).split(',')
14
  r = int(float(parts[0]))
colormap_chooser.py ADDED
@@ -0,0 +1,933 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Colormap Chooser Gradio Component
2
+ ===================================
3
+
4
+ A reusable, importable Gradio component that provides a **scrollable, wide-strip**
5
+ chooser for Matplotlib (and compatible) colormaps. Designed to drop into an
6
+ existing Gradio Blocks app.
7
+
8
+ Features
9
+ --------
10
+ * Long, skinny gradient bars (not square tiles).
11
+ * Smart sampling:
12
+ - Continuous maps → ~20 sample steps (configurable) interpolated across width.
13
+ - Categorical / qualitative maps → actual number of colors (`cmap.N`).
14
+ * Scrollable gallery (height-capped w/ CSS).
15
+ * Selection callback returns the **selected colormap name** (string) you can pass
16
+ directly to Matplotlib (`mpl.colormaps[name]` or `plt.get_cmap(name)`).
17
+ * Optional category + search filtering UI.
18
+ * Minimal dependencies: NumPy, Matplotlib, Gradio.
19
+
20
+ Quick Start
21
+ -----------
22
+ ```python
23
+ import gradio as gr
24
+ from colormap_chooser import ColormapChooser, setup_colormaps
25
+
26
+ # Set up colormaps with custom collections
27
+ categories = setup_colormaps(
28
+ included_collections=['matplotlib', 'cmocean', 'scientific'],
29
+ excluded_collections=['colorcet']
30
+ )
31
+
32
+ chooser = ColormapChooser(
33
+ categories=categories,
34
+ gallery_kwargs=dict(columns=4, allow_preview=True, height="400px")
35
+ )
36
+
37
+ with gr.Blocks() as demo:
38
+ with gr.Row():
39
+ chooser.render() # inserts the component cluster
40
+ # Use chooser.selected_name as an input to your plotting fn
41
+ import numpy as np, matplotlib.pyplot as plt
42
+ def show_demo(cmap_name):
43
+ data = np.random.rand(32, 32)
44
+ fig, ax = plt.subplots()
45
+ im = ax.imshow(data, cmap=cmap_name)
46
+ ax.set_title(cmap_name)
47
+ fig.colorbar(im, ax=ax)
48
+ return fig
49
+ out = gr.Plot()
50
+ chooser.selected_name.change(show_demo, chooser.selected_name, out)
51
+
52
+ demo.launch()
53
+ ```
54
+
55
+ Installation
56
+ ------------
57
+ Drop this file in your project (e.g., `colormap_chooser.py`) and import.
58
+
59
+ Customizing
60
+ -----------
61
+ Pass your own category dict, default sampling counts, or CSS overrides at
62
+ construction time; see class docstring below.
63
+ """
64
+
65
+ from __future__ import annotations
66
+
67
+ import numpy as np
68
+ import matplotlib as mpl
69
+ import matplotlib.colors as mcolors
70
+ import matplotlib.pyplot as plt
71
+ import gradio as gr
72
+ from typing import Any, Dict, List, Optional, Sequence, Tuple
73
+
74
+ # ------------------------------------------------------------------
75
+ # Default category mapping (extend or replace at init)
76
+ # ------------------------------------------------------------------
77
+ DEFAULT_CATEGORIES: Dict[str, List[str]] = {
78
+ "Perceptually Uniform": ["viridis", "plasma", "inferno", "magma", "cividis"],
79
+ "Sequential": ["Blues", "Greens", "Oranges", "Purples", "Reds", "Greys"],
80
+ "Diverging": ["coolwarm", "bwr", "seismic", "PiYG", "PRGn", "RdBu"],
81
+ "Qualitative": ["tab10", "tab20", "Set1", "Set2", "Accent"],
82
+ }
83
+
84
+
85
+ # ------------------------------------------------------------------
86
+ # Colormap setup functions
87
+ # ------------------------------------------------------------------
88
+
89
def load_matplotlib_colormaps():
    """Collect every colormap that Matplotlib reports as registered.

    Returns:
        Dict mapping each name from ``plt.colormaps()`` to the object that
        ``plt.get_cmap`` returns for it. Names whose lookup raises are
        silently left out.
    """
    registry = {}
    for cmap_name in plt.colormaps():
        try:
            registry[cmap_name] = plt.get_cmap(cmap_name)
        except Exception:
            # Best effort: a name that cannot be resolved is simply skipped.
            pass
    return registry
105
+
106
+
107
def load_external_colormaps():
    """Collect colormap-like attributes from the optional ``colormaps`` package.

    Every public attribute (no leading underscore) of the package that is
    either callable or exposes a ``colors`` attribute is kept.

    Returns:
        Dict of attribute_name -> attribute value; empty when the package
        is not installed.
    """
    found = {}

    try:
        import colormaps

        public_names = [n for n in dir(colormaps) if not n.startswith('_')]
        for attr_name in public_names:
            try:
                candidate = getattr(colormaps, attr_name)
                # Heuristic only: anything callable or carrying `.colors` is kept.
                looks_like_cmap = (
                    hasattr(candidate, '__call__') or hasattr(candidate, 'colors')
                )
                if looks_like_cmap:
                    found[attr_name] = candidate
            except Exception:
                continue
    except ImportError:
        # The package is optional; without it there is nothing to add.
        pass

    return found
130
+
131
+
132
def categorize_colormaps(
    colormap_dict: Dict[str, Any],
    included_collections: List[str],
    excluded_collections: List[str]
) -> Dict[str, List[str]]:
    """Group colormap names into Sequential / Diverging / Qualitative lists.

    Names are matched case-insensitively against hard-coded lists of known
    colormaps (matplotlib, cmocean, scientific, cmasher) and against a few
    substring markers. Within each category the names are ordered by
    collection priority (matplotlib, cmocean, scientific, cmasher, other)
    and then alphabetically.

    Args:
        colormap_dict: Dict of colormap_name -> colormap_object. Only the
            keys are inspected.
        included_collections: Collection names to include. When non-empty,
            a colormap is kept if one of these strings occurs in its name,
            if 'matplotlib' is listed and Matplotlib has the name
            registered, or if the name belongs to one of the known
            cmocean / scientific / cmasher lists. When empty, no inclusion
            filter is applied.
        excluded_collections: Collection names to exclude; a colormap is
            dropped when any of these strings occurs in its name.

    Returns:
        Dict {"Category": [list_of_names]}. Empty categories are omitted and
        colormaps that match no category ("Other") are never returned.
    """

    def _lowered(*names):
        # Every lookup below uses the lower-cased colormap name, so the
        # reference lists must be lower-cased too; otherwise entries such as
        # 'batlowK' or 'grayC' could never match.
        return frozenset(n.lower() for n in names)

    # Known categorizations based on documentation
    matplotlib_sequential = _lowered(
        'viridis', 'plasma', 'inferno', 'magma', 'cividis',  # Perceptually uniform
        'ylorbr', 'ylorrd', 'orrd', 'purd', 'rdpu', 'bupu',  # Multi-hue sequential
        'gnbu', 'pubu', 'ylgnbu', 'pubugn', 'bugn', 'ylgn',
        'binary', 'gist_yarg', 'gist_gray', 'gray', 'bone', 'pink',  # Sequential (2)
        'spring', 'summer', 'autumn', 'winter', 'cool', 'wistia',
        'hot', 'afmhot', 'gist_heat', 'copper',
    )

    # Single-color sequential maps to exclude
    single_color_sequential = _lowered(
        'blues', 'greens', 'oranges', 'purples', 'reds', 'greys',
    )

    matplotlib_diverging = _lowered(
        'piyg', 'prgn', 'brbg', 'puor', 'rdgy', 'rdbu',
        'rdylbu', 'rdylgn', 'spectral', 'coolwarm', 'bwr', 'seismic',
        'berlin', 'managua', 'vanimo',
    )

    matplotlib_qualitative = _lowered(
        'pastel1', 'pastel2', 'paired', 'accent',
        'dark2', 'set1', 'set2', 'set3',
        'tab10', 'tab20', 'tab20b', 'tab20c',
    )

    matplotlib_miscellaneous = _lowered(
        'flag', 'prism', 'ocean', 'gist_earth', 'terrain', 'gist_stern',
        'gnuplot', 'gnuplot2', 'cmrmap', 'cubehelix', 'brg',
        'gist_rainbow', 'rainbow', 'jet', 'turbo', 'nipy_spectral',
        'gist_ncar', 'twilight', 'twilight_shifted', 'hsv',
    )

    # External colormap collections
    cmocean_sequential = _lowered(
        'thermal', 'haline', 'solar', 'ice', 'gray', 'oxy', 'deep', 'dense',
        'algae', 'matter', 'turbid', 'speed', 'amp', 'tempo', 'rain',
    )
    cmocean_diverging = _lowered('balance', 'delta', 'curl', 'diff', 'tarn')
    cmocean_other = _lowered('phase', 'topo')

    scientific_sequential = _lowered(
        'batlow', 'batlowK', 'batlowW', 'acton', 'bamako', 'bilbao', 'buda', 'davos',
        'devon', 'grayC', 'hawaii', 'imola', 'lajolla', 'lapaz', 'nuuk', 'oslo',
        'tokyo', 'turku', 'actonS', 'bamO', 'brocO', 'corko', 'corkO', 'davosS',
        'grayCS', 'hawaiiS', 'imolaS', 'lajollaS', 'lapazS', 'nuukS', 'osloS',
        'tokyoS', 'turkuS',
    )
    scientific_diverging = _lowered(
        'bam', 'bamo', 'berlin', 'broc', 'brocO', 'cork', 'corko', 'lisbon',
        'managua', 'roma', 'romao', 'tofino', 'vanimo', 'vik', 'viko',
    )

    cmasher_sequential = _lowered(
        'amber', 'amethyst', 'apple', 'arctic', 'autumn', 'bubblegum', 'chroma',
        'cosmic', 'dusk', 'ember', 'emerald', 'flamingo', 'freeze', 'gem', 'gothic',
        'heat', 'jungle', 'lavender', 'neon', 'neutral', 'nuclear', 'ocean',
        'pepper', 'plasma_r', 'rainforest', 'savanna', 'sunburst', 'swamp', 'torch',
        'toxic', 'tree', 'voltage', 'voltage_r',
    )
    cmasher_diverging = _lowered(
        'copper', 'emergency', 'fusion', 'guppy', 'holly', 'iceburn', 'infinity',
        'pride', 'prinsenvlag', 'redshift', 'seasons', 'seaweed', 'viola',
        'waterlily', 'watermelon', 'wildfire',
    )

    # Collections in descending priority; the index is the sort rank.
    collections_by_priority = (
        matplotlib_sequential | matplotlib_diverging
        | matplotlib_qualitative | matplotlib_miscellaneous,
        cmocean_sequential | cmocean_diverging | cmocean_other,
        scientific_sequential | scientific_diverging,
        cmasher_sequential | cmasher_diverging,
    )

    def get_collection_priority(name_lower):
        for rank, members in enumerate(collections_by_priority):
            if name_lower in members:
                return rank
        # Everything else sorts last.
        return len(collections_by_priority)

    # Names from the external lists always pass the inclusion filter.
    external_known = (
        collections_by_priority[1]
        | collections_by_priority[2]
        | collections_by_priority[3]
    )
    sequential_known = (
        cmocean_sequential | scientific_sequential
        | cmasher_sequential | matplotlib_sequential
    )
    diverging_known = (
        cmocean_diverging | scientific_diverging
        | cmasher_diverging | matplotlib_diverging
    )

    qualitative_markers = (
        'tab10', 'tab20', 'set1', 'set2', 'set3', 'paired', 'accent', 'pastel', 'dark2',
    )
    sequential_markers = (
        'sequential', 'viridis', 'plasma', 'inferno', 'magma', 'cividis',
    )
    diverging_markers = (
        'diverging', 'bwr', 'coolwarm', 'seismic', 'rdbu', 'rdgy', 'piyg', 'prgn', 'brbg',
    )

    # Query Matplotlib's registry once, and only if the caller asked for it,
    # instead of once per colormap inside the loop.
    if 'matplotlib' in included_collections:
        matplotlib_names = frozenset(plt.colormaps())
    else:
        matplotlib_names = frozenset()

    excluded_lower = [excluded.lower() for excluded in excluded_collections]
    included_lower = [included.lower() for included in included_collections]

    # (priority, name, category) for every colormap that survives the filters
    valid_colormaps = []

    for name in colormap_dict:
        name_lower = name.lower()

        # Skip numbered variants (like brbg_9, set1_9, brbg_4_r, piyg_8_r, etc.):
        # a digit-only last part (name_4) or second-to-last part (name_4_r).
        parts = name_lower.split('_')
        if len(parts) >= 2 and (parts[-2].isdigit() or parts[-1].isdigit()):
            continue

        # Skip single-color sequential maps
        if name_lower in single_color_sequential:
            continue

        if any(excluded in name_lower for excluded in excluded_lower):
            continue

        if included_collections:
            wanted = (
                name_lower in external_known
                or name in matplotlib_names
                or any(included in name_lower for included in included_lower)
            )
            if not wanted:
                continue

        # Branch order matters: a name listed as both sequential and
        # diverging resolves to Sequential because that test comes first.
        if (name_lower in matplotlib_qualitative
                or any(marker in name_lower for marker in qualitative_markers)):
            category = "Qualitative"
        elif (name_lower in sequential_known
                or any(marker in name_lower for marker in sequential_markers)):
            category = "Sequential"
        elif (name_lower in diverging_known
                or any(marker in name_lower for marker in diverging_markers)):
            category = "Diverging"
        else:
            category = "Other"

        valid_colormaps.append((get_collection_priority(name_lower), name, category))

    # Order by collection priority, then case-insensitively by name; the
    # grouping below keeps that order inside each category.
    valid_colormaps.sort(key=lambda item: (item[0], item[1].lower()))

    # "Other" is intentionally absent: uncategorized colormaps stay hidden.
    grouped = {"Sequential": [], "Diverging": [], "Qualitative": []}
    for _, name, category in valid_colormaps:
        if category in grouped:
            grouped[category].append(name)

    return {cat_name: names for cat_name, names in grouped.items() if names}
327
+
328
+
329
def setup_colormaps(
    included_collections: Optional[List[str]] = None,
    excluded_collections: Optional[List[str]] = None,
    additional_colormaps: Optional[Dict[str, any]] = None
) -> Dict[str, List[str]]:
    """Gather colormaps from all sources and return them grouped by category.

    Args:
        included_collections: Collection names to include; defaults to
            matplotlib, cmocean, scientific, cmasher, colorbrewer and
            cartocolors.
        excluded_collections: Collection names to exclude; defaults to
            colorcet, carbonplan and sciviz.
        additional_colormaps: Extra name -> colormap entries merged in last,
            so they override same-named entries from the other sources.

    Returns:
        Dict of {"Category": [list_of_colormap_names]} ready for ColormapChooser
    """
    if included_collections is None:
        included_collections = [
            'matplotlib', 'cmocean', 'scientific', 'cmasher', 'colorbrewer', 'cartocolors'
        ]
    if excluded_collections is None:
        excluded_collections = ['colorcet', 'carbonplan', 'sciviz']

    pool = {}

    # Matplotlib's own maps are only loaded when that collection is requested.
    if 'matplotlib' in included_collections:
        builtin = load_matplotlib_colormaps()
        pool.update(builtin)
        print(f"Added {len(builtin)} matplotlib colormaps")

    # External packages are optional; a failure here must not abort setup.
    try:
        external = load_external_colormaps()
        pool.update(external)
        print(f"Added {len(external)} external colormaps")
    except Exception as e:
        print(f"Could not load external colormaps: {e}")

    if additional_colormaps:
        pool.update(additional_colormaps)
        print(f"Added {len(additional_colormaps)} additional colormaps")

    return categorize_colormaps(pool, included_collections, excluded_collections)
376
+
377
+
378
+ # ------------------------------------------------------------------
379
+ # Utility helpers
380
+ # ------------------------------------------------------------------
381
+
382
+ def _flatten_categories(categories: Dict[str, Sequence[str]]) -> List[str]:
383
+ names = []
384
+ for _, vals in categories.items():
385
+ names.extend(vals)
386
+ # maintain insertion order; drop dupes while preserving first occurrence
387
+ seen = set()
388
+ out = []
389
+ for n in names:
390
+ if n not in seen:
391
+ seen.add(n)
392
+ out.append(n)
393
+ return out
394
+
395
+
396
+ def _build_name2cat(categories: Dict[str, Sequence[str]]) -> Dict[str, str]:
397
+ m = {}
398
+ for cat, vals in categories.items():
399
+ for n in vals:
400
+ m[n] = cat
401
+ return m
402
+
403
+
404
+ # ------------------------------------------------------------------
405
+ # Sampling policy
406
+ # ------------------------------------------------------------------
407
+
408
+ def _is_categorical_cmap(
409
+ cmap: mcolors.Colormap,
410
+ declared_category: Optional[str] = None,
411
+ qualitative_label: str = "Qualitative",
412
+ max_auto: int = 32,
413
+ ) -> bool:
414
+ """Heuristic: treat as categorical/qualitative.
415
+
416
+ Priority:
417
+ 1. If user-declared category == qualitative_label → True.
418
+ 2. If ListedColormap with small N → True.
419
+ 3. If colormap name suggests it's qualitative → True.
420
+ 4. Else False (continuous).
421
+ """
422
+ # Check if explicitly declared as qualitative
423
+ if declared_category == qualitative_label:
424
+ return True
425
+
426
+ # Check if it's a ListedColormap with small N
427
+ if isinstance(cmap, mcolors.ListedColormap) and cmap.N <= max_auto:
428
+ return True
429
+
430
+ # Additional check: if the colormap name suggests it's qualitative
431
+ # This is a fallback in case the declared_category doesn't match exactly
432
+ if hasattr(cmap, 'name'):
433
+ name_lower = cmap.name.lower()
434
+ qualitative_names = {
435
+ 'tab10', 'tab20', 'tab20b', 'tab20c', 'set1', 'set2', 'set3',
436
+ 'pastel1', 'pastel2', 'paired', 'accent', 'dark2'
437
+ }
438
+ if name_lower in qualitative_names:
439
+ return True
440
+
441
+ return False
442
+
443
+
444
def _cmap_strip(
    name: str,
    width: int = 10,
    height: int = 16,
    smooth_steps: int = 20,
    declared_category: Optional[str] = None,
    qualitative_label: str = "Qualitative",
    max_auto: int = 32,
):
    """Return RGB uint8 preview strip for *name* colormap.

    Continuous maps are sampled at *smooth_steps* points and linearly
    interpolated across *width* pixels. Categorical maps are drawn as
    equal-width discrete blocks, one per colour; when the strip is too
    narrow to give every colour at least three pixels, an evenly spaced
    subset of the colours (at least two) is shown instead.

    Args:
        name: Key looked up in ``mpl.colormaps``.
        width: Strip width in pixels.
        height: Strip height in pixels.
        smooth_steps: Sample count for continuous maps.
        declared_category: Category the caller filed this map under, passed
            through to the categorical heuristic.
        qualitative_label: Category label that forces categorical rendering.
        max_auto: Size limit passed through to the categorical heuristic.

    Returns:
        ``np.ndarray`` of dtype uint8 with shape ``(height, width, 3)``.

    Raises:
        KeyError: If *name* is not a registered colormap.
    """
    cmap = mpl.colormaps[name]
    categorical = _is_categorical_cmap(
        cmap,
        declared_category=declared_category,
        qualitative_label=qualitative_label,
        max_auto=max_auto,
    )

    if categorical:
        n = cmap.N
        if hasattr(cmap, "colors"):
            # to_rgba_array normalises RGB rows, RGBA rows and colour
            # strings alike; the alpha channel is dropped afterwards.
            cols = mcolors.to_rgba_array(cmap.colors)[:, :3]
        else:
            # No explicit colour list: sample the centre of each of n bins.
            xs = np.linspace(0, 1, n, endpoint=False) + (0.5 / n)
            cols = cmap(xs)[..., :3]

        min_block_width = 3  # Minimum pixels per color block for visibility

        if width >= n * min_block_width:
            num_blocks = n
        else:
            # Not enough room for every colour: show at least two, never
            # more than exist.
            num_blocks = min(n, max(2, width // min_block_width))

        if num_blocks == n:
            selected_cols = cols
        else:
            indices = np.linspace(0, n - 1, num_blocks, dtype=int)
            selected_cols = cols[indices]

        # Never let a block collapse to zero pixels; for very narrow strips
        # the surplus is trimmed below instead.
        block_w = max(1, width // num_blocks)

        row = np.repeat(selected_cols, block_w, axis=0)  # (num_blocks*block_w, 3)
        if row.shape[0] < width:
            # Integer division left a remainder: extend the last colour.
            pad = np.repeat(row[-1:], width - row.shape[0], axis=0)
            row = np.concatenate([row, pad], axis=0)
        row = row[:width]

        arr = np.repeat(row[np.newaxis, :, :], height, axis=0)  # (h, width, 3)
        return (arr * 255).astype(np.uint8)

    # Continuous: sample, then interpolate linearly between neighbours.
    xs = np.linspace(0, 1, smooth_steps)
    cols = cmap(xs)[..., :3]
    xi = np.linspace(0, smooth_steps - 1, width)
    lo = np.floor(xi).astype(int)
    hi = np.minimum(lo + 1, smooth_steps - 1)
    t = xi - lo
    strip = (1 - t)[:, None] * cols[lo] + t[:, None] * cols[hi]
    arr = np.repeat(strip[np.newaxis, :, :], height, axis=0)
    return (arr * 255).astype(np.uint8)
535
+
536
+
537
+ # ------------------------------------------------------------------
538
+ # ColormapChooser class
539
+ # ------------------------------------------------------------------
540
+ class ColormapChooser:
541
+ """Reusable scrollable colormap selector for Gradio.
542
+
543
+ Parameters
544
+ ----------
545
+ categories:
546
+ Dict mapping *Category Label* → list of cmap names. If None, uses
547
+ DEFAULT_CATEGORIES defined above. You may pass additional categories or
548
+ override existing ones. Order preserved.
549
+ smooth_steps:
550
+ Approx sample count for continuous maps (default 20).
551
+ strip_width:
552
+ Pixel width of preview strip images (default 512).
553
+ strip_height:
554
+ Pixel height of preview strip images (default 16).
555
+ css_height:
556
+ Max CSS height (pixels) for the scrollable gallery viewport.
557
+ qualitative_label:
558
+ Category label used to force qualitative sampling when present.
559
+ max_auto:
560
+ If a ListedColormap has N <= max_auto, treat as categorical even if not
561
+ declared Qualitative.
562
+ elem_id:
563
+ DOM id for the gallery (used to scope CSS overrides). Default 'cmap_gallery'.
564
+ show_search:
565
+ Whether to render the search Textbox.
566
+ show_category:
567
+ Whether to render the category Radio selector.
568
+ show_preview:
569
+ Show the big preview strip under the gallery. Off by default.
570
+ show_selected_name:
571
+ Show the textbox that echoes the selected colormap name. Off by default.
572
+ show_selected_info:
573
+ Show the markdown info line. Off by default.
574
+ gallery_kwargs:
575
+ Dictionary of keyword arguments to pass to the Gradio Gallery component
576
+ when it is created. For example, `columns=4, allow_preview=True, height="400px"`.
577
+
578
+ Public attributes after render():
579
+ category (optional)
580
+ search (optional)
581
+ gallery
582
+ preview
583
+ selected_name (Textbox; value string)
584
+ selected_info (Markdown)
585
+ names_state (State of current filtered cmap names)
586
+
587
+ Usage: see module Quick Start above.
588
+ """
589
+
590
+ def __init__(
591
+ self,
592
+ *,
593
+ categories: Optional[Dict[str, Sequence[str]]] = None,
594
+ smooth_steps: int = 10,
595
+ strip_width: int = 10,
596
+ strip_height: int = 16,
597
+ css_height: int = 240,
598
+ qualitative_label: str = "Qualitative",
599
+ max_auto: int = 32,
600
+ elem_id: str = "cmap_gallery",
601
+ show_search: bool = True,
602
+ show_category: bool = True,
603
+ columns: int = 3,
604
+ thumb_margin_px: int = 2, # NEW
605
+ gallery_kwargs: Optional[Dict[str, Any]] = None,
606
+ show_preview: bool = False,
607
+ show_selected_name: bool = False,
608
+ show_selected_info: bool = True,
609
+ ) -> None:
610
+ self.categories = categories if categories is not None else DEFAULT_CATEGORIES
611
+ self.smooth_steps = smooth_steps
612
+ self.strip_width = strip_width
613
+ self.strip_height = strip_height
614
+ self.css_height = css_height
615
+ self.qualitative_label = qualitative_label
616
+ self.max_auto = max_auto
617
+ self.elem_id = elem_id
618
+ self.show_search = show_search
619
+ self.show_category = show_category
620
+ self.columns = columns
621
+ self.thumb_margin_px = thumb_margin_px # NEW
622
+ self.gallery_kwargs = gallery_kwargs or {}
623
+ # visibility flags
624
+ self.show_preview = show_preview
625
+ self.show_selected_name = show_selected_name
626
+ self.show_selected_info = show_selected_info
627
+ self._all_names = _flatten_categories(self.categories)
628
+ self._name2cat = _build_name2cat(self.categories)
629
+ self._tile_cache: Dict[str, np.ndarray] = {}
630
+
631
+ # public gradio components (populated in render)
632
+ self.category = None
633
+ self.search = None
634
+ self.gallery = None
635
+ self.preview = None
636
+ self.selected_name = None
637
+ self.selected_info = None
638
+ self.names_state = None
639
+
640
+ # ------------------
641
+ # internal helpers
642
+ # ------------------
643
+ def _tile(self, name: str) -> np.ndarray:
644
+ if name not in self._tile_cache:
645
+ self._tile_cache[name] = _cmap_strip(
646
+ name,
647
+ width=self.strip_width,
648
+ height=self.strip_height,
649
+ smooth_steps=self.smooth_steps,
650
+ declared_category=self._name2cat.get(name),
651
+ qualitative_label=self.qualitative_label,
652
+ max_auto=self.max_auto,
653
+ )
654
+ return self._tile_cache[name]
655
+
656
+ def _make_gallery_items(self, names: Sequence[str]):
657
+ return [(self._tile(n), n) for n in names]
658
+
659
+ # ------------------
660
+ # event functions
661
+ # ------------------
662
def _filter(self, cat: str, s: str):
    """Rebuild the gallery for a category / search-text combination.

    Args:
        cat: category key; ignored unless the category selector is shown
            and the key exists in ``self.categories``.
        s: case-insensitive substring to match against colormap names;
            ignored when empty or when the search box is hidden.

    Returns:
        A 4-tuple ``(gallery, preview, selected_name, selected_info)`` of
        updates: a fresh gallery with no selection, and cleared values for
        the three selection widgets.
    """
    use_category = self.show_category and cat in self.categories
    names = list(self.categories[cat] if use_category else self._all_names)

    if s and self.show_search:
        needle = s.lower()
        names = [candidate for candidate in names if needle in candidate.lower()]

    # Remember the new list for the select-callback.
    # NOTE(review): this assigns the State component's `.value` attribute;
    # whether that reaches an already-initialised per-session state should be
    # confirmed against the Gradio version in use.
    self.names_state.value = names

    # Updated gallery; caller-supplied gallery_kwargs override the defaults.
    gallery_args = {
        "value": self._make_gallery_items(names),
        "selected_index": None,
    }
    gallery_args.update(self.gallery_kwargs)

    # Clear the other widgets so the old selection disappears.
    return (
        gr.Gallery(**gallery_args),
        gr.update(value=None),
        gr.update(value=""),
        gr.update(value=""),
    )
688
+
689
def _select(self, evt: gr.SelectData, names: Sequence[str]):
    """Resolve a gallery click into ``(preview_image, name, info_markdown)``.

    ``names`` is the list currently backing the gallery; ``evt.index`` is
    looked up in it. When there is no list, no index, or the index is past
    the end, the preview is left untouched and the text outputs report that
    nothing is selected.
    """
    if not names:
        return gr.update(), "", "Nothing selected"
    position = evt.index
    if position is None or position >= len(names):
        return gr.update(), "", "Nothing selected"

    name = names[position]
    # The preview strip is rendered at least twice the thumbnail size.
    preview_width = max(self.strip_width * 2, 768)
    preview_height = max(self.strip_height * 2, 32)
    big = _cmap_strip(
        name,
        width=preview_width,
        height=preview_height,
        smooth_steps=self.smooth_steps,
        declared_category=self._name2cat.get(name),
        qualitative_label=self.qualitative_label,
        max_auto=self.max_auto,
    )
    info = f"**Selected:** `{name}` _(Category: {self._name2cat.get(name, '?')})_"
    return big, name, info
704
+
705
+ # ------------------
706
+ # CSS block builder
707
+ # ------------------
708
def css(self) -> str:
    """Return the CSS block that styles this chooser's gallery.

    Selectors are keyed on ``self.elem_id`` (gallery) and
    ``{self.elem_id}_wrap`` (wrapper row); ``self.css_height`` sets the
    scroll-pane height in px and ``self.thumb_margin_px`` the thumbnail
    margin. The returned string is meant to be passed to ``gr.Blocks(css=...)``.
    """
    return f"""
    /* ───── 0. easy visual check the CSS is live (remove later) ───── */
    #{self.elem_id} {{
        /* background:rgba(255,255,0,.05); */
    }}

    /* the wrapper *is* the .block, so it owns the padding var */
    #{self.elem_id}_wrap {{
        padding: 0 !important;
        --block-padding: 0 !important;
    }}

    /* ───── 1. the wrapper Gradio marks .fixed-height: make it scroll ─── */
    #{self.elem_id} .grid-wrap {{
        height: {self.css_height}px; /* kill inline 200 px or similar */
        max-height: {self.css_height}px; /* cap the gallery’s height */
        overflow-y: auto; /* rows that don’t fit will scroll */
    }}

    /* ───── 2. the real grid: keep masonry maths intact, tweak gap ─── */
    #{self.elem_id} .grid-container {{
        height: auto !important; /* sometimes Gradio sets one */
        gap: 7px; /* tighter gutters (define attr) */
        grid-auto-rows:auto !important;
    }}

    /* ───── 3. thumbnail boxes keep your ultra-wide shape ──────────── */
    #{self.elem_id} .thumbnail-item {{
        aspect-ratio: 3/1; /* e.g. 5/1 */
        height: auto !important; /* beats Gradio’s inline 100 % */
        margin: {self.thumb_margin_px}px !important;
        overflow: hidden; /* just in case */
    }}

    /* ───── 4. images fill each box neatly ─────────────────────────── */
    #{self.elem_id} img {{
        width: 100%;
        height: 100%;
        object-fit: cover; /* crop to fill */
        object-position: left;
        display: block; /* kill inline-img whitespace */
    }}

    /* ───── 5. widen the “Selected:” info line ───────────────────── */
    .cmap_selected_info {{
        max-width: 100% !important; /* kill default 45 rem limit */
    }}
    """
757
+
758
+ # ------------------
759
+ # Render into an existing Blocks context
760
+ # ------------------
761
def render(self):
    """Create the Gradio components for the chooser and wire their callbacks.

    Must be called *inside* an active ``gr.Blocks()`` context.

    Returns:
        dict: the created components, keyed ``"gallery"``,
        ``"selected_name"``, ``"preview"``, ``"info"``, ``"category"``,
        ``"search"`` and ``"names_state"``. When the category selector or
        search box is hidden, the corresponding entry is a ``gr.State``
        placeholder so the filter callback keeps a fixed signature.
    """
    # Initial list: first category when the selector is shown, else all names.
    # `next(..., None)` avoids StopIteration on an empty categories mapping.
    first_cat = next(iter(self.categories), None) if self.show_category else None
    if first_cat is not None:
        init_names = list(self.categories[first_cat])
    else:
        init_names = list(self._all_names)

    # Tiles are rendered lazily by _make_gallery_items; no bulk precompute.

    # layout
    if self.show_category or self.show_search:
        with gr.Row():
            if self.show_category:
                self.category = gr.Radio(list(self.categories.keys()), value=first_cat, label="Category")
            else:
                self.category = gr.State(None)  # shim so filter signature works
            if self.show_search:
                self.search = gr.Textbox(label="Search", placeholder="type to filter...")
            else:
                self.search = gr.State("")
    else:
        self.category = gr.State(None)
        self.search = gr.State("")

    self.names_state = gr.State(init_names)

    gkw = {
        "value": self._make_gallery_items(init_names),
        "label": None,  # remove label
        "allow_preview": False,
        "elem_id": self.elem_id,
        "show_share_button": False,
        "columns": getattr(self, "columns", 3),
    }
    gkw.update(self.gallery_kwargs)  # caller overrides win
    self.gallery = gr.Gallery(**gkw)

    self.preview = gr.Image(
        label="Preview", interactive=False, height=60, visible=self.show_preview
    )
    self.selected_name = gr.Textbox(
        label="Selected cmap", interactive=False, visible=self.show_selected_name
    )
    self.selected_info = gr.Markdown(
        visible=self.show_selected_info,
        elem_classes="cmap_selected_info",
    )

    # wiring
    if self.show_category or self.show_search:

        def _wrapped_filter(cat, s):
            if not self.show_category:
                cat = None
            if not self.show_search:
                s = ""
            updates = self._filter(cat, s)
            # _filter records the filtered names on names_state.value; send
            # them back as an explicit output so the *session's* state (the
            # list _select indexes into) matches the gallery just returned.
            return (*updates, list(self.names_state.value))

        outputs = [
            self.gallery,
            self.preview,
            self.selected_name,
            self.selected_info,
            self.names_state,
        ]

        if self.show_category:
            self.category.change(
                _wrapped_filter,
                [self.category, self.search],
                outputs,
            )
        if self.show_search:
            self.search.change(
                _wrapped_filter,
                [self.category, self.search],
                outputs,
            )

    def _wrapped_select(evt: gr.SelectData, names):
        return self._select(evt, names)

    self.gallery.select(
        _wrapped_select,
        [self.names_state],
        [self.preview, self.selected_name, self.selected_info],
    )

    return {
        "gallery": self.gallery,
        "selected_name": self.selected_name,
        "preview": self.preview,
        "info": self.selected_info,
        "category": self.category,
        "search": self.search,
        "names_state": self.names_state,
    }
859
+
860
+ # ==========================================================
861
+ # NEW TAB-BASED RENDERER
862
+ # ==========================================================
863
def render_tabs(self):
    """Render the chooser as one Gallery per category inside ``gr.Tabs``.

    No search box is provided; each tab already shows a single category.
    A shared preview image, name textbox and info markdown are created
    below the tabs and every gallery's select event feeds them.

    Returns:
        dict: ``"galleries"`` (category -> Gallery), ``"selected_name"``,
        ``"preview"``, ``"info"`` and ``"tabs"`` (the Tabs container).
    """
    wrapper_id = f"{self.elem_id}_wrap"
    # Settings shared by every gallery; only "value" differs per category.
    common = {
        "label": None,  # remove label
        "allow_preview": False,
        "show_share_button": False,
        "elem_id": self.elem_id,
        "columns": getattr(self, "columns", 3),
        "show_label": False,
    }
    galleries = {}

    with gr.Tabs() as root_tabs:
        # One tab + gallery per category, in mapping order.
        for cat, names in self.categories.items():
            with gr.TabItem(cat):
                gallery_args = {"value": self._make_gallery_items(names)}
                gallery_args.update(common)
                gallery_args.update(self.gallery_kwargs)  # caller overrides win
                with gr.Row(elem_id=wrapper_id):  # wrapper targeted by css()
                    galleries[cat] = gr.Gallery(**gallery_args)

    # Shared preview / meta area under the tabs.
    self.preview = gr.Image(
        label="Preview", interactive=False, height=60, visible=self.show_preview
    )
    self.selected_name = gr.Textbox(
        label="Selected cmap", interactive=False, visible=self.show_selected_name
    )
    self.selected_info = gr.Markdown(
        visible=self.show_selected_info,
        elem_classes="cmap_selected_info",
    )

    # Every gallery uses the same _select callback, fed with its own name list.
    def _wrapped_select(evt: gr.SelectData, names):
        return self._select(evt, names)

    for cat, gal in galleries.items():
        category_names = gr.State(list(self.categories[cat]))
        gal.select(
            _wrapped_select,
            [category_names],
            [self.preview, self.selected_name, self.selected_info],
        )

    return {
        "galleries": galleries,
        "selected_name": self.selected_name,
        "preview": self.preview,
        "info": self.selected_info,
        "tabs": root_tabs,
    }
923
+
924
+
925
+ # ------------------------------------------------------------------
926
+ # Minimal self-demo (only runs if module executed directly)
927
+ # ------------------------------------------------------------------
928
if __name__ == "__main__":
    # Self-demo: a default chooser rendered with the category/search layout.
    chooser = ColormapChooser()
    stylesheet = chooser.css()
    with gr.Blocks(css=stylesheet) as demo:
        gr.Markdown("## Colormap Chooser Demo")
        chooser.render()
    demo.launch()
colormap_chooser_testing_app.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from colormap_chooser import ColormapChooser, setup_colormaps
3
+
4
# Set up colormaps with our preferred collections and ordering
print("Setting up colormaps...")
included_collections = [
    'matplotlib', 'cmocean', 'scientific', 'cmasher', 'colorbrewer', 'cartocolors',
]
excluded_collections = ['colorcet', 'carbonplan', 'sciviz']
categories = setup_colormaps(
    included_collections=included_collections,
    excluded_collections=excluded_collections,
)

# Create the chooser with our categories
chooser = ColormapChooser(
    categories=categories,
    smooth_steps=10,
    strip_width=200,
    strip_height=50,
    css_height=180,  # outer box height (becomes a scroll-pane)
    thumb_margin_px=2,  # margin around each strip
    # anything else the Gallery needs
    gallery_kwargs={"columns": 3, "allow_preview": False, "height": "200px"},
)

# Dump the generated stylesheet so it can be inspected while testing.
print(chooser.css())
26
+
27
with gr.Blocks(css=chooser.css()) as demo:
    with gr.Row():
        with gr.Column(scale=1):
            chooser.render_tabs()
        with gr.Column(scale=2):
            plot = gr.Plot(label="Demo Plot")

    # When the user picks a cmap, update the plot
    def _plot(name):
        """Draw fixed-seed random data with colormap *name*; None when no name is set."""
        print(f"Plotting {name}")
        # The chooser emits "" when nothing is selected; imshow would raise
        # ValueError on an empty cmap name, so leave the plot empty instead.
        if not name:
            return None
        import numpy as np
        import matplotlib.pyplot as plt
        data = np.random.RandomState(0).randn(100, 100)
        fig, ax = plt.subplots()
        im = ax.imshow(data, cmap=name)
        fig.colorbar(im, ax=ax)
        plt.close(fig)  # drop pyplot's reference; the Figure object is still returned
        return fig

    chooser.selected_name.change(_plot, chooser.selected_name, plot)

demo.launch(debug=True, share=False, inbrowser=True)
openalex_utils.py CHANGED
@@ -1,6 +1,6 @@
1
  import numpy as np
2
  from urllib.parse import urlparse, parse_qs
3
- from pyalex import Works
4
  import pandas as pd
5
  import ast, json
6
 
@@ -213,4 +213,164 @@ def get_records_from_dois(doi_list, block_size=50):
213
  all_records.extend(record_list)
214
  except Exception as e:
215
  print(f"Error fetching DOIs {sublist}: {e}")
216
- return pd.DataFrame(all_records)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import numpy as np
2
  from urllib.parse import urlparse, parse_qs
3
+ from pyalex import Works, Authors, Institutions
4
  import pandas as pd
5
  import ast, json
6
 
 
213
  all_records.extend(record_list)
214
  except Exception as e:
215
  print(f"Error fetching DOIs {sublist}: {e}")
216
+ return pd.DataFrame(all_records)
217
+
218
def _describe_cited_work(work_id):
    """Return a ``Cites: <last name> (<year>)`` label, or a raw-id fallback on any failure."""
    fallback = f"Cites: Work {work_id}"
    try:
        cited_work = Works()[work_id]
        if not cited_work:
            return fallback

        author_name = "Unknown"
        year = "Unknown"

        authorships = cited_work.get('authorships')
        if authorships:
            display_name = authorships[0]['author'].get('display_name')
            if display_name:
                # Last whitespace-separated token (assumes "First Last" order).
                name_parts = display_name.split()
                author_name = name_parts[-1] if name_parts else display_name

        if cited_work.get('publication_year'):
            year = str(cited_work['publication_year'])

        return f"Cites: {author_name} ({year})"
    except Exception as e:  # best-effort lookup: never let it break the label
        print(f"Could not fetch cited work {work_id}: {e}")
        return fallback


def _describe_entity(entity_cls, entity_id, prefix, noun):
    """Return ``<prefix>: <display_name>`` for an OpenAlex entity, or ``<prefix>: <noun> <id>``."""
    try:
        entity = entity_cls()[entity_id]
        if entity and entity.get('display_name'):
            return f"{prefix}: {entity['display_name']}"
    except Exception as e:  # best-effort lookup: fall through to the raw id
        print(f"Could not fetch {noun.lower()} {entity_id}: {e}")
    return f"{prefix}: {noun} {entity_id}"


def _describe_filter(key, value):
    """Turn one ``key:value`` filter (other than ``publication_year``) into a label."""
    if key == 'default.search':
        # Clean up search term (remove surrounding quotes if present)
        search_term = value.strip('"\'')
        return f"Search: '{search_term}'"
    if key == 'cites':
        return _describe_cited_work(value)
    if key == 'authorships.institutions.lineage':
        return _describe_entity(Institutions, value, "From", "Institution")
    if key == 'authorships.author.id':
        return _describe_entity(Authors, value, "By", "Author")
    if key == 'type':
        type_mapping = {
            'article': 'Articles',
            'book': 'Books',
            'book-chapter': 'Book Chapters',
            'dissertation': 'Dissertations',
            'preprint': 'Preprints',
        }
        work_type = type_mapping.get(value, value.replace('-', ' ').title())
        return f"Type: {work_type}"
    if key == 'host_venue.id':
        # No venue lookup is performed; the raw id is shown.
        return f"In: Venue {value}"
    if key.startswith('concepts.id'):
        # Concept/topic ids are shown as-is (could be enhanced with a lookup).
        return f"Topic: {value}"

    # Generic handling for other filters
    clean_key = key.replace('_', ' ').replace('.', ' ').title()
    clean_value = value.replace('_', ' ')
    return f"{clean_key}: {clean_value}"


def openalex_url_to_readable_name(url):
    """
    Convert an OpenAlex URL to a short, human-readable query description.

    Only the ``filter`` query parameter is interpreted. Filters are split on
    commas and described one by one; ``publication_year`` is appended last.
    Lookups of cited works, institutions and authors go through pyalex and
    fall back to the raw id when the lookup fails.

    Args:
        url (str): The OpenAlex search URL

    Returns:
        str: A description of at most 100 characters (longer ones are cut
        and end in "..."); "OpenAlex Query" when no filter is recognised.

    Examples:
        - "Search: 'Kuramoto Model'"
        - "Search: 'quantum physics', 2020-2023"
        - "Cites: Popper (1959)"
        - "From: University of Pittsburgh, 1999-2020"
        - "By: Albert Einstein, 1905-1955"
    """
    query_params = parse_qs(urlparse(url).query)

    parts = []
    year_range = None

    if 'filter' in query_params:
        for f in query_params['filter'][0].split(','):
            key, sep, value = f.partition(':')
            if not sep:
                continue  # not a key:value pair

            if key == 'publication_year':
                # Single years and ranges are displayed verbatim; no parsing
                # is needed (and none that could fail on unusual values).
                year_range = value
                continue

            try:
                parts.append(_describe_filter(key, value))
            except Exception as e:  # a bad filter must not hide the others
                print(f"Error processing filter {f}: {e}")

    # Combine parts into final description
    description = ", ".join(parts) if parts else "OpenAlex Query"

    # Add year range if present
    if year_range:
        if parts:
            description += f", {year_range}"
        else:
            description = f"Works from {year_range}"

    # Limit length to keep it readable
    if len(description) > 100:
        description = description[:97] + "..."

    return description
ui_utils.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ UI utility functions for the OpenAlex Mapper Gradio app.
3
+ """
4
+
5
+ from openalex_utils import openalex_url_to_readable_name
6
+
7
+
8
def highlight_queries(text: str) -> str:
    """Render semicolon-separated OpenAlex URLs as HTML pills with readable names.

    Args:
        text: raw textbox content; URLs are separated by ";". Empty or
            whitespace-only input yields a hint message instead of pills.

    Returns:
        An HTML fragment: a hint ``<div>`` when there is nothing to show,
        otherwise a header with the query count followed by one ``<span>``
        pill per URL. A URL whose description cannot be built is labelled
        "Query <n>".
    """
    from html import escape  # local import: only needed for the pill labels

    palette = ["#f5f5f5",  # set to only light grey
               # "#e8f4fd", "#fff2e8", "#f0f9e8", "#fdf2f8",
               # "#f3e8ff", "#e8f8f5", "#fef7e8", "#f8f0e8"
               ]

    # Handle empty input
    if not text or not text.strip():
        return "<div style='padding: 10px; color: #666; font-style: italic;'>Enter OpenAlex URLs separated by semicolons to see query descriptions</div>"

    # Split URLs on semicolons and strip whitespace
    urls = [url.strip() for url in text.split(";") if url.strip()]

    if not urls:
        return "<div style='padding: 10px; color: #666; font-style: italic;'>No valid URLs found</div>"

    pills = []
    for i, url in enumerate(urls):
        color = palette[i % len(palette)]
        try:
            # Get readable name for the URL
            readable_name = openalex_url_to_readable_name(url)
        except Exception as e:
            print(f"Error processing URL {url}: {e}")
            readable_name = f"Query {i+1}"

        # The name is derived from user-typed URL text (e.g. the search
        # term), so escape it before embedding it in markup.
        label = escape(str(readable_name))

        pills.append(
            f'<span style="background:{color};'
            'padding: 8px 12px; margin: 4px; '
            'border-radius: 12px; font-weight: 500;'
            'display: inline-block; font-family: \'Roboto Condensed\', sans-serif;'
            'border: 1px solid rgba(0,0,0,0.1); font-size: 14px;'
            'box-shadow: 0 1px 3px rgba(0,0,0,0.1);">'
            f'{label}</span>'
        )

    return (
        "<div style='padding: 8px 0;'>"
        "<div style='font-size: 12px; color: #666; margin-bottom: 6px; font-weight: 500;'>"
        f"{'Query' if len(urls) == 1 else 'Queries'} ({len(urls)}):</div>"
        "<div style='display: flex; flex-wrap: wrap; gap: 4px;'>"
        + "".join(pills) +
        "</div></div>"
    )