Werli committed on
Commit
8ad3890
·
verified ·
1 Parent(s): b3766ea

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +640 -640
app.py CHANGED
@@ -1,641 +1,641 @@
1
- import os
2
- import io,copy,requests,spaces,gradio as gr,numpy as np
3
- from PIL import Image, ImageOps
4
- import argparse,huggingface_hub,onnxruntime as rt,pandas as pd,traceback,tempfile,zipfile,re,ast,time
5
- from datetime import datetime,timezone
6
- from collections import defaultdict
7
- from apscheduler.schedulers.background import BackgroundScheduler
8
- import json
9
- from modules.classifyTags import classify_tags,process_tags
10
- from modules.florence2 import process_image,single_task_list,update_task_dropdown
11
- from modules.reorganizer_model import reorganizer_list,reorganizer_class
12
- from modules.tag_enhancer import prompt_enhancer
13
- from modules.booru import gelbooru_gradio,fetch_gelbooru_images,on_select
14
-
15
- os.environ['PYTORCH_ENABLE_MPS_FALLBACK']='1'
16
-
17
- TITLE = "Multi-Tagger v1.2"
18
- DESCRIPTION = """
19
- Multi-Tagger is a versatile application for advanced image analysis and captioning. Perfect for AI artists or enthusiasts, it offers a range of features:
20
-
21
- - **Automatic Tag Categorization**: Tags are grouped into categories.
22
- - **Tag Enhancement**: Boost your prompts with enhanced descriptions using a built-in prompt enhancer.
23
- - **Reorganizer**: Use a reorganizer model to format tags into a natural-language description.
24
- - **Batch Support**: Upload and process multiple images simultaneously.
25
- - **Downloadable Output**: Get almost all results as downloadable `.txt`, `.json`, and `.png` files in a `.zip` archive.
26
- - **Image Fetcher**: Search for images from **Gelbooru** using flexible tag filters.
27
- - CUDA or CPU support.
28
-
29
- Example image by [me](https://huggingface.co/Werli).
30
- """
31
-
32
- # Dataset v3 series of models:
33
- SWINV2_MODEL_DSV3_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
34
- CONV_MODEL_DSV3_REPO = "SmilingWolf/wd-convnext-tagger-v3"
35
- VIT_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-tagger-v3"
36
- VIT_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-large-tagger-v3"
37
- EVA02_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"
38
- # Dataset v2 series of models:
39
- MOAT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2"
40
- SWIN_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
41
- CONV_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
42
- CONV2_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
43
- VIT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"
44
- # IdolSankaku series of models:
45
- EVA02_LARGE_MODEL_IS_DSV1_REPO = "deepghs/idolsankaku-eva02-large-tagger-v1"
46
- SWINV2_MODEL_IS_DSV1_REPO = "deepghs/idolsankaku-swinv2-tagger-v1"
47
- # Files to download from the repos
48
- MODEL_FILENAME = "model.onnx"
49
- LABEL_FILENAME = "selected_tags.csv"
50
-
51
# Emoticon-style tag names. load_labels leaves these untouched instead of
# replacing their underscores with spaces.
- kaomojis=['0_0','(o)_(o)','+_+','+_-','._.','<o>_<o>','<|>_<|>','=_=','>_<','3_3','6_9','>_o','@_@','^_^','o_o','u_u','x_x','|_|','||_||']
52
def parse_args() -> argparse.Namespace:
    """Parse the command-line options that configure the tagger UI.

    Returns:
        Namespace with the float options ``score_slider_step``,
        ``score_general_threshold`` and ``score_character_threshold`` and the
        boolean ``share`` flag.
    """
    float_options = (
        ("--score-slider-step", 0.05),
        ("--score-general-threshold", 0.35),
        ("--score-character-threshold", 0.85),
    )
    parser = argparse.ArgumentParser()
    for flag, default in float_options:
        parser.add_argument(flag, type=float, default=default)
    parser.add_argument("--share", action="store_true")
    return parser.parse_args()
53
def load_labels(dataframe) -> tuple[list[str], list[int], list[int], list[int]]:
    """Split a tag table into display names and per-category row positions.

    Args:
        dataframe: Table with a ``name`` column (tag strings) and a
            ``category`` column (integer category codes).

    Returns:
        ``(tag_names, rating_indexes, general_indexes, character_indexes)``.
        ``tag_names`` has underscores replaced by spaces, except for entries
        listed in ``kaomojis``. The three index lists hold the row positions
        (NumPy integers) whose category is 9, 0 and 4 respectively.
    """
    tag_names = (
        dataframe["name"]
        .map(lambda name: name if name in kaomojis else name.replace("_", " "))
        .tolist()
    )
    categories = dataframe["category"]
    rating_indexes = list(np.where(categories == 9)[0])
    general_indexes = list(np.where(categories == 0)[0])
    character_indexes = list(np.where(categories == 4)[0])
    return tag_names, rating_indexes, general_indexes, character_indexes
54
def mcut_threshold(probs):
    """Return the midpoint of the widest gap between sorted probabilities.

    Args:
        probs: 1-D NumPy array of scores.

    Returns:
        The mean of the two adjacent values (in descending order) that are
        separated by the largest difference.
    """
    descending = probs[probs.argsort()[::-1]]
    gaps = descending[:-1] - descending[1:]
    widest = gaps.argmax()
    return (descending[widest] + descending[widest + 1]) / 2
55
-
56
class Timer:
    """Record labelled ``time.perf_counter`` checkpoints and print the gaps.

    ``checkpoints`` is a list of ``(label, timestamp)`` tuples; the first entry
    acts as the baseline for the next report.
    """

    def __init__(self):
        self.start_time = time.perf_counter()
        self.checkpoints = [("Start", self.start_time)]

    def checkpoint(self, label="Checkpoint"):
        """Append a checkpoint named ``label`` stamped with the current time."""
        self.checkpoints.append((label, time.perf_counter()))

    def _label_width(self):
        """Return the longest recorded label length (0 when nothing is recorded)."""
        return max((len(label) for label, _ in self.checkpoints), default=0)

    def _print_intervals(self, prev_time, width):
        """Print the elapsed time of every checkpoint after the first one."""
        for label, curr_time in self.checkpoints[1:]:
            print(f"{label.ljust(width)}: {curr_time - prev_time:.3f} seconds")
            prev_time = curr_time

    def report(self, is_clear_checkpoints=True):
        """Print the time between consecutive checkpoints.

        Args:
            is_clear_checkpoints: When true, drop the printed checkpoints and
                record a fresh baseline checkpoint.
        """
        # The list is empty right after report_all(); there is nothing to print then.
        if self.checkpoints:
            self._print_intervals(self.checkpoints[0][1], self._label_width())
        if is_clear_checkpoints:
            self.checkpoints.clear()
            self.checkpoint()

    def report_all(self):
        """Print the remaining intervals plus the total since ``start_time``, then clear."""
        print("\n> Execution Time Report:")
        width = self._label_width()
        self._print_intervals(self.start_time, width)
        # With no checkpoints left the total is reported as zero instead of raising.
        end_time = self.checkpoints[-1][1] if self.checkpoints else self.start_time
        total_time = end_time - self.start_time
        print(f"{'Total Execution Time'.ljust(width)}: {total_time:.3f} seconds\n")
        self.checkpoints.clear()

    def restart(self):
        """Reset the start time and the checkpoint list."""
        self.start_time = time.perf_counter()
        self.checkpoints = [("Start", self.start_time)]
