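"""Gradio app: background removal with BiRefNet.

An uploaded image is resized, normalized, and passed through the
ZhengPeng7/BiRefNet segmentation model; the predicted mask is attached
as the image's alpha channel so the background becomes transparent.
"""
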
import gradio as gr
from gradio_imageslider import ImageSlider
from loadimg import load_img
# import spaces
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms

# "high" lets PyTorch use TF32 (or equivalent reduced-precision) kernels for float32 matmuls.
torch.set_float32_matmul_precision("high")

# Load BiRefNet (remote code from the Hugging Face Hub) and move it to the GPU if one is available.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet", trust_remote_code=True
)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
birefnet.to(device)
birefnet.eval()  # inference mode
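
# Preprocessing: resize to the model's 1024x1024 input and normalize with ImageNet statistics.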
transform_image = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)


# @spaces.GPU
def fn(image):
    """Return `image` with its background removed via a BiRefNet-predicted alpha mask."""
    im = load_img(image, output_type="pil")
    im = im.convert("RGB")
    image_size = im.size
    origin = im.copy()  # kept for the commented-out before/after slider output
    image = load_img(im)
    input_images = transform_image(image).unsqueeze(0).to(device)
    # Prediction: take the model's final output and squash it to [0, 1] with a sigmoid.
    with torch.no_grad():
        preds = birefnet(input_images)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()
    pred_pil = transforms.ToPILImage()(pred)
    # Resize the mask back to the original resolution and attach it as the alpha channel.
    mask = pred_pil.resize(image_size)
    image.putalpha(mask)
    # return (image, origin)
    return image


# slider1 = ImageSlider(label="birefnet", type="pil")
img1 = gr.Image(type="pil", image_mode="RGBA")
slider2 = ImageSlider(label="birefnet", type="pil")
image = gr.Image(label="Upload an image")
text = gr.Textbox(label="Paste an image URL")


chameleon = load_img("chameleon.jpg", output_type="pil")  # example image expected next to this script

url = "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg"  # example for the commented-out URL tab
demo = gr.Interface(
    fn, inputs=image, outputs=img1, examples=[chameleon], api_name="image"
)

# tab2 = gr.Interface(fn, inputs=text, outputs=slider2, examples=[url], api_name="text")


# demo = gr.TabbedInterface(
#     [tab1, tab2], ["image", "text"], title="birefnet for background removal (WIP 🛠️, works for linux)"
# )

if __name__ == "__main__":
    demo.launch()