Werli commited on
Commit
43ad32d
·
verified ·
1 Parent(s): aecd31a

Upload 2 files

Browse files

Added "Rule34" and "Xbooru". Removed "OR_tags" as it's not needed anymore and renamed "AND_tags" to "Tags".

Files changed (2) hide show
  1. app.py +632 -636
  2. modules/booru.py +110 -131
app.py CHANGED
@@ -1,637 +1,633 @@
1
- import os,io,copy,json,requests,spaces,gradio as gr,numpy as np
2
- import argparse,huggingface_hub,onnxruntime as rt,pandas as pd,traceback,tempfile,zipfile,re,ast,time
3
- from datetime import datetime,timezone
4
- from collections import defaultdict
5
- from PIL import Image,ImageOps
6
- from modules.booru import gelbooru_gradio,fetch_gelbooru_images,on_select
7
- from apscheduler.schedulers.background import BackgroundScheduler
8
- from modules.classifyTags import classify_tags,process_tags
9
- from modules.reorganizer_model import reorganizer_list,reorganizer_class
10
- from modules.tag_enhancer import prompt_enhancer
11
- from modules.florence2 import process_image,single_task_list,update_task_dropdown
12
-
13
- os.environ['PYTORCH_ENABLE_MPS_FALLBACK']='1'
14
-
15
- TITLE = "Multi-Tagger v1.2"
16
- DESCRIPTION = """
17
- Multi-Tagger is a versatile application for advanced image analysis and captioning. Perfect for AI artists or enthusiasts, it offers a range of features:
18
-
19
- - **Automatic Tag Categorization**: Tags are grouped into categories.
20
- - **Tag Enhancement**: Boost your prompts with enhanced descriptions using a built-in prompt enhancer.
21
- - **Reorganizer**: Use a reorganizer model to format tags into a natural-language description.
22
- - **Batch Support**: Upload and process multiple images simultaneously.
23
- - **Downloadable Output**: Get almost all results as downloadable `.txt`, `.json`, and `.png` files in a `.zip` archive.
24
- - **Image Fetcher**: Search for images from **Gelbooru** using flexible tag filters.
25
- - **CUDA** and **CPU** support.
26
- """
27
-
28
- # Dataset v3 series of models:
29
- SWINV2_MODEL_DSV3_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
30
- CONV_MODEL_DSV3_REPO = "SmilingWolf/wd-convnext-tagger-v3"
31
- VIT_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-tagger-v3"
32
- VIT_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-large-tagger-v3"
33
- EVA02_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"
34
- # Dataset v2 series of models:
35
- MOAT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2"
36
- SWIN_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
37
- CONV_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
38
- CONV2_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
39
- VIT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"
40
- # IdolSankaku series of models:
41
- EVA02_LARGE_MODEL_IS_DSV1_REPO = "deepghs/idolsankaku-eva02-large-tagger-v1"
42
- SWINV2_MODEL_IS_DSV1_REPO = "deepghs/idolsankaku-swinv2-tagger-v1"
43
- # Files to download from the repos
44
- MODEL_FILENAME = "model.onnx"
45
- LABEL_FILENAME = "selected_tags.csv"
46
-
47
- kaomojis=['0_0','(o)_(o)','+_+','+_-','._.','<o>_<o>','<|>_<|>','=_=','>_<','3_3','6_9','>_o','@_@','^_^','o_o','u_u','x_x','|_|','||_||']
48
- def parse_args()->argparse.Namespace:parser=argparse.ArgumentParser();parser.add_argument('--score-slider-step',type=float,default=.05);parser.add_argument('--score-general-threshold',type=float,default=.35);parser.add_argument('--score-character-threshold',type=float,default=.85);parser.add_argument('--share',action='store_true');return parser.parse_args()
49
- def load_labels(dataframe)->list[str]:name_series=dataframe['name'];name_series=name_series.map(lambda x:x.replace('_',' ')if x not in kaomojis else x);tag_names=name_series.tolist();rating_indexes=list(np.where(dataframe['category']==9)[0]);general_indexes=list(np.where(dataframe['category']==0)[0]);character_indexes=list(np.where(dataframe['category']==4)[0]);return tag_names,rating_indexes,general_indexes,character_indexes
50
- def mcut_threshold(probs):sorted_probs=probs[probs.argsort()[::-1]];difs=sorted_probs[:-1]-sorted_probs[1:];t=difs.argmax();thresh=(sorted_probs[t]+sorted_probs[t+1])/2;return thresh
51
-
52
- class Timer:
53
- def __init__(self):self.start_time=time.perf_counter();self.checkpoints=[('Start',self.start_time)]
54
- def checkpoint(self,label='Checkpoint'):now=time.perf_counter();self.checkpoints.append((label,now))
55
- def report(self,is_clear_checkpoints=True):
56
- max_label_length=max(len(label)for(label,_)in self.checkpoints);prev_time=self.checkpoints[0][1]
57
- for(label,curr_time)in self.checkpoints[1:]:elapsed=curr_time-prev_time;print(f"{label.ljust(max_label_length)}: {elapsed:.3f} seconds");prev_time=curr_time
58
- if is_clear_checkpoints:self.checkpoints.clear();self.checkpoint()
59
- def report_all(self):
60
- print('\n> Execution Time Report:');max_label_length=max(len(label)for(label,_)in self.checkpoints)if len(self.checkpoints)>0 else 0;prev_time=self.start_time
61
- for(label,curr_time)in self.checkpoints[1:]:elapsed=curr_time-prev_time;print(f"{label.ljust(max_label_length)}: {elapsed:.3f} seconds");prev_time=curr_time
62
- total_time=self.checkpoints[-1][1]-self.start_time;print(f"{'Total Execution Time'.ljust(max_label_length)}: {total_time:.3f} seconds\n");self.checkpoints.clear()
63
- def restart(self):self.start_time=time.perf_counter();self.checkpoints=[('Start',self.start_time)]
64
- class Predictor:
65
- def __init__(self):
66
- self.model_target_size = None
67
- self.last_loaded_repo = None
68
- def download_model(self, model_repo):
69
- csv_path = huggingface_hub.hf_hub_download(
70
- model_repo,
71
- LABEL_FILENAME,
72
- )
73
- model_path = huggingface_hub.hf_hub_download(
74
- model_repo,
75
- MODEL_FILENAME,
76
- )
77
- return csv_path, model_path
78
- def load_model(self, model_repo):
79
- if model_repo == self.last_loaded_repo:
80
- return
81
-
82
- csv_path, model_path = self.download_model(model_repo)
83
-
84
- tags_df = pd.read_csv(csv_path)
85
- sep_tags = load_labels(tags_df)
86
-
87
- self.tag_names = sep_tags[0]
88
- self.rating_indexes = sep_tags[1]
89
- self.general_indexes = sep_tags[2]
90
- self.character_indexes = sep_tags[3]
91
-
92
- model = rt.InferenceSession(model_path)
93
- _, height, width, _ = model.get_inputs()[0].shape
94
- self.model_target_size = height
95
-
96
- self.last_loaded_repo = model_repo
97
- self.model = model
98
- def prepare_image(self, path):
99
- image = Image.open(path)
100
- image = image.convert("RGBA")
101
- target_size = self.model_target_size
102
-
103
- canvas = Image.new("RGBA", image.size, (255, 255, 255))
104
- canvas.alpha_composite(image)
105
- image = canvas.convert("RGB")
106
-
107
- # Pad image to square
108
- image_shape = image.size
109
- max_dim = max(image_shape)
110
- pad_left = (max_dim - image_shape[0]) // 2
111
- pad_top = (max_dim - image_shape[1]) // 2
112
-
113
- padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
114
- padded_image.paste(image, (pad_left, pad_top))
115
-
116
- # Resize
117
- if max_dim != target_size:
118
- padded_image = padded_image.resize(
119
- (target_size, target_size),
120
- Image.BICUBIC,
121
- )
122
- # Convert to numpy array
123
- image_array = np.asarray(padded_image, dtype=np.float32)
124
- # Convert PIL-native RGB to BGR
125
- image_array = image_array[:, :, ::-1]
126
- return np.expand_dims(image_array, axis=0)
127
-
128
- def create_file(self, content: str, directory: str, fileName: str) -> str:
129
- # Write the content to a file
130
- file_path = os.path.join(directory, fileName)
131
- if fileName.endswith('.json'):
132
- with open(file_path, 'w', encoding="utf-8") as file:
133
- file.write(content)
134
- else:
135
- with open(file_path, 'w+', encoding="utf-8") as file:
136
- file.write(content)
137
-
138
- return file_path
139
-
140
- def predict(
141
- self,
142
- gallery,
143
- model_repo,
144
- general_thresh,
145
- general_mcut_enabled,
146
- character_thresh,
147
- character_mcut_enabled,
148
- characters_merge_enabled,
149
- reorganizer_model_repo,
150
- additional_tags_prepend,
151
- additional_tags_append,
152
- tag_results,
153
- progress=gr.Progress()
154
- ):
155
- # Clear tag_results before starting a new prediction
156
- tag_results.clear()
157
-
158
- gallery_len = len(gallery)
159
- print(f"Predict load model: {model_repo}, gallery length: {gallery_len}")
160
-
161
- timer = Timer() # Create a timer
162
- progressRatio = 0.5 if reorganizer_model_repo else 1
163
- progressTotal = gallery_len + 1
164
- current_progress = 0
165
-
166
- self.load_model(model_repo)
167
- current_progress += progressRatio/progressTotal;
168
- progress(current_progress, desc="Initialize wd model finished")
169
- timer.checkpoint(f"Initialize wd model")
170
-
171
- txt_infos = []
172
- output_dir = tempfile.mkdtemp()
173
- if not os.path.exists(output_dir):
174
- os.makedirs(output_dir)
175
-
176
- sorted_general_strings = ""
177
- # Create categorized output string
178
- categorized_output_strings = []
179
- rating = None
180
- character_res = None
181
- general_res = None
182
-
183
- if reorganizer_model_repo:
184
- print(f"Reorganizer load model {reorganizer_model_repo}")
185
- reorganizer = reorganizer_class(reorganizer_model_repo, loadModel=True)
186
- current_progress += progressRatio/progressTotal;
187
- progress(current_progress, desc="Initialize reoganizer model finished")
188
- timer.checkpoint(f"Initialize reoganizer model")
189
-
190
- timer.report()
191
-
192
- prepend_list = [tag.strip() for tag in additional_tags_prepend.split(",") if tag.strip()]
193
- append_list = [tag.strip() for tag in additional_tags_append.split(",") if tag.strip()]
194
- if prepend_list and append_list:
195
- append_list = [item for item in append_list if item not in prepend_list]
196
-
197
- # Dictionary to track counters for each filename
198
- name_counters = defaultdict(int)
199
-
200
- for idx, value in enumerate(gallery):
201
- try:
202
- image_path = value[0]
203
- image_name = os.path.splitext(os.path.basename(image_path))[0]
204
-
205
- # Increment the counter for the current name
206
- name_counters[image_name] += 1
207
-
208
- if name_counters[image_name] > 1:
209
- image_name = f"{image_name}_{name_counters[image_name]:02d}"
210
-
211
- image = self.prepare_image(image_path)
212
-
213
- input_name = self.model.get_inputs()[0].name
214
- label_name = self.model.get_outputs()[0].name
215
- print(f"Gallery {idx:02d}: Starting run wd model...")
216
- preds = self.model.run([label_name], {input_name: image})[0]
217
-
218
- labels = list(zip(self.tag_names, preds[0].astype(float)))
219
-
220
- # First 4 labels are actually ratings: pick one with argmax
221
- ratings_names = [labels[i] for i in self.rating_indexes]
222
- rating = dict(ratings_names)
223
-
224
- # Then we have general tags: pick any where prediction confidence > threshold
225
- general_names = [labels[i] for i in self.general_indexes]
226
-
227
- if general_mcut_enabled:
228
- general_probs = np.array([x[1] for x in general_names])
229
- general_thresh = mcut_threshold(general_probs)
230
-
231
- general_res = [x for x in general_names if x[1] > general_thresh]
232
- general_res = dict(general_res)
233
-
234
- # Everything else is characters: pick any where prediction confidence > threshold
235
- character_names = [labels[i] for i in self.character_indexes]
236
-
237
- if character_mcut_enabled:
238
- character_probs = np.array([x[1] for x in character_names])
239
- character_thresh = mcut_threshold(character_probs)
240
- character_thresh = max(0.15, character_thresh)
241
-
242
- character_res = [x for x in character_names if x[1] > character_thresh]
243
- character_res = dict(character_res)
244
- character_list = list(character_res.keys())
245
-
246
- sorted_general_list = sorted(
247
- general_res.items(),
248
- key=lambda x: x[1],
249
- reverse=True,
250
- )
251
- sorted_general_list = [x[0] for x in sorted_general_list]
252
- # Remove values from character_list that already exist in sorted_general_list
253
- character_list = [item for item in character_list if item not in sorted_general_list]
254
- # Remove values from sorted_general_list that already exist in prepend_list or append_list
255
- if prepend_list:
256
- sorted_general_list = [item for item in sorted_general_list if item not in prepend_list]
257
- if append_list:
258
- sorted_general_list = [item for item in sorted_general_list if item not in append_list]
259
-
260
- sorted_general_list = prepend_list + sorted_general_list + append_list
261
-
262
- sorted_general_strings = ", ".join((character_list if characters_merge_enabled else []) + sorted_general_list).replace("(", "\(").replace(")", "\)")
263
-
264
- classified_tags, unclassified_tags = classify_tags(sorted_general_list)
265
-
266
- # Create a single string of ALL categorized tags for the current image
267
- categorized_output_string = ', '.join([', '.join(tags) for tags in classified_tags.values()])
268
- categorized_output_strings.append(categorized_output_string)
269
- # Collect all categorized output strings into a single string
270
- final_categorized_output = ', '.join(categorized_output_strings)
271
-
272
- # Create a .txt file for "Output (string)" and "Categorized Output (string)"
273
- txt_content = f"Output (string): {sorted_general_strings}\nCategorized Output (string): {final_categorized_output}"
274
- txt_file = self.create_file(txt_content, output_dir, f"{image_name}_output.txt")
275
- txt_infos.append({"path": txt_file, "name": f"{image_name}_output.txt"})
276
-
277
- # Create a .json file for "Categorized (tags)"
278
- json_content = json.dumps(classified_tags, indent=4)
279
- json_file = self.create_file(json_content, output_dir, f"{image_name}_categorized_tags.json")
280
- txt_infos.append({"path": json_file, "name": f"{image_name}_categorized_tags.json"})
281
-
282
- # Save a copy of the uploaded image in PNG format
283
- image_path = value[0]
284
- image = Image.open(image_path)
285
- image.save(os.path.join(output_dir, f"{image_name}.png"), format="PNG")
286
- txt_infos.append({"path": os.path.join(output_dir, f"{image_name}.png"), "name": f"{image_name}.png"})
287
-
288
- current_progress += progressRatio/progressTotal;
289
- progress(current_progress, desc=f"image{idx:02d}, predict finished")
290
- timer.checkpoint(f"image{idx:02d}, predict finished")
291
-
292
- if reorganizer_model_repo:
293
- print(f"Starting reorganizer...")
294
- reorganize_strings = reorganizer.reorganize(sorted_general_strings)
295
- reorganize_strings = re.sub(r" *Title: *", "", reorganize_strings)
296
- reorganize_strings = re.sub(r"\n+", ",", reorganize_strings)
297
- reorganize_strings = re.sub(r",,+", ",", reorganize_strings)
298
- sorted_general_strings += ",\n\n" + reorganize_strings
299
-
300
- current_progress += progressRatio/progressTotal;
301
- progress(current_progress, desc=f"image{idx:02d}, reorganizer finished")
302
- timer.checkpoint(f"image{idx:02d}, reorganizer finished")
303
-
304
- txt_file = self.create_file(sorted_general_strings, output_dir, image_name + ".txt")
305
- txt_infos.append({"path":txt_file, "name": image_name + ".txt"})
306
-
307
- # Store the result in tag_results using image_path as the key
308
- tag_results[image_path] = {
309
- "strings": sorted_general_strings,
310
- "strings2": categorized_output_string, # Store the categorized output string here
311
- "classified_tags": classified_tags,
312
- "rating": rating,
313
- "character_res": character_res,
314
- "general_res": general_res,
315
- "unclassified_tags": unclassified_tags,
316
- "enhanced_tags": "" # Initialize as empty string
317
- }
318
-
319
- timer.report()
320
- except Exception as e:
321
- print(traceback.format_exc())
322
- print("Error predict: " + str(e))
323
- # Zip creation logic:
324
- download = []
325
- if txt_infos is not None and len(txt_infos) > 0:
326
- downloadZipPath = os.path.join(output_dir, "Multi-tagger-" + datetime.now().strftime("%Y%m%d-%H%M%S") + ".zip")
327
- with zipfile.ZipFile(downloadZipPath, 'w', zipfile.ZIP_DEFLATED) as taggers_zip:
328
- for info in txt_infos:
329
- # Get file name from lookup
330
- taggers_zip.write(info["path"], arcname=info["name"])
331
- download.append(downloadZipPath)
332
- # End zip creation logic
333
- if reorganizer_model_repo:
334
- reorganizer.release_vram()
335
- del reorganizer
336
-
337
- progress(1, desc=f"Predict completed")
338
- timer.report_all() # Print all recorded times
339
- print("Predict is complete.")
340
-
341
- return download, sorted_general_strings, final_categorized_output, classified_tags, rating, character_res, general_res, unclassified_tags, tag_results
342
- def get_selection_from_gallery(gallery: list, tag_results: dict, selected_state: gr.SelectData):
343
- if not selected_state:
344
- return selected_state
345
- tag_result = {
346
- "strings": "",
347
- "strings2": "",
348
- "classified_tags": "{}",
349
- "rating": "",
350
- "character_res": "",
351
- "general_res": "",
352
- "unclassified_tags": "{}",
353
- "enhanced_tags": ""
354
- }
355
- if selected_state.value["image"]["path"] in tag_results:
356
- tag_result = tag_results[selected_state.value["image"]["path"]]
357
- return (selected_state.value["image"]["path"], selected_state.value["caption"]), tag_result["strings"], tag_result["strings2"], tag_result["classified_tags"], tag_result["rating"], tag_result["character_res"], tag_result["general_res"], tag_result["unclassified_tags"], tag_result["enhanced_tags"]
358
- def append_gallery(gallery:list,image:str):
359
- if gallery is None:gallery=[]
360
- if not image:return gallery,None
361
- gallery.append(image);return gallery,None
362
- def extend_gallery(gallery:list,images):
363
- if gallery is None:gallery=[]
364
- if not images:return gallery
365
- gallery.extend(images);return gallery
366
- def remove_image_from_gallery(gallery:list,selected_image:str):
367
- if not gallery or not selected_image:return gallery
368
- selected_image=ast.literal_eval(selected_image)
369
- if selected_image in gallery:gallery.remove(selected_image)
370
- return gallery
371
- args = parse_args()
372
- predictor = Predictor()
373
- dropdown_list = [
374
- EVA02_LARGE_MODEL_DSV3_REPO,
375
- SWINV2_MODEL_DSV3_REPO,
376
- CONV_MODEL_DSV3_REPO,
377
- VIT_MODEL_DSV3_REPO,
378
- VIT_LARGE_MODEL_DSV3_REPO,
379
- # ---
380
- MOAT_MODEL_DSV2_REPO,
381
- SWIN_MODEL_DSV2_REPO,
382
- CONV_MODEL_DSV2_REPO,
383
- CONV2_MODEL_DSV2_REPO,
384
- VIT_MODEL_DSV2_REPO,
385
- # ---
386
- SWINV2_MODEL_IS_DSV1_REPO,
387
- EVA02_LARGE_MODEL_IS_DSV1_REPO,
388
- ]
389
-
390
- def _restart_space():
391
- HF_TOKEN=os.getenv('HF_TOKEN')
392
- if not HF_TOKEN:raise ValueError('HF_TOKEN environment variable is not set.')
393
- huggingface_hub.HfApi().restart_space(repo_id='Werli/Multi-Tagger',token=HF_TOKEN,factory_reboot=False)
394
- scheduler=BackgroundScheduler()
395
- # Add a job to restart the space every 2 days (172800 seconds)
396
- restart_space_job = scheduler.add_job(_restart_space, "interval", seconds=172800)
397
- scheduler.start()
398
- next_run_time_utc=restart_space_job.next_run_time.astimezone(timezone.utc)
399
- NEXT_RESTART=f"Next Restart: {next_run_time_utc.strftime('%Y-%m-%d %H:%M:%S')} (UTC) - The space will restart every 2 days to ensure stability and performance. It uses a background scheduler to handle the restart process."
400
-
401
- css = """
402
- #output {height: 500px; overflow: auto; border: 1px solid #ccc;}
403
- label.float.svelte-i3tvor {position: relative !important;}
404
- .reduced-height.svelte-11chud3 {height: calc(80% - var(--size-10));}
405
- """
406
-
407
- with gr.Blocks(title=TITLE, css=css, theme=gr.themes.Soft(), fill_width=True) as demo:
408
- gr.Markdown(value=f"<h1 style='text-align: center; margin-bottom: 1rem'>{TITLE}</h1>")
409
- gr.Markdown(value=DESCRIPTION)
410
- gr.Markdown(NEXT_RESTART)
411
- with gr.Tab(label="Waifu Diffusion"):
412
- with gr.Row():
413
- with gr.Column():
414
- submit = gr.Button(value="Submit", variant="primary", size="lg")
415
- with gr.Column(variant="panel"):
416
- # Create an Image component for uploading images
417
- image_input = gr.Image(label="Upload an Image or clicking paste from clipboard button", type="filepath", sources=["upload", "clipboard"], height=150)
418
- with gr.Row():
419
- upload_button = gr.UploadButton("Upload multiple images", file_types=["image"], file_count="multiple", size="sm")
420
- remove_button = gr.Button("Remove Selected Image", size="sm")
421
- gallery = gr.Gallery(columns=5, rows=5, show_share_button=False, interactive=True, height="500px", label="Grid of images")
422
- model_repo = gr.Dropdown(
423
- dropdown_list,
424
- value=EVA02_LARGE_MODEL_DSV3_REPO,
425
- label="Model",
426
- )
427
- with gr.Row():
428
- general_thresh = gr.Slider(
429
- 0,
430
- 1,
431
- step=args.score_slider_step,
432
- value=args.score_general_threshold,
433
- label="General Tags Threshold",
434
- scale=3,
435
- )
436
- general_mcut_enabled = gr.Checkbox(
437
- value=False,
438
- label="Use MCut threshold",
439
- scale=1,
440
- )
441
- with gr.Row():
442
- character_thresh = gr.Slider(
443
- 0,
444
- 1,
445
- step=args.score_slider_step,
446
- value=args.score_character_threshold,
447
- label="Character Tags Threshold",
448
- scale=3,
449
- )
450
- character_mcut_enabled = gr.Checkbox(
451
- value=False,
452
- label="Use MCut threshold",
453
- scale=1,
454
- )
455
- with gr.Row():
456
- characters_merge_enabled = gr.Checkbox(
457
- value=True,
458
- label="Merge characters into the string output",
459
- scale=1,
460
- )
461
- with gr.Row():
462
- reorganizer_model_repo = gr.Dropdown(
463
- [None] + reorganizer_list,
464
- value=None,
465
- label="Reorganizer Model",
466
- info="Use a model to create a description for you",
467
- )
468
- with gr.Row():
469
- additional_tags_prepend = gr.Text(label="Prepend Additional tags (comma split)")
470
- additional_tags_append = gr.Text(label="Append Additional tags (comma split)")
471
- with gr.Row():
472
- clear = gr.ClearButton(
473
- components=[
474
- gallery,
475
- model_repo,
476
- general_thresh,
477
- general_mcut_enabled,
478
- character_thresh,
479
- character_mcut_enabled,
480
- characters_merge_enabled,
481
- reorganizer_model_repo,
482
- additional_tags_prepend,
483
- additional_tags_append,
484
- ],
485
- variant="secondary",
486
- size="lg",
487
- )
488
- with gr.Column(variant="panel"):
489
- download_file = gr.File(label="Download includes: All outputs* and image(s)") # 0
490
- character_res = gr.Label(label="Output (characters)") # 1
491
- sorted_general_strings = gr.Textbox(label="Output (string)*", show_label=True, show_copy_button=True) # 2
492
- final_categorized_output = gr.Textbox(label="Categorized (string)* - If it's too long, select an image to display tags correctly.", show_label=True, show_copy_button=True) # 3
493
- pe_generate_btn = gr.Button(value="ENHANCE TAGS", size="lg", variant="primary") # 4
494
- enhanced_tags = gr.Textbox(label="Enhanced Tags", show_label=True, show_copy_button=True) # 5
495
- prompt_enhancer_model = gr.Radio(["Medium", "Long", "Flux"], label="Model Choice", value="Medium", info="Enhance your prompts with Medium or Long answers") # 6
496
- categorized = gr.JSON(label="Categorized (tags)* - JSON") # 7
497
- rating = gr.Label(label="Rating") # 8
498
- general_res = gr.Label(label="Output (tags)") # 9
499
- unclassified = gr.JSON(label="Unclassified (tags)") # 10
500
- clear.add(
501
- [
502
- download_file,
503
- sorted_general_strings,
504
- final_categorized_output,
505
- categorized,
506
- rating,
507
- character_res,
508
- general_res,
509
- unclassified,
510
- prompt_enhancer_model,
511
- enhanced_tags,
512
- ]
513
- )
514
- tag_results = gr.State({})
515
- # Define the event listener to add the uploaded image to the gallery
516
- image_input.change(append_gallery, inputs=[gallery, image_input], outputs=[gallery, image_input])
517
- # When the upload button is clicked, add the new images to the gallery
518
- upload_button.upload(extend_gallery, inputs=[gallery, upload_button], outputs=gallery)
519
- # Event to update the selected image when an image is clicked in the gallery
520
- selected_image = gr.Textbox(label="Selected Image", visible=False)
521
- gallery.select(get_selection_from_gallery,inputs=[gallery, tag_results],outputs=[selected_image, sorted_general_strings, final_categorized_output, categorized, rating, character_res, general_res, unclassified, enhanced_tags])
522
- # Event to remove a selected image from the gallery
523
- remove_button.click(remove_image_from_gallery, inputs=[gallery, selected_image], outputs=gallery)
524
- # Event to for the Prompt Enhancer Button
525
- pe_generate_btn.click(lambda tags,model:prompt_enhancer('','',tags,model)[0],inputs=[final_categorized_output,prompt_enhancer_model],outputs=[enhanced_tags])
526
- submit.click(
527
- predictor.predict,
528
- inputs=[
529
- gallery,
530
- model_repo,
531
- general_thresh,
532
- general_mcut_enabled,
533
- character_thresh,
534
- character_mcut_enabled,
535
- characters_merge_enabled,
536
- reorganizer_model_repo,
537
- additional_tags_prepend,
538
- additional_tags_append,
539
- tag_results,
540
- ],
541
- outputs=[download_file, sorted_general_strings, final_categorized_output, categorized, rating, character_res, general_res, unclassified, tag_results,],
542
- )
543
- gr.Examples(
544
- [["images/1girl.png", VIT_LARGE_MODEL_DSV3_REPO, 0.35, False, 0.85, False]],
545
- inputs=[
546
- image_input,
547
- model_repo,
548
- general_thresh,
549
- general_mcut_enabled,
550
- character_thresh,
551
- character_mcut_enabled,
552
- ],
553
- )
554
- with gr.Tab(label="Florence 2 Image Captioning"):
555
- with gr.Row():
556
- with gr.Column(variant="panel"):
557
- input_img = gr.Image(label="Input Picture")
558
- task_type = gr.Radio(choices=['Single task', 'Cascaded task'], label='Task type selector', value='Single task')
559
- task_prompt = gr.Dropdown(choices=single_task_list, label="Task Prompt", value="Caption")
560
- task_type.change(fn=update_task_dropdown, inputs=task_type, outputs=task_prompt)
561
- text_input = gr.Textbox(label="Text Input (optional)")
562
- submit_btn = gr.Button(value="Submit")
563
- with gr.Column(variant="panel"):
564
- output_text = gr.Textbox(label="Output Text", show_label=True, show_copy_button=True, lines=8)
565
- output_img = gr.Image(label="Output Image")
566
- gr.Examples(
567
- examples=[
568
- ["images/image1.png", 'Object Detection'],
569
- ["images/image2.png", 'OCR with Region']
570
- ],
571
- inputs=[input_img, task_prompt],
572
- outputs=[output_text, output_img],
573
- fn=process_image,
574
- cache_examples=False,
575
- label='Try examples'
576
- )
577
- submit_btn.click(process_image, [input_img, task_prompt, text_input], [output_text, output_img])
578
- with gr.Tab(label="Gelbooru Image Fetcher"):
579
- with gr.Row():
580
- with gr.Column():
581
- gr.Markdown("### ⚙️ Search Parameters")
582
- site = gr.Dropdown(label="Select Source", choices=["Gelbooru", "None (will not work)"], value="Gelbooru")
583
- OR_tags = gr.Textbox(label="OR Tags (comma-separated)", placeholder="e.g. solo, 1girl, 1boy, artist, character, ...")
584
- AND_tags = gr.Textbox(label="AND Tags (comma-separated)", placeholder="e.g. black hair, cat ears, holding, granblue fantasy, ...")
585
- exclude_tags = gr.Textbox(label="Exclude Tags (comma-separated)", placeholder="e.g. animated, watermark, username, ...")
586
- score = gr.Number(label="Minimum Score", value=0)
587
- count = gr.Slider(label="Number of Images", minimum=1, maximum=4, step=1, value=1) # Increase if necessary (not recommend)
588
- Safe = gr.Checkbox(label="Include Safe", value=True)
589
- Questionable = gr.Checkbox(label="Include Questionable", value=True)
590
- Explicit = gr.Checkbox(label="Include Explicit", value=False)
591
- #user_id = gr.Textbox(label="User ID (Optional)", value="")
592
- #api_key = gr.Textbox(label="API Key (Optional)", value="", type="password")
593
-
594
- submit_btn = gr.Button("Fetch Images", variant="primary")
595
-
596
- with gr.Column():
597
- gr.Markdown("### 📄 Results")
598
- images_output = gr.Gallery(label="Images", columns=3, rows=2, object_fit="contain", height=500)
599
- tags_output = gr.Textbox(label="Tags", placeholder="Select an image to show tags", lines=5, show_copy_button=True)
600
- post_url_output = gr.Textbox(label="Post URL", lines=1, show_copy_button=True)
601
- image_url_output = gr.Textbox(label="Image URL", lines=1, show_copy_button=True)
602
-
603
- # State to store tags, URLs
604
- tags_state = gr.State([])
605
- post_url_state = gr.State([])
606
- image_url_state = gr.State([])
607
-
608
- submit_btn.click(
609
- fn=gelbooru_gradio,
610
- inputs=[OR_tags, AND_tags, exclude_tags, score, count, Safe, Questionable, Explicit, site], # add 'api_key' and 'user_id' if necessary
611
- outputs=[images_output, tags_state, post_url_state, image_url_state],
612
- )
613
-
614
- images_output.select(
615
- fn=on_select,
616
- inputs=[tags_state, post_url_state, image_url_state],
617
- outputs=[tags_output, post_url_output, image_url_output],
618
- )
619
- gr.Markdown("""
620
- ---
621
- ComfyUI version: [Comfyui-Gelbooru](https://github.com/1mckw/Comfyui-Gelbooru)
622
- """)
623
- with gr.Tab(label="Categorizer++"):
624
- with gr.Row():
625
- with gr.Column(variant="panel"):
626
- input_tags = gr.Textbox(label="Input Tags", placeholder="1girl, cat, horns, blue hair, ...\nor\n? 1girl 1234567? cat 1234567? horns 1234567? blue hair 1234567? ...", lines=4)
627
- submit_button = gr.Button(value="Submit", variant="primary", size="lg")
628
- with gr.Column(variant="panel"):
629
- categorized_string = gr.Textbox(label="Categorized (string)", show_label=True, show_copy_button=True, lines=8)
630
- categorized_json = gr.JSON(label="Categorized (tags) - JSON")
631
- submit_button.click(process_tags, inputs=[input_tags], outputs=[categorized_string, categorized_json])
632
- with gr.Column(variant="panel"):
633
- pe_generate_btn = gr.Button(value="ENHANCE TAGS", size="lg", variant="primary")
634
- enhanced_tags = gr.Textbox(label="Enhanced Tags", show_label=True, show_copy_button=True)
635
- prompt_enhancer_model = gr.Radio(["Medium", "Long", "Flux"], label="Model Choice", value="Medium", info="Enhance your prompts with Medium or Long answers")
636
- pe_generate_btn.click(lambda tags,model:prompt_enhancer('','',tags,model)[0],inputs=[categorized_string,prompt_enhancer_model],outputs=[enhanced_tags])
637
  demo.queue(max_size=2).launch()
 