68
class Predictor:
    """Tag images with an ONNX tagger model and bundle the results for download.

    The most recently loaded model is kept on the instance, so consecutive
    predictions with the same repository skip the download and session set-up.
    """

    def __init__(self):
        self.model_target_size = None  # input height of the loaded model
        self.last_loaded_repo = None  # repository id of the loaded model

    def download_model(self, model_repo):
        """Download the label CSV and the ONNX file of ``model_repo``.

        Returns:
            ``(csv_path, model_path)`` as local file paths.
        """
        csv_path = huggingface_hub.hf_hub_download(model_repo, LABEL_FILENAME)
        model_path = huggingface_hub.hf_hub_download(model_repo, MODEL_FILENAME)
        return csv_path, model_path

    def load_model(self, model_repo):
        """Make ``model_repo`` the active model unless it already is."""
        if model_repo == self.last_loaded_repo:
            return

        csv_path, model_path = self.download_model(model_repo)
        tag_names, rating_indexes, general_indexes, character_indexes = load_labels(
            pd.read_csv(csv_path)
        )
        model = rt.InferenceSession(model_path)
        _, height, _, _ = model.get_inputs()[0].shape

        # Commit only once everything above succeeded, so a failed load never
        # leaves the new label tables paired with the previous model.
        self.tag_names = tag_names
        self.rating_indexes = rating_indexes
        self.general_indexes = general_indexes
        self.character_indexes = character_indexes
        self.model_target_size = height
        self.last_loaded_repo = model_repo
        self.model = model

    def prepare_image(self, path):
        """Load ``path`` and convert it into the model's input array.

        The image is flattened onto a white background, padded to a centred
        square, resized to ``model_target_size`` and returned as a float32
        array with a leading batch axis and the channel order reversed
        (RGB to BGR).
        """
        target_size = self.model_target_size

        # Close the file handle as soon as the pixels have been converted.
        with Image.open(path) as source:
            rgba = source.convert("RGBA")

        canvas = Image.new("RGBA", rgba.size, (255, 255, 255))
        canvas.alpha_composite(rgba)
        image = canvas.convert("RGB")

        # Pad image to square
        width, height = image.size
        max_dim = max(width, height)
        pad_left = (max_dim - width) // 2
        pad_top = (max_dim - height) // 2

        padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
        padded_image.paste(image, (pad_left, pad_top))

        # Resize
        if max_dim != target_size:
            padded_image = padded_image.resize(
                (target_size, target_size),
                Image.BICUBIC,
            )

        image_array = np.asarray(padded_image, dtype=np.float32)
        # Convert PIL-native RGB to BGR
        image_array = image_array[:, :, ::-1]
        return np.expand_dims(image_array, axis=0)

    def create_file(self, content: str, directory: str, fileName: str) -> str:
        """Write ``content`` as UTF-8 to ``directory/fileName`` and return the path."""
        file_path = os.path.join(directory, fileName)
        with open(file_path, "w", encoding="utf-8") as file:
            file.write(content)
        return file_path

    @staticmethod
    def _split_tags(text):
        """Return the non-empty, stripped comma-separated items of ``text``."""
        return [tag.strip() for tag in (text or "").split(",") if tag.strip()]

    def predict(
        self,
        gallery,
        model_repo,
        general_thresh,
        general_mcut_enabled,
        character_thresh,
        character_mcut_enabled,
        characters_merge_enabled,
        reorganizer_model_repo,
        additional_tags_prepend,
        additional_tags_append,
        tag_results,
        progress=gr.Progress()
    ):
        """Tag every image in ``gallery`` and build the downloadable archive.

        Args:
            gallery: Sequence of items whose first element is an image path;
                ``None`` is treated as an empty gallery.
            model_repo: Repository id of the tagger model to use.
            general_thresh: Minimum score for general tags.
            general_mcut_enabled: Derive the general threshold with
                ``mcut_threshold`` per image instead.
            character_thresh: Minimum score for character tags.
            character_mcut_enabled: Derive the character threshold with
                ``mcut_threshold`` per image (never below 0.15).
            characters_merge_enabled: Prefix the output string with the
                character tags.
            reorganizer_model_repo: Optional reorganizer model; falsy disables it.
            additional_tags_prepend: Comma-separated tags placed first.
            additional_tags_append: Comma-separated tags placed last.
            tag_results: Dict that is cleared and refilled with one entry per
                image path.
            progress: Gradio progress callback.

        Returns:
            ``(download, sorted_general_strings, final_categorized_output,
            classified_tags, rating, character_res, general_res,
            unclassified_tags, tag_results)``. Apart from ``download``,
            ``final_categorized_output`` and ``tag_results``, the values
            describe the last image that was processed.
        """
        # Clear tag_results before starting a new prediction
        tag_results.clear()

        # Tolerate a None gallery (append_gallery/extend_gallery guard against
        # the same value) instead of failing on len(None).
        gallery = gallery or []
        gallery_len = len(gallery)
        print(f"Predict load model: {model_repo}, gallery length: {gallery_len}")

        timer = Timer()
        progressRatio = 0.5 if reorganizer_model_repo else 1
        progressTotal = gallery_len + 1
        current_progress = 0

        self.load_model(model_repo)
        current_progress += progressRatio / progressTotal
        progress(current_progress, desc="Initialize wd model finished")
        timer.checkpoint("Initialize wd model")

        txt_infos = []
        output_dir = tempfile.mkdtemp()  # mkdtemp already creates the directory

        # Defaults returned when no image is processed successfully.
        sorted_general_strings = ""
        categorized_output_strings = []
        final_categorized_output = ""
        classified_tags = None
        unclassified_tags = None
        rating = None
        character_res = None
        general_res = None

        reorganizer = None
        if reorganizer_model_repo:
            print(f"Reorganizer load model {reorganizer_model_repo}")
            reorganizer = reorganizer_class(reorganizer_model_repo, loadModel=True)
            current_progress += progressRatio / progressTotal
            progress(current_progress, desc="Initialize reoganizer model finished")
            timer.checkpoint("Initialize reoganizer model")

        timer.report()

        prepend_list = self._split_tags(additional_tags_prepend)
        # Tags that are already prepended are not appended a second time.
        append_list = [
            tag for tag in self._split_tags(additional_tags_append)
            if tag not in prepend_list
        ]
        extra_tags = set(prepend_list) | set(append_list)

        # The model's tensor names do not change between images.
        input_name = self.model.get_inputs()[0].name
        label_name = self.model.get_outputs()[0].name

        # Counts how often each base file name occurred, to keep outputs unique.
        name_counters = defaultdict(int)

        for idx, value in enumerate(gallery):
            try:
                image_path = value[0]
                base_name = os.path.splitext(os.path.basename(image_path))[0]
                name_counters[base_name] += 1
                if name_counters[base_name] > 1:
                    image_name = f"{base_name}_{name_counters[base_name]:02d}"
                else:
                    image_name = base_name

                image = self.prepare_image(image_path)

                print(f"Gallery {idx:02d}: Starting run wd model...")
                preds = self.model.run([label_name], {input_name: image})[0]
                labels = list(zip(self.tag_names, preds[0].astype(float)))

                # Scores of every rating label.
                rating = dict(labels[i] for i in self.rating_indexes)

                # General tags: keep those scoring above the threshold.
                general_names = [labels[i] for i in self.general_indexes]
                image_general_thresh = general_thresh
                if general_mcut_enabled:
                    general_probs = np.array([score for _, score in general_names])
                    image_general_thresh = mcut_threshold(general_probs)
                general_res = {
                    name: score for name, score in general_names
                    if score > image_general_thresh
                }

                # Character tags: keep those scoring above the threshold.
                character_names = [labels[i] for i in self.character_indexes]
                image_character_thresh = character_thresh
                if character_mcut_enabled:
                    character_probs = np.array([score for _, score in character_names])
                    image_character_thresh = max(0.15, mcut_threshold(character_probs))
                character_res = {
                    name: score for name, score in character_names
                    if score > image_character_thresh
                }

                sorted_general_list = [
                    name for name, _ in sorted(
                        general_res.items(),
                        key=lambda item: item[1],
                        reverse=True,
                    )
                ]
                # Drop character tags that already appear among the general tags.
                character_list = [
                    name for name in character_res if name not in sorted_general_list
                ]
                # Drop general tags that the user supplies explicitly.
                sorted_general_list = [
                    name for name in sorted_general_list if name not in extra_tags
                ]
                sorted_general_list = prepend_list + sorted_general_list + append_list

                merged_tags = (character_list if characters_merge_enabled else []) + sorted_general_list
                # Escape parentheses with a backslash in the output string.
                sorted_general_strings = ", ".join(merged_tags).replace("(", "\\(").replace(")", "\\)")

                classified_tags, unclassified_tags = classify_tags(sorted_general_list)

                # One string holding ALL categorized tags of the current image.
                categorized_output_string = ", ".join(
                    ", ".join(tags) for tags in classified_tags.values()
                )
                categorized_output_strings.append(categorized_output_string)
                # NOTE(review): this joins the strings of every image so far, so
                # each per-image .txt below also lists earlier images' tags;
                # confirm that this is intended.
                final_categorized_output = ", ".join(categorized_output_strings)

                # .txt file for "Output (string)" and "Categorized Output (string)"
                txt_name = f"{image_name}_output.txt"
                txt_content = f"Output (string): {sorted_general_strings}\nCategorized Output (string): {final_categorized_output}"
                txt_infos.append({"path": self.create_file(txt_content, output_dir, txt_name), "name": txt_name})

                # .json file for "Categorized (tags)"
                json_name = f"{image_name}_categorized_tags.json"
                json_content = json.dumps(classified_tags, indent=4)
                txt_infos.append({"path": self.create_file(json_content, output_dir, json_name), "name": json_name})

                # Copy of the uploaded image in PNG format
                png_name = f"{image_name}.png"
                png_path = os.path.join(output_dir, png_name)
                with Image.open(image_path) as original:
                    original.save(png_path, format="PNG")
                txt_infos.append({"path": png_path, "name": png_name})

                current_progress += progressRatio / progressTotal
                progress(current_progress, desc=f"image{idx:02d}, predict finished")
                timer.checkpoint(f"image{idx:02d}, predict finished")

                if reorganizer is not None:
                    print("Starting reorganizer...")
                    reorganize_strings = reorganizer.reorganize(sorted_general_strings)
                    reorganize_strings = re.sub(r" *Title: *", "", reorganize_strings)
                    reorganize_strings = re.sub(r"\n+", ",", reorganize_strings)
                    reorganize_strings = re.sub(r",,+", ",", reorganize_strings)
                    sorted_general_strings += ",\n\n" + reorganize_strings

                    current_progress += progressRatio / progressTotal
                    progress(current_progress, desc=f"image{idx:02d}, reorganizer finished")
                    timer.checkpoint(f"image{idx:02d}, reorganizer finished")

                plain_name = image_name + ".txt"
                txt_infos.append({"path": self.create_file(sorted_general_strings, output_dir, plain_name), "name": plain_name})

                # Store the result in tag_results using image_path as the key
                tag_results[image_path] = {
                    "strings": sorted_general_strings,
                    "strings2": categorized_output_string,
                    "classified_tags": classified_tags,
                    "rating": rating,
                    "character_res": character_res,
                    "general_res": general_res,
                    "unclassified_tags": unclassified_tags,
                    "enhanced_tags": "",  # filled in later by the enhancer
                }

                timer.report()
            except Exception as e:
                # Best effort: a failing image is logged and the batch continues.
                print(traceback.format_exc())
                print("Error predict: " + str(e))

        # Zip creation logic:
        download = []
        if txt_infos:
            downloadZipPath = os.path.join(output_dir, "Multi-tagger-" + datetime.now().strftime("%Y%m%d-%H%M%S") + ".zip")
            with zipfile.ZipFile(downloadZipPath, "w", zipfile.ZIP_DEFLATED) as taggers_zip:
                for info in txt_infos:
                    taggers_zip.write(info["path"], arcname=info["name"])
            download.append(downloadZipPath)
        # End zip creation logic

        if reorganizer is not None:
            reorganizer.release_vram()
            del reorganizer

        progress(1, desc="Predict completed")
        timer.report_all()  # Print all recorded times
        print("Predict is complete.")

        return download, sorted_general_strings, final_categorized_output, classified_tags, rating, character_res, general_res, unclassified_tags, tag_results
