# Page header captured together with this file (not part of the program):
#   thliang01's picture
#   Update app.py
#   ea295fd verified
import gradio as gr
import requests
import base64
import os
import io
from PIL import Image
# --- OpenAI API Call Logic ---
def call_openai_image_api(prompt: str, api_key: str, input_image: Image.Image | None = None):
    """
    Generate a new image, or edit an existing one, through the OpenAI Images API.

    The edit endpoint is used when ``input_image`` is given; otherwise the
    generation endpoint is used. Every failure is reported through the returned
    status message instead of being raised, so the function can be wired
    directly to a UI callback.

    Args:
        prompt: The text prompt for image generation or editing.
        api_key: The OpenAI API key.
        input_image: A PIL Image object for editing, or None for generation.

    Returns:
        A tuple containing:
        - original_image (PIL.Image or None): The original image if editing, else None.
        - result_image (PIL.Image or None): The generated/edited image, or None on error.
        - status_message (str): A message indicating success or error details.
    """
    # Pasted keys/prompts frequently carry stray whitespace or a trailing
    # newline; a whitespace-only value is treated the same as a missing one.
    if isinstance(api_key, str):
        api_key = api_key.strip()
    if isinstance(prompt, str):
        prompt = prompt.strip()

    if not api_key:
        return None, None, "Error: OpenAI API Key is missing."
    if not prompt:
        return None, None, "Error: Prompt cannot be empty."

    headers = {"Authorization": f"Bearer {api_key}"}
    model = "gpt-image-1"  # Replace with "dall-e-2" or "dall-e-3" if needed.
    size = "1024x1024"
    # (connect, read) timeouts in seconds. Without a timeout requests waits
    # forever on a stalled connection; image rendering itself can be slow,
    # hence the generous read timeout.
    timeout = (10, 300)
    # gpt-image-1 always answers with base64 data and does not accept the
    # `response_format` parameter. The DALL-E models default to returning URLs,
    # so they must be asked for base64 explicitly.
    format_params = {"response_format": "b64_json"} if model.startswith("dall-e") else {}
    response = None  # Kept visible to the except block for error details.

    try:
        if input_image is not None:
            # --- Image Editing ---
            if not isinstance(input_image, Image.Image):
                return None, None, "Error: Invalid image provided for editing."

            # The edit endpoint takes a multipart upload, so serialise the
            # PIL image into an in-memory PNG file.
            byte_stream = io.BytesIO()
            input_image.save(byte_stream, format="PNG")
            byte_stream.seek(0)  # Rewind so requests reads from the start.

            files = {
                "image": ("input_image.png", byte_stream, "image/png"),
            }
            data = {
                "prompt": prompt,
                "model": model,
                "size": size,
                **format_params,
            }
            api_url = "https://api.openai.com/v1/images/edits"
            print("Calling OpenAI Image Edit API...")  # Debug print
            response = requests.post(
                api_url, headers=headers, files=files, data=data, timeout=timeout
            )
        else:
            # --- Image Generation ---
            # `json=` makes requests set the JSON Content-Type header itself.
            payload = {
                "prompt": prompt,
                "model": model,
                "size": size,
                "n": 1,  # Generate one image
                **format_params,
            }
            api_url = "https://api.openai.com/v1/images/generations"
            print("Calling OpenAI Image Generation API...")  # Debug print
            response = requests.post(
                api_url, headers=headers, json=payload, timeout=timeout
            )

        print(f"API Response Status Code: {response.status_code}")  # Debug print
        response.raise_for_status()  # Raise HTTPError for bad responses (4xx or 5xx)

        response_data = response.json()

        # Validate the shape step by step so an unexpected payload produces a
        # readable status message rather than an AttributeError/KeyError.
        images = response_data.get("data") if isinstance(response_data, dict) else None
        if not isinstance(images, list) or not images:
            return input_image, None, f"Error: Unexpected API response format - 'data' array missing or empty: {response_data}"
        first_image = images[0]
        img_b64 = first_image.get("b64_json") if isinstance(first_image, dict) else None
        if not img_b64:
            return input_image, None, f"Error: Unexpected API response format - 'b64_json' key missing: {response_data}"

        img_bytes = base64.b64decode(img_b64)
        result_image = Image.open(io.BytesIO(img_bytes))
        # Image.open is lazy; force decoding here so corrupt image data is
        # reported by this function instead of failing later in the caller.
        result_image.load()
        print("Image processed successfully.")  # Debug print
        return input_image, result_image, "Success!"

    except requests.exceptions.RequestException as e:
        error_message = f"API Request Error: {e}"
        # A response only exists if the failure happened after the server
        # answered (e.g. an HTTP 4xx/5xx raised by raise_for_status).
        if response is not None:
            try:
                error_detail = response.json()
                error_message += f"\nAPI Error Details: {error_detail}"
            except ValueError:
                # Body was not JSON. ValueError is the base class of the JSON
                # decode errors raised by every requests version.
                error_message += f"\nRaw Response Text: {response.text}"
        print(error_message)  # Debug print
        return input_image, None, error_message
    except Exception as e:
        # Last-resort boundary: surface anything else (image encoding or
        # decoding problems, malformed base64, ...) as a status message.
        error_message = f"An unexpected error occurred: {e}"
        print(error_message)  # Debug print
        return input_image, None, error_message
