Werli's picture
Upload 6 files
665df06 verified
raw
history blame
4.92 kB
import requests,re,base64,io,numpy as np
from PIL import Image,ImageOps
import torch,gradio as gr
# Helper to load image from URL
def loadImageFromUrl(url):
    """Load an image from a data URL or HTTP(S) URL as a float tensor.

    ``data:image/...`` URLs are base64-decoded in place; other URLs are
    fetched with ``requests`` (5 second timeout). ``s3://`` URLs are rejected.
    EXIF orientation is applied, the image is composited onto a black
    background through its alpha channel, and the pixels are returned as a
    float32 tensor of shape (1, height, width, 3) scaled to [0, 1].

    Raises:
        Exception: for ``s3://`` URLs, or carrying the response body when the
            HTTP status code is not 200.
    """
    if url.startswith("data:image/"):
        encoded = url.split(",")[1]
        raw_bytes = base64.b64decode(encoded)
    elif url.startswith("s3://"):
        raise Exception("S3 URLs not supported in this interface")
    else:
        resp = requests.get(url, timeout=5)
        if resp.status_code != 200:
            raise Exception(resp.text)
        raw_bytes = resp.content
    picture = ImageOps.exif_transpose(Image.open(io.BytesIO(raw_bytes)))
    # Normalize to RGBA so the last band is always an alpha mask.
    rgba = picture if picture.mode == "RGBA" else picture.convert("RGBA")
    flattened = Image.new("RGB", rgba.size, (0, 0, 0))
    flattened.paste(rgba, mask=rgba.split()[-1])
    pixels = np.array(flattened).astype(np.float32) / 255.0
    return torch.from_numpy(pixels).unsqueeze(0)
# Fetch data from Gelbooru or None
def _split_tags(text):
    """Split a comma-separated string into normalized booru tags.

    Each item is stripped, inner spaces become underscores and backslashes
    (e.g. from an escaped prompt) are removed. Empty items, such as those
    produced by a trailing comma, are dropped; ``None`` counts as empty.
    """
    items = (
        item.strip().replace(' ', '_').replace('\\', '')
        for item in (text or '').split(',')
    )
    return [item for item in items if item]


def _rating_exclusions(site, Safe, Questionable, Explicit):
    """Return a negated rating tag for each rating the user did not enable.

    The "None" site uses the names safe/questionable/explicit. Any other
    site uses Gelbooru's names: "general" instead of "safe", and
    "Sensitive" is excluded together with "questionable".
    """
    excluded = []
    if not Safe:
        excluded.append("-rating:safe" if site == "None" else "-rating:general")
    if not Questionable:
        excluded.append("-rating:questionable")
        if site != "None":
            excluded.append("-rating:Sensitive")
    if not Explicit:
        excluded.append("-rating:explicit")
    return excluded