346
def get_selection_from_gallery(gallery: list, tag_results: dict, selected_state: gr.SelectData):
    """Return the stored tagging results for the gallery image that was clicked.

    Args:
        gallery: Current gallery value (not used).
        tag_results: Mapping of image path to the result dict stored by the
            predictor.
        selected_state: Selection event; its ``value`` carries
            ``["image"]["path"]`` and ``["caption"]``.

    Returns:
        A 9-tuple: the selected ``(path, caption)`` pair (``None`` without a
        selection) followed by the ``strings``, ``strings2``,
        ``classified_tags``, ``rating``, ``character_res``, ``general_res``,
        ``unclassified_tags`` and ``enhanced_tags`` entries, which fall back
        to empty placeholders for images without stored results.
    """
    # Insertion order of this dict matches the order of the output components.
    empty_result = {
        "strings": "",
        "strings2": "",
        "classified_tags": "{}",
        "rating": "",
        "character_res": "",
        "general_res": "",
        "unclassified_tags": "{}",
        "enhanced_tags": "",
    }
    if not selected_state:
        # Return one value per output instead of a single object.
        return (None, *empty_result.values())

    selected = selected_state.value
    image_path = selected["image"]["path"]
    tag_result = tag_results.get(image_path, empty_result)
    return ((image_path, selected["caption"]), *(tag_result[key] for key in empty_result))
362
def append_gallery(gallery: list, image: str):
    """Add ``image`` to ``gallery`` in place and reset the image input.

    Returns:
        ``(gallery, None)``; a ``None`` gallery becomes a new list and a falsy
        ``image`` leaves the gallery unchanged.
    """
    items = [] if gallery is None else gallery
    if image:
        items.append(image)
    return items, None
366
def extend_gallery(gallery: list, images):
    """Add every item of ``images`` to ``gallery`` in place and return it.

    A ``None`` gallery becomes a new list; falsy ``images`` leave the gallery
    unchanged.
    """
    items = gallery if gallery is not None else []
    if images:
        items.extend(images)
    return items
370
def remove_image_from_gallery(gallery: list, selected_image: str):
    """Remove the selected entry from ``gallery`` in place and return it.

    Args:
        gallery: Current gallery items.
        selected_image: Text form of a Python literal identifying the entry,
            parsed with ``ast.literal_eval``.

    Returns:
        The gallery, unchanged when it is empty, nothing is selected, the text
        cannot be parsed, or the parsed entry is not present.
    """
    if not gallery or not selected_image:
        return gallery
    try:
        target = ast.literal_eval(selected_image)
    except (ValueError, SyntaxError, TypeError):
        # Text that is not a valid literal cannot match any entry.
        return gallery
    if target in gallery:
        gallery.remove(target)
    return gallery
