import gc

# get socket and check if the name is vgldgx01
import socket

if socket.gethostname() != "vgldgx01":
    import spaces  # [uncomment to use ZeroGPU]

import numpy as np
import PIL.Image
import torch
from controlnet_aux.util import HWC3
from diffusers import (
    ControlNetModel,
    DiffusionPipeline,
    StableDiffusionControlNetPipeline,
    StableDiffusionImg2ImgPipeline,
    UniPCMultistepScheduler,
    DDIMScheduler,  # rgb2x
)
import torchvision
from torchvision import transforms
from cv_utils import resize_image
from preprocessor import Preprocessor
from settings import MAX_IMAGE_RESOLUTION, MAX_NUM_IMAGES
from tqdm.auto import tqdm
import subprocess
from rgb2x.pipeline_rgb2x import StableDiffusionAOVMatEstPipeline
from app_texnet import image_to_temp_path
import os
import time
import tempfile
from text2tex.scripts.generate_texture import text2tex_call, init_args
from glob import glob

# Task name -> ControlNet checkpoint id. Only "texnet" is currently enabled;
# the commented entries are the stock ControlNet v1.1 checkpoints.
CONTROLNET_MODEL_IDS = {
    # "Openpose": "lllyasviel/control_v11p_sd15_openpose",
    # "Canny": "lllyasviel/control_v11p_sd15_canny",
    # "MLSD": "lllyasviel/control_v11p_sd15_mlsd",
    # "scribble": "lllyasviel/control_v11p_sd15_scribble",
    # "softedge": "lllyasviel/control_v11p_sd15_softedge",
    # "segmentation": "lllyasviel/control_v11p_sd15_seg",
    # "depth": "lllyasviel/control_v11f1p_sd15_depth",
    # "NormalBae": "lllyasviel/control_v11p_sd15_normalbae",
    # "lineart": "lllyasviel/control_v11p_sd15_lineart",
    # "lineart_anime": "lllyasviel/control_v11p_sd15s2_lineart_anime",
    # "shuffle": "lllyasviel/control_v11e_sd15_shuffle",
    # "ip2p": "lllyasviel/control_v11e_sd15_ip2p",
    # "inpaint": "lllyasviel/control_v11e_sd15_inpaint",
    # "texnet": "/home/jyang/projects/ObjectReal/logs/train_texnet_deploy/checkpoint-55000/controlnet"
    # load and call
    "texnet": "jingyangcarl/texnet",
}

# rgb2x AOV name -> text prompt. Insertion order is the order in which the
# AOVs are generated, so it also fixes the order of the returned predictions.
_RGB2X_AOV_PROMPTS = {
    "albedo": "Albedo (diffuse basecolor)",
    "normal": "Camera-space Normal",
    "roughness": "Roughness",
    "metallic": "Metallicness",
    "irradiance": "Irradiance (diffuse lighting)",
}

# Longest side (in pixels) of the image handed to the rgb2x pipeline.
_RGB2X_MAX_SIDE = 1000


def download_all_controlnet_weights() -> None:
    """Instantiate every enabled ControlNet checkpoint once via from_pretrained."""
    for model_id in CONTROLNET_MODEL_IDS.values():
        ControlNetModel.from_pretrained(model_id)


