import base64
import io
import re

import gradio as gr
import numpy as np
import requests
import torch
from PIL import Image, ImageOps

# Load an image from a URL or data URI and return it as a
# [1, H, W, 3] float32 tensor with values in [0, 1]
def loadImageFromUrl(url):
    if url.startswith("data:image/"):
        # Inline base64-encoded data URI
        i = Image.open(io.BytesIO(base64.b64decode(url.split(",")[1])))
    elif url.startswith("s3://"):
        raise Exception("S3 URLs not supported in this interface")
    else:
        response = requests.get(url, timeout=5)
        if response.status_code != 200:
            raise Exception(response.text)
        i = Image.open(io.BytesIO(response.content))

    # Respect EXIF orientation, then flatten any transparency onto black
    i = ImageOps.exif_transpose(i)
    if i.mode != "RGBA":
        i = i.convert("RGBA")

    alpha = i.split()[-1]
    image = Image.new("RGB", i.size, (0, 0, 0))
    image.paste(i, mask=alpha)

    image = np.array(image).astype(np.float32) / 255.0
    image = torch.from_numpy(image)[None,]
    return image
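
# Example (illustrative):
#   t = loadImageFromUrl("https://example.com/picture.png")
#   t.shape  # -> torch.Size([1, H, W, 3]), float32 in [0, 1]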

# Fetch matching posts from Gelbooru (or the fallback endpoint when site is "None")
def fetch_gelbooru_images(site, OR_tags, AND_tags, exclude_tag, score, count, Safe, Questionable, Explicit): # add 'api_key' and 'user_id' if necessary
    # AND tags: comma-separated input -> "tag1+tag2+..."
    AND_tags = [item.strip().replace(' ', '_').replace('\\', '') for item in AND_tags.split(',')]
    AND_tags = '+'.join(item for item in AND_tags if item)

    # OR tags: comma-separated input -> "{tag1 ~ tag2 ~ ...}" (booru OR syntax)
    OR_tags = [item.strip().replace(' ', '_').replace('\\', '') for item in OR_tags.split(',')]
    OR_tags = [item for item in OR_tags if item]
    if len(OR_tags) > 1:
        OR_tags = '{' + ' ~ '.join(OR_tags) + '}'
    else:
        OR_tags = OR_tags[0] if OR_tags else ''

    # Exclude tags: comma-separated input -> "-tag1+-tag2+..." (empty items skipped)
    exclude_tag = '+'.join('-' + item.strip().replace(' ', '_') for item in exclude_tag.split(',') if item.strip())

    # Exclude any rating category whose checkbox is off
    # (the two APIs use different rating tag names)
    rate_exclusion = ""
    if not Safe:
        rate_exclusion += "+-rating%3asafe" if site == "None" else "+-rating%3ageneral"
    if not Questionable:
        if site == "None":
            rate_exclusion += "+-rating%3aquestionable"
        else:
            rate_exclusion += "+-rating%3aquestionable+-rating%3asensitive"
    if not Explicit:
        rate_exclusion += "+-rating%3aexplicit"

    if site == "None":
        base_url = "https://api.example.com/index.php"
    else:
        base_url = "https://gelbooru.com/index.php"

    query_params = (
        f"page=dapi&s=post&q=index&tags=sort%3arandom+"
        f"{exclude_tag}+{OR_tags}+{AND_tags}+{rate_exclusion}"
        f"+score%3a>{score}&limit={count}&json=1"
        #f"+score%3a>{score}&api_key={api_key}&user_id={user_id}&limit={count}&json=1"
    )
    url = f"{base_url}?{query_params}"
    url = re.sub(r"\++", "+", url)  # collapse the empty slots left by blank tag groups
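    # For example, with AND_tags="1girl", exclude_tag="monochrome", score=10,
    # count=20 and only Safe checked, the final URL is (illustrative):
    #   https://gelbooru.com/index.php?page=dapi&s=post&q=index&tags=sort%3arandom+-monochrome+1girl+-rating%3aquestionable+-rating%3asensitive+-rating%3aexplicit+score%3a>10&limit=20&json=1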

    response = requests.get(url, verify=True)
    if site == "None":
        posts = response.json()
    else:
        posts = response.json().get('post', [])

    image_urls = [post.get("file_url", "") for post in posts]
    # Turn "tag_one tag_two" into "tag one, tag two" and escape parentheses
    tags_list = [post.get("tags", "").replace(" ", ", ").replace("_", " ").replace("(", "\\(").replace(")", "\\)").strip() for post in posts]
    ids_list = [str(post.get("id", "")) for post in posts]

    if site == "Gelbooru":
        post_urls = [f"https://gelbooru.com/index.php?page=post&s=view&id={id}" for id in ids_list]
    else:
        post_urls = ["" for _ in ids_list]  # no canonical post page for the fallback endpoint

    return image_urls, tags_list, post_urls
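
# Example call (illustrative values): up to 20 random safe posts tagged
# "1girl", excluding "monochrome", with score above 10:
#   urls, tags, posts = fetch_gelbooru_images(
#       "Gelbooru", "", "1girl", "monochrome", 10, 20, True, False, False)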

# Main function to fetch and return processed images
def gelbooru_gradio(
    OR_tags, AND_tags, exclude_tags, score, count, Safe, Questionable, Explicit, site  # add 'api_key' and 'user_id' if necessary
):
    image_urls, tags_list, post_urls = fetch_gelbooru_images(
        site, OR_tags, AND_tags, exclude_tags, score, count, Safe, Questionable, Explicit # 'api_key' and 'user_id' if necessary
    )

    if not image_urls:
        return [], [], [], []

    image_data = []
    for url in image_urls:
        try:
            image = loadImageFromUrl(url)
            # Tensor [1, H, W, 3] in [0, 1] -> uint8 array for PIL
            image = (image * 255).clamp(0, 255).cpu().numpy().astype(np.uint8)[0]
            image = Image.fromarray(image)
            image_data.append(image)
        except Exception as e:
            print(f"Error loading image from {url}: {e}")
            continue

    return image_data, tags_list, post_urls, image_urls

# Update the detail fields when a gallery image is clicked
def on_select(evt: gr.SelectData, tags_list, post_url_list, image_url_list):
    idx = evt.index
    if 0 <= idx < len(tags_list):
        return tags_list[idx], post_url_list[idx], image_url_list[idx]
    return "No tags", "", ""