375
# Command-line options are parsed once; the sliders further down read their
# step and default values from `args`.
- args = parse_args()
376
# Single Predictor instance whose predict method is wired to the Submit button.
- predictor = Predictor()
377
# Choices offered by the "Model" dropdown, in display order.
- dropdown_list = [
378
- EVA02_LARGE_MODEL_DSV3_REPO,
379
- SWINV2_MODEL_DSV3_REPO,
380
- CONV_MODEL_DSV3_REPO,
381
- VIT_MODEL_DSV3_REPO,
382
- VIT_LARGE_MODEL_DSV3_REPO,
383
- # ---
384
- MOAT_MODEL_DSV2_REPO,
385
- SWIN_MODEL_DSV2_REPO,
386
- CONV_MODEL_DSV2_REPO,
387
- CONV2_MODEL_DSV2_REPO,
388
- VIT_MODEL_DSV2_REPO,
389
- # ---
390
- SWINV2_MODEL_IS_DSV1_REPO,
391
- EVA02_LARGE_MODEL_IS_DSV1_REPO,
392
- ]
393
-
394
- def _restart_space():
395
- HF_TOKEN=os.getenv('HF_TOKEN')
396
- if not HF_TOKEN:raise ValueError('HF_TOKEN environment variable is not set.')
397
- huggingface_hub.HfApi().restart_space(repo_id='Werli/Multi-Tagger',token=HF_TOKEN,factory_reboot=False)
398
# Background scheduler that periodically calls _restart_space.
- scheduler=BackgroundScheduler()
399
- # Add a job to restart the space every 2 days (172800 seconds)
400
- restart_space_job = scheduler.add_job(_restart_space, "interval", seconds=172800)
401
- scheduler.start()
402
# Convert the job's next run time to UTC for the notice shown in the UI.
- next_run_time_utc=restart_space_job.next_run_time.astimezone(timezone.utc)
403
- NEXT_RESTART=f"Next Restart: {next_run_time_utc.strftime('%Y-%m-%d %H:%M:%S')} (UTC) - The space will restart every 2 days to ensure stability and performance. It uses a background scheduler to handle the restart process."
404
-
405
# Custom CSS handed to gr.Blocks below.
- css = """
406
- #output {height: 500px; overflow: auto; border: 1px solid #ccc;}
407
- label.float.svelte-i3tvor {position: relative !important;}
408
- .reduced-height.svelte-11chud3 {height: calc(80% - var(--size-10));}
409
- """
410
-
411
# Root of the UI: a centred title, the description and the next-restart notice,
# followed by one tab per tool.
- with gr.Blocks(title=TITLE, css=css, theme="Werli/Multi-Tagger", fill_width=True) as demo:
412
- gr.Markdown(value=f"<h1 style='text-align: center; margin-bottom: 1rem'>{TITLE}</h1>")
413
- gr.Markdown(value=DESCRIPTION)
414
- gr.Markdown(NEXT_RESTART)
415
# Tab 1: batch tagging. The Submit button runs predictor.predict over the gallery.
- with gr.Tab(label="Waifu Diffusion"):
416
- with gr.Row():
417
- with gr.Column():
418
- submit = gr.Button(value="Submit", variant="primary", size="lg")
419
- with gr.Column(variant="panel"):
420
- # Create an Image component for uploading images
421
- image_input = gr.Image(label="Upload an Image or clicking paste from clipboard button", type="filepath", sources=["upload", "clipboard"], height=150)
422
- with gr.Row():
423
- upload_button = gr.UploadButton("Upload multiple images", file_types=["image"], file_count="multiple", size="sm")
424
- remove_button = gr.Button("Remove Selected Image", size="sm")
425
- gallery = gr.Gallery(columns=5, rows=5, show_share_button=False, interactive=True, height="500px", label="Grid of images")
426
- model_repo = gr.Dropdown(
427
- dropdown_list,
428
- value=EVA02_LARGE_MODEL_DSV3_REPO,
429
- label="Model",
430
- )
431
# Threshold sliders take their step and default value from the parsed CLI args.
- with gr.Row():
432
- general_thresh = gr.Slider(
433
- 0,
434
- 1,
435
- step=args.score_slider_step,
436
- value=args.score_general_threshold,
437
- label="General Tags Threshold",
438
- scale=3,
439
- )
440
- general_mcut_enabled = gr.Checkbox(
441
- value=False,
442
- label="Use MCut threshold",
443
- scale=1,
444
- )
445
- with gr.Row():
446
- character_thresh = gr.Slider(
447
- 0,
448
- 1,
449
- step=args.score_slider_step,
450
- value=args.score_character_threshold,
451
- label="Character Tags Threshold",
452
- scale=3,
453
- )
454
- character_mcut_enabled = gr.Checkbox(
455
- value=False,
456
- label="Use MCut threshold",
457
- scale=1,
458
- )
459
- with gr.Row():
460
- characters_merge_enabled = gr.Checkbox(
461
- value=True,
462
- label="Merge characters into the string output",
463
- scale=1,
464
- )
465
# Optional reorganizer model; the default None leaves it disabled.
- with gr.Row():
466
- reorganizer_model_repo = gr.Dropdown(
467
- [None] + reorganizer_list,
468
- value=None,
469
- label="Reorganizer Model",
470
- info="Use a model to create a description for you",
471
- )
472
- with gr.Row():
473
- additional_tags_prepend = gr.Text(label="Prepend Additional tags (comma split)")
474
- additional_tags_append = gr.Text(label="Append Additional tags (comma split)")
475
# The clear button starts with the input controls; the output components are
# registered further down through clear.add(...).
- with gr.Row():
476
- clear = gr.ClearButton(
477
- components=[
478
- gallery,
479
- model_repo,
480
- general_thresh,
481
- general_mcut_enabled,
482
- character_thresh,
483
- character_mcut_enabled,
484
- characters_merge_enabled,
485
- reorganizer_model_repo,
486
- additional_tags_prepend,
487
- additional_tags_append,
488
- ],
489
- variant="secondary",
490
- size="lg",
491
- )
492
# Output panel; the trailing "# N" markers number the components in creation order.
- with gr.Column(variant="panel"):
493
- download_file = gr.File(label="Download includes: All outputs* and image(s)") # 0
494
- character_res = gr.Label(label="Output (characters)") # 1
495
- sorted_general_strings = gr.Textbox(label="Output (string)*", show_label=True, show_copy_button=True) # 2
496
- final_categorized_output = gr.Textbox(label="Categorized (string)* - If it's too long, select an image to display tags correctly.", show_label=True, show_copy_button=True) # 3
497
- pe_generate_btn = gr.Button(value="ENHANCE TAGS", size="lg", variant="primary") # 4
498
- enhanced_tags = gr.Textbox(label="Enhanced Tags", show_label=True, show_copy_button=True) # 5
499
- prompt_enhancer_model = gr.Radio(["Medium", "Long", "Flux"], label="Model Choice", value="Medium", info="Enhance your prompts with Medium or Long answers") # 6
500
- categorized = gr.JSON(label="Categorized (tags)* - JSON") # 7
501
- rating = gr.Label(label="Rating") # 8
502
- general_res = gr.Label(label="Output (tags)") # 9
503
- unclassified = gr.JSON(label="Unclassified (tags)") # 10
504
- clear.add(
505
- [
506
- download_file,
507
- sorted_general_strings,
508
- final_categorized_output,
509
- categorized,
510
- rating,
511
- character_res,
512
- general_res,
513
- unclassified,
514
- prompt_enhancer_model,
515
- enhanced_tags,
516
- ]
517
- )
518
# State dict passed to and returned by predictor.predict; the gallery select
# handler reads it to show the results of the clicked image.
- tag_results = gr.State({})
519
- # Define the event listener to add the uploaded image to the gallery
520
- image_input.change(append_gallery, inputs=[gallery, image_input], outputs=[gallery, image_input])
521
- # When the upload button is clicked, add the new images to the gallery
522
- upload_button.upload(extend_gallery, inputs=[gallery, upload_button], outputs=gallery)
523
- # Event to update the selected image when an image is clicked in the gallery
524
- selected_image = gr.Textbox(label="Selected Image", visible=False)
525
- gallery.select(get_selection_from_gallery,inputs=[gallery, tag_results],outputs=[selected_image, sorted_general_strings, final_categorized_output, categorized, rating, character_res, general_res, unclassified, enhanced_tags])
526
- # Event to remove a selected image from the gallery
527
- remove_button.click(remove_image_from_gallery, inputs=[gallery, selected_image], outputs=gallery)
528
- # Event to for the Prompt Enhancer Button
529
- pe_generate_btn.click(lambda tags,model:prompt_enhancer('','',tags,model)[0],inputs=[final_categorized_output,prompt_enhancer_model],outputs=[enhanced_tags])
530
# The input order must match the parameter order of Predictor.predict, and the
# output order must match its return tuple.
- submit.click(
531
- predictor.predict,
532
- inputs=[
533
- gallery,
534
- model_repo,
535
- general_thresh,
536
- general_mcut_enabled,
537
- character_thresh,
538
- character_mcut_enabled,
539
- characters_merge_enabled,
540
- reorganizer_model_repo,
541
- additional_tags_prepend,
542
- additional_tags_append,
543
- tag_results,
544
- ],
545
- outputs=[download_file, sorted_general_strings, final_categorized_output, categorized, rating, character_res, general_res, unclassified, tag_results,],
546
- )
547
# Example row: image, model, and the two threshold / MCut settings.
- gr.Examples(
548
- [["images/1girl.png", VIT_LARGE_MODEL_DSV3_REPO, 0.35, False, 0.85, False]],
549
- inputs=[
550
- image_input,
551
- model_repo,
552
- general_thresh,
553
- general_mcut_enabled,
554
- character_thresh,
555
- character_mcut_enabled,
556
- ],
557
- )
558
# Tab 2: Florence 2 captioning. The Submit button passes the image, the task
# prompt and the optional text to process_image.
- with gr.Tab(label="Florence 2 Image Captioning"):
559
- with gr.Row():
560
- with gr.Column(variant="panel"):
561
- input_img = gr.Image(label="Input Picture")
562
- task_type = gr.Radio(choices=['Single task', 'Cascaded task'], label='Task type selector', value='Single task')
563
- task_prompt = gr.Dropdown(choices=single_task_list, label="Task Prompt", value="Caption")
564
# Changing the task type lets update_task_dropdown refresh the prompt dropdown.
- task_type.change(fn=update_task_dropdown, inputs=task_type, outputs=task_prompt)
565
- text_input = gr.Textbox(label="Text Input (optional)")
566
- submit_btn = gr.Button(value="Submit")
567
- with gr.Column(variant="panel"):
568
- output_text = gr.Textbox(label="Output Text", show_label=True, show_copy_button=True, lines=8)
569
- output_img = gr.Image(label="Output Image")
570
# Examples only fill the image and task prompt; caching is disabled.
- gr.Examples(
571
- examples=[
572
- ["images/image1.png", 'Object Detection'],
573
- ["images/image2.png", 'OCR with Region']
574
- ],
575
- inputs=[input_img, task_prompt],
576
- outputs=[output_text, output_img],
577
- fn=process_image,
578
- cache_examples=False,
579
- label='Try examples'
580
- )
581
- submit_btn.click(process_image, [input_img, task_prompt, text_input], [output_text, output_img])
582
# Tab 3: Gelbooru search. gelbooru_gradio fills the gallery and three State
# lists; on_select reads those lists when an image is clicked.
- with gr.Tab(label="Gelbooru Image Fetcher"):
583
- with gr.Row():
584
- with gr.Column():
585
- gr.Markdown("### ⚙️ Search Parameters")
586
- site = gr.Dropdown(label="Select Source", choices=["Gelbooru", "None (will not work)"], value="Gelbooru")
587
- OR_tags = gr.Textbox(label="OR Tags (comma-separated)", placeholder="e.g. solo, 1girl, 1boy, artist, character, ...")
588
- AND_tags = gr.Textbox(label="AND Tags (comma-separated)", placeholder="e.g. black hair, cat ears, black hair, granblue fantasy, ...")
589
- exclude_tags = gr.Textbox(label="Exclude Tags (comma-separated)", placeholder="e.g. animated, watermark, username, ...")
590
- score = gr.Number(label="Minimum Score", value=0)
591
- count = gr.Slider(label="Number of Images", minimum=1, maximum=4, step=1, value=1) # Increase if necessary (not recommend)
592
- Safe = gr.Checkbox(label="Include Safe", value=True)
593
- Questionable = gr.Checkbox(label="Include Questionable", value=True)
594
- Explicit = gr.Checkbox(label="Include Explicit", value=False)
595
- #user_id = gr.Textbox(label="User ID (Optional)", value="")
596
- #api_key = gr.Textbox(label="API Key (Optional)", value="", type="password")
597
-
598
# NOTE(review): this rebinds submit_btn, which the Florence tab also uses; that
# tab registers its click handler before this line runs.
- submit_btn = gr.Button("Fetch Images", variant="primary")
599
-
600
- with gr.Column():
601
- gr.Markdown("### 📄 Results")
602
- images_output = gr.Gallery(label="Images", columns=3, rows=2, object_fit="contain", height=500)
603
- tags_output = gr.Textbox(label="Tags", placeholder="Select an image to show tags", lines=5, show_copy_button=True)
604
- post_url_output = gr.Textbox(label="Post URL", lines=1, show_copy_button=True)
605
- image_url_output = gr.Textbox(label="Image URL", lines=1, show_copy_button=True)
606
-
607
- # State to store tags, URLs
608
- tags_state = gr.State([])
609
- post_url_state = gr.State([])
610
- image_url_state = gr.State([])
611
-
612
- submit_btn.click(
613
- fn=gelbooru_gradio,
614
- inputs=[OR_tags, AND_tags, exclude_tags, score, count, Safe, Questionable, Explicit, site], # add 'api_key' and 'user_id' if necessary
615
- outputs=[images_output, tags_state, post_url_state, image_url_state],
616
- )
617
-
618
- images_output.select(
619
- fn=on_select,
620
- inputs=[tags_state, post_url_state, image_url_state],
621
- outputs=[tags_output, post_url_output, image_url_output],
622
- )
623
- gr.Markdown("""
624
- ---
625
- ComfyUI version: [Comfyui-Gelbooru](https://github.com/1mckw/Comfyui-Gelbooru)
626
- """)
627
# Tab 4: categorise a pasted tag string with process_tags, then optionally feed
# the categorized string to prompt_enhancer.
- with gr.Tab(label="Categorizer++"):
628
- with gr.Row():
629
- with gr.Column(variant="panel"):
630
- input_tags = gr.Textbox(label="Input Tags", placeholder="1girl, cat, horns, blue hair, ...\nor\n? 1girl 1234567? cat 1234567? horns 1234567? blue hair 1234567? ...", lines=4)
631
- submit_button = gr.Button(value="Submit", variant="primary", size="lg")
632
- with gr.Column(variant="panel"):
633
- categorized_string = gr.Textbox(label="Categorized (string)", show_label=True, show_copy_button=True, lines=8)
634
- categorized_json = gr.JSON(label="Categorized (tags) - JSON")
635
- submit_button.click(process_tags, inputs=[input_tags], outputs=[categorized_string, categorized_json])
636
- with gr.Column(variant="panel"):
637
# NOTE(review): the next three names are also assigned in the Waifu Diffusion
# tab; the handlers registered there keep their original components.
- pe_generate_btn = gr.Button(value="ENHANCE TAGS", size="lg", variant="primary")
638
- enhanced_tags = gr.Textbox(label="Enhanced Tags", show_label=True, show_copy_button=True)
639
- prompt_enhancer_model = gr.Radio(["Medium", "Long", "Flux"], label="Model Choice", value="Medium", info="Enhance your prompts with Medium or Long answers")
640
- pe_generate_btn.click(lambda tags,model:prompt_enhancer('','',tags,model)[0],inputs=[categorized_string,prompt_enhancer_model],outputs=[enhanced_tags])
641
  demo.queue(max_size=2).launch()
 
