|
import numpy as np
|
|
from diffusers import FluxFillPipeline
|
|
from PIL import Image
|
|
|
|
class GenExWorldInitializerPipeline(FluxFillPipeline):
    """Fill pipeline that expands one image into a 2048x1024 panorama input.

    The input image is centre-cropped to a square and used as the ``front``
    face of a cubemap whose other five faces are white. The cubemap is
    projected to an equirectangular canvas, and a matching mask (black on the
    front face, white elsewhere) is generated the same way. Both are handed
    to ``FluxFillPipeline.__call__``.
    """

    # Edge length, in pixels, of every cubemap face built by this class.
    _FACE_SIZE = 512
    # Size of the equirectangular canvas given to the parent pipeline.
    _PANO_WIDTH = 2048
    _PANO_HEIGHT = 1024
    # Text always placed at the start of the prompt.
    _PROMPT_PREFIX = 'GenEx Panoramic World Initialization'

    # face name -> (major_axis, major_sign, u_axis, u_sign, v_axis, v_sign)
    # Axes index the rotated direction vector (0 = x, 1 = y, 2 = z).  A pixel
    # belongs to a face when ``major_axis`` has the largest magnitude and the
    # sign given by ``major_sign``.  The face coordinates are then
    # ``(sign * component / |major component| + 1) / 2`` for u and v.
    _FACE_LAYOUT = {
        'right': (0, 1, 2, -1, 1, 1),
        'left': (0, -1, 2, 1, 1, 1),
        'bottom': (1, 1, 0, 1, 2, -1),
        'top': (1, -1, 0, 1, 2, 1),
        'front': (2, 1, 0, 1, 1, 1),
        'back': (2, -1, 0, -1, 1, 1),
    }

    def precompute_rotation_matrix(self, rx, ry, rz):
        """Return the 3x3 rotation matrix ``Rz @ Ry @ Rx``.

        Args:
            rx: Rotation about the x axis, in degrees.
            ry: Rotation about the y axis, in degrees.
            rz: Rotation about the z axis, in degrees.

        Returns:
            A ``(3, 3)`` NumPy array.
        """
        rx, ry, rz = np.deg2rad([rx, ry, rz])
        cos_x, sin_x = np.cos(rx), np.sin(rx)
        cos_y, sin_y = np.cos(ry), np.sin(ry)
        cos_z, sin_z = np.cos(rz), np.sin(rz)

        rot_x = np.array([
            [1, 0, 0],
            [0, cos_x, -sin_x],
            [0, sin_x, cos_x],
        ])
        rot_y = np.array([
            [cos_y, 0, sin_y],
            [0, 1, 0],
            [-sin_y, 0, cos_y],
        ])
        rot_z = np.array([
            [cos_z, -sin_z, 0],
            [sin_z, cos_z, 0],
            [0, 0, 1],
        ])
        return rot_z @ rot_y @ rot_x

    @staticmethod
    def _face_to_rgb_array(face_image):
        """Return ``face_image`` as an array with at most three channels.

        PIL images are converted to RGB first, so RGBA, palette and greyscale
        inputs no longer break the channel handling. Array inputs are given a
        trailing channel axis when 2-D and are cut to their first three
        channels.
        """
        if isinstance(face_image, Image.Image):
            face_image = face_image.convert("RGB")
        pixels = np.asarray(face_image)
        if pixels.ndim == 2:
            # A single channel broadcasts across the three canvas channels.
            pixels = pixels[:, :, np.newaxis]
        return pixels[:, :, :3]

    def _single_view_cubemap(self, front):
        """Return a six-face cubemap: ``front`` plus five white RGB faces."""
        size = (self._FACE_SIZE, self._FACE_SIZE)
        faces = {'front': front}
        for name in ('back', 'left', 'right', 'top', 'bottom'):
            faces[name] = Image.new("RGB", size, (255, 255, 255))
        return faces

    def cubemap_to_equirectangular(self, cubemap_faces, output_width, output_height, scale_factor=2):
        """Project cubemap faces onto an equirectangular image.

        The canvas is rendered at ``scale_factor`` times the requested size.
        When ``scale_factor > 1`` it is resized down to
        ``(output_width, output_height)`` with Lanczos resampling.

        Args:
            cubemap_faces: Mapping from face name (``'right'``, ``'left'``,
                ``'bottom'``, ``'top'``, ``'front'``, ``'back'``) to a PIL
                image or array. Faces may be omitted; the pixels they would
                cover stay black.
            output_width: Width of the returned image in pixels.
            output_height: Height of the returned image in pixels.
            scale_factor: Supersampling multiplier for the render size.

        Returns:
            A ``PIL.Image.Image`` in RGB mode.

        Raises:
            ValueError: If ``cubemap_faces`` contains an unknown face name.
        """
        # An unknown name used to raise NameError or silently reuse the
        # previous face's mask and coordinates; reject it up front instead.
        unknown = [name for name in cubemap_faces if name not in self._FACE_LAYOUT]
        if unknown:
            raise ValueError(
                f"Unknown cubemap face name(s): {unknown!r}; "
                f"expected any of {list(self._FACE_LAYOUT)!r}"
            )

        render_width = output_width * scale_factor
        render_height = output_height * scale_factor

        rotation = self.precompute_rotation_matrix(90, -90, 180)

        # Column index -> theta in [-pi, pi); row index -> phi in [-pi/2, pi/2).
        theta = (np.arange(render_width) / render_width) * 2 * np.pi - np.pi
        phi = (np.arange(render_height) / render_height) * np.pi - (np.pi / 2)

        # Unit direction for every pixel; theta varies along columns and phi
        # along rows, so broadcasting replaces a full meshgrid.
        grid_shape = (render_height, render_width)
        cos_phi = np.cos(phi)[:, np.newaxis]
        xs = cos_phi * np.cos(theta)[np.newaxis, :]
        ys = cos_phi * np.sin(theta)[np.newaxis, :]
        zs = np.broadcast_to(np.sin(phi)[:, np.newaxis], grid_shape)

        directions = np.stack([xs.ravel(), ys.ravel(), zs.ravel()])
        rotated = (rotation @ directions).reshape((3,) + grid_shape)
        magnitudes = np.abs(rotated)
        # Axis with the largest magnitude decides which pair of faces is hit.
        dominant_axis = np.argmax(magnitudes, axis=0)

        canvas = np.zeros(grid_shape + (3,), dtype=np.uint8)

        for face_name, face_image in cubemap_faces.items():
            face_pixels = self._face_to_rgb_array(face_image)
            major, major_sign, u_axis, u_sign, v_axis, v_sign = self._FACE_LAYOUT[face_name]

            on_face = (dominant_axis == major) & (major_sign * rotated[major] > 0)
            depth = magnitudes[major][on_face]
            u = (u_sign * rotated[u_axis][on_face] / depth + 1) / 2
            v = (v_sign * rotated[v_axis][on_face] / depth + 1) / 2

            # u selects the column and v the row of the face image.
            face_height, face_width = face_pixels.shape[:2]
            u_pixel = np.clip((u * face_width).astype(int), 0, face_width - 1)
            v_pixel = np.clip((v * face_height).astype(int), 0, face_height - 1)

            # Boolean-mask assignment visits pixels in the same row-major
            # order in which u and v were extracted.
            canvas[on_face] = face_pixels[v_pixel, u_pixel]

        equirectangular_image = Image.fromarray(canvas)
        if scale_factor > 1:
            equirectangular_image = equirectangular_image.resize(
                (output_width, output_height), Image.LANCZOS
            )
        return equirectangular_image

    def preprocess_image(self, image: Image.Image) -> tuple:
        """Build the front view and the panorama canvas for ``image``.

        The image is converted to RGB, centre-cropped to a square and resized
        to 512x512. That square becomes the ``front`` cubemap face; the other
        five faces are white.

        Args:
            image: Source image in any PIL mode.

        Returns:
            ``(front, input_panorama)``: the 512x512 RGB front face and the
            2048x1024 equirectangular projection of the cubemap.
        """
        # Non-RGB modes (RGBA, L, P) would not fit the 3-channel canvas.
        image = image.convert("RGB")

        width, height = image.size
        side = min(width, height)
        left = (width - side) // 2
        top = (height - side) // 2
        square = image.crop((left, top, left + side, top + side))
        front = square.resize((self._FACE_SIZE, self._FACE_SIZE))

        input_panorama = self.cubemap_to_equirectangular(
            self._single_view_cubemap(front),
            self._PANO_WIDTH,
            self._PANO_HEIGHT,
            scale_factor=2,
        )
        return front, input_panorama

    def preprocess_mask(self) -> Image.Image:
        """Load ``pano_mask.png`` as a greyscale mask resized to 2048x1024.

        The path is relative to the current working directory. ``__call__``
        does not use this method; it calls ``create_mask`` instead.
        """
        # convert() loads the pixel data, so the file can be closed right away.
        with Image.open("pano_mask.png") as mask_file:
            mask = mask_file.convert("L")
        return mask.resize((self._PANO_WIDTH, self._PANO_HEIGHT))

    def create_mask(self) -> Image.Image:
        """Return the 2048x1024 greyscale mask matching ``preprocess_image``.

        The front face is black and the other five faces are white. The
        projection is rendered without supersampling (``scale_factor=1``).
        Presumably white marks the area the fill model should generate;
        confirm against the parent pipeline's ``mask_image`` contract.
        """
        black_front = Image.new("RGB", (self._FACE_SIZE, self._FACE_SIZE), (0, 0, 0))
        mask = self.cubemap_to_equirectangular(
            self._single_view_cubemap(black_front),
            self._PANO_WIDTH,
            self._PANO_HEIGHT,
            scale_factor=1,
        )
        return mask.convert("L")

    def __call__(
        self,
        image: Image.Image,
        prompt: str = None,
        guidance_scale: float = 3.5,
        **kwargs,
    ):
        """Run the fill pipeline on the panorama built from ``image``.

        Args:
            image: Source view; see ``preprocess_image``.
            prompt: Optional text appended to the fixed prompt prefix.
            guidance_scale: Forwarded to the parent pipeline.
            **kwargs: Additional keyword arguments forwarded unchanged to
                ``FluxFillPipeline.__call__``. ``image``, ``mask_image``,
                ``width`` and ``height`` are set here and must not be passed.

        Returns:
            ``(front, output)``: the 512x512 front face and whatever the
            parent pipeline's ``__call__`` returns.
        """
        front, panorama = self.preprocess_image(image)
        mask = self.create_mask()

        if prompt:
            prompt = self._PROMPT_PREFIX + ', ' + prompt
        else:
            prompt = self._PROMPT_PREFIX

        return front, super().__call__(
            prompt=prompt,
            image=panorama,
            mask_image=mask,
            guidance_scale=guidance_scale,
            width=self._PANO_WIDTH,
            height=self._PANO_HEIGHT,
            **kwargs,
        )