openalex_mapper / colormap_chooser.py
m7n's picture
Many updates, mainly added categorical.
f895c88
"""Colormap Chooser Gradio Component
===================================
A reusable, importable Gradio component that provides a **scrollable, wide-strip**
chooser for Matplotlib (and compatible) colormaps. Designed to drop into an
existing Gradio Blocks app.
Features
--------
* Long, skinny gradient bars (not square tiles).
* Smart sampling:
- Continuous maps → ~20 sample steps (configurable) interpolated across width.
- Categorical / qualitative maps → actual number of colors (`cmap.N`).
* Scrollable gallery (height-capped w/ CSS).
* Selection callback returns the **selected colormap name** (string) you can pass
directly to Matplotlib (`mpl.colormaps[name]` or `plt.get_cmap(name)`).
* Optional category + search filtering UI.
* Minimal dependencies: NumPy, Matplotlib, Gradio.
Quick Start
-----------
```python
import gradio as gr
from colormap_chooser import ColormapChooser, setup_colormaps
# Set up colormaps with custom collections
categories = setup_colormaps(
included_collections=['matplotlib', 'cmocean', 'scientific'],
excluded_collections=['colorcet']
)
chooser = ColormapChooser(
categories=categories,
gallery_kwargs=dict(columns=4, allow_preview=True, height="400px")
)
with gr.Blocks() as demo:
with gr.Row():
chooser.render() # inserts the component cluster
# Use chooser.selected_name as an input to your plotting fn
import numpy as np, matplotlib.pyplot as plt
def show_demo(cmap_name):
data = np.random.rand(32, 32)
fig, ax = plt.subplots()
im = ax.imshow(data, cmap=cmap_name)
ax.set_title(cmap_name)
fig.colorbar(im, ax=ax)
return fig
out = gr.Plot()
chooser.selected_name.change(show_demo, chooser.selected_name, out)
demo.launch()
```
Installation
------------
Drop this file in your project (e.g., `colormap_chooser.py`) and import.
Customizing
-----------
Pass your own category dict, default sampling counts, or CSS overrides at
construction time; see class docstring below.
"""
from __future__ import annotations
import numpy as np
import matplotlib as mpl
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import gradio as gr
from typing import Any, Dict, List, Optional, Sequence, Tuple
# ------------------------------------------------------------------
# Default category mapping (extend or replace at init)
# ------------------------------------------------------------------
# Fallback mapping of category label -> colormap names, used by ColormapChooser
# when no ``categories`` argument is supplied. Dict insertion order is the
# order in which the categories are offered in the UI.
DEFAULT_CATEGORIES: Dict[str, List[str]] = {
    "Perceptually Uniform": ["viridis", "plasma", "inferno", "magma", "cividis"],
    "Sequential": ["Blues", "Greens", "Oranges", "Purples", "Reds", "Greys"],
    "Diverging": ["coolwarm", "bwr", "seismic", "PiYG", "PRGn", "RdBu"],
    # The label "Qualitative" matches the default ``qualitative_label`` and
    # therefore forces discrete (block) sampling for these maps.
    "Qualitative": ["tab10", "tab20", "Set1", "Set2", "Accent"],
}
# ------------------------------------------------------------------
# Colormap setup functions
# ------------------------------------------------------------------
def load_matplotlib_colormaps():
    """
    Collect every colormap that pyplot reports as registered.

    Returns:
        Dict of colormap_name -> colormap_object. A name whose lookup raises
        is left out instead of aborting the scan.
    """
    registry = {}
    for cmap_name in plt.colormaps():
        try:
            registry[cmap_name] = plt.get_cmap(cmap_name)
        except Exception:
            # Best effort: one unloadable entry must not hide all the others.
            pass
    return registry
def load_external_colormaps():
    """
    Harvest colormap-like objects from the optional ``colormaps`` package.

    Every public attribute of the package that is callable or exposes a
    ``colors`` attribute is kept. When the package is not installed the
    result is simply empty.

    Returns:
        Dict of attribute_name -> attribute_object.
    """
    found = {}
    try:
        import colormaps

        for public_name in dir(colormaps):
            if public_name.startswith('_'):
                continue
            try:
                candidate = getattr(colormaps, public_name)
                # Duck-typing check: colormap objects are callable and/or
                # carry a ``colors`` table.
                looks_like_cmap = (
                    hasattr(candidate, '__call__') or hasattr(candidate, 'colors')
                )
                if looks_like_cmap:
                    found[public_name] = candidate
            except Exception:
                # Attribute access may fail (e.g. lazy loaders); skip it.
                continue
    except ImportError:
        # Optional dependency missing: report no external colormaps.
        pass
    return found