1
+ import os
2
+ import io,copy,requests,spaces,gradio as gr,numpy as np
3
+ from PIL import Image, ImageOps
4
+ import argparse,huggingface_hub,onnxruntime as rt,pandas as pd,traceback,tempfile,zipfile,re,ast,time
5
+ from datetime import datetime,timezone
6
+ from collections import defaultdict
7
+ from apscheduler.schedulers.background import BackgroundScheduler
8
+ import json
9
+ from modules.classifyTags import classify_tags,process_tags
10
+ from modules.florence2 import process_image,single_task_list,update_task_dropdown
11
+ from modules.reorganizer_model import reorganizer_list,reorganizer_class
12
+ from modules.tag_enhancer import prompt_enhancer
13
+ from modules.booru import gelbooru_gradio,fetch_gelbooru_images,on_select
14
+
15
# PyTorch setting; presumably lets operators missing on the MPS backend fall back to the CPU - TODO confirm.
os.environ['PYTORCH_ENABLE_MPS_FALLBACK']='1'

# Page title and introductory Markdown rendered at the top of the Gradio UI.
TITLE = "Multi-Tagger v1.2"
DESCRIPTION = """
Multi-Tagger is a versatile application for advanced image analysis and captioning. Perfect for AI artists or enthusiasts, it offers a range of features:

- **Automatic Tag Categorization**: Tags are grouped into categories.
- **Tag Enhancement**: Boost your prompts with enhanced descriptions using a built-in prompt enhancer.
- **Reorganizer**: Use a reorganizer model to format tags into a natural-language description.
- **Batch Support**: Upload and process multiple images simultaneously.
- **Downloadable Output**: Get almost all results as downloadable `.txt`, `.json`, and `.png` files in a `.zip` archive.
- **Image Fetcher**: Search for images from **Gelbooru** using flexible tag filters.
- CUDA or CPU support.

Example image by [me](https://huggingface.co/Werli).
"""

# Dataset v3 series of models:
SWINV2_MODEL_DSV3_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
CONV_MODEL_DSV3_REPO = "SmilingWolf/wd-convnext-tagger-v3"
VIT_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-tagger-v3"
VIT_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-large-tagger-v3"
EVA02_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"
# Dataset v2 series of models:
MOAT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2"
SWIN_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
CONV_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
CONV2_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
VIT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"
# IdolSankaku series of models:
EVA02_LARGE_MODEL_IS_DSV1_REPO = "deepghs/idolsankaku-eva02-large-tagger-v1"
SWINV2_MODEL_IS_DSV1_REPO = "deepghs/idolsankaku-swinv2-tagger-v1"
# Files to download from the repos
MODEL_FILENAME = "model.onnx"
LABEL_FILENAME = "selected_tags.csv"

