|
import base64
import os
import re
import sqlite3
import tempfile
import uuid
from contextlib import closing
from io import BytesIO
from typing import Dict, List, Optional

import cv2
import numpy as np
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.staticfiles import StaticFiles
from PIL import Image
from pillow_lut import load_cube_file
from pydantic import BaseModel
from starlette.middleware.cors import CORSMiddleware

from ai import generate_cube
|
|
|
# Application instance; title/version are surfaced in FastAPI's generated
# OpenAPI docs (and repeated by the "/" endpoint).
app = FastAPI(title="LUT Transformation API", version="1.0.0")

# NOTE(review): wildcard origins combined with allow_credentials=True means any
# site may issue credentialed cross-origin requests to this API — confirm this
# is intended, or replace "*" with an explicit allow-list of front-end origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Serve the local "static" directory under /static; the preview endpoint reads
# its sample image from static/sample.jpg. Starlette's StaticFiles checks by
# default that the directory exists, so the app presumably fails at import time
# if "static" is missing relative to the working directory — verify on deploy.
app.mount("/static", StaticFiles(directory="static"), name="static")
|
|
|
|
|
# Request body pairing a stored cube-file ID with a free-text prompt.
# NOTE(review): same fields as LUTTransformRequest and not used by any endpoint
# visible in this file — confirm whether it is still needed.
# (Documented with comments, not a docstring, so the JSON schema is unchanged.)
class LUTRequest(BaseModel):
    # ID of a previously uploaded cube file.
    file_id: str
    # Free-text prompt describing the desired adjustment.
    user_prompt: str
|
|
|
|
|
# Request body for POST /transform-lut.
# (Documented with comments, not a docstring, so the JSON schema is unchanged.)
class LUTTransformRequest(BaseModel):
    # ID returned by POST /upload-cube; looked up in the cube_files table.
    file_id: str
    # Free-text prompt forwarded to the AI helper that produces the adjustments.
    user_prompt: str
|
|
|
|
|
# Response model for POST /upload-cube.
# (Documented with comments, not a docstring, so the JSON schema is unchanged.)
class CubeFileResponse(BaseModel):
    # Newly generated UUID string identifying the stored file.
    file_id: str
    # Original upload file name as supplied by the client.
    file_name: str
|
|
|
|
|
# One entry of the GET /cube-files listing.
# (Documented with comments, not a docstring, so the JSON schema is unchanged.)
class CubeFileListItem(BaseModel):
    # UUID string primary key of the stored file.
    file_id: str
    # Original upload file name.
    file_name: str
    # Text value of the upload_date column, filled by SQLite's CURRENT_TIMESTAMP
    # default (UTC, "YYYY-MM-DD HH:MM:SS").
    upload_date: str
|
|
|
|
|
# SQLite database file holding uploaded cube files. This is a relative path, so
# it is resolved against the process's current working directory.
DATABASE_PATH = "cube_files.db"
|
|
|
|
|
def init_database():
    """Create the ``cube_files`` table in ``DATABASE_PATH`` if it does not exist.

    Safe to call repeatedly because the DDL uses ``CREATE TABLE IF NOT EXISTS``.
    """
    # closing() releases the connection even if execute() raises; the inner
    # ``with conn`` commits on success and rolls back on error (sqlite3's
    # connection context manager does NOT close the connection by itself).
    with closing(sqlite3.connect(DATABASE_PATH)) as conn, conn:
        conn.execute(
            """
            CREATE TABLE IF NOT EXISTS cube_files (
                id TEXT PRIMARY KEY,
                file_name TEXT NOT NULL,
                file_data BLOB NOT NULL,
                upload_date TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
            """
        )
