# Gelbooru image search — Gradio Space
import base64
import io
import re

import gradio as gr
import numpy as np
import requests
import torch
from PIL import Image, ImageOps
# Helper to load image from URL
def loadImageFromUrl(url):
    """Load an image from an http(s) URL or a ``data:image/...`` URI.

    Returns a float32 torch tensor of shape (1, H, W, 3) with values in
    [0, 1]. Any alpha channel is composited onto a black background so
    downstream code can assume plain RGB.

    Raises:
        Exception: for s3:// URLs (unsupported here) or non-200 HTTP
            responses (the response body is used as the message).
    """
    if url.startswith("data:image/"):
        # Split on the FIRST comma only: the header ends at the first
        # comma and the payload must be kept intact.
        i = Image.open(io.BytesIO(base64.b64decode(url.split(",", 1)[1])))
    elif url.startswith("s3://"):
        raise Exception("S3 URLs not supported in this interface")
    else:
        response = requests.get(url, timeout=5)
        if response.status_code != 200:
            raise Exception(response.text)
        i = Image.open(io.BytesIO(response.content))

    # Apply EXIF orientation before any pixel work.
    i = ImageOps.exif_transpose(i)

    # Flatten transparency onto black.
    if i.mode != "RGBA":
        i = i.convert("RGBA")
    alpha = i.split()[-1]
    image = Image.new("RGB", i.size, (0, 0, 0))
    image.paste(i, mask=alpha)

    # HWC uint8 -> float32 in [0, 1], plus a leading batch dimension.
    image = np.array(image).astype(np.float32) / 255.0
    image = torch.from_numpy(image)[None,]
    return image
# Fetch data from Gelbooru or None
def fetch_gelbooru_images(site, OR_tags, AND_tags, exclude_tag, score, count, Safe, Questionable, Explicit):  # add 'api_key' and 'user_id' if necessary
    """Query the imageboard API for random posts matching the tag filters.

    Args:
        site: "Gelbooru" or "None" (the alternate endpoint, which uses a
            slightly different rating vocabulary and JSON shape).
        OR_tags: comma-separated tags, any of which may match.
        AND_tags: comma-separated tags, all of which must match.
        exclude_tag: comma-separated tags to exclude.
        score: minimum post score.
        count: maximum number of posts to request.
        Safe / Questionable / Explicit: include posts of that rating.

    Returns:
        (image_urls, tags_list, post_urls) — three parallel lists.
    """
    # AND tags: normalize "a, b c" -> "a+b_c" (underscores, '+'-joined).
    AND_tags = AND_tags.rstrip(',').rstrip(' ')
    AND_tags = AND_tags.split(',')
    AND_tags = [item.strip().replace(' ', '_').replace('\\', '') for item in AND_tags]
    AND_tags = [item for item in AND_tags if item]
    if len(AND_tags) > 1:
        AND_tags = '+'.join(AND_tags)
    else:
        AND_tags = AND_tags[0] if AND_tags else ''

    # OR tags: booru "any of" syntax is "{a ~ b ~ c}".
    OR_tags = OR_tags.rstrip(',').rstrip(' ')
    OR_tags = OR_tags.split(',')
    OR_tags = [item.strip().replace(' ', '_').replace('\\', '') for item in OR_tags]
    OR_tags = [item for item in OR_tags if item]
    if len(OR_tags) > 1:
        OR_tags = '{' + ' ~ '.join(OR_tags) + '}'
    else:
        OR_tags = OR_tags[0] if OR_tags else ''

    # Exclude tags: each becomes "-tag". An empty input yields a lone "-",
    # which the "-+" cleanup below removes from the final URL.
    exclude_tag = '+'.join('-' + item.strip().replace(' ', '_') for item in exclude_tag.split(','))

    # Build rating exclusions; the two sites name their ratings differently.
    rate_exclusion = ""
    if not Safe:
        if site == "None":
            rate_exclusion += "+-rating%3asafe"
        else:
            rate_exclusion += "+-rating%3ageneral"
    if not Questionable:
        if site == "None":
            rate_exclusion += "+-rating%3aquestionable"
        else:
            rate_exclusion += "+-rating%3aquestionable+-rating%3aSensitive"
    if not Explicit:
        # Both sites use the same "explicit" rating name.
        rate_exclusion += "+-rating%3aexplicit"

    if site == "None":
        base_url = "https://api.example.com/index.php"
    else:
        base_url = "https://gelbooru.com/index.php"
    query_params = (
        f"page=dapi&s=post&q=index&tags=sort%3arandom+"
        f"{exclude_tag}+{OR_tags}+{AND_tags}+{rate_exclusion}"
        f"+score%3a>{score}&limit={count}&json=1"
        #f"+score%3a>{score}&api_key={api_key}&user_id={user_id}&limit={count}&json=1"
    )
    # Drop empty "-" exclusions and collapse runs of '+' left by empty fields.
    url = f"{base_url}?{query_params}".replace("-+", "")
    url = re.sub(r"\++", "+", url)

    response = requests.get(url, verify=True)
    if site == "None":
        posts = response.json()
    else:
        # Gelbooru nests the post list under the "post" key.
        posts = response.json().get('post', [])

    image_urls = [post.get("file_url", "") for post in posts]
    tags_list = [post.get("tags", "").replace(" ", ", ").replace("_", " ").replace("(", "\\(").replace(")", "\\)").strip() for post in posts]
    ids_list = [str(post.get("id", "")) for post in posts]

    # BUGFIX: post_urls must be bound for every site, otherwise the return
    # below raises UnboundLocalError when site != "Gelbooru".
    if site == "Gelbooru":
        post_urls = [f"https://gelbooru.com/index.php?page=post&s=view&id={pid}" for pid in ids_list]
    else:
        # No public post page known for the alternate endpoint.
        post_urls = ["" for _ in ids_list]
    return image_urls, tags_list, post_urls