1
+ import os,io,copy,json,requests,spaces,gradio as gr,numpy as np
2
+ import argparse,huggingface_hub,onnxruntime as rt,pandas as pd,traceback,tempfile,zipfile,re,ast,time
3
+ from datetime import datetime,timezone
4
+ from collections import defaultdict
5
+ from PIL import Image,ImageOps
6
+ from modules.booru import booru_gradio,on_select
7
+ from apscheduler.schedulers.background import BackgroundScheduler
8
+ from modules.classifyTags import classify_tags,process_tags
9
+ from modules.reorganizer_model import reorganizer_list,reorganizer_class
10
+ from modules.tag_enhancer import prompt_enhancer
11
+ from modules.florence2 import process_image,single_task_list,update_task_dropdown
12
+
13
+ os.environ['PYTORCH_ENABLE_MPS_FALLBACK']='1'
14
+
15
+ TITLE = "Multi-Tagger v1.2"
16
+ DESCRIPTION = """
17
+ Multi-Tagger is a versatile application for advanced image analysis and captioning. Perfect for AI artists or enthusiasts, it offers a range of features:
18
+
19
+ - **Automatic Tag Categorization**: Tags are grouped into categories.
20
+ - **Tag Enhancement**: Boost your prompts with enhanced descriptions using a built-in prompt enhancer.
21
+ - **Reorganizer**: Use a reorganizer model to format tags into a natural-language description.
22
+ - **Batch Support**: Upload and process multiple images simultaneously.
23
+ - **Downloadable Output**: Get almost all results as downloadable `.txt`, `.json`, and `.png` files in a `.zip` archive.
24
+ - **Image Fetcher**: Search for images from **Gelbooru** using flexible tag filters.
25
+ - **CUDA** and **CPU** support.
26
+ """
27
+
28
+ # Dataset v3 series of models:
29
+ SWINV2_MODEL_DSV3_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
30
+ CONV_MODEL_DSV3_REPO = "SmilingWolf/wd-convnext-tagger-v3"
31
+ VIT_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-tagger-v3"
32
+ VIT_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-large-tagger-v3"
33
+ EVA02_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"
34
+ # Dataset v2 series of models:
35
+ MOAT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2"
36
+ SWIN_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
37
+ CONV_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
38
+ CONV2_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
39
+ VIT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"
40
+ # IdolSankaku series of models:
41
+ EVA02_LARGE_MODEL_IS_DSV1_REPO = "deepghs/idolsankaku-eva02-large-tagger-v1"
42
+ SWINV2_MODEL_IS_DSV1_REPO = "deepghs/idolsankaku-swinv2-tagger-v1"
43
+ # Files to download from the repos
44
+ MODEL_FILENAME = "model.onnx"
45
+ LABEL_FILENAME = "selected_tags.csv"
46
+
47
+ kaomojis=['0_0','(o)_(o)','+_+','+_-','._.','<o>_<o>','<|>_<|>','=_=','>_<','3_3','6_9','>_o','@_@','^_^','o_o','u_u','x_x','|_|','||_||']
48
+ def parse_args()->argparse.Namespace:parser=argparse.ArgumentParser();parser.add_argument('--score-slider-step',type=float,default=.05);parser.add_argument('--score-general-threshold',type=float,default=.35);parser.add_argument('--score-character-threshold',type=float,default=.85);parser.add_argument('--share',action='store_true');return parser.parse_args()
49
+ def load_labels(dataframe)->list[str]:name_series=dataframe['name'];name_series=name_series.map(lambda x:x.replace('_',' ')if x not in kaomojis else x);tag_names=name_series.tolist();rating_indexes=list(np.where(dataframe['category']==9)[0]);general_indexes=list(np.where(dataframe['category']==0)[0]);character_indexes=list(np.where(dataframe['category']==4)[0]);return tag_names,rating_indexes,general_indexes,character_indexes
50
+ def mcut_threshold(probs):sorted_probs=probs[probs.argsort()[::-1]];difs=sorted_probs[:-1]-sorted_probs[1:];t=difs.argmax();thresh=(sorted_probs[t]+sorted_probs[t+1])/2;return thresh
51
+
52
+ class Timer:
53
+ def __init__(self):self.start_time=time.perf_counter();self.checkpoints=[('Start',self.start_time)]
54
+ def checkpoint(self,label='Checkpoint'):now=time.perf_counter();self.checkpoints.append((label,now))
55
+ def report(self,is_clear_checkpoints=True):
56
+ max_label_length=max(len(label)for(label,_)in self.checkpoints);prev_time=self.checkpoints[0][1]
57
+ for(label,curr_time)in self.checkpoints[1:]:elapsed=curr_time-prev_time;print(f"{label.ljust(max_label_length)}: {elapsed:.3f} seconds");prev_time=curr_time
58
+ if is_clear_checkpoints:self.checkpoints.clear();self.checkpoint()
59
+ def report_all(self):
60
+ print('\n> Execution Time Report:');max_label_length=max(len(label)for(label,_)in self.checkpoints)if len(self.checkpoints)>0 else 0;prev_time=self.start_time
61
+ for(label,curr_time)in self.checkpoints[1:]:elapsed=curr_time-prev_time;print(f"{label.ljust(max_label_length)}: {elapsed:.3f} seconds");prev_time=curr_time
62
+ total_time=self.checkpoints[-1][1]-self.start_time;print(f"{'Total Execution Time'.ljust(max_label_length)}: {total_time:.3f} seconds\n");self.checkpoints.clear()
63
+ def restart(self):self.start_time=time.perf_counter();self.checkpoints=[('Start',self.start_time)]
64
+ class Predictor:
65
+ def __init__(self):
66
+ self.model_target_size = None
67
+ self.last_loaded_repo = None
68
+ def download_model(self, model_repo):
69
+ csv_path = huggingface_hub.hf_hub_download(
70
+ model_repo,
71
+ LABEL_FILENAME,
72
+ )
73
+ model_path = huggingface_hub.hf_hub_download(
74
+ model_repo,
75
+ MODEL_FILENAME,
76
+ )
77
+ return csv_path, model_path
78
+ def load_model(self, model_repo):
79
+ if model_repo == self.last_loaded_repo:
80
+ return
81
+
82
+ csv_path, model_path = self.download_model(model_repo)
83
+
84
+ tags_df = pd.read_csv(csv_path)
85
+ sep_tags = load_labels(tags_df)
86
+
87
+ self.tag_names = sep_tags[0]
88
+ self.rating_indexes = sep_tags[1]
89
+ self.general_indexes = sep_tags[2]
90
+ self.character_indexes = sep_tags[3]
91
+
92
+ model = rt.InferenceSession(model_path)
93
+ _, height, width, _ = model.get_inputs()[0].shape
94
+ self.model_target_size = height
95
+
96
+ self.last_loaded_repo = model_repo
97
+ self.model = model
98
+ def prepare_image(self, path):
99
+ image = Image.open(path)
100
+ image = image.convert("RGBA")
101
+ target_size = self.model_target_size
102
+
103
+ canvas = Image.new("RGBA", image.size, (255, 255, 255))
104
+ canvas.alpha_composite(image)
105
+ image = canvas.convert("RGB")
106
+
107
+ # Pad image to square
108
+ image_shape = image.size
109
+ max_dim = max(image_shape)
110
+ pad_left = (max_dim - image_shape[0]) // 2
111
+ pad_top = (max_dim - image_shape[1]) // 2
112
+
113
+ padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
114
+ padded_image.paste(image, (pad_left, pad_top))
115
+
116
+ # Resize
117
+ if max_dim != target_size:
118
+ padded_image = padded_image.resize(
119
+ (target_size, target_size),
120
+ Image.BICUBIC,
121
+ )
122
+ # Convert to numpy array
123
+ image_array = np.asarray(padded_image, dtype=np.float32)
124
+ # Convert PIL-native RGB to BGR
125
+ image_array = image_array[:, :, ::-1]
126
+ return np.expand_dims(image_array, axis=0)
127
+
128
+ def create_file(self, content: str, directory: str, fileName: str) -> str:
129
+ # Write the content to a file
130
+ file_path = os.path.join(directory, fileName)
131
+ if fileName.endswith('.json'):
132
+ with open(file_path, 'w', encoding="utf-8") as file:
133
+ file.write(content)
134
+ else:
135
+ with open(file_path, 'w+', encoding="utf-8") as file:
136
+ file.write(content)
137
+
138
+ return file_path
139
+
140
+ def predict(
141
+ self,
142
+ gallery,
143
+ model_repo,
144
+ general_thresh,
145
+ general_mcut_enabled,
146
+ character_thresh,
147
+ character_mcut_enabled,
148
+ characters_merge_enabled,
149
+ reorganizer_model_repo,
150
+ additional_tags_prepend,
151
+ additional_tags_append,
152
+ tag_results,
153
+ progress=gr.Progress()
154
+ ):
155
+ # Clear tag_results before starting a new prediction
156
+ tag_results.clear()
157
+
158
+ gallery_len = len(gallery)
159
+ print(f"Predict load model: {model_repo}, gallery length: {gallery_len}")
160
+
161
+ timer = Timer() # Create a timer
162
+ progressRatio = 0.5 if reorganizer_model_repo else 1
163
+ progressTotal = gallery_len + 1
164
+ current_progress = 0
165
+
166
+ self.load_model(model_repo)
167
+ current_progress += progressRatio/progressTotal;
168
+ progress(current_progress, desc="Initialize wd model finished")
169
+ timer.checkpoint(f"Initialize wd model")
170
+
171
+ txt_infos = []
172
+ output_dir = tempfile.mkdtemp()
173
+ if not os.path.exists(output_dir):
174
+ os.makedirs(output_dir)
175
+
176
+ sorted_general_strings = ""
177
+ # Create categorized output string
178
+ categorized_output_strings = []
179
+ rating = None
180
+ character_res = None
181
+ general_res = None
182
+
183
+ if reorganizer_model_repo:
184
+ print(f"Reorganizer load model {reorganizer_model_repo}")
185
+ reorganizer = reorganizer_class(reorganizer_model_repo, loadModel=True)
186
+ current_progress += progressRatio/progressTotal;
187
+ progress(current_progress, desc="Initialize reoganizer model finished")
188
+ timer.checkpoint(f"Initialize reoganizer model")
189
+
190
+ timer.report()
191
+
192
+ prepend_list = [tag.strip() for tag in additional_tags_prepend.split(",") if tag.strip()]
193
+ append_list = [tag.strip() for tag in additional_tags_append.split(",") if tag.strip()]
194
+ if prepend_list and append_list:
195
+ append_list = [item for item in append_list if item not in prepend_list]
196
+
197
+ # Dictionary to track counters for each filename
198
+ name_counters = defaultdict(int)
199
+
200
+ for idx, value in enumerate(gallery):
201
+ try:
202
+ image_path = value[0]
203
+ image_name = os.path.splitext(os.path.basename(image_path))[0]
204
+
205
+ # Increment the counter for the current name
206
+ name_counters[image_name] += 1
207
+
208
+ if name_counters[image_name] > 1:
209
+ image_name = f"{image_name}_{name_counters[image_name]:02d}"
210
+
211
+ image = self.prepare_image(image_path)
212
+
213
+ input_name = self.model.get_inputs()[0].name
214
+ label_name = self.model.get_outputs()[0].name
215
+ print(f"Gallery {idx:02d}: Starting run wd model...")
216
+ preds = self.model.run([label_name], {input_name: image})[0]
217
+
218
+ labels = list(zip(self.tag_names, preds[0].astype(float)))
219
+
220
+ # First 4 labels are actually ratings: pick one with argmax
221
+ ratings_names = [labels[i] for i in self.rating_indexes]
222
+ rating = dict(ratings_names)
223
+
224
+ # Then we have general tags: pick any where prediction confidence > threshold
225
+ general_names = [labels[i] for i in self.general_indexes]
226
+
227
+ if general_mcut_enabled:
228
+ general_probs = np.array([x[1] for x in general_names])
229
+ general_thresh = mcut_threshold(general_probs)
230
+
231
+ general_res = [x for x in general_names if x[1] > general_thresh]
232
+ general_res = dict(general_res)
233
+
234
+ # Everything else is characters: pick any where prediction confidence > threshold
235
+ character_names = [labels[i] for i in self.character_indexes]
236
+
237
+ if character_mcut_enabled:
238
+ character_probs = np.array([x[1] for x in character_names])
239
+ character_thresh = mcut_threshold(character_probs)
240
+ character_thresh = max(0.15, character_thresh)
241
+
242
+ character_res = [x for x in character_names if x[1] > character_thresh]
243
+ character_res = dict(character_res)
244
+ character_list = list(character_res.keys())
245
+
246
+ sorted_general_list = sorted(
247
+ general_res.items(),
248
+ key=lambda x: x[1],
249
+ reverse=True,
250
+ )
251
+ sorted_general_list = [x[0] for x in sorted_general_list]
252
+ # Remove values from character_list that already exist in sorted_general_list
253
+ character_list = [item for item in character_list if item not in sorted_general_list]
254
+ # Remove values from sorted_general_list that already exist in prepend_list or append_list
255
+ if prepend_list:
256
+ sorted_general_list = [item for item in sorted_general_list if item not in prepend_list]
257
+ if append_list:
258
+ sorted_general_list = [item for item in sorted_general_list if item not in append_list]
259
+
260
+ sorted_general_list = prepend_list + sorted_general_list + append_list
261
+
262
+ sorted_general_strings = ", ".join((character_list if characters_merge_enabled else []) + sorted_general_list).replace("(", "\(").replace(")", "\)")
263
+
264
+ classified_tags, unclassified_tags = classify_tags(sorted_general_list)
265
+
266
+ # Create a single string of ALL categorized tags for the current image
267
+ categorized_output_string = ', '.join([', '.join(tags) for tags in classified_tags.values()])
268
+ categorized_output_strings.append(categorized_output_string)
269
+ # Collect all categorized output strings into a single string
270
+ final_categorized_output = ', '.join(categorized_output_strings)
271
+
272
+ # Create a .txt file for "Output (string)" and "Categorized Output (string)"
273
+ txt_content = f"Output (string): {sorted_general_strings}\nCategorized Output (string): {final_categorized_output}"
274
+ txt_file = self.create_file(txt_content, output_dir, f"{image_name}_output.txt")
275
+ txt_infos.append({"path": txt_file, "name": f"{image_name}_output.txt"})
276
+
277
+ # Create a .json file for "Categorized (tags)"
278
+ json_content = json.dumps(classified_tags, indent=4)
279
+ json_file = self.create_file(json_content, output_dir, f"{image_name}_categorized_tags.json")
280
+ txt_infos.append({"path": json_file, "name": f"{image_name}_categorized_tags.json"})
281
+
282
+ # Save a copy of the uploaded image in PNG format
283
+ image_path = value[0]
284
+ image = Image.open(image_path)
285
+ image.save(os.path.join(output_dir, f"{image_name}.png"), format="PNG")
286
+ txt_infos.append({"path": os.path.join(output_dir, f"{image_name}.png"), "name": f"{image_name}.png"})
287
+
288
+ current_progress += progressRatio/progressTotal;
289
+ progress(current_progress, desc=f"image{idx:02d}, predict finished")
290
+ timer.checkpoint(f"image{idx:02d}, predict finished")
291
+
292
+ if reorganizer_model_repo:
293
+ print(f"Starting reorganizer...")
294
+ reorganize_strings = reorganizer.reorganize(sorted_general_strings)
295
+ reorganize_strings = re.sub(r" *Title: *", "", reorganize_strings)
296
+ reorganize_strings = re.sub(r"\n+", ",", reorganize_strings)
297
+ reorganize_strings = re.sub(r",,+", ",", reorganize_strings)
298
+ sorted_general_strings += ",\n\n" + reorganize_strings
299
+
300
+ current_progress += progressRatio/progressTotal;
301
+ progress(current_progress, desc=f"image{idx:02d}, reorganizer finished")
302
+ timer.checkpoint(f"image{idx:02d}, reorganizer finished")
303
+
304
+ txt_file = self.create_file(sorted_general_strings, output_dir, image_name + ".txt")
305
+ txt_infos.append({"path":txt_file, "name": image_name + ".txt"})
306
+
307
+ # Store the result in tag_results using image_path as the key
308
+ tag_results[image_path] = {
309
+ "strings": sorted_general_strings,
310
+ "strings2": categorized_output_string, # Store the categorized output string here
311
+ "classified_tags": classified_tags,
312
+ "rating": rating,
313
+ "character_res": character_res,
314
+ "general_res": general_res,
315
+ "unclassified_tags": unclassified_tags,
316
+ "enhanced_tags": "" # Initialize as empty string
317
+ }
318
+
319
+ timer.report()
320
+ except Exception as e:
321
+ print(traceback.format_exc())
322
+ print("Error predict: " + str(e))
323
+ # Zip creation logic:
324
+ download = []
325
+ if txt_infos is not None and len(txt_infos) > 0:
326
+ downloadZipPath = os.path.join(output_dir, "Multi-tagger-" + datetime.now().strftime("%Y%m%d-%H%M%S") + ".zip")
327
+ with zipfile.ZipFile(downloadZipPath, 'w', zipfile.ZIP_DEFLATED) as taggers_zip:
328
+ for info in txt_infos:
329
+ # Get file name from lookup
330
+ taggers_zip.write(info["path"], arcname=info["name"])
331
+ download.append(downloadZipPath)
332
+ # End zip creation logic
333
+ if reorganizer_model_repo:
334
+ reorganizer.release_vram()
335
+ del reorganizer
336
+
337
+ progress(1, desc=f"Predict completed")
338
+ timer.report_all() # Print all recorded times
339
+ print("Predict is complete.")
340
+
341
+ return download, sorted_general_strings, final_categorized_output, classified_tags, rating, character_res, general_res, unclassified_tags, tag_results
342
+ def get_selection_from_gallery(gallery: list, tag_results: dict, selected_state: gr.SelectData):
343
+ if not selected_state:
344
+ return selected_state
345
+ tag_result = {
346
+ "strings": "",
347
+ "strings2": "",
348
+ "classified_tags": "{}",
349
+ "rating": "",
350
+ "character_res": "",
351
+ "general_res": "",
352
+ "unclassified_tags": "{}",
353
+ "enhanced_tags": ""
354
+ }
355
+ if selected_state.value["image"]["path"] in tag_results:
356
+ tag_result = tag_results[selected_state.value["image"]["path"]]
357
+ return (selected_state.value["image"]["path"], selected_state.value["caption"]), tag_result["strings"], tag_result["strings2"], tag_result["classified_tags"], tag_result["rating"], tag_result["character_res"], tag_result["general_res"], tag_result["unclassified_tags"], tag_result["enhanced_tags"]
358
+ def append_gallery(gallery:list,image:str):
359
+ if gallery is None:gallery=[]
360
+ if not image:return gallery,None
361
+ gallery.append(image);return gallery,None
362
+ def extend_gallery(gallery:list,images):
363
+ if gallery is None:gallery=[]
364
+ if not images:return gallery
365
+ gallery.extend(images);return gallery
366
+ def remove_image_from_gallery(gallery:list,selected_image:str):
367
+ if not gallery or not selected_image:return gallery
368
+ selected_image=ast.literal_eval(selected_image)
369
+ if selected_image in gallery:gallery.remove(selected_image)
370
+ return gallery
371
+ args = parse_args()
372
+ predictor = Predictor()
373
+ dropdown_list = [
374
+ EVA02_LARGE_MODEL_DSV3_REPO,
375
+ SWINV2_MODEL_DSV3_REPO,
376
+ CONV_MODEL_DSV3_REPO,
377
+ VIT_MODEL_DSV3_REPO,
378
+ VIT_LARGE_MODEL_DSV3_REPO,
379
+ # ---
380
+ MOAT_MODEL_DSV2_REPO,
381
+ SWIN_MODEL_DSV2_REPO,
382
+ CONV_MODEL_DSV2_REPO,
383
+ CONV2_MODEL_DSV2_REPO,
384
+ VIT_MODEL_DSV2_REPO,
385
+ # ---
386
+ SWINV2_MODEL_IS_DSV1_REPO,
387
+ EVA02_LARGE_MODEL_IS_DSV1_REPO,
388
+ ]
389
+
390
+ def _restart_space():
391
+ HF_TOKEN=os.getenv('HF_TOKEN')
392
+ if not HF_TOKEN:raise ValueError('HF_TOKEN environment variable is not set.')
393
+ huggingface_hub.HfApi().restart_space(repo_id='Werli/Multi-Tagger',token=HF_TOKEN,factory_reboot=False)
394
+ scheduler=BackgroundScheduler()
395
+ # Add a job to restart the space every 2 days (172800 seconds)
396
+ restart_space_job = scheduler.add_job(_restart_space, "interval", seconds=172800)
397
+ scheduler.start()
398
+ next_run_time_utc=restart_space_job.next_run_time.astimezone(timezone.utc)
399
+ NEXT_RESTART=f"Next Restart: {next_run_time_utc.strftime('%Y-%m-%d %H:%M:%S')} (UTC) - The space will restart every 2 days to ensure stability and performance. It uses a background scheduler to handle the restart process."
400
+
401
+ css = """
402
+ #output {height: 500px; overflow: auto; border: 1px solid #ccc;}
403
+ label.float.svelte-i3tvor {position: relative !important;}
404
+ .reduced-height.svelte-11chud3 {height: calc(80% - var(--size-10));}
405
+ """
406
+
407
+ with gr.Blocks(title=TITLE, css=css, theme=gr.themes.Soft(), fill_width=True) as demo:
408
+ gr.Markdown(value=f"<h1 style='text-align: center; margin-bottom: 1rem'>{TITLE}</h1>")
409
+ gr.Markdown(value=DESCRIPTION)
410
+ gr.Markdown(NEXT_RESTART)
411
+ with gr.Tab(label="Waifu Diffusion"):
412
+ with gr.Row():
413
+ with gr.Column():
414
+ submit = gr.Button(value="Submit", variant="primary", size="lg")
415
+ with gr.Column(variant="panel"):
416
+ # Create an Image component for uploading images
417
+ image_input = gr.Image(label="Upload an Image or clicking paste from clipboard button", type="filepath", sources=["upload", "clipboard"], height=150)
418
+ with gr.Row():
419
+ upload_button = gr.UploadButton("Upload multiple images", file_types=["image"], file_count="multiple", size="sm")
420
+ remove_button = gr.Button("Remove Selected Image", size="sm")
421
+ gallery = gr.Gallery(columns=5, rows=5, show_share_button=False, interactive=True, height="500px", label="Grid of images")
422
+ model_repo = gr.Dropdown(
423
+ dropdown_list,
424
+ value=EVA02_LARGE_MODEL_DSV3_REPO,
425
+ label="Model",
426
+ )
427
+ with gr.Row():
428
+ general_thresh = gr.Slider(
429
+ 0,
430
+ 1,
431
+ step=args.score_slider_step,
432
+ value=args.score_general_threshold,
433
+ label="General Tags Threshold",
434
+ scale=3,
435
+ )
436
+ general_mcut_enabled = gr.Checkbox(
437
+ value=False,
438
+ label="Use MCut threshold",
439
+ scale=1,
440
+ )
441
+ with gr.Row():
442
+ character_thresh = gr.Slider(
443
+ 0,
444
+ 1,
445
+ step=args.score_slider_step,
446
+ value=args.score_character_threshold,
447
+ label="Character Tags Threshold",
448
+ scale=3,
449
+ )
450
+ character_mcut_enabled = gr.Checkbox(
451
+ value=False,
452
+ label="Use MCut threshold",
453
+ scale=1,
454
+ )
455
+ with gr.Row():
456
+ characters_merge_enabled = gr.Checkbox(
457
+ value=True,
458
+ label="Merge characters into the string output",
459
+ scale=1,
460
+ )
461
+ with gr.Row():
462
+ reorganizer_model_repo = gr.Dropdown(
463
+ [None] + reorganizer_list,
464
+ value=None,
465
+ label="Reorganizer Model",
466
+ info="Use a model to create a description for you",
467
+ )
468
+ with gr.Row():
469
+ additional_tags_prepend = gr.Text(label="Prepend Additional tags (comma split)")
470
+ additional_tags_append = gr.Text(label="Append Additional tags (comma split)")
471
+ with gr.Row():
472
+ clear = gr.ClearButton(
473
+ components=[
474
+ gallery,
475
+ model_repo,
476
+ general_thresh,
477
+ general_mcut_enabled,
478
+ character_thresh,
479
+ character_mcut_enabled,
480
+ characters_merge_enabled,
481
+ reorganizer_model_repo,
482
+ additional_tags_prepend,
483
+ additional_tags_append,
484
+ ],
485
+ variant="secondary",
486
+ size="lg",
487
+ )
488
+ with gr.Column(variant="panel"):
489
+ download_file = gr.File(label="Download includes: All outputs* and image(s)") # 0
490
+ character_res = gr.Label(label="Output (characters)") # 1
491
+ sorted_general_strings = gr.Textbox(label="Output (string)*", show_label=True, show_copy_button=True) # 2
492
+ final_categorized_output = gr.Textbox(label="Categorized (string)* - If it's too long, select an image to display tags correctly.", show_label=True, show_copy_button=True) # 3
493
+ pe_generate_btn = gr.Button(value="ENHANCE TAGS", size="lg", variant="primary") # 4
494
+ enhanced_tags = gr.Textbox(label="Enhanced Tags", show_label=True, show_copy_button=True) # 5
495
+ prompt_enhancer_model = gr.Radio(["Medium", "Long", "Flux"], label="Model Choice", value="Medium", info="Enhance your prompts with Medium or Long answers") # 6
496
+ categorized = gr.JSON(label="Categorized (tags)* - JSON") # 7
497
+ rating = gr.Label(label="Rating") # 8
498
+ general_res = gr.Label(label="Output (tags)") # 9
499
+ unclassified = gr.JSON(label="Unclassified (tags)") # 10
500
+ clear.add(
501
+ [
502
+ download_file,
503
+ sorted_general_strings,
504
+ final_categorized_output,
505
+ categorized,
506
+ rating,
507
+ character_res,
508
+ general_res,
509
+ unclassified,
510
+ prompt_enhancer_model,
511
+ enhanced_tags,
512
+ ]
513
+ )
514
+ tag_results = gr.State({})
515
+ # Define the event listener to add the uploaded image to the gallery
516
+ image_input.change(append_gallery, inputs=[gallery, image_input], outputs=[gallery, image_input])
517
+ # When the upload button is clicked, add the new images to the gallery
518
+ upload_button.upload(extend_gallery, inputs=[gallery, upload_button], outputs=gallery)
519
+ # Event to update the selected image when an image is clicked in the gallery
520
+ selected_image = gr.Textbox(label="Selected Image", visible=False)
521
+ gallery.select(get_selection_from_gallery,inputs=[gallery, tag_results],outputs=[selected_image, sorted_general_strings, final_categorized_output, categorized, rating, character_res, general_res, unclassified, enhanced_tags])
522
+ # Event to remove a selected image from the gallery
523
+ remove_button.click(remove_image_from_gallery, inputs=[gallery, selected_image], outputs=gallery)
524
+ # Event to for the Prompt Enhancer Button
525
+ pe_generate_btn.click(lambda tags,model:prompt_enhancer('','',tags,model)[0],inputs=[final_categorized_output,prompt_enhancer_model],outputs=[enhanced_tags])
526
+ submit.click(
527
+ predictor.predict,
528
+ inputs=[
529
+ gallery,
530
+ model_repo,
531
+ general_thresh,
532
+ general_mcut_enabled,
533
+ character_thresh,
534
+ character_mcut_enabled,
535
+ characters_merge_enabled,
536
+ reorganizer_model_repo,
537
+ additional_tags_prepend,
538
+ additional_tags_append,
539
+ tag_results,
540
+ ],
541
+ outputs=[download_file, sorted_general_strings, final_categorized_output, categorized, rating, character_res, general_res, unclassified, tag_results,],
542
+ )
543
+ gr.Examples(
544
+ [["images/1girl.png", VIT_LARGE_MODEL_DSV3_REPO, 0.35, False, 0.85, False]],
545
+ inputs=[
546
+ image_input,
547
+ model_repo,
548
+ general_thresh,
549
+ general_mcut_enabled,
550
+ character_thresh,
551
+ character_mcut_enabled,
552
+ ],
553
+ )
554
+ with gr.Tab(label="Florence 2 Image Captioning"):
555
+ with gr.Row():
556
+ with gr.Column(variant="panel"):
557
+ input_img = gr.Image(label="Input Picture")
558
+ task_type = gr.Radio(choices=['Single task', 'Cascaded task'], label='Task type selector', value='Single task')
559
+ task_prompt = gr.Dropdown(choices=single_task_list, label="Task Prompt", value="Caption")
560
+ task_type.change(fn=update_task_dropdown, inputs=task_type, outputs=task_prompt)
561
+ text_input = gr.Textbox(label="Text Input (optional)")
562
+ submit_btn = gr.Button(value="Submit")
563
+ with gr.Column(variant="panel"):
564
+ output_text = gr.Textbox(label="Output Text", show_label=True, show_copy_button=True, lines=8)
565
+ output_img = gr.Image(label="Output Image")
566
+ gr.Examples(
567
+ examples=[
568
+ ["images/image1.png", 'Object Detection'],
569
+ ["images/image2.png", 'OCR with Region']
570
+ ],
571
+ inputs=[input_img, task_prompt],
572
+ outputs=[output_text, output_img],
573
+ fn=process_image,
574
+ cache_examples=False,
575
+ label='Try examples'
576
+ )
577
+ submit_btn.click(process_image, [input_img, task_prompt, text_input], [output_text, output_img])
578
+ with gr.Tab("Booru Image Fetcher"):
579
+ with gr.Row():
580
+ with gr.Column():
581
+ gr.Markdown("### ⚙️ Search Parameters")
582
+ site = gr.Dropdown(label="Select Source", choices=["Gelbooru", "Rule34", "Xbooru"], value="Gelbooru")
583
+ Tags = gr.Textbox(label="Tags (comma-separated)", placeholder="e.g. solo, 1girl, 1boy, artist name, character, black hair, cat ears, holding, granblue fantasy, ...")
584
+ exclude_tags = gr.Textbox(label="Exclude Tags (comma-separated)", placeholder="e.g. animated, watermark, username, ...")
585
+ score = gr.Number(label="Minimum Score", value=0)
586
+ count = gr.Slider(label="Number of Images", minimum=1, maximum=4, step=1, value=1)
587
+ Safe = gr.Checkbox(label="Include Safe", value=True)
588
+ Questionable = gr.Checkbox(label="Include Questionable", value=True)
589
+ Explicit = gr.Checkbox(label="Include Explicit", value=False)
590
+ submit_btn = gr.Button("Fetch Images", variant="primary")
591
+
592
+ with gr.Column():
593
+ gr.Markdown("### 📄 Results")
594
+ images_output = gr.Gallery(label="Images", columns=3, rows=2, object_fit="contain", height=500)
595
+ tags_output = gr.Textbox(label="Tags", placeholder="Select an image to show tags", lines=5, show_copy_button=True)
596
+ post_url_output = gr.Textbox(label="Post URL", lines=1, show_copy_button=True)
597
+ image_url_output = gr.Textbox(label="Image URL", lines=1, show_copy_button=True)
598
+
599
+ # State to store tags, URLs
600
+ tags_state = gr.State([])
601
+ post_url_state = gr.State([])
602
+ image_url_state = gr.State([])
603
+
604
+ submit_btn.click(
605
+ fn=booru_gradio,
606
+ inputs=[Tags, exclude_tags, score, count, Safe, Questionable, Explicit, site],
607
+ outputs=[images_output, tags_state, post_url_state, image_url_state],
608
+ )
609
+
610
+ images_output.select(
611
+ fn=on_select,
612
+ inputs=[tags_state, post_url_state, image_url_state],
613
+ outputs=[tags_output, post_url_output, image_url_output],
614
+ )
615
+ gr.Markdown("""
616
+ ---
617
+ ComfyUI version: [Comfyui-Gelbooru](https://github.com/1mckw/Comfyui-Gelbooru)
618
+ """)
619
+ with gr.Tab(label="Categorizer++"):
620
+ with gr.Row():
621
+ with gr.Column(variant="panel"):
622
+ input_tags = gr.Textbox(label="Input Tags", placeholder="1girl, cat, horns, blue hair, ...\nor\n? 1girl 1234567? cat 1234567? horns 1234567? blue hair 1234567? ...", lines=4)
623
+ submit_button = gr.Button(value="Submit", variant="primary", size="lg")
624
+ with gr.Column(variant="panel"):
625
+ categorized_string = gr.Textbox(label="Categorized (string)", show_label=True, show_copy_button=True, lines=8)
626
+ categorized_json = gr.JSON(label="Categorized (tags) - JSON")
627
+ submit_button.click(process_tags, inputs=[input_tags], outputs=[categorized_string, categorized_json])
628
+ with gr.Column(variant="panel"):
629
+ pe_generate_btn = gr.Button(value="ENHANCE TAGS", size="lg", variant="primary")
630
+ enhanced_tags = gr.Textbox(label="Enhanced Tags", show_label=True, show_copy_button=True)
631
+ prompt_enhancer_model = gr.Radio(["Medium", "Long", "Flux"], label="Model Choice", value="Medium", info="Enhance your prompts with Medium or Long answers")
632
+ pe_generate_btn.click(lambda tags,model:prompt_enhancer('','',tags,model)[0],inputs=[categorized_string,prompt_enhancer_model],outputs=[enhanced_tags])
 
 
 
 
633
  demo.queue(max_size=2).launch()