|
|
|
|
|
def save_cube_file_to_db(file_name: str, file_data: bytes) -> str:
    """Store a cube file's name and raw bytes and return its new ID.

    Args:
        file_name: Original file name to record.
        file_data: Raw file contents, stored in the ``file_data`` BLOB column.

    Returns:
        A freshly generated UUID4 string used as the row's primary key.

    Raises:
        sqlite3.Error: If the insert fails (the transaction is rolled back and
            the connection is still closed).
    """
    file_id = str(uuid.uuid4())

    # closing() always releases the connection; ``with conn`` commits the
    # insert on success and rolls it back if execute() raises.
    with closing(sqlite3.connect(DATABASE_PATH)) as conn, conn:
        conn.execute(
            "INSERT INTO cube_files (id, file_name, file_data) VALUES (?, ?, ?)",
            (file_id, file_name, file_data),
        )

    return file_id
|
|
|
|
|
def get_cube_file_from_db(file_id: str) -> Optional[tuple]:
    """Fetch a stored cube file by ID.

    Args:
        file_id: Primary key returned by :func:`save_cube_file_to_db`.

    Returns:
        A ``(file_name, file_data)`` tuple, or ``None`` if no row matches.
    """
    # closing() guarantees the connection is released even if the query fails.
    with closing(sqlite3.connect(DATABASE_PATH)) as conn:
        return conn.execute(
            "SELECT file_name, file_data FROM cube_files WHERE id = ?", (file_id,)
        ).fetchone()
|
|
|
|
|
def list_cube_files_from_db() -> List[tuple]:
    """List stored cube files, newest first.

    Returns:
        A list of ``(id, file_name, upload_date)`` tuples ordered by
        ``upload_date`` descending. Rows sharing the same second-resolution
        timestamp are ordered by insertion (latest first) so the listing is
        deterministic.
    """
    # closing() guarantees the connection is released even if the query fails.
    with closing(sqlite3.connect(DATABASE_PATH)) as conn:
        return conn.execute(
            "SELECT id, file_name, upload_date FROM cube_files "
            "ORDER BY upload_date DESC, rowid DESC"
        ).fetchall()