def categorize_colormaps(
    colormap_dict: Dict[str, Any],
    included_collections: List[str],
    excluded_collections: List[str]
) -> Dict[str, List[str]]:
    """
    Categorize colormaps by type with priority ordering.

    A colormap name is dropped when it
      * is a numbered variant (``name_9`` or ``name_9_r``),
      * is one of the single-hue sequential maps (``Blues``, ``Greens``, ...),
      * contains any excluded collection name as a substring, or
      * (only when ``included_collections`` is non-empty) matches none of the
        inclusion rules: collection name is a substring of the colormap name,
        ``'matplotlib'`` is requested and the name is registered with
        matplotlib, or the name belongs to one of the known cmocean /
        scientific / cmasher name sets.

    All name matching is case-insensitive.

    Args:
        colormap_dict: Dict of colormap_name -> colormap_object (only the keys
            are inspected)
        included_collections: List of collection names to include
        excluded_collections: List of collection names to exclude

    Returns:
        Dict {"Category": [list_of_names]} with colormaps ordered by collection
        priority (matplotlib, cmocean, scientific, cmasher, everything else)
        and then alphabetically. Empty categories and the catch-all "Other"
        category are omitted.
    """
    def _lowered(*names):
        # Every lookup below uses ``name.lower()``, so the reference sets must
        # be stored lower-cased as well; otherwise entries such as 'batlowK'
        # could never match.
        return frozenset(n.lower() for n in names)

    # Known categorizations based on documentation
    matplotlib_sequential = _lowered(
        'viridis', 'plasma', 'inferno', 'magma', 'cividis',  # Perceptually uniform
        'ylorbr', 'ylorrd', 'orrd', 'purd', 'rdpu', 'bupu',  # Multi-hue sequential
        'gnbu', 'pubu', 'ylgnbu', 'pubugn', 'bugn', 'ylgn',
        'binary', 'gist_yarg', 'gist_gray', 'gray', 'bone', 'pink',  # Sequential (2)
        'spring', 'summer', 'autumn', 'winter', 'cool', 'wistia',
        'hot', 'afmhot', 'gist_heat', 'copper',
    )
    # Single-color sequential maps to exclude
    single_color_sequential = _lowered(
        'blues', 'greens', 'oranges', 'purples', 'reds', 'greys',
    )
    matplotlib_diverging = _lowered(
        'piyg', 'prgn', 'brbg', 'puor', 'rdgy', 'rdbu',
        'rdylbu', 'rdylgn', 'spectral', 'coolwarm', 'bwr', 'seismic',
        'berlin', 'managua', 'vanimo',
    )
    matplotlib_qualitative = _lowered(
        'pastel1', 'pastel2', 'paired', 'accent',
        'dark2', 'set1', 'set2', 'set3',
        'tab10', 'tab20', 'tab20b', 'tab20c',
    )
    matplotlib_miscellaneous = _lowered(
        'flag', 'prism', 'ocean', 'gist_earth', 'terrain', 'gist_stern',
        'gnuplot', 'gnuplot2', 'cmrmap', 'cubehelix', 'brg',
        'gist_rainbow', 'rainbow', 'jet', 'turbo', 'nipy_spectral',
        'gist_ncar', 'twilight', 'twilight_shifted', 'hsv',
    )
    # External colormap collections
    cmocean_sequential = _lowered(
        'thermal', 'haline', 'solar', 'ice', 'gray', 'oxy', 'deep', 'dense',
        'algae', 'matter', 'turbid', 'speed', 'amp', 'tempo', 'rain',
    )
    cmocean_diverging = _lowered('balance', 'delta', 'curl', 'diff', 'tarn')
    cmocean_other = _lowered('phase', 'topo')
    scientific_sequential = _lowered(
        'batlow', 'batlowK', 'batlowW', 'acton', 'bamako', 'bilbao', 'buda', 'davos',
        'devon', 'grayC', 'hawaii', 'imola', 'lajolla', 'lapaz', 'nuuk', 'oslo',
        'tokyo', 'turku', 'actonS', 'bamO', 'brocO', 'corko', 'corkO', 'davosS',
        'grayCS', 'hawaiiS', 'imolaS', 'lajollaS', 'lapazS', 'nuukS', 'osloS',
        'tokyoS', 'turkuS',
    )
    scientific_diverging = _lowered(
        'bam', 'bamo', 'berlin', 'broc', 'brocO', 'cork', 'corko', 'lisbon',
        'managua', 'roma', 'romao', 'tofino', 'vanimo', 'vik', 'viko',
    )
    cmasher_sequential = _lowered(
        'amber', 'amethyst', 'apple', 'arctic', 'autumn', 'bubblegum', 'chroma',
        'cosmic', 'dusk', 'ember', 'emerald', 'flamingo', 'freeze', 'gem', 'gothic',
        'heat', 'jungle', 'lavender', 'neon', 'neutral', 'nuclear', 'ocean',
        'pepper', 'plasma_r', 'rainforest', 'savanna', 'sunburst', 'swamp', 'torch',
        'toxic', 'tree', 'voltage', 'voltage_r',
    )
    cmasher_diverging = _lowered(
        'copper', 'emergency', 'fusion', 'guppy', 'holly', 'iceburn', 'infinity',
        'pride', 'prinsenvlag', 'redshift', 'seasons', 'seaweed', 'viola',
        'waterlily', 'watermelon', 'wildfire',
    )

    # Collections in priority order; the index of the first set containing a
    # name is that name's sort priority (0 = matplotlib ... 3 = cmasher).
    collections_by_priority = (
        matplotlib_sequential | matplotlib_diverging
        | matplotlib_qualitative | matplotlib_miscellaneous,
        cmocean_sequential | cmocean_diverging | cmocean_other,
        scientific_sequential | scientific_diverging,
        cmasher_sequential | cmasher_diverging,
    )
    known_external = (
        collections_by_priority[1]
        | collections_by_priority[2]
        | collections_by_priority[3]
    )
    sequential_names = (
        cmocean_sequential | scientific_sequential
        | cmasher_sequential | matplotlib_sequential
    )
    diverging_names = (
        cmocean_diverging | scientific_diverging
        | cmasher_diverging | matplotlib_diverging
    )

    # Substring hints that catch prefixed / suffixed variants (e.g. "tab10_r").
    qualitative_hints = ('tab10', 'tab20', 'set1', 'set2', 'set3',
                         'paired', 'accent', 'pastel', 'dark2')
    sequential_hints = ('viridis', 'plasma', 'inferno', 'magma', 'cividis')
    diverging_hints = ('bwr', 'coolwarm', 'seismic', 'rdbu', 'rdgy',
                       'piyg', 'prgn', 'brbg')

    included_lower = [c.lower() for c in (included_collections or [])]
    excluded_lower = [c.lower() for c in (excluded_collections or [])]

    # Query matplotlib's registry once (and only if it is actually requested)
    # instead of rebuilding the name list for every colormap.
    if 'matplotlib' in included_lower:
        matplotlib_registered = frozenset(plt.colormaps())
    else:
        matplotlib_registered = frozenset()

    def get_collection_priority(name_lower):
        for priority, members in enumerate(collections_by_priority):
            if name_lower in members:
                return priority
        # Everything else sorts last.
        return len(collections_by_priority)

    def _is_included(name, name_lower):
        if not included_lower:
            # No include filter: everything that was not excluded passes.
            return True
        # NOTE(review): names from the known cmocean/scientific/cmasher sets
        # pass whenever *any* include filter is given, even if their own
        # collection is not listed. Kept as-is to preserve existing behaviour;
        # confirm whether per-collection filtering was intended.
        if name_lower in known_external:
            return True
        if name in matplotlib_registered:
            return True
        return any(collection in name_lower for collection in included_lower)

    def _category_of(name_lower):
        if (name_lower in matplotlib_qualitative
                or any(hint in name_lower for hint in qualitative_hints)):
            return "Qualitative"
        if (name_lower in sequential_names
                or 'sequential' in name_lower
                or any(hint in name_lower for hint in sequential_hints)):
            return "Sequential"
        if (name_lower in diverging_names
                or 'diverging' in name_lower
                or any(hint in name_lower for hint in diverging_hints)):
            return "Diverging"
        return "Other"

    # Collect (priority, name, category) for every colormap that survives
    # the filters.
    selected = []
    for name in colormap_dict:
        name_lower = name.lower()

        # Skip numbered variants (like brbg_9, set1_9, brbg_4_r, piyg_8_r, etc.)
        parts = name_lower.split('_')
        if len(parts) >= 2 and (parts[-2].isdigit() or parts[-1].isdigit()):
            continue
        # Skip single-color sequential maps
        if name_lower in single_color_sequential:
            continue
        if any(collection in name_lower for collection in excluded_lower):
            continue
        if not _is_included(name, name_lower):
            continue

        selected.append((get_collection_priority(name_lower), name, _category_of(name_lower)))

    # Order by collection priority, then case-insensitively by name.
    selected.sort(key=lambda item: (item[0], item[1].lower()))

    # Group by category in a fixed display order. "Other" is intentionally
    # hidden, so it is not collected at all.
    categories: Dict[str, List[str]] = {
        "Sequential": [],
        "Diverging": [],
        "Qualitative": [],
    }
    for _priority, name, category in selected:
        if category in categories:
            categories[category].append(name)

    # Drop categories that ended up empty.
    return {label: names for label, names in categories.items() if names}