modules/booru.py CHANGED
@@ -1,132 +1,111 @@
1
- import requests,re,base64,io,numpy as np
2
- from PIL import Image,ImageOps
3
- import torch,gradio as gr
4
-
5
- # Helper to load image from URL
6
- def loadImageFromUrl(url):
7
- if url.startswith("data:image/"):
8
- i = Image.open(io.BytesIO(base64.b64decode(url.split(",")[1])))
9
- elif url.startswith("s3://"):
10
- raise Exception("S3 URLs not supported in this interface")
11
- else:
12
- response = requests.get(url, timeout=5)
13
- if response.status_code != 200:
14
- raise Exception(response.text)
15
- i = Image.open(io.BytesIO(response.content))
16
-
17
- i = ImageOps.exif_transpose(i)
18
- if i.mode != "RGBA":
19
- i = i.convert("RGBA")
20
-
21
- alpha = i.split()[-1]
22
- image = Image.new("RGB", i.size, (0, 0, 0))
23
- image.paste(i, mask=alpha)
24
-
25
- image = np.array(image).astype(np.float32) / 255.0
26
- image = torch.from_numpy(image)[None,]
27
- return image
28
-
29
- # Fetch data from Gelbooru or None
30
- def fetch_gelbooru_images(site, OR_tags, AND_tags, exclude_tag, score, count, Safe, Questionable, Explicit): # add 'api_key' and 'user_id' if necessary
31
- # AND_tags
32
- AND_tags = AND_tags.rstrip(',').rstrip(' ')
33
- AND_tags = AND_tags.split(',')
34
- AND_tags = [item.strip().replace(' ', '_').replace('\\', '') for item in AND_tags]
35
- AND_tags = [item for item in AND_tags if item]
36
- if len(AND_tags) > 1:
37
- AND_tags = '+'.join(AND_tags)
38
- else:
39
- AND_tags = AND_tags[0] if AND_tags else ''
40
-
41
- # OR_tags
42
- OR_tags = OR_tags.rstrip(',').rstrip(' ')
43
- OR_tags = OR_tags.split(',')
44
- OR_tags = [item.strip().replace(' ', '_').replace('\\', '') for item in OR_tags]
45
- OR_tags = [item for item in OR_tags if item]
46
- if len(OR_tags) > 1:
47
- OR_tags = '{' + ' ~ '.join(OR_tags) + '}'
48
- else:
49
- OR_tags = OR_tags[0] if OR_tags else ''
50
-
51
- # Exclude tags
52
- exclude_tag = '+'.join('-' + item.strip().replace(' ', '_') for item in exclude_tag.split(','))
53
-
54
- rate_exclusion = ""
55
- if not Safe:
56
- if site == "None":
57
- rate_exclusion += "+-rating%3asafe"
58
- else:
59
- rate_exclusion += "+-rating%3ageneral"
60
- if not Questionable:
61
- if site == "None":
62
- rate_exclusion += "+-rating%3aquestionable"
63
- else:
64
- rate_exclusion += "+-rating%3aquestionable+-rating%3aSensitive"
65
- if not Explicit:
66
- if site == "None":
67
- rate_exclusion += "+-rating%3aexplicit"
68
- else:
69
- rate_exclusion += "+-rating%3aexplicit"
70
-
71
- if site == "None":
72
- base_url = "https://api.example.com/index.php"
73
- else:
74
- base_url = "https://gelbooru.com/index.php"
75
-
76
- query_params = (
77
- f"page=dapi&s=post&q=index&tags=sort%3arandom+"
78
- f"{exclude_tag}+{OR_tags}+{AND_tags}+{rate_exclusion}"
79
- f"+score%3a>{score}&limit={count}&json=1"
80
- #f"+score%3a>{score}&api_key={api_key}&user_id={user_id}&limit={count}&json=1"
81
- )
82
- url = f"{base_url}?{query_params}".replace("-+", "")
83
- url = re.sub(r"\++", "+", url)
84
-
85
- response = requests.get(url, verify=True)
86
- if site == "None":
87
- posts = response.json()
88
- else:
89
- posts = response.json().get('post', [])
90
-
91
- image_urls = [post.get("file_url", "") for post in posts]
92
- tags_list = [post.get("tags", "").replace(" ", ", ").replace("_", " ").replace("(", "\\(").replace(")", "\\)").strip() for post in posts]
93
- #tags_list = [post.get("tags", "").replace("_", " ").replace(" ", ", ").strip() for post in posts]
94
- ids_list = [str(post.get("id", "")) for post in posts]
95
-
96
- if site == "Gelbooru":
97
- post_urls = [f"https://gelbooru.com/index.php?page=post&s=view&id={id}" for id in ids_list]
98
- #else:
99
- # post_urls = [f"https://api.none.com/index.php?page=post&s=view&id={id}" for id in ids_list]
100
-
101
- return image_urls, tags_list, post_urls
102
-
103
- # Main function to fetch and return processed images
104
- def gelbooru_gradio(
105
- OR_tags, AND_tags, exclude_tags, score, count, Safe, Questionable, Explicit, site # add 'api_key' and 'user_id' if necessary
106
- ):
107
- image_urls, tags_list, post_urls = fetch_gelbooru_images(
108
- site, OR_tags, AND_tags, exclude_tags, score, count, Safe, Questionable, Explicit # 'api_key' and 'user_id' if necessary
109
- )
110
-
111
- if not image_urls:
112
- return [], [], [], []
113
-
114
- image_data = []
115
- for url in image_urls:
116
- try:
117
- image = loadImageFromUrl(url)
118
- image = (image * 255).clamp(0, 255).cpu().numpy().astype(np.uint8)[0]
119
- image = Image.fromarray(image)
120
- image_data.append(image)
121
- except Exception as e:
122
- print(f"Error loading image from {url}: {e}")
123
- continue
124
-
125
- return image_data, tags_list, post_urls, image_urls
126
-
127
- # Update UI on image click
128
- def on_select(evt: gr.SelectData, tags_list, post_url_list, image_url_list):
129
- idx = evt.index
130
- if idx < len(tags_list):
131
- return tags_list[idx], post_url_list[idx], image_url_list[idx]
132
  return "No tags", "", ""
 