def fetch_gelbooru_images(site, OR_tags, AND_tags, exclude_tag, score, count, Safe, Questionable, Explicit, timeout=10):  # add 'api_key' and 'user_id' if necessary
    """Query the booru API for random posts matching the given filters.

    Args:
        site: "None" targets the placeholder API at api.example.com, whose
            JSON response is a bare list of posts. Any other value targets
            gelbooru.com, whose JSON nests the posts under a "post" key.
        OR_tags: comma-separated tags, any of which may match.
        AND_tags: comma-separated tags that must all match.
        exclude_tag: comma-separated tags that must not match.
        score: only posts with a score greater than this are requested.
        count: maximum number of posts to request (``limit``).
        Safe, Questionable, Explicit: whether each rating is allowed.
        timeout: seconds to wait for the API response.

    Returns:
        ``(image_urls, tags_list, post_urls)``: three parallel lists with one
        entry per returned post. Tags are rendered as a comma-separated prompt
        with underscores turned into spaces and parentheses escaped.

    Raises:
        requests.HTTPError: if the API answers with an error status.

    To authenticate, add ``api_key`` and ``user_id`` entries to ``params``.
    """
    or_list = _split_tags(OR_tags)
    # Gelbooru OR syntax is "{a ~ b ~ c}"; a single tag needs no braces.
    or_query = ['{' + ' ~ '.join(or_list) + '}'] if len(or_list) > 1 else or_list
    query_tags = (
        ["sort:random"]
        + ['-' + tag for tag in _split_tags(exclude_tag)]
        + or_query
        + _split_tags(AND_tags)
        + _rating_exclusions(site, Safe, Questionable, Explicit)
        + [f"score:>{score}"]
    )
    base_url = "https://api.example.com/index.php" if site == "None" else "https://gelbooru.com/index.php"
    params = {
        "page": "dapi",
        "s": "post",
        "q": "index",
        # Tags are space-separated; requests URL-encodes the value (space -> '+'),
        # so characters like '&' or '#' inside a tag cannot break the query string.
        "tags": " ".join(query_tags),
        "limit": count,
        "json": 1,
    }
    response = requests.get(base_url, params=params, timeout=timeout, verify=True)
    response.raise_for_status()
    data = response.json()
    posts = data if site == "None" else data.get('post', [])
    image_urls = [post.get("file_url", "") for post in posts]
    tags_list = [
        post.get("tags", "").replace(" ", ", ").replace("_", " ").replace("(", "\\(").replace(")", "\\)").strip()
        for post in posts
    ]
    ids_list = [str(post.get("id", "")) for post in posts]
    # Built for every site so all three returned lists always line up
    # (for Gelbooru this yields https://gelbooru.com/index.php?page=post&s=view&id=...).
    post_urls = [f"{base_url}?page=post&s=view&id={post_id}" for post_id in ids_list]
    return image_urls, tags_list, post_urls
# Main function to fetch and return processed images
def gelbooru_gradio(
    OR_tags, AND_tags, exclude_tags, score, count, Safe, Questionable, Explicit, site  # add 'api_key' and 'user_id' if necessary
):
    """Fetch random posts and load their images for the Gradio gallery.

    Returns:
        ``(images, tags_list, post_urls, image_urls)``. ``images`` holds a PIL
        image for every post whose image loaded successfully. The three
        metadata lists are filtered to exactly those posts, so index ``i`` in
        every list refers to the same post (``on_select`` relies on this).
        All four lists are empty when the query returns no posts.

    Posts whose image cannot be downloaded or decoded are reported with
    ``print`` and skipped rather than aborting the whole batch.
    """
    image_urls, tags_list, post_urls = fetch_gelbooru_images(
        site, OR_tags, AND_tags, exclude_tags, score, count, Safe, Questionable, Explicit  # 'api_key' and 'user_id' if necessary
    )
    if not image_urls:
        return [], [], [], []
    image_data, kept_tags, kept_post_urls, kept_image_urls = [], [], [], []
    for url, tags, post_url in zip(image_urls, tags_list, post_urls):
        try:
            tensor = loadImageFromUrl(url)
            # loadImageFromUrl yields a (1, H, W, 3) float tensor in [0, 1];
            # turn it back into an 8-bit PIL image for the gallery.
            pixels = (tensor * 255).clamp(0, 255).cpu().numpy().astype(np.uint8)[0]
            image = Image.fromarray(pixels)
        except Exception as e:
            # Best-effort: a single broken image should not fail the search.
            print(f"Error loading image from {url}: {e}")
            continue
        # Keep metadata only for displayed images so gallery indices stay aligned.
        image_data.append(image)
        kept_tags.append(tags)
        kept_post_urls.append(post_url)
        kept_image_urls.append(url)
    return image_data, kept_tags, kept_post_urls, kept_image_urls
# Update UI on image click
def on_select(evt: gr.SelectData, tags_list, post_url_list, image_url_list):
    """Return (tags, post URL, image URL) for the clicked gallery item.

    ``evt.index`` is used to index all three lists. When it is not smaller
    than ``len(tags_list)``, the placeholder ("No tags", "", "") is returned.
    """
    position = evt.index
    if position >= len(tags_list):
        return "No tags", "", ""
    return tags_list[position], post_url_list[position], image_url_list[position]