# Emoticon-style tag names; load_labels leaves their underscores untouched
# instead of replacing them with spaces.
kaomojis=['0_0','(o)_(o)','+_+','+_-','._.','<o>_<o>','<|>_<|>','=_=','>_<','3_3','6_9','>_o','@_@','^_^','o_o','u_u','x_x','|_|','||_||']
52
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the app.

    Returns:
        A namespace with ``score_slider_step``, ``score_general_threshold``,
        ``score_character_threshold`` (floats) and ``share`` (bool).
    """
    cli = argparse.ArgumentParser()
    # Float-valued options, registered in the original order with their defaults.
    float_options = (
        ('--score-slider-step', .05),
        ('--score-general-threshold', .35),
        ('--score-character-threshold', .85),
    )
    for flag, default in float_options:
        cli.add_argument(flag, type=float, default=default)
    cli.add_argument('--share', action='store_true')
    return cli.parse_args()
53
def load_labels(dataframe) -> tuple[list[str], list[int], list[int], list[int]]:
    """Split a tag table into display names and per-category row indexes.

    Args:
        dataframe: table with a ``name`` column (tag names) and a ``category``
            column (numeric category codes).

    Returns:
        ``(tag_names, rating_indexes, general_indexes, character_indexes)``.
        ``tag_names`` has underscores replaced by spaces, except for names
        listed in ``kaomojis``; the three index lists hold the row positions
        whose category is 9, 0 and 4 respectively.
    """
    # Kaomoji tags are emoticons whose underscores are part of the tag itself.
    tag_names = dataframe['name'].map(
        lambda name: name if name in kaomojis else name.replace('_', ' ')
    ).tolist()

    def positions(category_code):
        # Row positions of one category, converted to plain Python ints.
        return np.where(dataframe['category'] == category_code)[0].tolist()

    rating_indexes = positions(9)
    general_indexes = positions(0)
    character_indexes = positions(4)
    return tag_names, rating_indexes, general_indexes, character_indexes
54
def mcut_threshold(probs):
    """Return the Maximum Cut (MCut) threshold for a set of scores.

    The scores are sorted in descending order and the threshold is placed in
    the middle of the largest gap between two neighbouring scores.

    Args:
        probs: one-dimensional array-like of scores (at least two values).

    Returns:
        The midpoint of the two scores that bound the largest gap.

    Raises:
        ValueError: if fewer than two scores are supplied, since no gap exists.
    """
    probs = np.asarray(probs)
    if probs.size < 2:
        # Previously surfaced as an opaque "argmax of an empty sequence" error.
        raise ValueError('mcut_threshold needs at least two scores')
    sorted_probs = np.sort(probs)[::-1]
    difs = sorted_probs[:-1] - sorted_probs[1:]
    t = difs.argmax()
    return (sorted_probs[t] + sorted_probs[t + 1]) / 2
55
+
56
class Timer:
    """Record labelled ``time.perf_counter`` checkpoints and print the time between them."""

    def __init__(self):
        self.start_time = time.perf_counter()
        # List of (label, timestamp); the first entry is the baseline.
        self.checkpoints = [('Start', self.start_time)]

    def checkpoint(self, label='Checkpoint'):
        """Append a checkpoint named ``label`` stamped with the current time."""
        self.checkpoints.append((label, time.perf_counter()))

    def _print_intervals(self):
        """Print the time between consecutive checkpoints; return the label width used."""
        if not self.checkpoints:
            return 0
        width = max(len(label) for label, _ in self.checkpoints)
        # Always measure from the first remaining checkpoint: after a clearing
        # report() that baseline is later than start_time.
        prev_time = self.checkpoints[0][1]
        for label, curr_time in self.checkpoints[1:]:
            print(f"{label.ljust(width)}: {curr_time - prev_time:.3f} seconds")
            prev_time = curr_time
        return width

    def report(self, is_clear_checkpoints=True):
        """Print the intervals recorded so far.

        Args:
            is_clear_checkpoints: when true, drop the printed checkpoints and
                record a fresh baseline for the next report.
        """
        self._print_intervals()
        if is_clear_checkpoints:
            self.checkpoints.clear()
            self.checkpoint()

    def report_all(self):
        """Print remaining intervals plus the total time since start, then clear all checkpoints."""
        print('\n> Execution Time Report:')
        width = self._print_intervals()
        # With no checkpoints left (e.g. a second call) fall back to "now"
        # instead of indexing into an empty list.
        end_time = self.checkpoints[-1][1] if self.checkpoints else time.perf_counter()
        total_time = end_time - self.start_time
        print(f"{'Total Execution Time'.ljust(width)}: {total_time:.3f} seconds\n")
        self.checkpoints.clear()

    def restart(self):
        """Reset the start time and discard every checkpoint."""
        self.start_time = time.perf_counter()
        self.checkpoints = [('Start', self.start_time)]
68
class Predictor:
    """Tag gallery images with an ONNX tagger model and package the results.

    The model and its label table are fetched by ``load_model`` and kept on
    the instance until a different repository is requested.
    """

    def __init__(self):
        # Side length the model input is resized to; set by load_model.
        self.model_target_size = None
        # Repository id of the model currently held in self.model.
        self.last_loaded_repo = None

    def download_model(self, model_repo):
        """Download the label CSV and the ONNX model of ``model_repo``.

        Returns:
            ``(csv_path, model_path)`` as returned by ``hf_hub_download``.
        """
        csv_path = huggingface_hub.hf_hub_download(model_repo, LABEL_FILENAME)
        model_path = huggingface_hub.hf_hub_download(model_repo, MODEL_FILENAME)
        return csv_path, model_path

    def load_model(self, model_repo):
        """Load ``model_repo`` unless it is already the active model."""
        if model_repo == self.last_loaded_repo:
            return

        csv_path, model_path = self.download_model(model_repo)

        tags_df = pd.read_csv(csv_path)
        sep_tags = load_labels(tags_df)

        self.tag_names = sep_tags[0]
        self.rating_indexes = sep_tags[1]
        self.general_indexes = sep_tags[2]
        self.character_indexes = sep_tags[3]

        model = rt.InferenceSession(model_path)
        # The input shape is unpacked as (_, height, width, _); only the height is kept.
        _, height, _, _ = model.get_inputs()[0].shape
        self.model_target_size = height

        self.last_loaded_repo = model_repo
        self.model = model

    def prepare_image(self, path):
        """Turn the image file at ``path`` into the model's input array.

        Transparency is flattened onto white, the picture is padded to a white
        square, resized to ``self.model_target_size`` and returned as a
        float32 array with reversed channel order and a leading axis of size 1.
        """
        # Close the file handle as soon as the pixels are converted.
        with Image.open(path) as opened:
            image = opened.convert("RGBA")
        target_size = self.model_target_size

        canvas = Image.new("RGBA", image.size, (255, 255, 255))
        canvas.alpha_composite(image)
        image = canvas.convert("RGB")

        # Pad image to square
        image_shape = image.size
        max_dim = max(image_shape)
        pad_left = (max_dim - image_shape[0]) // 2
        pad_top = (max_dim - image_shape[1]) // 2

        padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
        padded_image.paste(image, (pad_left, pad_top))

        # Resize
        if max_dim != target_size:
            padded_image = padded_image.resize(
                (target_size, target_size),
                Image.BICUBIC,
            )
        # Convert to numpy array
        image_array = np.asarray(padded_image, dtype=np.float32)
        # Convert PIL-native RGB to BGR
        image_array = image_array[:, :, ::-1]
        return np.expand_dims(image_array, axis=0)

    def create_file(self, content: str, directory: str, fileName: str) -> str:
        """Write ``content`` to ``directory/fileName`` (UTF-8) and return the path."""
        file_path = os.path.join(directory, fileName)
        # The former 'w' / 'w+' branches wrote identical files; one mode suffices.
        with open(file_path, 'w', encoding="utf-8") as file:
            file.write(content)
        return file_path

    def predict(
        self,
        gallery,
        model_repo,
        general_thresh,
        general_mcut_enabled,
        character_thresh,
        character_mcut_enabled,
        characters_merge_enabled,
        reorganizer_model_repo,
        additional_tags_prepend,
        additional_tags_append,
        tag_results,
        progress=gr.Progress()
    ):
        """Tag every gallery image and bundle the outputs into a zip archive.

        Args:
            gallery: sequence of items whose first element is an image path.
            model_repo: repository id of the tagger model to use.
            general_thresh: minimum score for general tags (replaced per image
                by the MCut threshold when ``general_mcut_enabled``).
            general_mcut_enabled: derive the general threshold with MCut.
            character_thresh: minimum score for character tags.
            character_mcut_enabled: derive the character threshold with MCut
                (never below 0.15).
            characters_merge_enabled: put character tags in front of the
                output string.
            reorganizer_model_repo: optional reorganizer model; falsy disables it.
            additional_tags_prepend: comma-separated tags placed first.
            additional_tags_append: comma-separated tags placed last.
            tag_results: dict that is cleared and refilled with one entry per
                image path.
            progress: Gradio progress callback.

        Returns:
            ``(download, sorted_general_strings, final_categorized_output,
            classified_tags, rating, character_res, general_res,
            unclassified_tags, tag_results)``; all but ``download``, the
            cumulative ``final_categorized_output`` and ``tag_results``
            describe the last successfully processed image.
        """
        # Clear tag_results before starting a new prediction
        tag_results.clear()

        # Treat a missing gallery like an empty one.
        gallery = gallery or []
        gallery_len = len(gallery)
        print(f"Predict load model: {model_repo}, gallery length: {gallery_len}")

        timer = Timer()  # Create a timer
        progressRatio = 0.5 if reorganizer_model_repo else 1
        progressTotal = gallery_len + 1
        current_progress = 0

        self.load_model(model_repo)
        current_progress += progressRatio / progressTotal
        progress(current_progress, desc="Initialize wd model finished")
        timer.checkpoint("Initialize wd model")

        txt_infos = []
        output_dir = tempfile.mkdtemp()  # mkdtemp already creates the directory

        # Defaults returned when no image is processed successfully; without
        # them the final return raised UnboundLocalError.
        sorted_general_strings = ""
        categorized_output_strings = []
        final_categorized_output = ""
        classified_tags = {}
        unclassified_tags = {}
        rating = None
        character_res = None
        general_res = None

        reorganizer = None
        if reorganizer_model_repo:
            print(f"Reorganizer load model {reorganizer_model_repo}")
            reorganizer = reorganizer_class(reorganizer_model_repo, loadModel=True)
            current_progress += progressRatio / progressTotal
            progress(current_progress, desc="Initialize reoganizer model finished")
            timer.checkpoint("Initialize reoganizer model")

        timer.report()

        # The text inputs may be None (e.g. after the clear button).
        prepend_list = [tag.strip() for tag in (additional_tags_prepend or "").split(",") if tag.strip()]
        append_list = [tag.strip() for tag in (additional_tags_append or "").split(",") if tag.strip()]
        if prepend_list and append_list:
            append_list = [item for item in append_list if item not in prepend_list]

        # Dictionary to track counters for each filename
        name_counters = defaultdict(int)

        for idx, value in enumerate(gallery):
            try:
                image_path = value[0]
                image_name = os.path.splitext(os.path.basename(image_path))[0]

                # Repeated base names get a numeric suffix so outputs do not collide.
                name_counters[image_name] += 1
                if name_counters[image_name] > 1:
                    image_name = f"{image_name}_{name_counters[image_name]:02d}"

                image_array = self.prepare_image(image_path)

                input_name = self.model.get_inputs()[0].name
                label_name = self.model.get_outputs()[0].name
                print(f"Gallery {idx:02d}: Starting run wd model...")
                preds = self.model.run([label_name], {input_name: image_array})[0]

                labels = list(zip(self.tag_names, preds[0].astype(float)))

                # Rating tags: keep every score, keyed by tag name.
                ratings_names = [labels[i] for i in self.rating_indexes]
                rating = dict(ratings_names)

                # General tags: keep those whose score exceeds the threshold.
                general_names = [labels[i] for i in self.general_indexes]

                if general_mcut_enabled:
                    general_probs = np.array([x[1] for x in general_names])
                    general_thresh = mcut_threshold(general_probs)

                general_res = dict(x for x in general_names if x[1] > general_thresh)

                # Character tags: keep those whose score exceeds the threshold.
                character_names = [labels[i] for i in self.character_indexes]

                if character_mcut_enabled:
                    character_probs = np.array([x[1] for x in character_names])
                    character_thresh = mcut_threshold(character_probs)
                    character_thresh = max(0.15, character_thresh)

                character_res = dict(x for x in character_names if x[1] > character_thresh)
                character_list = list(character_res.keys())

                sorted_general_list = sorted(
                    general_res.items(),
                    key=lambda x: x[1],
                    reverse=True,
                )
                sorted_general_list = [x[0] for x in sorted_general_list]
                # Remove values from character_list that already exist in sorted_general_list
                character_list = [item for item in character_list if item not in sorted_general_list]
                # Remove values from sorted_general_list that already exist in prepend_list or append_list
                if prepend_list:
                    sorted_general_list = [item for item in sorted_general_list if item not in prepend_list]
                if append_list:
                    sorted_general_list = [item for item in sorted_general_list if item not in append_list]

                sorted_general_list = prepend_list + sorted_general_list + append_list

                # Escape parentheses with a backslash ("\\(" is the valid
                # spelling of the former invalid escape "\(").
                sorted_general_strings = ", ".join(
                    (character_list if characters_merge_enabled else []) + sorted_general_list
                ).replace("(", "\\(").replace(")", "\\)")

                classified_tags, unclassified_tags = classify_tags(sorted_general_list)

                # Create a single string of ALL categorized tags for the current image
                categorized_output_string = ', '.join([', '.join(tags) for tags in classified_tags.values()])
                categorized_output_strings.append(categorized_output_string)
                # Collect all categorized output strings into a single string
                # NOTE(review): this accumulates over every image processed so
                # far, so each image's _output.txt also lists earlier images' tags.
                final_categorized_output = ', '.join(categorized_output_strings)

                # Create a .txt file for "Output (string)" and "Categorized Output (string)"
                txt_content = f"Output (string): {sorted_general_strings}\nCategorized Output (string): {final_categorized_output}"
                txt_file = self.create_file(txt_content, output_dir, f"{image_name}_output.txt")
                txt_infos.append({"path": txt_file, "name": f"{image_name}_output.txt"})

                # Create a .json file for "Categorized (tags)"
                json_content = json.dumps(classified_tags, indent=4)
                json_file = self.create_file(json_content, output_dir, f"{image_name}_categorized_tags.json")
                txt_infos.append({"path": json_file, "name": f"{image_name}_categorized_tags.json"})

                # Save a copy of the uploaded image in PNG format
                png_name = f"{image_name}.png"
                png_path = os.path.join(output_dir, png_name)
                with Image.open(image_path) as original_image:
                    original_image.save(png_path, format="PNG")
                txt_infos.append({"path": png_path, "name": png_name})

                current_progress += progressRatio / progressTotal
                progress(current_progress, desc=f"image{idx:02d}, predict finished")
                timer.checkpoint(f"image{idx:02d}, predict finished")

                if reorganizer_model_repo:
                    print("Starting reorganizer...")
                    reorganize_strings = reorganizer.reorganize(sorted_general_strings)
                    reorganize_strings = re.sub(r" *Title: *", "", reorganize_strings)
                    reorganize_strings = re.sub(r"\n+", ",", reorganize_strings)
                    reorganize_strings = re.sub(r",,+", ",", reorganize_strings)
                    sorted_general_strings += ",\n\n" + reorganize_strings

                    current_progress += progressRatio / progressTotal
                    progress(current_progress, desc=f"image{idx:02d}, reorganizer finished")
                    timer.checkpoint(f"image{idx:02d}, reorganizer finished")

                txt_file = self.create_file(sorted_general_strings, output_dir, image_name + ".txt")
                txt_infos.append({"path": txt_file, "name": image_name + ".txt"})

                # Store the result in tag_results using image_path as the key
                tag_results[image_path] = {
                    "strings": sorted_general_strings,
                    "strings2": categorized_output_string,  # Store the categorized output string here
                    "classified_tags": classified_tags,
                    "rating": rating,
                    "character_res": character_res,
                    "general_res": general_res,
                    "unclassified_tags": unclassified_tags,
                    "enhanced_tags": ""  # Initialize as empty string
                }

                timer.report()
            except Exception as e:
                # Best effort: a failing image is logged and the batch continues.
                print(traceback.format_exc())
                print("Error predict: " + str(e))

        # Zip creation logic:
        download = []
        if txt_infos:
            downloadZipPath = os.path.join(output_dir, "Multi-tagger-" + datetime.now().strftime("%Y%m%d-%H%M%S") + ".zip")
            with zipfile.ZipFile(downloadZipPath, 'w', zipfile.ZIP_DEFLATED) as taggers_zip:
                for info in txt_infos:
                    # Get file name from lookup
                    taggers_zip.write(info["path"], arcname=info["name"])
            download.append(downloadZipPath)
        # End zip creation logic

        if reorganizer is not None:
            reorganizer.release_vram()
            del reorganizer

        progress(1, desc="Predict completed")
        timer.report_all()  # Print all recorded times
        print("Predict is complete.")

        return download, sorted_general_strings, final_categorized_output, classified_tags, rating, character_res, general_res, unclassified_tags, tag_results
346
def get_selection_from_gallery(gallery: list, tag_results: dict, selected_state: gr.SelectData):
    """Return the stored tagging results for the image selected in the gallery.

    Args:
        gallery: current gallery value (unused; supplied by the event binding).
        tag_results: mapping of image path to the result dict stored by
            ``Predictor.predict``.
        selected_state: selection event data; ``value["image"]["path"]`` and
            ``value["caption"]`` are read.

    Returns:
        A 9-tuple: ``(path, caption)`` followed by the ``strings``,
        ``strings2``, ``classified_tags``, ``rating``, ``character_res``,
        ``general_res``, ``unclassified_tags`` and ``enhanced_tags`` entries,
        which are blank for images that have not been tagged yet.
    """
    if not selected_state:
        # Nine outputs are bound to this handler; returning a single value
        # would not match them, so leave every output unchanged instead.
        return tuple(gr.update() for _ in range(9))

    selected_path = selected_state.value["image"]["path"]
    blank_result = {
        "strings": "",
        "strings2": "",
        "classified_tags": "{}",
        "rating": "",
        "character_res": "",
        "general_res": "",
        "unclassified_tags": "{}",
        "enhanced_tags": ""
    }
    tag_result = tag_results.get(selected_path, blank_result)
    return (
        (selected_path, selected_state.value["caption"]),
        tag_result["strings"],
        tag_result["strings2"],
        tag_result["classified_tags"],
        tag_result["rating"],
        tag_result["character_res"],
        tag_result["general_res"],
        tag_result["unclassified_tags"],
        tag_result["enhanced_tags"],
    )
362
def append_gallery(gallery: list, image: str):
    """Add ``image`` to ``gallery`` and hand back a cleared image value.

    Returns:
        ``(gallery, None)``: the (possibly newly created) gallery list,
        extended in place when ``image`` is truthy, and ``None`` for the
        second output.
    """
    items = [] if gallery is None else gallery
    if image:
        items.append(image)
    return items, None
366
def extend_gallery(gallery: list, images):
    """Add every entry of ``images`` to ``gallery`` and return the gallery.

    A ``None`` gallery is replaced by a new list; a falsy ``images`` leaves
    the gallery as it is. An existing list is extended in place.
    """
    items = [] if gallery is None else gallery
    if images:
        items.extend(images)
    return items
370
def remove_image_from_gallery(gallery: list, selected_image: str):
    """Remove the selected entry from ``gallery``.

    Args:
        gallery: current gallery entries.
        selected_image: Python-literal text of the entry to remove, as held
            in the hidden "Selected Image" textbox.

    Returns:
        The gallery, with the matching entry removed in place when present.
        It is returned unchanged when nothing is selected or when the text
        is not a valid literal.
    """
    if not gallery or not selected_image:
        return gallery
    try:
        # literal_eval only accepts Python literals, so arbitrary code in the
        # text is rejected rather than executed.
        target = ast.literal_eval(selected_image)
    except (ValueError, TypeError, SyntaxError):
        # Unparseable selection text: keep the gallery instead of crashing.
        return gallery
    if target in gallery:
        gallery.remove(target)
    return gallery
375
# Command-line options, parsed once when the module is loaded.
args = parse_args()
# Predictor instance whose predict method is bound to the Submit button.
predictor = Predictor()
# Choices offered by the "Model" dropdown, grouped by model series.
dropdown_list = [
    EVA02_LARGE_MODEL_DSV3_REPO,
    SWINV2_MODEL_DSV3_REPO,
    CONV_MODEL_DSV3_REPO,
    VIT_MODEL_DSV3_REPO,
    VIT_LARGE_MODEL_DSV3_REPO,
    # ---
    MOAT_MODEL_DSV2_REPO,
    SWIN_MODEL_DSV2_REPO,
    CONV_MODEL_DSV2_REPO,
    CONV2_MODEL_DSV2_REPO,
    VIT_MODEL_DSV2_REPO,
    # ---
    SWINV2_MODEL_IS_DSV1_REPO,
    EVA02_LARGE_MODEL_IS_DSV1_REPO,
]
393
+
394
def _restart_space():
    """Ask the Hugging Face Hub to restart the 'Werli/Multi-Tagger' Space.

    Raises:
        ValueError: if the ``HF_TOKEN`` environment variable is missing or empty.
    """
    token = os.getenv('HF_TOKEN')
    if not token:
        raise ValueError('HF_TOKEN environment variable is not set.')
    api = huggingface_hub.HfApi()
    api.restart_space(
        repo_id='Werli/Multi-Tagger',
        token=token,
        factory_reboot=False,
    )
398
scheduler=BackgroundScheduler()
# Add a job to restart the space every 2 days (172800 seconds)
restart_space_job = scheduler.add_job(_restart_space, "interval", seconds=172800)
scheduler.start()
# Convert the job's next run time to UTC for the notice below.
next_run_time_utc=restart_space_job.next_run_time.astimezone(timezone.utc)
# NOTE: built once at module load, so the displayed time is not refreshed afterwards.
NEXT_RESTART=f"Next Restart: {next_run_time_utc.strftime('%Y-%m-%d %H:%M:%S')} (UTC) - The space will restart every 2 days to ensure stability and performance. It uses a background scheduler to handle the restart process."

# Extra CSS passed to gr.Blocks.
css = """
#output {height: 500px; overflow: auto; border: 1px solid #ccc;}
label.float.svelte-i3tvor {position: relative !important;}
.reduced-height.svelte-11chud3 {height: calc(80% - var(--size-10));}
"""
410
+
411
+ with gr.Blocks(title=TITLE, css=css, theme=gr.themes.Soft(), fill_width=True) as demo:
412
+ gr.Markdown(value=f"<h1 style='text-align: center; margin-bottom: 1rem'>{TITLE}</h1>")
413
+ gr.Markdown(value=DESCRIPTION)
414
+ gr.Markdown(NEXT_RESTART)
415
+ with gr.Tab(label="Waifu Diffusion"):
416
+ with gr.Row():
417
+ with gr.Column():
418
+ submit = gr.Button(value="Submit", variant="primary", size="lg")
419
+ with gr.Column(variant="panel"):
420
+ # Create an Image component for uploading images
421
+ image_input = gr.Image(label="Upload an Image or clicking paste from clipboard button", type="filepath", sources=["upload", "clipboard"], height=150)
422
+ with gr.Row():
423
+ upload_button = gr.UploadButton("Upload multiple images", file_types=["image"], file_count="multiple", size="sm")
424
+ remove_button = gr.Button("Remove Selected Image", size="sm")
425
+ gallery = gr.Gallery(columns=5, rows=5, show_share_button=False, interactive=True, height="500px", label="Grid of images")
426
+ model_repo = gr.Dropdown(
427
+ dropdown_list,
428
+ value=EVA02_LARGE_MODEL_DSV3_REPO,
429
+ label="Model",
430
+ )
431
+ with gr.Row():
432
+ general_thresh = gr.Slider(
433
+ 0,
434
+ 1,
435
+ step=args.score_slider_step,
436
+ value=args.score_general_threshold,
437
+ label="General Tags Threshold",
438
+ scale=3,
439
+ )
440
+ general_mcut_enabled = gr.Checkbox(
441
+ value=False,
442
+ label="Use MCut threshold",
443
+ scale=1,
444
+ )
445
+ with gr.Row():
446
+ character_thresh = gr.Slider(
447
+ 0,
448
+ 1,
449
+ step=args.score_slider_step,
450
+ value=args.score_character_threshold,
451
+ label="Character Tags Threshold",
452
+ scale=3,
453
+ )
454
+ character_mcut_enabled = gr.Checkbox(
455
+ value=False,
456
+ label="Use MCut threshold",
457
+ scale=1,
458
+ )
459
+ with gr.Row():
460
+ characters_merge_enabled = gr.Checkbox(
461
+ value=True,
462
+ label="Merge characters into the string output",
463
+ scale=1,
464
+ )
465
+ with gr.Row():
466
+ reorganizer_model_repo = gr.Dropdown(
467
+ [None] + reorganizer_list,
468
+ value=None,
469
+ label="Reorganizer Model",
470
+ info="Use a model to create a description for you",
471
+ )
472
+ with gr.Row():
473
+ additional_tags_prepend = gr.Text(label="Prepend Additional tags (comma split)")
474
+ additional_tags_append = gr.Text(label="Append Additional tags (comma split)")
475
+ with gr.Row():
476
+ clear = gr.ClearButton(
477
+ components=[
478
+ gallery,
479
+ model_repo,
480
+ general_thresh,
481
+ general_mcut_enabled,
482
+ character_thresh,
483
+ character_mcut_enabled,
484
+ characters_merge_enabled,
485
+ reorganizer_model_repo,
486
+ additional_tags_prepend,
487
+ additional_tags_append,
488
+ ],
489
+ variant="secondary",
490
+ size="lg",
491
+ )
492
+ with gr.Column(variant="panel"):
493
+ download_file = gr.File(label="Download includes: All outputs* and image(s)") # 0
494
+ character_res = gr.Label(label="Output (characters)") # 1
495
+ sorted_general_strings = gr.Textbox(label="Output (string)*", show_label=True, show_copy_button=True) # 2
496
+ final_categorized_output = gr.Textbox(label="Categorized (string)* - If it's too long, select an image to display tags correctly.", show_label=True, show_copy_button=True) # 3
497
+ pe_generate_btn = gr.Button(value="ENHANCE TAGS", size="lg", variant="primary") # 4
498
+ enhanced_tags = gr.Textbox(label="Enhanced Tags", show_label=True, show_copy_button=True) # 5
499
+ prompt_enhancer_model = gr.Radio(["Medium", "Long", "Flux"], label="Model Choice", value="Medium", info="Enhance your prompts with Medium or Long answers") # 6
500
+ categorized = gr.JSON(label="Categorized (tags)* - JSON") # 7
501
+ rating = gr.Label(label="Rating") # 8
502
+ general_res = gr.Label(label="Output (tags)") # 9
503
+ unclassified = gr.JSON(label="Unclassified (tags)") # 10
504
+ clear.add(
505
+ [
506
+ download_file,
507
+ sorted_general_strings,
508
+ final_categorized_output,
509
+ categorized,
510
+ rating,
511
+ character_res,
512
+ general_res,
513
+ unclassified,
514
+ prompt_enhancer_model,
515
+ enhanced_tags,
516
+ ]
517
+ )
518
+ tag_results = gr.State({})
519
+ # Define the event listener to add the uploaded image to the gallery
520
+ image_input.change(append_gallery, inputs=[gallery, image_input], outputs=[gallery, image_input])
521
+ # When the upload button is clicked, add the new images to the gallery
522
+ upload_button.upload(extend_gallery, inputs=[gallery, upload_button], outputs=gallery)
523
+ # Event to update the selected image when an image is clicked in the gallery
524
+ selected_image = gr.Textbox(label="Selected Image", visible=False)
525
+ gallery.select(get_selection_from_gallery,inputs=[gallery, tag_results],outputs=[selected_image, sorted_general_strings, final_categorized_output, categorized, rating, character_res, general_res, unclassified, enhanced_tags])
526
+ # Event to remove a selected image from the gallery
527
+ remove_button.click(remove_image_from_gallery, inputs=[gallery, selected_image], outputs=gallery)
528
+ # Event for the Prompt Enhancer button
529
+ pe_generate_btn.click(lambda tags,model:prompt_enhancer('','',tags,model)[0],inputs=[final_categorized_output,prompt_enhancer_model],outputs=[enhanced_tags])
530
+ submit.click(
531
+ predictor.predict,
532
+ inputs=[
533
+ gallery,
534
+ model_repo,
535
+ general_thresh,
536
+ general_mcut_enabled,
537
+ character_thresh,
538
+ character_mcut_enabled,
539
+ characters_merge_enabled,
540
+ reorganizer_model_repo,
541
+ additional_tags_prepend,
542
+ additional_tags_append,
543
+ tag_results,
544
+ ],
545
+ outputs=[download_file, sorted_general_strings, final_categorized_output, categorized, rating, character_res, general_res, unclassified, tag_results,],
546
+ )
547
+ gr.Examples(
548
+ [["images/1girl.png", VIT_LARGE_MODEL_DSV3_REPO, 0.35, False, 0.85, False]],
549
+ inputs=[
550
+ image_input,
551
+ model_repo,
552
+ general_thresh,
553
+ general_mcut_enabled,
554
+ character_thresh,
555
+ character_mcut_enabled,
556
+ ],
557
+ )
558
+ with gr.Tab(label="Florence 2 Image Captioning"):
559
+ with gr.Row():
560
+ with gr.Column(variant="panel"):
561
+ input_img = gr.Image(label="Input Picture")
562
+ task_type = gr.Radio(choices=['Single task', 'Cascaded task'], label='Task type selector', value='Single task')
563
+ task_prompt = gr.Dropdown(choices=single_task_list, label="Task Prompt", value="Caption")
564
+ task_type.change(fn=update_task_dropdown, inputs=task_type, outputs=task_prompt)
565
+ text_input = gr.Textbox(label="Text Input (optional)")
566
+ submit_btn = gr.Button(value="Submit")
567
+ with gr.Column(variant="panel"):
568
+ output_text = gr.Textbox(label="Output Text", show_label=True, show_copy_button=True, lines=8)
569
+ output_img = gr.Image(label="Output Image")
570
+ gr.Examples(
571
+ examples=[
572
+ ["images/image1.png", 'Object Detection'],
573
+ ["images/image2.png", 'OCR with Region']
574
+ ],
575
+ inputs=[input_img, task_prompt],
576
+ outputs=[output_text, output_img],
577
+ fn=process_image,
578
+ cache_examples=False,
579
+ label='Try examples'
580
+ )
581
+ submit_btn.click(process_image, [input_img, task_prompt, text_input], [output_text, output_img])
582
+ with gr.Tab(label="Gelbooru Image Fetcher"):
583
+ with gr.Row():
584
+ with gr.Column():
585
+ gr.Markdown("### ⚙️ Search Parameters")
586
+ site = gr.Dropdown(label="Select Source", choices=["Gelbooru", "None (will not work)"], value="Gelbooru")
587
+ OR_tags = gr.Textbox(label="OR Tags (comma-separated)", placeholder="e.g. solo, 1girl, 1boy, artist, character, ...")
588
+ AND_tags = gr.Textbox(label="AND Tags (comma-separated)", placeholder="e.g. black hair, cat ears, black hair, granblue fantasy, ...")
589
+ exclude_tags = gr.Textbox(label="Exclude Tags (comma-separated)", placeholder="e.g. animated, watermark, username, ...")
590
+ score = gr.Number(label="Minimum Score", value=0)
591
+ count = gr.Slider(label="Number of Images", minimum=1, maximum=4, step=1, value=1) # Increase if necessary (not recommend)
592
+ Safe = gr.Checkbox(label="Include Safe", value=True)
593
+ Questionable = gr.Checkbox(label="Include Questionable", value=True)
594
+ Explicit = gr.Checkbox(label="Include Explicit", value=False)
595
+ #user_id = gr.Textbox(label="User ID (Optional)", value="")
596
+ #api_key = gr.Textbox(label="API Key (Optional)", value="", type="password")
597
+
598
+ submit_btn = gr.Button("Fetch Images", variant="primary")
599
+
600
+ with gr.Column():
601
+ gr.Markdown("### 📄 Results")
602
+ images_output = gr.Gallery(label="Images", columns=3, rows=2, object_fit="contain", height=500)
603
+ tags_output = gr.Textbox(label="Tags", placeholder="Select an image to show tags", lines=5, show_copy_button=True)
604
+ post_url_output = gr.Textbox(label="Post URL", lines=1, show_copy_button=True)
605
+ image_url_output = gr.Textbox(label="Image URL", lines=1, show_copy_button=True)
606
+
607
+ # State to store tags, URLs
608
+ tags_state = gr.State([])
609
+ post_url_state = gr.State([])
610
+ image_url_state = gr.State([])
611
+
612
+ submit_btn.click(
613
+ fn=gelbooru_gradio,
614
+ inputs=[OR_tags, AND_tags, exclude_tags, score, count, Safe, Questionable, Explicit, site], # add 'api_key' and 'user_id' if necessary
615
+ outputs=[images_output, tags_state, post_url_state, image_url_state],
616
+ )
617
+
618
+ images_output.select(
619
+ fn=on_select,
620
+ inputs=[tags_state, post_url_state, image_url_state],
621
+ outputs=[tags_output, post_url_output, image_url_output],
622
+ )
623
+ gr.Markdown("""
624
+ ---
625
+ ComfyUI version: [Comfyui-Gelbooru](https://github.com/1mckw/Comfyui-Gelbooru)
626
+ """)
627
+ with gr.Tab(label="Categorizer++"):
628
+ with gr.Row():
629
+ with gr.Column(variant="panel"):
630
+ input_tags = gr.Textbox(label="Input Tags", placeholder="1girl, cat, horns, blue hair, ...\nor\n? 1girl 1234567? cat 1234567? horns 1234567? blue hair 1234567? ...", lines=4)
631
+ submit_button = gr.Button(value="Submit", variant="primary", size="lg")
632
+ with gr.Column(variant="panel"):
633
+ categorized_string = gr.Textbox(label="Categorized (string)", show_label=True, show_copy_button=True, lines=8)
634
+ categorized_json = gr.JSON(label="Categorized (tags) - JSON")
635
+ submit_button.click(process_tags, inputs=[input_tags], outputs=[categorized_string, categorized_json])
636
+ with gr.Column(variant="panel"):
637
+ pe_generate_btn = gr.Button(value="ENHANCE TAGS", size="lg", variant="primary")
638
+ enhanced_tags = gr.Textbox(label="Enhanced Tags", show_label=True, show_copy_button=True)
639
+ prompt_enhancer_model = gr.Radio(["Medium", "Long", "Flux"], label="Model Choice", value="Medium", info="Enhance your prompts with Medium or Long answers")
640
+ pe_generate_btn.click(lambda tags,model:prompt_enhancer('','',tags,model)[0],inputs=[categorized_string,prompt_enhancer_model],outputs=[enhanced_tags])
641
  demo.queue(max_size=2).launch()