from collections import defaultdict

import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib import cm
import torch


def draw_panoptic_segmentation(model, segmentation, segments_info,
                               output_path='final_mask.png', *,
                               show_legend=False):
    """Render a panoptic segmentation map to a PNG file and return its path.

    Every value in *segmentation* is treated as a segment id and coloured with
    a discrete ``viridis`` colormap holding one colour per id in
    ``0..max(segmentation)``. The image and the (optional) legend use the same
    colormap and normalisation, so a legend patch has exactly the colour its
    segment is drawn with.

    Args:
        model: Object exposing ``model.config.id2label`` (label id -> class
            name). Presumably a Hugging Face model — confirm with callers.
        segmentation: Tensor of integer segment ids, shown with ``imshow``
            after being moved to CPU (assumed 2-D — TODO confirm).
        segments_info: Iterable of dicts with at least ``'id'`` (segment id)
            and ``'label_id'`` (key into ``id2label``).
        output_path: File the figure is saved to. Defaults to
            ``'final_mask.png'`` (the previously hard-coded name).
        show_legend: When True, add a legend with one entry per segment,
            labelled ``"<class name>-<instance index>"``. Off by default,
            matching the original behaviour where the legend was disabled.

    Returns:
        ``output_path``.
    """
    seg = segmentation.detach().cpu()
    # One colour per id 0..max_id. max_id must be a plain int, and the map
    # needs max_id + 1 entries; otherwise the two highest ids share a colour.
    max_id = max(int(seg.max().item()), 0) if seg.numel() else 0
    num_colors = max_id + 1
    # plt.get_cmap replaces cm.get_cmap, which was removed in matplotlib 3.9.
    cmap = plt.get_cmap('viridis', num_colors)

    fig, ax = plt.subplots()
    try:
        # vmin/vmax at half-integers map id k to colormap entry k, which is
        # exactly cmap(k). Nearest interpolation avoids blending neighbouring
        # ids into colours that belong to no segment.
        ax.imshow(seg.numpy(), cmap=cmap, vmin=-0.5, vmax=max_id + 0.5,
                  interpolation='nearest')

        if show_legend:
            instances_counter = defaultdict(int)
            handles = []
            for segment in segments_info:
                label_id = segment['label_id']
                class_name = model.config.id2label[label_id]
                # Number repeated instances of the same class: "person-0", "person-1", ...
                label = f"{class_name}-{instances_counter[label_id]}"
                instances_counter[label_id] += 1
                handles.append(mpatches.Patch(color=cmap(segment['id']), label=label))
            ax.legend(handles=handles)

        fig.savefig(output_path)
    finally:
        # pyplot keeps every figure alive until closed; release it even on error.
        plt.close(fig)
    return output_path