File size: 1,289 Bytes
2c5aba6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a634e56
2c5aba6
 
 
 
 
 
 
 
 
 
60daf2a
2c5aba6
60daf2a
2c5aba6
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image


def preprocess(image, output, binarize, threshold):
    """Turn an input tensor and a model output tensor into NumPy arrays.

    Both arguments are moved to the CPU, detached from the autograd graph,
    converted to NumPy and stripped of size-1 dimensions.

    Args:
        image: Tensor that is 3-D once size-1 dimensions are squeezed out.
            Its first axis is moved to the end (presumably channels-first to
            channels-last -- confirm against the caller).
        output: Tensor holding the model output.
        binarize: When truthy, the output is thresholded to 0./1. values.
        threshold: Cut-off used when ``binarize`` is set; values strictly
            greater than it become 1., everything else becomes 0.

    Returns:
        A ``(image, output)`` tuple of NumPy arrays.
    """

    def _as_numpy(tensor):
        # Shared tensor -> ndarray conversion for both arguments.
        return tensor.cpu().detach().numpy().squeeze()

    picture = np.transpose(_as_numpy(image), (1, 2, 0))
    # Affine rescale x -> (x + 1) / 2; presumably maps [-1, 1] inputs to
    # [0, 1] for display -- confirm the normalisation used upstream.
    picture = (picture + 1) * 0.5

    scores = _as_numpy(output)
    if not binarize:
        return picture, scores

    return picture, np.where(scores > threshold, 1., 0.)


def enlarge_array(output, grid_shape=(14, 14), scale=16):
    """Upsample a flat grid of values by repeating every cell.

    ``output`` is reshaped to ``grid_shape`` and each cell is then expanded
    into a ``scale`` x ``scale`` tile of identical values (nearest-neighbour
    enlargement). With the defaults a 196-element input becomes a
    224 x 224 array.

    Args:
        output: Array-like with exactly ``grid_shape[0] * grid_shape[1]``
            elements.
        grid_shape: ``(rows, cols)`` of the grid to reshape ``output`` into.
        scale: Integer repetition factor applied along both axes.

    Returns:
        A 2-D NumPy array of shape
        ``(grid_shape[0] * scale, grid_shape[1] * scale)`` with the same
        dtype as the input.

    Raises:
        ValueError: If ``output`` cannot be reshaped to ``grid_shape``.
    """
    grid = np.reshape(output, grid_shape)
    # Repeat rows, then columns; np.repeat works on the ndarray directly,
    # so no DataFrame round trip is needed.
    stretched_rows = np.repeat(grid, scale, axis=0)
    return np.repeat(stretched_rows, scale, axis=1)


def visualize_output(image, output, binarize, threshold):
    """Render the image with the model output overlaid and return RGB pixels.

    The tensors are converted with ``preprocess``, the output is enlarged
    with ``enlarge_array``, and the result is drawn semi-transparently on top
    of the image in an off-screen 6 x 6 inch Matplotlib figure.

    Args:
        image: Input image tensor, forwarded to ``preprocess``.
        output: Model output tensor, forwarded to ``preprocess``.
        binarize: Whether ``preprocess`` should threshold the output.
        threshold: Threshold forwarded to ``preprocess``.

    Returns:
        A ``uint8`` NumPy array of shape ``(height, width, 3)`` holding the
        rendered figure.
    """
    image, output = preprocess(image, output, binarize, threshold)
    output = enlarge_array(output)
    output_mask = Image.fromarray(output * 255)

    fig = plt.figure(figsize=(6, 6))
    try:
        ax = fig.add_subplot(111)
        ax.axis('off')
        ax.imshow(image)
        # The overlay is drawn identically whether or not the output was
        # binarized, so no branch on `binarize` is needed here.
        ax.imshow(output_mask, alpha=.45)
        fig.tight_layout(pad=0)
        fig.canvas.draw()

        # buffer_rgba() replaces tostring_rgb(), which was deprecated in
        # Matplotlib 3.8 and later removed. The buffer already has shape
        # (height, width, 4); drop alpha and copy so the result does not
        # reference the canvas after the figure is closed.
        rgba = np.asarray(fig.canvas.buffer_rgba())
        data = rgba[..., :3].copy()
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # without this each call would leak a figure.
        plt.close(fig)

    return data