|
|
|
|
|
class LUTTransformer:
    """Parse, adjust and re-save 3D LUTs stored as ``.cube`` text files.

    Typical use: :meth:`parse_cube_file`, then :meth:`apply_json_transformation`,
    then :meth:`save_cube_file`. Every method reports failure by printing the
    error and returning ``False`` rather than raising.
    """

    # Luma upper bounds (exclusive) used to bucket entries into tonal ranges:
    # luma < SHADOW_LIMIT is "shadows", luma < MIDTONE_LIMIT is "midtones",
    # anything else is "highlights".
    SHADOW_LIMIT = 0.33
    MIDTONE_LIMIT = 0.66

    def __init__(self):
        # Title from the TITLE line ("" if the file has none).
        self.title = ""
        # Edge length N from LUT_3D_SIZE; 0 until a file has been parsed.
        self.size = 0
        # Output [r, g, b] triples in file order (N**3 rows once parsed).
        self.lut_data = []

    def parse_cube_file(self, filepath: str) -> bool:
        """Parse a ``.cube`` file into ``title``, ``size`` and ``lut_data``.

        Only ``TITLE`` and ``LUT_3D_SIZE`` keywords are interpreted. Other
        keyword lines (e.g. ``DOMAIN_MIN``, ``DOMAIN_MAX``,
        ``LUT_3D_INPUT_RANGE``) are skipped instead of being mistaken for data.
        A data row is a line whose first three tokens are finite numbers. Signs
        and exponent notation such as ``1.0e-03`` are accepted.

        Args:
            filepath: Path of the ``.cube`` file to read.

        Returns:
            True when ``LUT_3D_SIZE`` is positive and exactly ``size ** 3`` data
            rows were read; False otherwise (the reason is printed).
        """
        try:
            with open(filepath, "r", encoding="utf-8", errors="replace") as file:
                lines = file.readlines()

            # Reset all state so a reused instance never keeps stale values.
            self.title = ""
            self.size = 0
            self.lut_data = []

            for raw_line in lines:
                line = raw_line.strip()

                if not line or line.startswith("#"):
                    continue

                if line.startswith("TITLE"):
                    self.title = line.split('"')[1] if '"' in line else line.split()[1]
                elif line.startswith("LUT_3D_SIZE"):
                    self.size = int(line.split()[1])
                else:
                    rgb = self._parse_rgb_row(line)
                    if rgb is not None:
                        self.lut_data.append(rgb)

            expected_rows = self.size ** 3
            if self.size <= 0 or len(self.lut_data) != expected_rows:
                print(
                    "Error parsing cube file: expected a positive LUT_3D_SIZE and "
                    f"{expected_rows} data rows, got size={self.size} and "
                    f"{len(self.lut_data)} rows"
                )
                return False

            return True

        except Exception as e:
            print(f"Error parsing cube file: {e}")
            return False

    @staticmethod
    def _parse_rgb_row(line: str) -> Optional[List[float]]:
        """Return the first three numbers of a data row, or None for non-data lines."""
        tokens = re.split(r"[\s,]+", line)
        if len(tokens) < 3:
            return None
        try:
            rgb = [float(token) for token in tokens[:3]]
        except ValueError:
            # Keyword lines (DOMAIN_MIN, LUT_3D_INPUT_RANGE, ...) start with a word.
            return None
        # Reject "nan"/"inf", which float() would otherwise accept.
        return rgb if np.isfinite(rgb).all() else None

    @staticmethod
    def _channel_gains(adjustment: Optional[Dict]) -> List[float]:
        """Return ``[r, g, b]`` multipliers; missing or null entries mean 1.0."""
        # Null ranges/channels are treated as "no change" because JSON dumps of
        # optional fields may contain None.
        if adjustment is None:
            return [1.0, 1.0, 1.0]
        gains = []
        for channel in ("r", "g", "b"):
            value = adjustment.get(channel)
            gains.append(1.0 if value is None else float(value))
        return gains

    def apply_json_transformation(self, json_adjustments: Dict) -> bool:
        """Scale LUT outputs by per-channel gains selected by tonal range.

        Each entry is bucketed by the luma of its *unadjusted* value (weights
        0.299/0.587/0.114, the ITU-R BT.601 luma coefficients) into
        ``"shadows"``, ``"midtones"`` or ``"highlights"``. It is multiplied by
        that range's ``{"r", "g", "b"}`` gains when present, then by the optional
        ``"glob"`` gains, and finally clipped to [0, 1].

        Args:
            json_adjustments: Mapping of range name to a gains mapping. Missing
                or null ranges and channels default to a gain of 1.0.

        Returns:
            True on success; False (after printing the error) on failure.
        """
        try:
            lut = np.asarray(self.lut_data, dtype=float).reshape(-1, 3)

            # Same evaluation order as the scalar 0.299*r + 0.587*g + 0.114*b,
            # so every entry lands in exactly the same bucket.
            luminance = 0.299 * lut[:, 0] + 0.587 * lut[:, 1] + 0.114 * lut[:, 2]
            shadows = luminance < self.SHADOW_LIMIT
            midtones = ~shadows & (luminance < self.MIDTONE_LIMIT)
            highlights = ~(shadows | midtones)

            for key, mask in (
                ("shadows", shadows),
                ("midtones", midtones),
                ("highlights", highlights),
            ):
                lut[mask] *= self._channel_gains(json_adjustments.get(key))

            lut *= self._channel_gains(json_adjustments.get("glob"))

            self.lut_data = np.clip(lut, 0.0, 1.0).tolist()
            return True

        except Exception as e:
            print(f"Error applying transformation: {e}")
            return False

    def save_cube_file(self, output_path: str, new_title: Optional[str] = None) -> bool:
        """Write the current LUT to ``output_path`` in ``.cube`` format.

        Args:
            output_path: Destination file path (overwritten).
            new_title: TITLE to write. Defaults to ``"<title>_modified"``.

        Returns:
            True on success; False (after printing the error) on failure.
        """
        try:
            title = new_title if new_title else f"{self.title}_modified"
            # Format every row before opening the file, so a bad value cannot
            # leave a truncated, half-written LUT behind.
            rows = "".join(f"{r:.6f} {g:.6f} {b:.6f}\n" for r, g, b in self.lut_data)

            with open(output_path, "w", encoding="utf-8") as file:
                file.write(f'TITLE "{title}"\n')
                file.write(f"LUT_3D_SIZE {self.size}\n\n")
                file.write(rows)

            return True

        except Exception as e:
            print(f"Error saving cube file: {e}")
            return False
