Spaces:
Runtime error
Runtime error
File size: 5,220 Bytes
b4b75f2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
import matplotlib
matplotlib.use('Agg')
import gradio as gr
import tensorflow as tf
from huggingface_hub import from_pretrained_keras
import numpy as np
from collections import defaultdict
import matplotlib.pyplot as plt
import plotly.express as px
from plotly import subplots
import pandas as pd
import random
# Pool the CIFAR-10 train and test splits: every image is clustered, so the
# split distinction is irrelevant for this demo.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_data = np.concatenate([x_train, x_test])
y_data = np.concatenate([y_train, y_test])

# CIFAR-10 label names; position i is the name of integer label i.
classes = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]
num_classes = len(classes)
clustering_model = from_pretrained_keras("johko/semantic-image-clustering")

# Per-image probability distribution over the model's clusters.
clustering_probs = clustering_model.predict(x_data, batch_size=500, verbose=1)

# Hard assignment: the most probable cluster id (0-based) for each image.
cluster_assignments = tf.math.argmax(clustering_probs, axis=-1).numpy()

# Probability of that assignment. Members with the highest confidence are
# treated as the cluster's "prototypes".
cluster_confidence = tf.math.reduce_max(clustering_probs, axis=-1).numpy()

# cluster id -> list of (image index, confidence), in dataset order.
clusters = defaultdict(list)
for image_idx, (cluster_id, confidence) in enumerate(
    zip(cluster_assignments, cluster_confidence)
):
    clusters[cluster_id].append((image_idx, confidence))
def get_cluster_size(cluster_number: int):
    """Return a sentence stating how many images were assigned to a cluster.

    Args:
        cluster_number: 1-based cluster number as chosen in the UI slider.
            Keys in ``clusters`` are the 0-based argmax ids, hence ``- 1``.

    Returns:
        A human-readable summary string.
    """
    # Slider values may arrive as floats; normalise so the message reads "#3", not "#3.0".
    cluster_number = int(cluster_number)
    # .get() instead of [] so a lookup never inserts an empty entry into the defaultdict.
    cluster_size = len(clusters.get(cluster_number - 1, []))
    return f"Cluster #{cluster_number} consists of {cluster_size} objects"
def get_images_from_cluster(cluster_number: int, num_images: int, image_mode: str):
    """Plot a single row of example images drawn from one cluster.

    Args:
        cluster_number: 1-based cluster number as chosen in the UI; keys in
            ``clusters`` are 0-based, hence ``- 1``.
        num_images: Number of images requested. It is clamped to the cluster
            size so small clusters no longer raise IndexError.
        image_mode: ``"Random Images from Cluster"`` draws a random sample.
            Any other value (including ``None``) shows the members with the
            highest clustering confidence first.

    Returns:
        A matplotlib Figure with one subplot per image. Each subplot is titled
        with the image's ground-truth class name.
    """
    members = clusters.get(int(cluster_number) - 1, [])
    count = max(0, min(int(num_images), len(members)))

    if image_mode == "Random Images from Cluster":
        # random.sample returns a new list. The previous random.shuffle
        # reordered the shared per-cluster list inside `clusters` in place.
        chosen = random.sample(members, count)
    else:
        # Each member is (image index, confidence); highest confidence first.
        chosen = sorted(members, key=lambda member: member[1], reverse=True)[:count]

    fig = plt.figure()
    # Draw on this figure's own axes rather than pyplot's global "current" figure.
    for position, (image_idx, _) in enumerate(chosen, start=1):
        ax = fig.add_subplot(1, len(chosen), position)
        ax.imshow(x_data[image_idx].astype("uint8"))
        ax.set_title(classes[y_data[image_idx][0]])
        ax.axis("off")
    fig.tight_layout()
    # Unregister from pyplot so open figures do not pile up across requests.
    # The Figure object stays usable for rendering by the caller.
    plt.close(fig)
    return fig
def get_cluster_details(cluster_number: int):
    """Build a pie chart of the ground-truth classes inside one cluster.

    Args:
        cluster_number: 1-based cluster number as chosen in the UI; keys in
            ``clusters`` are 0-based, hence ``- 1``.

    Returns:
        A plotly figure with one slice per class name in ``classes``, sized by
        how many of the cluster's images carry that label.
    """
    # .get() instead of [] so a lookup never inserts an empty entry into the defaultdict.
    members = clusters.get(int(cluster_number) - 1, [])
    cluster_label_counts = [0] * num_classes
    for image_idx, _ in members:
        cluster_label_counts[y_data[image_idx][0]] += 1
    # One-column frame (column label 0) indexed by class name, as px.pie expects below.
    count_df = pd.Series(dict(zip(classes, cluster_label_counts))).to_frame()
    fig_pie = px.pie(count_df, values=0, names=count_df.index, title='Number of class objects in cluster')
    return fig_pie
def get_cluster_info(cluster_number: int, num_images: int, image_mode: str):
    """Produce every output shown for one cluster in the Gradio UI.

    Returns a 3-item list in the order of the output components: the
    cluster-size sentence, the example-image figure, and the class pie chart.
    """
    return [
        get_cluster_size(cluster_number),
        get_images_from_cluster(cluster_number, num_images, image_mode),
        get_cluster_details(cluster_number),
    ]
# Footer credits rendered below the app. Blank lines separate the <center>
# wrapper from the content so Markdown renderers still process the inner text.
article = """<center>

Authors: <a href='https://twitter.com/johko990' target='_blank'>Johannes Kolbe</a> after an example by
<a href='https://www.linkedin.com/in/khalid-salama-24403144/' target='_blank'>Khalid Salama</a> on
<a href='https://keras.io/examples/vision/semantic_image_clustering/' target='_blank'><b>keras.io</b></a>

</center>"""

# Header text shown at the top of the app. Without the blank lines, the whole
# string would form one raw HTML block and the heading/links would not render.
description = """<center>

# Semantic Image Clustering

This space is intended to give you insights into image clusters, created by a model trained with the [**Semantic Clustering by Adopting Nearest neighbors (SCAN)**](https://arxiv.org/abs/2005.12320) (Van Gansbeke et al., 2020) algorithm.

First choose one of the 20 clusters, and how many images you want to preview from it. There are two options for the images: either *Random*, which, as you might guess,
gives you random images from the cluster, or *High Similarity*, which gives you the images the model assigns to the cluster with the highest confidence.

</center>
"""
# Gradio UI: inputs and button on top, results below, credits at the bottom.
with gr.Blocks() as demo:
    gr.Markdown(description)
    with gr.Row():
        btn = gr.Button("Get Cluster Info")
        with gr.Column():
            inp = [
                gr.Slider(minimum=1, maximum=20, step=1, label="Select Cluster"),
                gr.Slider(minimum=6, maximum=15, step=1, label="Number of Images to Show", value=8),
                # Explicit default: with no selection the callback received None,
                # which get_images_from_cluster silently treated as this option.
                gr.Radio(
                    ["Random Images from Cluster", "High Similarity Images"],
                    label="Image Choice",
                    value="High Similarity Images",
                ),
            ]
    with gr.Row():
        with gr.Column():
            out1 = [gr.Text(label="Cluster Size"), gr.Plot(label="Image Examples"), gr.Plot(label="Class details")]
    gr.Markdown(article)
    # Outputs are filled positionally from get_cluster_info's 3-item list.
    btn.click(fn=get_cluster_info, inputs=inp, outputs=out1)

demo.launch()
|