def setup_colormaps(
    included_collections: Optional[List[str]] = None,
    excluded_collections: Optional[List[str]] = None,
    additional_colormaps: Optional[Dict[str, Any]] = None
) -> Dict[str, List[str]]:
    """
    Set up and categorize colormaps from various sources.

    Sources are merged in this order, later ones overriding earlier ones on
    name clashes: matplotlib's registry (only if 'matplotlib' is listed in
    ``included_collections``), the optional external ``colormaps`` package,
    then ``additional_colormaps``. Progress is reported with ``print``.

    Args:
        included_collections: List of collection names to include
            (e.g., ['matplotlib', 'cmocean', 'scientific']). Defaults to
            matplotlib, cmocean, scientific, cmasher, colorbrewer and
            cartocolors.
        excluded_collections: List of collection names to exclude. Defaults
            to colorcet, carbonplan and sciviz.
        additional_colormaps: Dict of additional colormaps to include

    Returns:
        Dict of {"Category": [list_of_colormap_names]} ready for ColormapChooser
    """
    # Work on private copies so the caller's lists are never aliased.
    if excluded_collections is None:
        excluded = ['colorcet', 'carbonplan', 'sciviz']
    else:
        excluded = list(excluded_collections)
    if included_collections is None:
        included = ['matplotlib', 'cmocean', 'scientific', 'cmasher', 'colorbrewer', 'cartocolors']
    else:
        included = list(included_collections)

    # Combine all colormaps
    all_colormaps: Dict[str, Any] = {}

    # Add matplotlib colormaps
    if 'matplotlib' in included:
        matplotlib_cmaps = load_matplotlib_colormaps()
        all_colormaps.update(matplotlib_cmaps)
        print(f"Added {len(matplotlib_cmaps)} matplotlib colormaps")

    # Add external colormaps; failure here is non-fatal because the external
    # package is optional.
    try:
        external_cmaps = load_external_colormaps()
        all_colormaps.update(external_cmaps)
        print(f"Added {len(external_cmaps)} external colormaps")
    except Exception as e:
        print(f"Could not load external colormaps: {e}")

    # Add any additional colormaps
    if additional_colormaps:
        all_colormaps.update(additional_colormaps)
        print(f"Added {len(additional_colormaps)} additional colormaps")

    # Categorize colormaps
    return categorize_colormaps(all_colormaps, included, excluded)