|
|
|
|
|
def generate_new_cube(user_prompt: str) -> dict:
    """Ask the AI helper for color adjustments matching ``user_prompt``.

    Delegates to ``ai.generate_cube`` and converts its result to a JSON-safe
    dict with ``model_dump(mode="json")``. The result is presumably a pydantic
    model whose fields match the keys read by
    ``LUTTransformer.apply_json_transformation`` ("shadows", "midtones",
    "highlights", "glob") — verify against the ``ai`` module.

    Args:
        user_prompt: Free-text description of the desired look.

    Returns:
        The adjustments as a plain dict.
    """
    return generate_cube(user_prompt).model_dump(mode="json")
|
|
|
|
|
def apply_lut_to_image(image_path: str, lut_path: str) -> np.ndarray:
    """Apply the ``.cube`` LUT at ``lut_path`` to the image at ``image_path``.

    The LUT is loaded with ``pillow_lut.load_cube_file`` and applied as a Pillow
    image filter.

    Args:
        image_path: Path of the source image.
        lut_path: Path of the ``.cube`` file.

    Returns:
        The filtered pixels as a NumPy array.

    Raises:
        Any error from loading the LUT or opening/filtering the image (printed,
        then re-raised).
    """
    try:
        lut = load_cube_file(lut_path)

        # The context manager closes the image's file handle deterministically
        # instead of leaving it to garbage collection.
        with Image.open(image_path) as im:
            # Color LUT filters expect RGB(A) pixel data; convert other modes
            # (e.g. L, P, CMYK) so the LUT is applied to RGB values.
            if im.mode not in ("RGB", "RGBA"):
                im = im.convert("RGB")
            result_image = im.filter(lut)

        return np.array(result_image)

    except Exception as e:
        print(f"Error applying LUT to image: {e}")
        raise
|
|
|
|
|
def create_split_preview(
    original_lut_path: str, new_lut_path: str, sample_image_path: str
) -> str:
    """Render a before/after comparison of two LUTs on one sample image.

    The left half comes from ``original_lut_path`` and the right half from
    ``new_lut_path``. A white vertical line (drawn with ``cv2.line``, thickness 2)
    marks the seam.

    Returns:
        The composite encoded as PNG, then base64, as a UTF-8 string.

    Raises:
        Any failure (printed, then re-raised).
    """
    try:
        before = apply_lut_to_image(sample_image_path, original_lut_path)
        after = apply_lut_to_image(sample_image_path, new_lut_path)

        rows, cols = before.shape[:2]
        seam = cols // 2

        # Stitch the left half of "before" to the right half of "after".
        composite = np.concatenate((before[:, :seam], after[:, seam:]), axis=1)
        cv2.line(composite, (seam, 0), (seam, rows), (255, 255, 255), 2)

        png_bytes = BytesIO()
        Image.fromarray(composite).save(png_bytes, format="PNG")
        return base64.b64encode(png_bytes.getvalue()).decode("utf-8")

    except Exception as e:
        print(f"Error creating split preview: {e}")
        raise
|
|
|
|
|
@app.on_event("startup")
async def startup_event():
    """Ensure the SQLite schema exists before the app serves requests."""
    # NOTE(review): FastAPI now recommends lifespan handlers over on_event.
    init_database()
|
|
|
|
|
@app.get("/")
async def root():
    """Identify the service by name and version."""
    return dict(message="LUT Transformation API", version="1.0.0")