# Main function to fetch and return processed images
def gelbooru_gradio(
    OR_tags, AND_tags, exclude_tags, score, count, Safe, Questionable, Explicit, site  # add 'api_key' and 'user_id' if necessary
):
    """Run a search and decode every result image for the gallery.

    Returns four parallel lists: decoded PIL images, tag strings,
    post-page URLs, and raw image URLs. Returns four empty lists when
    the search matched nothing; images that fail to download or decode
    are logged and skipped.
    """
    image_urls, tags_list, post_urls = fetch_gelbooru_images(
        site, OR_tags, AND_tags, exclude_tags, score, count, Safe, Questionable, Explicit  # 'api_key' and 'user_id' if necessary
    )
    if not image_urls:
        return [], [], [], []

    image_data = []
    for url in image_urls:
        try:
            # Tensor (1, H, W, 3) in [0, 1] -> uint8 HWC -> PIL image.
            tensor = loadImageFromUrl(url)
            pixels = (tensor * 255).clamp(0, 255).cpu().numpy().astype(np.uint8)[0]
            image_data.append(Image.fromarray(pixels))
        except Exception as e:
            # Best-effort: a single bad URL must not abort the whole batch.
            print(f"Error loading image from {url}: {e}")
            continue
    return image_data, tags_list, post_urls, image_urls
# Update UI on image click
def on_select(evt: gr.SelectData, tags_list, post_url_list, image_url_list):
    """Return (tags, post URL, image URL) for the clicked gallery item.

    The three lists are parallel per-result state. Falls back to
    ("No tags", "", "") whenever the clicked index is out of range for
    ANY of them (e.g. stale state after a new search) — the original
    only checked tags_list, so a length mismatch raised IndexError.
    """
    idx = evt.index
    if 0 <= idx < min(len(tags_list), len(post_url_list), len(image_url_list)):
        return tags_list[idx], post_url_list[idx], image_url_list[idx]
    return "No tags", "", ""