# ------------------------------------------------------------------
# Utility helpers
# ------------------------------------------------------------------
def _flatten_categories(categories: Dict[str, Sequence[str]]) -> List[str]:
    """Return all colormap names across categories, de-duplicated.

    Names keep the position of their first occurrence when walking the
    categories in insertion order.
    """
    # dict keys are unique and remember insertion order, which gives exactly
    # "first occurrence wins" semantics.
    first_seen = dict.fromkeys(
        cmap_name
        for members in categories.values()
        for cmap_name in members
    )
    return list(first_seen)
def _build_name2cat(categories: Dict[str, Sequence[str]]) -> Dict[str, str]:
    """Invert the category mapping into colormap name -> category label.

    If a name is listed under several categories, the last category (in
    insertion order) wins.
    """
    return {
        cmap_name: label
        for label, members in categories.items()
        for cmap_name in members
    }
# ------------------------------------------------------------------
# Sampling policy
# ------------------------------------------------------------------
def _is_categorical_cmap(
    cmap: mcolors.Colormap,
    declared_category: Optional[str] = None,
    qualitative_label: str = "Qualitative",
    max_auto: int = 32,
) -> bool:
    """Heuristic: decide whether *cmap* should be drawn as discrete blocks.

    Checks, in order (first hit wins):
      1. The user-declared category equals ``qualitative_label`` -> True.
      2. *cmap* is a ListedColormap with at most ``max_auto`` colors -> True.
      3. *cmap* has a ``name`` that is one of the well-known qualitative
         matplotlib maps (case-insensitive) -> True.
      4. Otherwise -> False (continuous).
    """
    # 1. Explicit declaration always wins.
    if declared_category == qualitative_label:
        return True

    # 2. Small listed colormaps are treated as palettes.
    small_listed = isinstance(cmap, mcolors.ListedColormap) and cmap.N <= max_auto
    if small_listed:
        return True

    # 3. Fallback on the colormap's own name, in case the declared category
    #    label does not match exactly.
    well_known_qualitative = {
        'tab10', 'tab20', 'tab20b', 'tab20c', 'set1', 'set2', 'set3',
        'pastel1', 'pastel2', 'paired', 'accent', 'dark2'
    }
    return hasattr(cmap, 'name') and cmap.name.lower() in well_known_qualitative