# --- Gradio Interface Setup ---
# Read the API key from the environment once at import time and prepare the
# matching hint that the UI displays next to the key input field.
api_key_env = os.environ.get("OPENAI_API_KEY")
if api_key_env:
    api_key_present_info = "OpenAI API Key found in environment variables."
else:
    api_key_present_info = "OpenAI API Key not found in environment variables. Please enter it below."
def process_image_request(prompt_input, api_key_input, uploaded_image):
    """
    Gradio callback: pick the API key and forward the request to the API helper.

    A key typed into the UI takes precedence; the key read from the
    environment is used only when the input field is empty.

    Returns the (original image, result image, status message) triple in the
    order expected by the output components.
    """
    effective_key = api_key_input or api_key_env

    # When generating, the first element is None and the "original" display
    # stays empty; when editing, it is the uploaded image.
    source_image, output_image, message = call_openai_image_api(
        prompt_input,
        effective_key,
        uploaded_image,
    )
    return source_image, output_image, message
# Assemble the UI with Blocks: inputs in a narrow left column, results in a
# wider right column.
# NOTE(review): nesting below is reconstructed from the `with` statements and
# component order — confirm against the intended layout.
with gr.Blocks() as demo:
    gr.Markdown("# OpenAI GPT-Image-1 Text-to-Image Demo")
    gr.Markdown("Enter a prompt to generate an image, or upload an image and enter a prompt to edit it.")

    with gr.Row():
        # Left column: everything the user provides.
        with gr.Column(scale=1):
            prompt_input = gr.Textbox(
                label="Image Description (Prompt)",
                lines=3,
                placeholder="e.g., A futuristic cityscape at sunset",
            )
            # Tell the user whether a key was already picked up from the environment.
            gr.Markdown(f"*{api_key_present_info}*")
            api_key_input = gr.Textbox(
                label="OpenAI API Key",
                type="password",
                placeholder="Enter your key if not set in environment",
            )
            uploaded_image_input = gr.Image(
                type="pil",
                label="Upload Image to Edit (Optional)",
            )
            submit_button = gr.Button("Generate / Edit Image")

        # Right column: status text plus before/after images side by side.
        with gr.Column(scale=2):
            status_output = gr.Textbox(label="Status", interactive=False)
            with gr.Row():
                original_image_output = gr.Image(
                    label="Original Image",
                    interactive=False,
                )
                result_image_output = gr.Image(
                    label="Generated / Edited Image",
                    interactive=False,
                )

    # Wire the button to the callback; the outputs follow the order of the
    # tuple returned by process_image_request.
    submit_button.click(
        fn=process_image_request,
        inputs=[prompt_input, api_key_input, uploaded_image_input],
        outputs=[original_image_output, result_image_output, status_output],
    )

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()