# Hhhh / imagegen_api.py
# Uploaded by Hjgugugjhuhjggg ("Upload 28 files"), commit e83e49f (verified), 1.12 kB.
import os
from flask import jsonify, send_file, request
from io import BytesIO
from PIL import Image
from main import *
import torch
def generate_image(prompt, output_path="output_image.png", seed=0):
    """Render *prompt* with the loaded image-generation model and save it.

    Args:
        prompt: Text prompt passed as the first positional argument to
            ``imagegen_model``.
        output_path: File path the resulting image is saved to via
            ``image.save(output_path)``.
        seed: Seed for the ``torch.Generator`` handed to the model. The
            default of 0 reproduces the previous (always-seeded-with-0)
            behaviour. Pass another int to get a different but reproducible
            image, or ``None`` to draw a fresh non-deterministic seed.

    Returns:
        ``output_path`` on success, or the exact string
        ``"Image generation model not initialized."`` when ``imagegen_model``
        is ``None``. Callers compare against this text, so do not change it.
    """
    # NOTE(review): imagegen_model and device are not defined in this file;
    # presumably they come from `from main import *` -- confirm.
    if imagegen_model is None:
        return "Image generation model not initialized."

    generator = torch.Generator(device=device)
    if seed is None:
        generator.seed()  # non-deterministic seed
    else:
        generator.manual_seed(seed)

    # Inference only: disable autograd bookkeeping.
    with torch.no_grad():
        # assumes a diffusers-style pipeline whose output has `.images`
        # (presumably PIL images) -- confirm against main.
        image = imagegen_model(prompt, generator=generator).images[0]

    image.save(output_path)
    return output_path
def imagegen_api():
    """Flask view: generate a PNG from a JSON body of the form ``{"prompt": ...}``.

    Returns:
        One of the following:

        - On success, the generated image as a PNG attachment named
          ``output.png``.
        - ``({"error": "Prompt is required"}, 400)`` when the body is not a
          JSON object, or when ``prompt`` is missing, not a string, or blank.
        - ``({"error": "Image generation failed"}, 500)`` when
          ``generate_image`` reports that no model is initialized.
    """
    import tempfile  # local import keeps this view self-contained

    # silent=True returns None (instead of raising an HTML 400/415) for
    # malformed or non-JSON bodies, so the client always gets the JSON
    # error below. A JSON array/scalar body is rejected the same way.
    data = request.get_json(silent=True)
    prompt = data.get('prompt') if isinstance(data, dict) else None
    if not isinstance(prompt, str) or not prompt.strip():
        return jsonify({"error": "Prompt is required"}), 400

    # Render into a per-request temp file. A fixed shared path would let
    # concurrent requests overwrite (and serve) each other's images.
    fd, tmp_path = tempfile.mkstemp(suffix=".png")
    os.close(fd)  # generate_image writes the path itself
    try:
        output_file = generate_image(prompt, output_path=tmp_path)
        if output_file == "Image generation model not initialized.":
            return jsonify({"error": "Image generation failed"}), 500

        image_io = BytesIO()
        # `with` closes the file handle that Image.open otherwise keeps open;
        # the save inside forces the lazy load before the file is removed.
        with Image.open(output_file) as pil_image:
            pil_image.save(image_io, 'PNG')
    finally:
        try:
            os.remove(tmp_path)
        except OSError:
            # Best-effort cleanup: a leftover temp file must not turn a
            # successful response into an error.
            pass

    image_io.seek(0)
    return send_file(image_io, mimetype='image/png', as_attachment=True, download_name="output.png")