def _cmap_strip(
    name: str,
    width: int = 10,
    height: int = 16,
    smooth_steps: int = 20,
    declared_category: Optional[str] = None,
    qualitative_label: str = "Qualitative",
    max_auto: int = 32,
):
    """Return RGB uint8 preview strip for *name* colormap.

    Continuous maps are sampled at *smooth_steps* points and linearly
    interpolated across *width* pixels. Categorical maps are drawn as equal
    blocks, one per color; when fewer than 3 pixels per color would be
    available, an evenly spaced subset of the colors is shown instead.

    Parameters
    ----------
    name:
        Key into ``mpl.colormaps``.
    width, height:
        Size of the returned image in pixels.
    smooth_steps:
        Number of samples taken from a continuous colormap (values below 1
        are treated as 1).
    declared_category, qualitative_label, max_auto:
        Forwarded to ``_is_categorical_cmap`` to choose the sampling mode.

    Returns
    -------
    numpy.ndarray of dtype uint8 with shape ``(height, width, 3)``.

    Raises
    ------
    KeyError
        If *name* is not a registered colormap.
    """
    cmap = mpl.colormaps[name]
    categorical = _is_categorical_cmap(
        cmap,
        declared_category=declared_category,
        qualitative_label=qualitative_label,
        max_auto=max_auto,
    )

    if categorical:
        n = cmap.N
        # Sample the centre of each of the n lookup bins. Going through the
        # colormap call (rather than reading ``cmap.colors``) works for any
        # colour specification (names, hex, RGBA) and always yields exactly
        # n entries.
        xs = np.linspace(0, 1, n, endpoint=False) + (0.5 / n)
        cols = cmap(xs)[..., :3]

        min_block_width = 3  # Minimum pixels per color block for visibility
        max_colors_that_fit = max(2, width // min_block_width)  # At least 2 colors
        if max_colors_that_fit >= n:
            # Enough room to show every color as its own block.
            selected_cols = cols
        else:
            # Not enough width - show an evenly spaced representative sample.
            indices = np.linspace(0, n - 1, max_colors_that_fit, dtype=int)
            selected_cols = cols[indices]
        num_blocks = len(selected_cols)

        # Map every output pixel to a block. Pixels left over after integer
        # division extend the last block; clamping block_w to 1 keeps very
        # narrow strips (width < num_blocks) from collapsing to zero width.
        block_w = max(1, width // num_blocks)
        pixel_to_block = np.minimum(np.arange(width) // block_w, num_blocks - 1)
        strip = selected_cols[pixel_to_block]
    else:
        # Continuous: coarse sampling followed by linear interpolation.
        steps = max(smooth_steps, 1)
        xs = np.linspace(0, 1, steps)
        cols = cmap(xs)[..., :3]
        xi = np.linspace(0, steps - 1, width)
        lo = np.floor(xi).astype(int)
        hi = np.minimum(lo + 1, steps - 1)
        t = xi - lo
        strip = (1 - t)[:, None] * cols[lo] + t[:, None] * cols[hi]

    # Stack the single pixel row `height` times and convert to 8-bit RGB.
    arr = np.repeat(strip[np.newaxis, :, :], height, axis=0)
    return (arr * 255).astype(np.uint8)
# ------------------------------------------------------------------
# ColormapChooser class
# ------------------------------------------------------------------
class ColormapChooser:
    """Reusable scrollable colormap selector for Gradio.

    Parameters
    ----------
    categories:
        Dict mapping *Category Label* -> list of cmap names. If None, uses
        DEFAULT_CATEGORIES defined above. You may pass additional categories or
        override existing ones. Order preserved.
    smooth_steps:
        Approx sample count for continuous maps (default 10).
    strip_width:
        Pixel width of preview strip images (default 10).
    strip_height:
        Pixel height of preview strip images (default 16).
    css_height:
        Max CSS height (pixels) for the scrollable gallery viewport.
    qualitative_label:
        Category label used to force qualitative sampling when present.
    max_auto:
        If a ListedColormap has N <= max_auto, treat as categorical even if not
        declared Qualitative.
    elem_id:
        DOM id for the gallery (used to scope CSS overrides). Default 'cmap_gallery'.
    show_search:
        Whether to render the search Textbox.
    show_category:
        Whether to render the category Radio selector.
    columns:
        Number of gallery columns (default 3); may be overridden through
        ``gallery_kwargs``.
    thumb_margin_px:
        Margin in pixels applied around each thumbnail by ``css()``.
    show_preview:
        Show the big preview strip under the gallery. Off by default.
    show_selected_name:
        Show the textbox that echoes the selected colormap name. Off by default.
    show_selected_info:
        Show the markdown info line. On by default.
    gallery_kwargs:
        Dictionary of keyword arguments to pass to the Gradio Gallery component
        when it is created. For example, `columns=4, allow_preview=True, height="400px"`.

    Public attributes after render():
        category (optional)
        search (optional)
        gallery
        preview
        selected_name (Textbox; value string)
        selected_info (Markdown)
        names_state (State of current filtered cmap names)

    Usage: see module Quick Start above.
    """

    def __init__(
        self,
        *,
        categories: Optional[Dict[str, Sequence[str]]] = None,
        smooth_steps: int = 10,
        strip_width: int = 10,
        strip_height: int = 16,
        css_height: int = 240,
        qualitative_label: str = "Qualitative",
        max_auto: int = 32,
        elem_id: str = "cmap_gallery",
        show_search: bool = True,
        show_category: bool = True,
        columns: int = 3,
        thumb_margin_px: int = 2,
        gallery_kwargs: Optional[Dict[str, Any]] = None,
        show_preview: bool = False,
        show_selected_name: bool = False,
        show_selected_info: bool = True,
    ) -> None:
        self.categories = categories if categories is not None else DEFAULT_CATEGORIES
        self.smooth_steps = smooth_steps
        self.strip_width = strip_width
        self.strip_height = strip_height
        self.css_height = css_height
        self.qualitative_label = qualitative_label
        self.max_auto = max_auto
        self.elem_id = elem_id
        self.show_search = show_search
        self.show_category = show_category
        self.columns = columns
        self.thumb_margin_px = thumb_margin_px
        self.gallery_kwargs = gallery_kwargs or {}
        # visibility flags
        self.show_preview = show_preview
        self.show_selected_name = show_selected_name
        self.show_selected_info = show_selected_info

        # Derived lookups: flat ordered name list and name -> category label.
        self._all_names = _flatten_categories(self.categories)
        self._name2cat = _build_name2cat(self.categories)
        # Thumbnails are rendered lazily and memoised per colormap name.
        self._tile_cache: Dict[str, np.ndarray] = {}

        # public gradio components (populated in render)
        self.category = None
        self.search = None
        self.gallery = None
        self.preview = None
        self.selected_name = None
        self.selected_info = None
        self.names_state = None

    # ------------------
    # internal helpers
    # ------------------
    def _tile(self, name: str) -> np.ndarray:
        """Return the (cached) thumbnail strip for colormap *name*."""
        if name not in self._tile_cache:
            self._tile_cache[name] = _cmap_strip(
                name,
                width=self.strip_width,
                height=self.strip_height,
                smooth_steps=self.smooth_steps,
                declared_category=self._name2cat.get(name),
                qualitative_label=self.qualitative_label,
                max_auto=self.max_auto,
            )
        return self._tile_cache[name]

    def _make_gallery_items(self, names: Sequence[str]):
        """Build the ``(image, caption)`` pairs a Gallery expects."""
        return [(self._tile(n), n) for n in names]

    def _filter_names(self, cat: Optional[str], s: Optional[str]) -> List[str]:
        """Return the colormap names matching category *cat* and search text *s*.

        An unknown/None category (or a disabled category selector) selects all
        names; the search is a case-insensitive substring match and is ignored
        when empty or when the search box is disabled.
        """
        if self.show_category and cat in self.categories:
            names = list(self.categories[cat])
        else:
            names = list(self._all_names)
        if s and self.show_search:
            needle = s.lower()
            names = [n for n in names if needle in n.lower()]
        return names

    def _filter_updates(self, names: Sequence[str]):
        """Build the gallery/preview/name/info updates for a filtered list."""
        # 1) an updated gallery with the selection cleared
        gkw = {
            "value": self._make_gallery_items(names),
            "selected_index": None,
        }
        gkw.update(self.gallery_kwargs)
        gallery_update = gr.Gallery(**gkw)
        # 2) clear the other widgets so old selection disappears
        preview_update = gr.update(value=None)
        name_update = gr.update(value="")
        info_update = gr.update(value="")
        return gallery_update, preview_update, name_update, info_update

    # ------------------
    # event functions
    # ------------------
    def _filter(self, cat: str, s: str):
        """Return ``(gallery, preview, name, info)`` updates for a filter change.

        The filtered name list itself is *not* stored on the shared
        ``names_state`` component here: assigning ``State.value`` changes the
        default for other sessions and is not seen by a session that already
        holds its own state. ``render()`` routes the list through the event
        outputs instead.
        """
        return self._filter_updates(self._filter_names(cat, s))

    def _select(self, evt: gr.SelectData, names: Sequence[str]):
        """Handle a gallery click: return ``(preview_image, name, info_markdown)``.

        *names* is the list currently shown in the gallery; ``evt.index`` is
        the position of the clicked thumbnail in that list.
        """
        if not names or evt.index is None or evt.index >= len(names):
            return gr.update(), "", "Nothing selected"
        name = names[evt.index]
        # Larger rendering for the preview area (at least 768 x 32 pixels).
        big = _cmap_strip(
            name,
            width=max(self.strip_width * 2, 768),
            height=max(self.strip_height * 2, 32),
            smooth_steps=self.smooth_steps,
            declared_category=self._name2cat.get(name),
            qualitative_label=self.qualitative_label,
            max_auto=self.max_auto,
        )
        info = f"**Selected:** `{name}` _(Category: {self._name2cat.get(name, '?')})_"
        return big, name, info

    # ------------------
    # CSS block builder
    # ------------------
    def css(self) -> str:
        """Return the CSS that scopes layout rules to this chooser's ``elem_id``.

        Pass the result to ``gr.Blocks(css=...)``.
        """
        return f"""
/* ───── 0. easy visual check the CSS is live (remove later) ───── */
#{self.elem_id} {{
/* background:rgba(255,255,0,.05); */
}}
/* the wrapper *is* the .block, so it owns the padding var */
#{self.elem_id}_wrap {{
padding: 0 !important;
--block-padding: 0 !important;
}}
/* ───── 1. the wrapper Gradio marks .fixed-height: make it scroll ─── */
#{self.elem_id} .grid-wrap {{
height: {self.css_height}px; /* kill inline 200 px or similar */
max-height: {self.css_height}px; /* cap the gallery’s height */
overflow-y: auto; /* rows that don’t fit will scroll */
}}
/* ───── 2. the real grid: keep masonry maths intact, tweak gap ─── */
#{self.elem_id} .grid-container {{
height: auto !important; /* sometimes Gradio sets one */
gap: 7px; /* tighter gutters (define attr) */
grid-auto-rows:auto !important;
}}
/* ───── 3. thumbnail boxes keep your ultra-wide shape ──────────── */
#{self.elem_id} .thumbnail-item {{
aspect-ratio: 3/1; /* e.g. 5/1 */
height: auto !important; /* beats Gradio’s inline 100 % */
margin: {self.thumb_margin_px}px !important;
overflow: hidden; /* just in case */
}}
/* ───── 4. images fill each box neatly ─────────────────────────── */
#{self.elem_id} img {{
width: 100%;
height: 100%;
object-fit: cover; /* crop to fill */
object-position: left;
display: block; /* kill inline-img whitespace */
}}
/* ───── 5. widen the β€œSelected:” info line ───────────────────── */
.cmap_selected_info {{
max-width: 100% !important; /* kill default 45 rem limit */
}}
"""

    # ------------------
    # Render into an existing Blocks context
    # ------------------
    def render(self):
        """Create Gradio UI elements and wire callbacks.

        Must be called *inside* an active `gr.Blocks()` context.

        Returns a dict of the created components with the keys ``gallery``,
        ``selected_name``, ``preview``, ``info``, ``category``, ``search`` and
        ``names_state``.
        """
        # initial list: first category or all
        # (first_cat is None when the category dict is empty)
        first_cat = next(iter(self.categories), None)
        if self.show_category:
            init_names = list(self.categories[first_cat]) if first_cat is not None else []
        else:
            init_names = list(self._all_names)

        # Tiles are produced lazily by _tile() while building gallery items;
        # there is no bulk precompute.

        # layout
        if self.show_category or self.show_search:
            with gr.Row():
                if self.show_category:
                    self.category = gr.Radio(
                        list(self.categories.keys()), value=first_cat, label="Category"
                    )
                else:
                    self.category = gr.State(None)  # shim so filter signature works
                if self.show_search:
                    self.search = gr.Textbox(label="Search", placeholder="type to filter...")
                else:
                    self.search = gr.State("")
        else:
            self.category = gr.State(None)
            self.search = gr.State("")

        # Per-session list of the names currently shown in the gallery.
        self.names_state = gr.State(init_names)

        gkw = {
            "value": self._make_gallery_items(init_names),
            "label": None,  # remove label
            "allow_preview": False,
            "elem_id": self.elem_id,
            "show_share_button": False,
            "columns": self.columns,
        }
        gkw.update(self.gallery_kwargs)
        self.gallery = gr.Gallery(**gkw)

        self.preview = gr.Image(
            label="Preview", interactive=False, height=60, visible=self.show_preview
        )
        self.selected_name = gr.Textbox(
            label="Selected cmap", interactive=False, visible=self.show_selected_name
        )
        self.selected_info = gr.Markdown(
            visible=self.show_selected_info,
            elem_classes="cmap_selected_info",
        )

        # wiring
        if self.show_category or self.show_search:
            def _wrapped_filter(cat, s):
                if not self.show_category:
                    cat = None
                if not self.show_search:
                    s = ""
                names = self._filter_names(cat, s)
                # The filtered list is returned as the last output so that
                # this session's names_state matches what the gallery shows.
                return (*self._filter_updates(names), names)

            outputs = [
                self.gallery,
                self.preview,
                self.selected_name,
                self.selected_info,
                self.names_state,
            ]
            if self.show_category:
                self.category.change(
                    _wrapped_filter,
                    [self.category, self.search],
                    outputs,
                )
            if self.show_search:
                self.search.change(
                    _wrapped_filter,
                    [self.category, self.search],
                    outputs,
                )

        def _wrapped_select(evt: gr.SelectData, names):
            return self._select(evt, names)

        self.gallery.select(
            _wrapped_select,
            [self.names_state],
            [self.preview, self.selected_name, self.selected_info],
        )

        return {
            "gallery": self.gallery,
            "selected_name": self.selected_name,
            "preview": self.preview,
            "info": self.selected_info,
            "category": self.category,
            "search": self.search,
            "names_state": self.names_state,
        }

    # ==========================================================
    # TAB-BASED RENDERER
    # ==========================================================
    def render_tabs(self):
        """
        Render the chooser as one Gallery per category inside a gradio Tabs
        container. No search box is provided - each tab already filters
        by category.

        Must be called inside an active `gr.Blocks()` context.

        Returns a dict with the keys ``galleries`` (category -> Gallery
        component), ``selected_name``, ``preview``, ``info`` and ``tabs``.
        """
        galleries = {}
        with gr.Tabs() as root_tabs:
            # --- build a tab + gallery for every category -------------
            for cat, names in self.categories.items():
                with gr.TabItem(cat):
                    # NOTE(review): every gallery (and wrapper row) receives
                    # the same elem_id so the rules from css() apply to all
                    # tabs; this yields duplicate DOM ids - confirm that is
                    # acceptable.
                    gkw = {
                        "value": self._make_gallery_items(names),
                        "label": None,  # remove label
                        "allow_preview": False,
                        "show_share_button": False,
                        "elem_id": self.elem_id,
                        "columns": self.columns,
                        "show_label": False,
                    }
                    gkw.update(self.gallery_kwargs)
                    # wrapper row targeted by the "#<elem_id>_wrap" CSS rule
                    with gr.Row(elem_id=f"{self.elem_id}_wrap"):
                        gal = gr.Gallery(**gkw)
                    galleries[cat] = gal

        # --- shared preview / meta area under the tabs ----------------
        self.preview = gr.Image(
            label="Preview", interactive=False, height=60, visible=self.show_preview
        )
        self.selected_name = gr.Textbox(
            label="Selected cmap", interactive=False, visible=self.show_selected_name
        )
        self.selected_info = gr.Markdown(
            visible=self.show_selected_info,
            elem_classes="cmap_selected_info",
        )

        # --- wiring: every gallery uses the same _select callback -----
        def _wrapped_select(evt: gr.SelectData, names):
            return self._select(evt, names)

        for cat, gal in galleries.items():
            gal.select(
                _wrapped_select,
                [gr.State(list(self.categories[cat]))],  # names list of this tab
                [self.preview, self.selected_name, self.selected_info],
            )

        return {
            "galleries": galleries,
            "selected_name": self.selected_name,
            "preview": self.preview,
            "info": self.selected_info,
            "tabs": root_tabs,
        }
# ------------------------------------------------------------------
# Minimal self-demo (only runs if module executed directly)
# ------------------------------------------------------------------
if __name__ == "__main__":
    # Stand-alone demo: a chooser with default categories and settings.
    chooser = ColormapChooser()
    # The chooser's CSS is passed to Blocks so its gallery rules are loaded.
    with gr.Blocks(css=chooser.css()) as demo:
        gr.Markdown("## Colormap Chooser Demo")
        chooser.render()
    demo.launch()