|
|
|
|
|
@app.post("/upload-cube", response_model=CubeFileResponse)
async def upload_cube_file(file: UploadFile = File(...)):
    """
    Upload a .cube file and save it to the database.

    Returns:
        The generated file ID and the stored file name.

    Raises:
        HTTPException: 400 if the file name does not end in ``.cube``
            (case-insensitive) or the upload is empty; 500 if reading or
            storing the file fails.
    """
    try:
        # UploadFile.filename may be None when the client omits it.
        filename = file.filename or ""
        if not filename.lower().endswith(".cube"):
            raise HTTPException(status_code=400, detail="Only .cube files are allowed")

        file_data = await file.read()
        if not file_data:
            raise HTTPException(status_code=400, detail="Uploaded file is empty")

        file_id = save_cube_file_to_db(filename, file_data)

        return CubeFileResponse(file_id=file_id, file_name=filename)

    except HTTPException:
        # Keep deliberate client errors (400) instead of re-wrapping them as 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error uploading file: {str(e)}") from e
|
|
|
|
|
@app.get("/cube-files", response_model=List[CubeFileListItem])
async def list_cube_files():
    """
    Return every uploaded cube file (ID, name, upload date), newest first.

    Any storage error is reported as an HTTP 500.
    """
    try:
        items = []
        for row_id, name, uploaded_at in list_cube_files_from_db():
            items.append(
                CubeFileListItem(file_id=row_id, file_name=name, upload_date=uploaded_at)
            )
        return items
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error listing files: {str(e)}")
|
|
|
|
|
@app.post("/transform-lut")
async def transform_lut(request: LUTTransformRequest):
    """
    Transform a LUT based on file ID and user prompt, return split preview image.

    Steps:
        1. Load the stored .cube bytes.
        2. Parse them.
        3. Ask the AI helper for adjustments derived from ``request.user_prompt``.
        4. Apply the adjustments and write the modified LUT.
        5. Render a half/half preview on ``static/sample.jpg``.

    Returns:
        A dict with ``success``, ``message``, ``file_name``,
        ``adjustments_applied`` and ``split_preview_base64``.

    Raises:
        HTTPException: 404 if the cube file or sample image is missing; 400 if
            the stored cube cannot be parsed; 500 for any other failure.
    """
    sample_image_path = "static/sample.jpg"

    try:
        file_data = get_cube_file_from_db(request.file_id)
        if not file_data:
            raise HTTPException(status_code=404, detail="Cube file not found")

        file_name, cube_data = file_data

        # Check prerequisites before the (potentially slow) AI call.
        if not os.path.exists(sample_image_path):
            raise HTTPException(status_code=404, detail="Sample image not found")

        # A private temp directory holds both LUT files and is removed
        # automatically when the block exits, whether it succeeds or raises.
        with tempfile.TemporaryDirectory() as tmp_dir:
            original_cube_path = os.path.join(tmp_dir, "original.cube")
            new_cube_path = os.path.join(tmp_dir, "modified.cube")

            with open(original_cube_path, "wb") as original_file:
                original_file.write(cube_data)

            transformer = LUTTransformer()
            if not transformer.parse_cube_file(original_cube_path):
                raise HTTPException(status_code=400, detail="Failed to parse cube file")

            adjustments = generate_new_cube(request.user_prompt)

            if not transformer.apply_json_transformation(adjustments):
                raise HTTPException(
                    status_code=500, detail="Failed to apply transformations"
                )

            if not transformer.save_cube_file(
                new_cube_path, f"{transformer.title}_AI_Modified"
            ):
                raise HTTPException(
                    status_code=500, detail="Failed to save new cube file"
                )

            split_preview_base64 = create_split_preview(
                original_cube_path, new_cube_path, sample_image_path
            )

        return {
            "success": True,
            "message": "LUT transformation completed successfully",
            "file_name": file_name,
            "adjustments_applied": adjustments,
            "split_preview_base64": split_preview_base64,
        }

    except HTTPException:
        # Preserve the intended status codes (404/400/500) instead of
        # re-wrapping them as a generic 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") from e
|
|
|
|
|
@app.get("/health")
async def health_check():
    """Report liveness and whether the preview sample image is present."""
    sample_present = os.path.exists("static/sample.jpg")
    return {"status": "healthy", "sample_image_exists": sample_present}
|
|