class Model:
    """Bundle of the ControlNet, img2img and rgb2x pipelines plus a Blender binary."""

    def __init__(
        self,
        base_model_id: str = "stable-diffusion-v1-5/stable-diffusion-v1-5",
        task_name: str = "Canny",
    ) -> None:
        """Load all pipelines onto the available device and ensure Blender exists.

        NOTE(review): the default ``task_name`` "Canny" is not a key of
        CONTROLNET_MODEL_IDS while only "texnet" is enabled, so callers must
        currently pass ``task_name="texnet"`` or ``load_pipe`` raises KeyError.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Empty strings force the first load_pipe call to actually load.
        self.base_model_id = ""
        self.task_name = ""
        self.pipe = self.load_pipe(base_model_id, task_name)
        self.pipe_base = StableDiffusionImg2ImgPipeline.from_pretrained(
            'runwayml/stable-diffusion-v1-5', safety_checker=None, torch_dtype=torch.float16
        ).to(self.device)
        self.preprocessor = Preprocessor()

        # set up pipe_rgb2x
        self.pipe_rgb2x = StableDiffusionAOVMatEstPipeline.from_pretrained(
            "zheng95z/rgb-to-x",
            torch_dtype=torch.float16,
        ).to(self.device)
        self.pipe_rgb2x.scheduler = DDIMScheduler.from_config(
            self.pipe_rgb2x.scheduler.config, rescale_betas_zero_snr=True, timestep_spacing="trailing"
        )
        self.pipe_rgb2x.set_progress_bar_config(disable=True)

        # setup blender: download and unpack it into /tmp on first use
        self.blender_path = '/tmp/blender-3.2.2-linux-x64/blender'
        if not os.path.exists(self.blender_path):
            print("Downloading Blender...")
            subprocess.run(
                [
                    "wget",
                    "https://download.blender.org/release/Blender3.2/blender-3.2.2-linux-x64.tar.xz",
                    "-O",
                    "/tmp/blender-3.2.2-linux-x64.tar.xz",
                ],
                check=True,
            )
            subprocess.run(["tar", "-xf", "/tmp/blender-3.2.2-linux-x64.tar.xz", "-C", "/tmp"], check=True)
            print("Blender downloaded and extracted.")

    def load_pipe(self, base_model_id: str, task_name: str) -> DiffusionPipeline:
        """Return a ControlNet pipeline for the given base model and task.

        The already-loaded pipeline is reused when both ids match the current
        ones; otherwise a new one is built and ``self.base_model_id`` /
        ``self.task_name`` are updated.

        Raises:
            KeyError: if ``task_name`` is not in CONTROLNET_MODEL_IDS.
        """
        if (
            base_model_id == self.base_model_id
            and task_name == self.task_name
            and hasattr(self, "pipe")
            and self.pipe is not None
        ):
            return self.pipe

        model_id = CONTROLNET_MODEL_IDS[task_name]
        controlnet = ControlNetModel.from_pretrained(model_id, torch_dtype=torch.float16)
        pipe = StableDiffusionControlNetPipeline.from_pretrained(
            base_model_id, safety_checker=None, controlnet=controlnet, torch_dtype=torch.float16
        )
        pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
        pipe.to(self.device)

        # xformers attention is skipped on ZeroGPU, where it doesn't work.
        on_zero_gpu = os.environ.get("SPACES_ZERO_GPU", "0") == "1"
        if self.device.type == "cuda" and not on_zero_gpu:
            pipe.enable_xformers_memory_efficient_attention()

        torch.cuda.empty_cache()
        gc.collect()
        self.base_model_id = base_model_id
        self.task_name = task_name
        return pipe

    def set_base_model(self, base_model_id: str) -> str:
        """Switch the base model and return the id that is active afterwards.

        If loading the new base model fails, the previous one is reloaded.
        """
        if not base_model_id or base_model_id == self.base_model_id:
            return self.base_model_id
        # Drop the old pipeline first so its memory can be reclaimed.
        del self.pipe
        torch.cuda.empty_cache()
        gc.collect()
        try:
            self.pipe = self.load_pipe(base_model_id, self.task_name)
        except Exception:  # noqa: BLE001
            self.pipe = self.load_pipe(self.base_model_id, self.task_name)
        return self.base_model_id

    def load_controlnet_weight(self, task_name: str) -> None:
        """Swap the pipeline's ControlNet for the one registered under ``task_name``.

        Raises:
            KeyError: if ``task_name`` is not in CONTROLNET_MODEL_IDS. The
                lookup happens before the current ControlNet is released, so
                the pipeline stays usable in that case.
        """
        if task_name == self.task_name:
            return
        model_id = CONTROLNET_MODEL_IDS[task_name]

        if self.pipe is not None and hasattr(self.pipe, "controlnet"):
            del self.pipe.controlnet
        torch.cuda.empty_cache()
        gc.collect()

        controlnet = ControlNetModel.from_pretrained(model_id, torch_dtype=torch.float16)
        controlnet.to(self.device)
        torch.cuda.empty_cache()
        gc.collect()
        self.pipe.controlnet = controlnet
        self.task_name = task_name

    def get_prompt(self, prompt: str, additional_prompt: str) -> str:
        """Join the prompt and the additional prompt; an empty prompt yields only the latter."""
        return additional_prompt if not prompt else f"{prompt}, {additional_prompt}"

    # @spaces.GPU #[uncomment to use ZeroGPU]
    @torch.autocast("cuda")
    def run_pipe(
        self,
        prompt: str,
        negative_prompt: str,
        control_image: PIL.Image.Image,
        num_images: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
    ) -> list[PIL.Image.Image]:
        """Run ``self.pipe`` with a generator seeded by ``seed`` and return its images."""
        generator = torch.Generator().manual_seed(seed)
        # self.pipe.to(self.device)
        return self.pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            guidance_scale=guidance_scale,
            num_images_per_prompt=num_images,
            num_inference_steps=num_steps,
            generator=generator,
            image=control_image,
        ).images

    # ------------------------------------------------------------------
    # Private helpers shared by the process_* entry points
    # ------------------------------------------------------------------

    @staticmethod
    def _validate_request(image, image_resolution: int, num_images: int) -> None:
        """Raise ValueError for a missing input or settings above the configured maxima."""
        if image is None:
            raise ValueError("An input image is required.")
        if image_resolution > MAX_IMAGE_RESOLUTION:
            raise ValueError(f"image_resolution must not exceed {MAX_IMAGE_RESOLUTION}.")
        if num_images > MAX_NUM_IMAGES:
            raise ValueError(f"num_images must not exceed {MAX_NUM_IMAGES}.")

    @staticmethod
    def _as_control_image(image: np.ndarray, image_resolution: int) -> PIL.Image.Image:
        """Use the input directly as the control image (HWC3 + resize, no preprocessor)."""
        image = HWC3(image)
        image = resize_image(image, resolution=image_resolution)
        return PIL.Image.fromarray(image)

    def _generate(
        self,
        task_name: str,
        control_image: PIL.Image.Image,
        prompt: str,
        additional_prompt: str,
        negative_prompt: str,
        num_images: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
    ) -> list[PIL.Image.Image]:
        """Load the task's ControlNet, run the pipeline, return [control_image, *results]."""
        self.load_controlnet_weight(task_name)
        results = self.run_pipe(
            prompt=self.get_prompt(prompt, additional_prompt),
            negative_prompt=negative_prompt,
            control_image=control_image,
            num_images=num_images,
            num_steps=num_steps,
            guidance_scale=guidance_scale,
            seed=seed,
        )
        return [control_image, *results]

    def _run_text2tex(self, obj_name: str, prompt: str, output_root: str) -> str:
        """Configure and run text2tex for ``examples/<obj_name>/mesh.obj``.

        Returns whatever ``text2tex_call`` returns (used by the caller as the
        output directory).
        """
        prompt_nospace = prompt.replace(' ', '_')
        args = init_args()
        args.input_dir = f'examples/{obj_name}/'
        args.output_dir = os.path.join(output_root, f'{obj_name}/{prompt_nospace}')
        args.obj_name = obj_name
        args.obj_file = 'mesh.obj'
        args.prompt = f'{prompt} {obj_name}'
        args.add_view_to_prompt = True
        args.ddim_steps = 5
        # args.ddim_steps = 50
        args.new_strength = 1.0
        args.update_strength = 0.3
        args.view_threshold = 0.1
        args.blend = 0
        args.dist = 1
        args.num_viewpoints = 2
        # args.num_viewpoints = 36
        args.viewpoint_mode = 'predefined'
        args.use_principle = True
        args.update_steps = 2
        # args.update_steps = 20
        args.update_mode = 'heuristic'
        args.seed = 42
        args.post_process = True
        args.device = '2080'
        args.uv_size = 1000
        args.image_size = 512
        # args.image_size = 768
        args.use_objaverse = True  # assume the mesh is normalized with y-axis as up
        return text2tex_call(args)

    def _estimate_intrinsics(
        self,
        photo: torch.Tensor,
        seed: int,
        inference_step: int = 50,
        num_samples: int = 1,
    ) -> list:
        """Run the rgb2x pipeline on ``photo`` once per AOV and per sample.

        ``photo`` is indexed as (channels, height, width). It is resized so
        its longest side is _RGB2X_MAX_SIDE and both sides are multiples of 8;
        every prediction is resized back to the original height/width.

        Returns a flat list ordered sample-major, then by the key order of
        _RGB2X_AOV_PROMPTS.
        """
        generator = torch.Generator(device=self.device).manual_seed(seed)

        old_height = photo.shape[1]
        old_width = photo.shape[2]
        ratio = old_height / old_width
        if old_height > old_width:
            new_height = _RGB2X_MAX_SIDE
            new_width = int(new_height / ratio)
        else:
            new_width = _RGB2X_MAX_SIDE
            new_height = int(new_width * ratio)
        # Floor both sides to a multiple of 8 (no-op when already aligned).
        new_width = new_width // 8 * 8
        new_height = new_height // 8 * 8
        photo = torchvision.transforms.Resize((new_height, new_width))(photo)

        restore_size = torchvision.transforms.Resize((old_height, old_width))
        predictions = []
        for _ in tqdm(range(num_samples), desc="Running Pipeline", leave=False):
            for aov_name, aov_prompt in _RGB2X_AOV_PROMPTS.items():
                generated_image = self.pipe_rgb2x(
                    prompt=aov_prompt,
                    photo=photo,
                    num_inference_steps=inference_step,
                    height=new_height,
                    width=new_width,
                    generator=generator,
                    required_aovs=[aov_name],
                ).images[0][0]
                predictions.append(restore_size(generated_image))
        return predictions

    def _export_blend(
        self,
        obj_path: str,
        base_color_path: str,
        normal_map_path: str,
        roughness_path: str,
        metallic_path: str,
        output_blend: str,
    ) -> None:
        """Run Blender headless with rgb2x/generate_blend.py to write ``output_blend``."""
        cmd = [
            self.blender_path,
            "--background",
            "--python",
            "rgb2x/generate_blend.py",
            "--",
            obj_path,
            base_color_path,
            normal_map_path,
            roughness_path,
            metallic_path,
            output_blend,
        ]
        subprocess.run(cmd, check=True)

    # ------------------------------------------------------------------
    # Public entry points
    # ------------------------------------------------------------------

    # @spaces.GPU #[uncomment to use ZeroGPU]
    @torch.inference_mode()
    def process_texnet(
        self,
        obj_name: str,
        represented_image: np.ndarray | None,  # not used
        image: np.ndarray,
        prompt: str,
        additional_prompt: str,
        negative_prompt: str,
        num_images: int,
        image_resolution: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
        low_threshold: int,
        high_threshold: int,
    ) -> tuple[
        list[PIL.Image.Image],
        list[PIL.Image.Image],
        list[PIL.Image.Image],
        list[PIL.Image.Image],
        list[str],
    ]:
        """Texture a mesh with text2tex, estimate material maps, export a .blend.

        The earlier ControlNet coarse/refine texturing path is disabled;
        ``additional_prompt``, ``negative_prompt``, ``guidance_scale`` and the
        thresholds are accepted for interface compatibility but not used here.

        Returns:
            ``([texture], [normal], [roughness], [metallic], [blend_path])``.

        Raises:
            ValueError: missing image or settings above the configured maxima.
            FileNotFoundError: text2tex produced no ``*_post.png`` texture.
        """
        self._validate_request(image, image_resolution, num_images)
        prompt_nospace = prompt.replace(' ', '_')

        output_dir = self._run_text2tex(obj_name, prompt, output_root='output')

        # get the texture and mesh with underscore '_post', which is the id of
        # the last mesh, should be good for the visual
        mesh_dir = os.path.join(output_dir, 'update', 'mesh')
        post_textures = glob(os.path.join(mesh_dir, "*_post.png"))
        if not post_textures:
            raise FileNotFoundError(f"text2tex produced no '*_post.png' texture in {mesh_dir}")
        post_idx = os.path.basename(post_textures[0]).split('_')[0]
        texture = PIL.Image.open(os.path.join(mesh_dir, f"{post_idx}.png")).convert("RGB")
        mesh_path = os.path.join(mesh_dir, f"{post_idx}.obj")
        torch.cuda.empty_cache()

        # use rgb2x for now for generating the material maps
        # NOTE(review): num_samples=num_images runs every AOV num_images
        # times, but only the first sample is consumed below.
        photo = torchvision.transforms.PILToTensor()(texture).to(self.device)
        preds = self._estimate_intrinsics(
            photo, seed=seed, inference_step=num_steps, num_samples=num_images
        )
        # Look predictions up by name; zip stops after the first sample.
        aovs = dict(zip(_RGB2X_AOV_PROMPTS, preds))

        intrinsic_dir = os.path.join(output_dir, 'intrinsic')
        base_color_path = image_to_temp_path(texture, "base_color", out_dir=intrinsic_dir)
        normal_map_path = image_to_temp_path(aovs["normal"], "normal_map", out_dir=intrinsic_dir)
        roughness_path = image_to_temp_path(aovs["roughness"], "roughness", out_dir=intrinsic_dir)
        metallic_path = image_to_temp_path(aovs["metallic"], "metallic", out_dir=intrinsic_dir)

        current_timecode = time.strftime("%Y%m%d_%H%M%S")
        output_blend_path = os.path.join(
            tempfile.mkdtemp(), f"{obj_name}_{prompt_nospace}_{current_timecode}.blend"
        )
        os.makedirs(os.path.dirname(output_blend_path), exist_ok=True)

        self._export_blend(
            obj_path=mesh_path,
            base_color_path=base_color_path,
            normal_map_path=normal_map_path,
            roughness_path=roughness_path,
            metallic_path=metallic_path,
            output_blend=output_blend_path,
        )

        # gallery
        return (
            [texture],
            [aovs["normal"]],
            [aovs["roughness"]],
            [aovs["metallic"]],
            [output_blend_path],
        )

    # @spaces.GPU #[uncomment to use ZeroGPU]
    @torch.inference_mode()
    def process_canny(
        self,
        image: np.ndarray,
        prompt: str,
        additional_prompt: str,
        negative_prompt: str,
        num_images: int,
        image_resolution: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
        low_threshold: int,
        high_threshold: int,
    ) -> list[PIL.Image.Image]:
        """Generate images conditioned on the "Canny" preprocessor output."""
        self._validate_request(image, image_resolution, num_images)
        self.preprocessor.load("Canny")
        control_image = self.preprocessor(
            image=image,
            low_threshold=low_threshold,
            high_threshold=high_threshold,
            detect_resolution=image_resolution,
        )
        return self._generate(
            "Canny", control_image, prompt, additional_prompt, negative_prompt,
            num_images, num_steps, guidance_scale, seed,
        )

    @torch.inference_mode()
    def process_mlsd(
        self,
        image: np.ndarray,
        prompt: str,
        additional_prompt: str,
        negative_prompt: str,
        num_images: int,
        image_resolution: int,
        preprocess_resolution: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
        value_threshold: float,
        distance_threshold: float,
    ) -> list[PIL.Image.Image]:
        """Generate images conditioned on the "MLSD" preprocessor output."""
        self._validate_request(image, image_resolution, num_images)
        self.preprocessor.load("MLSD")
        control_image = self.preprocessor(
            image=image,
            image_resolution=image_resolution,
            detect_resolution=preprocess_resolution,
            thr_v=value_threshold,
            thr_d=distance_threshold,
        )
        return self._generate(
            "MLSD", control_image, prompt, additional_prompt, negative_prompt,
            num_images, num_steps, guidance_scale, seed,
        )

    @torch.inference_mode()
    def process_scribble(
        self,
        image: np.ndarray,
        prompt: str,
        additional_prompt: str,
        negative_prompt: str,
        num_images: int,
        image_resolution: int,
        preprocess_resolution: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
        preprocessor_name: str,
    ) -> list[PIL.Image.Image]:
        """Generate images with the scribble ControlNet.

        ``preprocessor_name`` must be "None", "HED" or "PidiNet"; anything
        else raises ValueError.
        """
        self._validate_request(image, image_resolution, num_images)
        if preprocessor_name == "None":
            control_image = self._as_control_image(image, image_resolution)
        elif preprocessor_name == "HED":
            self.preprocessor.load(preprocessor_name)
            control_image = self.preprocessor(
                image=image,
                image_resolution=image_resolution,
                detect_resolution=preprocess_resolution,
                scribble=False,
            )
        elif preprocessor_name == "PidiNet":
            self.preprocessor.load(preprocessor_name)
            control_image = self.preprocessor(
                image=image,
                image_resolution=image_resolution,
                detect_resolution=preprocess_resolution,
                safe=False,
            )
        else:
            raise ValueError(f"Unknown preprocessor: {preprocessor_name!r}")
        return self._generate(
            "scribble", control_image, prompt, additional_prompt, negative_prompt,
            num_images, num_steps, guidance_scale, seed,
        )

    @torch.inference_mode()
    def process_scribble_interactive(
        self,
        image_and_mask: dict[str, np.ndarray | list[np.ndarray]] | None,
        prompt: str,
        additional_prompt: str,
        negative_prompt: str,
        num_images: int,
        image_resolution: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
    ) -> list[PIL.Image.Image]:
        """Generate images from the inverted ``"composite"`` entry of ``image_and_mask``."""
        self._validate_request(image_and_mask, image_resolution, num_images)
        image = 255 - image_and_mask["composite"]  # type: ignore
        control_image = self._as_control_image(image, image_resolution)
        return self._generate(
            "scribble", control_image, prompt, additional_prompt, negative_prompt,
            num_images, num_steps, guidance_scale, seed,
        )

    @torch.inference_mode()
    def process_softedge(
        self,
        image: np.ndarray,
        prompt: str,
        additional_prompt: str,
        negative_prompt: str,
        num_images: int,
        image_resolution: int,
        preprocess_resolution: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
        preprocessor_name: str,
    ) -> list[PIL.Image.Image]:
        """Generate images with the softedge ControlNet.

        ``preprocessor_name`` must be "None", "HED", "HED safe", "PidiNet" or
        "PidiNet safe"; anything else raises ValueError.
        """
        self._validate_request(image, image_resolution, num_images)
        if preprocessor_name == "None":
            control_image = self._as_control_image(image, image_resolution)
        elif preprocessor_name in ["HED", "HED safe"]:
            safe = "safe" in preprocessor_name
            self.preprocessor.load("HED")
            control_image = self.preprocessor(
                image=image,
                image_resolution=image_resolution,
                detect_resolution=preprocess_resolution,
                scribble=safe,
            )
        elif preprocessor_name in ["PidiNet", "PidiNet safe"]:
            safe = "safe" in preprocessor_name
            self.preprocessor.load("PidiNet")
            control_image = self.preprocessor(
                image=image,
                image_resolution=image_resolution,
                detect_resolution=preprocess_resolution,
                safe=safe,
            )
        else:
            raise ValueError(f"Unknown preprocessor: {preprocessor_name!r}")
        return self._generate(
            "softedge", control_image, prompt, additional_prompt, negative_prompt,
            num_images, num_steps, guidance_scale, seed,
        )

    @torch.inference_mode()
    def process_openpose(
        self,
        image: np.ndarray,
        prompt: str,
        additional_prompt: str,
        negative_prompt: str,
        num_images: int,
        image_resolution: int,
        preprocess_resolution: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
        preprocessor_name: str,
    ) -> list[PIL.Image.Image]:
        """Generate images with the Openpose ControlNet ("None" skips preprocessing)."""
        self._validate_request(image, image_resolution, num_images)
        if preprocessor_name == "None":
            control_image = self._as_control_image(image, image_resolution)
        else:
            self.preprocessor.load("Openpose")
            control_image = self.preprocessor(
                image=image,
                image_resolution=image_resolution,
                detect_resolution=preprocess_resolution,
                hand_and_face=True,
            )
        return self._generate(
            "Openpose", control_image, prompt, additional_prompt, negative_prompt,
            num_images, num_steps, guidance_scale, seed,
        )

    @torch.inference_mode()
    def process_segmentation(
        self,
        image: np.ndarray,
        prompt: str,
        additional_prompt: str,
        negative_prompt: str,
        num_images: int,
        image_resolution: int,
        preprocess_resolution: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
        preprocessor_name: str,
    ) -> list[PIL.Image.Image]:
        """Generate images with the segmentation ControlNet ("None" skips preprocessing)."""
        self._validate_request(image, image_resolution, num_images)
        if preprocessor_name == "None":
            control_image = self._as_control_image(image, image_resolution)
        else:
            self.preprocessor.load(preprocessor_name)
            control_image = self.preprocessor(
                image=image,
                image_resolution=image_resolution,
                detect_resolution=preprocess_resolution,
            )
        return self._generate(
            "segmentation", control_image, prompt, additional_prompt, negative_prompt,
            num_images, num_steps, guidance_scale, seed,
        )

    @torch.inference_mode()
    def process_depth(
        self,
        image: np.ndarray,
        prompt: str,
        additional_prompt: str,
        negative_prompt: str,
        num_images: int,
        image_resolution: int,
        preprocess_resolution: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
        preprocessor_name: str,
    ) -> list[PIL.Image.Image]:
        """Generate images with the depth ControlNet ("None" skips preprocessing)."""
        self._validate_request(image, image_resolution, num_images)
        if preprocessor_name == "None":
            control_image = self._as_control_image(image, image_resolution)
        else:
            self.preprocessor.load(preprocessor_name)
            control_image = self.preprocessor(
                image=image,
                image_resolution=image_resolution,
                detect_resolution=preprocess_resolution,
            )
        return self._generate(
            "depth", control_image, prompt, additional_prompt, negative_prompt,
            num_images, num_steps, guidance_scale, seed,
        )

    @torch.inference_mode()
    def process_normal(
        self,
        image: np.ndarray,
        prompt: str,
        additional_prompt: str,
        negative_prompt: str,
        num_images: int,
        image_resolution: int,
        preprocess_resolution: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
        preprocessor_name: str,
    ) -> list[PIL.Image.Image]:
        """Generate images with the NormalBae ControlNet ("None" skips preprocessing)."""
        self._validate_request(image, image_resolution, num_images)
        if preprocessor_name == "None":
            control_image = self._as_control_image(image, image_resolution)
        else:
            self.preprocessor.load("NormalBae")
            control_image = self.preprocessor(
                image=image,
                image_resolution=image_resolution,
                detect_resolution=preprocess_resolution,
            )
        return self._generate(
            "NormalBae", control_image, prompt, additional_prompt, negative_prompt,
            num_images, num_steps, guidance_scale, seed,
        )

    @torch.inference_mode()
    def process_lineart(
        self,
        image: np.ndarray,
        prompt: str,
        additional_prompt: str,
        negative_prompt: str,
        num_images: int,
        image_resolution: int,
        preprocess_resolution: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
        preprocessor_name: str,
    ) -> list[PIL.Image.Image]:
        """Generate images with the lineart or lineart_anime ControlNet.

        ``preprocessor_name`` must be "None", "None (anime)", "Lineart",
        "Lineart coarse" or "Lineart (anime)"; anything else raises
        ValueError. Names containing "anime" select the lineart_anime weights.
        """
        self._validate_request(image, image_resolution, num_images)
        if preprocessor_name in ["None", "None (anime)"]:
            control_image = self._as_control_image(image, image_resolution)
        elif preprocessor_name in ["Lineart", "Lineart coarse"]:
            coarse = "coarse" in preprocessor_name
            self.preprocessor.load("Lineart")
            control_image = self.preprocessor(
                image=image,
                image_resolution=image_resolution,
                detect_resolution=preprocess_resolution,
                coarse=coarse,
            )
        elif preprocessor_name == "Lineart (anime)":
            self.preprocessor.load("LineartAnime")
            control_image = self.preprocessor(
                image=image,
                image_resolution=image_resolution,
                detect_resolution=preprocess_resolution,
            )
        else:
            raise ValueError(f"Unknown preprocessor: {preprocessor_name!r}")
        task_name = "lineart_anime" if "anime" in preprocessor_name else "lineart"
        return self._generate(
            task_name, control_image, prompt, additional_prompt, negative_prompt,
            num_images, num_steps, guidance_scale, seed,
        )

    @torch.inference_mode()
    def process_shuffle(
        self,
        image: np.ndarray,
        prompt: str,
        additional_prompt: str,
        negative_prompt: str,
        num_images: int,
        image_resolution: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
        preprocessor_name: str,
    ) -> list[PIL.Image.Image]:
        """Generate images with the shuffle ControlNet ("None" skips preprocessing)."""
        self._validate_request(image, image_resolution, num_images)
        if preprocessor_name == "None":
            control_image = self._as_control_image(image, image_resolution)
        else:
            self.preprocessor.load(preprocessor_name)
            control_image = self.preprocessor(
                image=image,
                image_resolution=image_resolution,
            )
        return self._generate(
            "shuffle", control_image, prompt, additional_prompt, negative_prompt,
            num_images, num_steps, guidance_scale, seed,
        )

    @torch.inference_mode()
    def process_ip2p(
        self,
        image: np.ndarray,
        prompt: str,
        additional_prompt: str,
        negative_prompt: str,
        num_images: int,
        image_resolution: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
    ) -> list[PIL.Image.Image]:
        """Generate images with the ip2p ControlNet, using the input as the control image."""
        self._validate_request(image, image_resolution, num_images)
        control_image = self._as_control_image(image, image_resolution)
        return self._generate(
            "ip2p", control_image, prompt, additional_prompt, negative_prompt,
            num_images, num_steps, guidance_scale, seed,
        )