1
+ import requests,re,base64,io,numpy as np
2
+ from PIL import Image,ImageOps
3
+ import torch,gradio as gr
4
+
5
+ # Helper to load image from URL
6
+ def loadImageFromUrl(url):
7
+ response = requests.get(url, timeout=10)
8
+ if response.status_code != 200:
9
+ raise Exception(f"Failed to load image from {url}")
10
+ i = Image.open(io.BytesIO(response.content))
11
+ i = ImageOps.exif_transpose(i)
12
+ if i.mode != "RGBA":
13
+ i = i.convert("RGBA")
14
+ alpha = i.split()[-1]
15
+ image = Image.new("RGB", i.size, (0, 0, 0))
16
+ image.paste(i, mask=alpha)
17
+ image = np.array(image).astype(np.float32) / 255.0
18
+ image = torch.from_numpy(image)[None,]
19
+ return image
20
+
21
+ # Fetch data from multiple booru platforms
22
+ def fetch_booru_images(site, Tags, exclude_tags, score, count, Safe, Questionable, Explicit):
23
+ # Clean and format tags
24
+ def clean_tag_list(tags):
25
+ return [item.strip().replace(' ', '_') for item in tags.split(',') if item.strip()]
26
+
27
+ Tags = '+'.join(clean_tag_list(Tags)) if Tags else ''
28
+ exclude_tags = '+'.join('-' + tag for tag in clean_tag_list(exclude_tags))
29
+
30
+ rating_filters = []
31
+ if not Safe:
32
+ rating_filters.extend(["rating:safe", "rating:general"])
33
+ if not Questionable:
34
+ rating_filters.extend(["rating:questionable", "rating:sensitive"])
35
+ if not Explicit:
36
+ rating_filters.append("rating:explicit")
37
+ rating_filters = '+'.join(f'-{r}' for r in rating_filters)
38
+
39
+ score_filter = f"score:>{score}"
40
+
41
+ # Build query
42
+ base_query = f"tags=sort:random+{Tags}+{exclude_tags}+{score_filter}+{rating_filters}&limit={count}&json=1"
43
+ base_query = re.sub(r"\++", "+", base_query)
44
+
45
+ # Fetch data based on site
46
+ if site == "Gelbooru":
47
+ url = f"https://gelbooru.com/index.php?page=dapi&s=post&q=index&{base_query}"
48
+ response = requests.get(url).json()
49
+ posts = response.get("post", [])
50
+ elif site == "Rule34":
51
+ url = f"https://api.rule34.xxx/index.php?page=dapi&s=post&q=index&{base_query}"
52
+ response = requests.get(url).json()
53
+ posts = response
54
+ elif site == "Xbooru":
55
+ url = f"https://xbooru.com/index.php?page=dapi&s=post&q=index&{base_query}"
56
+ response = requests.get(url).json()
57
+ posts = response
58
+ else:
59
+ return [], [], []
60
+
61
+ # Extract image URLs, tags, and post URLs
62
+ image_urls = []
63
+ tags_list = [post.get("tags", "").replace(" ", ", ").replace("_", " ").replace("(", "\\(").replace(")", "\\)").strip() for post in posts]
64
+ post_urls = []
65
+
66
+ for post in posts:
67
+ if site in ["Gelbooru", "Rule34", "Xbooru"]:
68
+ file_url = post.get("file_url")
69
+ tags = post.get("tags", "").replace(" ", ", ").strip()
70
+ post_id = post.get("id", "")
71
+ else:
72
+ continue
73
+
74
+ if file_url:
75
+ image_urls.append(file_url)
76
+ tags_list.append(tags)
77
+ if site == "Gelbooru":
78
+ post_urls.append(f"https://gelbooru.com/index.php?page=post&s=view&id={post_id}")
79
+ elif site == "Rule34":
80
+ post_urls.append(f"https://rule34.xxx/index.php?page=post&s=view&id={post_id}")
81
+ elif site == "Xbooru":
82
+ post_urls.append(f"https://xbooru.com/index.php?page=post&s=view&id={post_id}")
83
+
84
+ return image_urls, tags_list, post_urls
85
+
86
+ # Main function to fetch and return processed images
87
+ def booru_gradio(Tags, exclude_tags, score, count, Safe, Questionable, Explicit, site):
88
+ image_urls, tags_list, post_urls = fetch_booru_images(site, Tags, exclude_tags, score, count, Safe, Questionable, Explicit)
89
+
90
+ if not image_urls:
91
+ return [], [], [], []
92
+
93
+ image_data = []
94
+ for url in image_urls:
95
+ try:
96
+ image = loadImageFromUrl(url)
97
+ image = (image * 255).clamp(0, 255).cpu().numpy().astype(np.uint8)[0]
98
+ image = Image.fromarray(image)
99
+ image_data.append(image)
100
+ except Exception as e:
101
+ print(f"Error loading image from {url}: {e}")
102
+ continue
103
+
104
+ return image_data, tags_list, post_urls, image_urls
105
+
106
+ # Update UI on image click
107
+ def on_select(evt: gr.SelectData, tags_list, post_url_list, image_url_list):
108
+ idx = evt.index
109
+ if idx < len(tags_list):
110
+ return tags_list[idx], post_url_list[idx], image_url_list[idx]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  return "No tags", "", ""