brestok commited on
Commit
642914f
·
1 Parent(s): b86a37b

Update Dockerfile to optimize application setup and streamline dependencies

Browse files
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ .env
.idea/.gitignore ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # Default ignored files
2
+ /shelf/
3
+ /workspace.xml
4
+ # Editor-based HTTP Client requests
5
+ /httpRequests/
6
+ # Datasource local storage ignored files
7
+ /dataSources/
8
+ /dataSources.local.xml
.idea/LutAI.iml ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <module type="PYTHON_MODULE" version="4">
3
+ <component name="NewModuleRootManager">
4
+ <content url="file://$MODULE_DIR$">
5
+ <excludeFolder url="file://$MODULE_DIR$/.venv" />
6
+ </content>
7
+ <orderEntry type="inheritedJdk" />
8
+ <orderEntry type="sourceFolder" forTests="false" />
9
+ </component>
10
+ </module>
.idea/inspectionProfiles/Project_Default.xml ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <component name="InspectionProjectProfileManager">
2
+ <profile version="1.0">
3
+ <option name="myName" value="Project Default" />
4
+ <inspection_tool class="DuplicatedCode" enabled="true" level="WEAK WARNING" enabled_by_default="true">
5
+ <Languages>
6
+ <language minSize="74" name="Python" />
7
+ </Languages>
8
+ </inspection_tool>
9
+ <inspection_tool class="Eslint" enabled="true" level="WARNING" enabled_by_default="true" />
10
+ <inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
11
+ <option name="ignoredPackages">
12
+ <value>
13
+ <list size="123">
14
+ <item index="0" class="java.lang.String" itemvalue="motor" />
15
+ <item index="1" class="java.lang.String" itemvalue="rsa" />
16
+ <item index="2" class="java.lang.String" itemvalue="pybit" />
17
+ <item index="3" class="java.lang.String" itemvalue="PyYAML" />
18
+ <item index="4" class="java.lang.String" itemvalue="cffi" />
19
+ <item index="5" class="java.lang.String" itemvalue="marshmallow" />
20
+ <item index="6" class="java.lang.String" itemvalue="pyasn1" />
21
+ <item index="7" class="java.lang.String" itemvalue="requests" />
22
+ <item index="8" class="java.lang.String" itemvalue="exceptiongroup" />
23
+ <item index="9" class="java.lang.String" itemvalue="starlette" />
24
+ <item index="10" class="java.lang.String" itemvalue="certifi" />
25
+ <item index="11" class="java.lang.String" itemvalue="anyio" />
26
+ <item index="12" class="java.lang.String" itemvalue="urllib3" />
27
+ <item index="13" class="java.lang.String" itemvalue="uvicorn" />
28
+ <item index="14" class="java.lang.String" itemvalue="python-jose" />
29
+ <item index="15" class="java.lang.String" itemvalue="uvloop" />
30
+ <item index="16" class="java.lang.String" itemvalue="passlib" />
31
+ <item index="17" class="java.lang.String" itemvalue="websockets" />
32
+ <item index="18" class="java.lang.String" itemvalue="annotated-types" />
33
+ <item index="19" class="java.lang.String" itemvalue="watchfiles" />
34
+ <item index="20" class="java.lang.String" itemvalue="dnspython" />
35
+ <item index="21" class="java.lang.String" itemvalue="pydantic" />
36
+ <item index="22" class="java.lang.String" itemvalue="pymongo" />
37
+ <item index="23" class="java.lang.String" itemvalue="ecdsa" />
38
+ <item index="24" class="java.lang.String" itemvalue="packaging" />
39
+ <item index="25" class="java.lang.String" itemvalue="starkbank-ecdsa" />
40
+ <item index="26" class="java.lang.String" itemvalue="pydash" />
41
+ <item index="27" class="java.lang.String" itemvalue="bcrypt" />
42
+ <item index="28" class="java.lang.String" itemvalue="fastapi" />
43
+ <item index="29" class="java.lang.String" itemvalue="pydantic_core" />
44
+ <item index="30" class="java.lang.String" itemvalue="python-http-client" />
45
+ <item index="31" class="java.lang.String" itemvalue="email_validator" />
46
+ <item index="32" class="java.lang.String" itemvalue="typing_extensions" />
47
+ <item index="33" class="java.lang.String" itemvalue="pycryptodome" />
48
+ <item index="34" class="java.lang.String" itemvalue="pyee" />
49
+ <item index="35" class="java.lang.String" itemvalue="azure-identity" />
50
+ <item index="36" class="java.lang.String" itemvalue="greenlet" />
51
+ <item index="37" class="java.lang.String" itemvalue="playwright" />
52
+ <item index="38" class="java.lang.String" itemvalue="pyppeteer" />
53
+ <item index="39" class="java.lang.String" itemvalue="SQLAlchemy" />
54
+ <item index="40" class="java.lang.String" itemvalue="pyarrow" />
55
+ <item index="41" class="java.lang.String" itemvalue="protobuf" />
56
+ <item index="42" class="java.lang.String" itemvalue="tornado" />
57
+ <item index="43" class="java.lang.String" itemvalue="solders" />
58
+ <item index="44" class="java.lang.String" itemvalue="rich" />
59
+ <item index="45" class="java.lang.String" itemvalue="numpy" />
60
+ <item index="46" class="java.lang.String" itemvalue="Jinja2" />
61
+ <item index="47" class="java.lang.String" itemvalue="DateTime" />
62
+ <item index="48" class="java.lang.String" itemvalue="attrs" />
63
+ <item index="49" class="java.lang.String" itemvalue="altair" />
64
+ <item index="50" class="java.lang.String" itemvalue="pandas" />
65
+ <item index="51" class="java.lang.String" itemvalue="Pygments" />
66
+ <item index="52" class="java.lang.String" itemvalue="pip" />
67
+ <item index="53" class="java.lang.String" itemvalue="pillow" />
68
+ <item index="54" class="java.lang.String" itemvalue="httpx" />
69
+ <item index="55" class="java.lang.String" itemvalue="boto3" />
70
+ <item index="56" class="java.lang.String" itemvalue="six" />
71
+ <item index="57" class="java.lang.String" itemvalue="botocore" />
72
+ <item index="58" class="java.lang.String" itemvalue="huggingface-hub" />
73
+ <item index="59" class="java.lang.String" itemvalue="nvidia-cuda-cupti-cu12" />
74
+ <item index="60" class="java.lang.String" itemvalue="nvidia-cufft-cu12" />
75
+ <item index="61" class="java.lang.String" itemvalue="python-dateutil" />
76
+ <item index="62" class="java.lang.String" itemvalue="python-dotenv" />
77
+ <item index="63" class="java.lang.String" itemvalue="MarkupSafe" />
78
+ <item index="64" class="java.lang.String" itemvalue="pycparser" />
79
+ <item index="65" class="java.lang.String" itemvalue="frozenlist" />
80
+ <item index="66" class="java.lang.String" itemvalue="fsspec" />
81
+ <item index="67" class="java.lang.String" itemvalue="nvidia-cusolver-cu12" />
82
+ <item index="68" class="java.lang.String" itemvalue="nvidia-curand-cu12" />
83
+ <item index="69" class="java.lang.String" itemvalue="filelock" />
84
+ <item index="70" class="java.lang.String" itemvalue="safetensors" />
85
+ <item index="71" class="java.lang.String" itemvalue="sentencepiece" />
86
+ <item index="72" class="java.lang.String" itemvalue="multiprocess" />
87
+ <item index="73" class="java.lang.String" itemvalue="pyarrow-hotfix" />
88
+ <item index="74" class="java.lang.String" itemvalue="nvidia-cuda-runtime-cu12" />
89
+ <item index="75" class="java.lang.String" itemvalue="sympy" />
90
+ <item index="76" class="java.lang.String" itemvalue="xxhash" />
91
+ <item index="77" class="java.lang.String" itemvalue="beautifulsoup4" />
92
+ <item index="78" class="java.lang.String" itemvalue="tokenizers" />
93
+ <item index="79" class="java.lang.String" itemvalue="nvidia-cuda-nvrtc-cu12" />
94
+ <item index="80" class="java.lang.String" itemvalue="transformers" />
95
+ <item index="81" class="java.lang.String" itemvalue="triton" />
96
+ <item index="82" class="java.lang.String" itemvalue="cryptography" />
97
+ <item index="83" class="java.lang.String" itemvalue="openai" />
98
+ <item index="84" class="java.lang.String" itemvalue="nvidia-cublas-cu12" />
99
+ <item index="85" class="java.lang.String" itemvalue="regex" />
100
+ <item index="86" class="java.lang.String" itemvalue="nvidia-nvtx-cu12" />
101
+ <item index="87" class="java.lang.String" itemvalue="PyMySQL" />
102
+ <item index="88" class="java.lang.String" itemvalue="Mako" />
103
+ <item index="89" class="java.lang.String" itemvalue="evaluate" />
104
+ <item index="90" class="java.lang.String" itemvalue="httpcore" />
105
+ <item index="91" class="java.lang.String" itemvalue="idna" />
106
+ <item index="92" class="java.lang.String" itemvalue="environs" />
107
+ <item index="93" class="java.lang.String" itemvalue="networkx" />
108
+ <item index="94" class="java.lang.String" itemvalue="nvidia-nvjitlink-cu12" />
109
+ <item index="95" class="java.lang.String" itemvalue="nvidia-cusparse-cu12" />
110
+ <item index="96" class="java.lang.String" itemvalue="datasets" />
111
+ <item index="97" class="java.lang.String" itemvalue="nvidia-nccl-cu12" />
112
+ <item index="98" class="java.lang.String" itemvalue="sniffio" />
113
+ <item index="99" class="java.lang.String" itemvalue="aiomysql" />
114
+ <item index="100" class="java.lang.String" itemvalue="sqladmin" />
115
+ <item index="101" class="java.lang.String" itemvalue="itsdangerous" />
116
+ <item index="102" class="java.lang.String" itemvalue="faiss-gpu" />
117
+ <item index="103" class="java.lang.String" itemvalue="aiocron" />
118
+ <item index="104" class="java.lang.String" itemvalue="tzdata" />
119
+ <item index="105" class="java.lang.String" itemvalue="dill" />
120
+ <item index="106" class="java.lang.String" itemvalue="nvidia-cudnn-cu12" />
121
+ <item index="107" class="java.lang.String" itemvalue="torch" />
122
+ <item index="108" class="java.lang.String" itemvalue="et-xmlfile" />
123
+ <item index="109" class="java.lang.String" itemvalue="python-multipart" />
124
+ <item index="110" class="java.lang.String" itemvalue="tqdm" />
125
+ <item index="111" class="java.lang.String" itemvalue="aiohttp" />
126
+ <item index="112" class="java.lang.String" itemvalue="multidict" />
127
+ <item index="113" class="java.lang.String" itemvalue="responses" />
128
+ <item index="114" class="java.lang.String" itemvalue="pytz" />
129
+ <item index="115" class="java.lang.String" itemvalue="openpyxl" />
130
+ <item index="116" class="java.lang.String" itemvalue="tomli" />
131
+ <item index="117" class="java.lang.String" itemvalue="asyncio" />
132
+ <item index="118" class="java.lang.String" itemvalue="colorama" />
133
+ <item index="119" class="java.lang.String" itemvalue="zope.interface" />
134
+ <item index="120" class="java.lang.String" itemvalue="setuptools" />
135
+ <item index="121" class="java.lang.String" itemvalue="click" />
136
+ <item index="122" class="java.lang.String" itemvalue="vellum-ai" />
137
+ </list>
138
+ </value>
139
+ </option>
140
+ </inspection_tool>
141
+ <inspection_tool class="PyPep8NamingInspection" enabled="true" level="WEAK WARNING" enabled_by_default="true">
142
+ <option name="ignoredErrors">
143
+ <list>
144
+ <option value="N802" />
145
+ </list>
146
+ </option>
147
+ </inspection_tool>
148
+ </profile>
149
+ </component>
.idea/inspectionProfiles/profiles_settings.xml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ <component name="InspectionProjectProfileManager">
2
+ <settings>
3
+ <option name="USE_PROJECT_PROFILE" value="false" />
4
+ <version value="1.0" />
5
+ </settings>
6
+ </component>
.idea/misc.xml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="Black">
4
+ <option name="sdkName" value="Python 3.12 (LutAI)" />
5
+ </component>
6
+ <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.12 (LutAI)" project-jdk-type="Python SDK" />
7
+ </project>
.idea/modules.xml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="ProjectModuleManager">
4
+ <modules>
5
+ <module fileurl="file://$PROJECT_DIR$/.idea/LutAI.iml" filepath="$PROJECT_DIR$/.idea/LutAI.iml" />
6
+ </modules>
7
+ </component>
8
+ </project>
README.md ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # LUT Transformation API
2
+
3
+ A FastAPI-based service for transforming .cube LUT files using AI-generated color adjustments and creating split preview images.
4
+
5
+ ## Features
6
+
7
+ - **LUT Transformation**: Transform .cube files using JSON-based color adjustments
8
+ - **AI Integration**: Generate color adjustments based on natural language prompts
9
+ - **Split Preview**: Create side-by-side comparison images showing before/after effects
10
+ - **Base64 Support**: Handle file uploads and image responses in base64 format
11
+
12
+ ## Setup
13
+
14
+ ### 1. Install Dependencies
15
+
16
+ ```bash
17
+ pip install -r requirements.txt
18
+ ```
19
+
20
+ ### 2. Run the Server
21
+
22
+ ```bash
23
+ uvicorn main:app --reload
24
+ ```
25
+
26
+ The API will be available at `http://localhost:8000`
27
+
28
+ ### 3. API Documentation
29
+
30
+ Visit `http://localhost:8000/docs` for interactive API documentation.
31
+
32
+ ## API Endpoints
33
+
34
+ ### POST /transform-lut
35
+
36
+ Transform a LUT file using a text prompt.
37
+
38
+ **Request Body:**
39
+ ```json
40
+ {
41
+ "cube_file_base64": "base64_encoded_cube_file",
42
+ "user_prompt": "Make this LUT more cinematic with cool shadows"
43
+ }
44
+ ```
45
+
46
+ **Response:**
47
+ ```json
48
+ {
49
+ "success": true,
50
+ "message": "LUT transformation completed successfully",
51
+ "adjustments_applied": {
52
+ "shadows": {"r": 0.9, "g": 1.0, "b": 1.2},
53
+ "midtones": {"r": 1.0, "g": 1.0, "b": 1.0},
54
+ "highlights": {"r": 1.1, "g": 1.05, "b": 0.95},
55
+ "global": {"r": 1.0, "g": 1.0, "b": 1.0}
56
+ },
57
+ "split_preview_base64": "base64_encoded_preview_image"
58
+ }
59
+ ```
60
+
61
+ ### GET /health
62
+
63
+ Check API health and sample image availability.
64
+
65
+ **Response:**
66
+ ```json
67
+ {
68
+ "status": "healthy",
69
+ "sample_image_exists": true
70
+ }
71
+ ```
72
+
73
+ ## How It Works
74
+
75
+ 1. **Upload**: Send a .cube file as base64 and a text prompt
76
+ 2. **AI Processing**: The `generate_new_cube()` function processes the prompt and returns JSON adjustments
77
+ 3. **LUT Transformation**: Apply the adjustments to the original LUT using the `LUTTransformer` class
78
+ 4. **Image Processing**: Apply both original and modified LUTs to the sample image
79
+ 5. **Split Preview**: Create a side-by-side comparison with a vertical divider line
80
+ 6. **Response**: Return the preview image as base64
81
+
82
+ ## LUT Adjustment Format
83
+
84
+ The AI generates adjustments in this JSON format:
85
+
86
+ ```json
87
+ {
88
+ "shadows": {"r": 0.9, "g": 1.0, "b": 1.2},
89
+ "midtones": {"r": 1.0, "g": 1.0, "b": 1.0},
90
+ "highlights": {"r": 1.1, "g": 1.05, "b": 0.95},
91
+ "global": {"r": 1.0, "g": 1.0, "b": 1.0}
92
+ }
93
+ ```
94
+
95
+ - **shadows**: Adjustments for darker regions (luminance < 0.33)
96
+ - **midtones**: Adjustments for medium regions (0.33 ≤ luminance < 0.66)
97
+ - **highlights**: Adjustments for brighter regions (luminance ≥ 0.66)
98
+ - **global**: Overall adjustments applied to all regions
99
+
100
+ ## Testing
101
+
102
+ Use the provided `test_main.http` file to test the endpoints, or use curl:
103
+
104
+ ```bash
105
+ curl -X POST "http://localhost:8000/transform-lut" \
106
+ -H "Content-Type: application/json" \
107
+ -d '{
108
+ "cube_file_base64": "VElUTEUgIlRlc3QgTFVUIgpMVVRfM0RfU0laRSAyCg...",
109
+ "user_prompt": "Make this LUT more cinematic with cool shadows"
110
+ }'
111
+ ```
112
+
113
+ ## Sample Image
114
+
115
+ The API uses `sample.jpg` as the standard test image for preview generation. Make sure this file exists in the project root.
116
+
117
+ ## AI Integration
118
+
119
+ Replace the placeholder `generate_new_cube()` function with your actual AI implementation that takes a user prompt and returns color adjustment JSON.
120
+
121
+ ## Error Handling
122
+
123
+ The API includes comprehensive error handling for:
124
+ - Invalid cube file formats
125
+ - Missing sample images
126
+ - Base64 decoding errors
127
+ - Image processing failures
128
+ - File system operations
ai.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_core.prompts import ChatPromptTemplate
2
+ from langchain_openai import ChatOpenAI
3
+
4
+ from dotenv import load_dotenv
5
+ from pydantic import BaseModel, Field
6
+
7
+ load_dotenv()
8
+
9
+
10
+ class RGB(BaseModel):
11
+ r: float = Field(
12
+ description="Red channel multiplier (0.0-2.0). Values above 1.0 increase red, below 1.0 decrease red. Use for color temperature adjustments - lower for cool tones, higher for warm tones.",
13
+ ge=0.0,
14
+ le=2.0
15
+ )
16
+ g: float = Field(
17
+ description="Green channel multiplier (0.0-2.0). Values above 1.0 increase green, below 1.0 decrease green. Affects magenta/green color balance.",
18
+ ge=0.0,
19
+ le=2.0
20
+ )
21
+ b: float = Field(
22
+ description="Blue channel multiplier (0.0-2.0). Values above 1.0 increase blue, below 1.0 decrease blue. Higher values create cooler tones, lower values create warmer tones.",
23
+ ge=0.0,
24
+ le=2.0
25
+ )
26
+
27
+
28
+ class CubeAI(BaseModel):
29
+ shadows: RGB = Field(
30
+ description="RGB multipliers for shadow tones (darker areas of the image, typically 0-33% luminance). Use to adjust color temperature and mood in dark areas."
31
+ )
32
+ midtones: RGB = Field(
33
+ description="RGB multipliers for midtone areas (medium brightness areas, typically 33-66% luminance). Primary area for overall color grading and mood."
34
+ )
35
+ highlights: RGB = Field(
36
+ description="RGB multipliers for highlight areas (brightest areas of the image, typically 66-100% luminance). Use for sky tones, bright surfaces, and overall brightness balance."
37
+ )
38
+ glob: RGB = Field(
39
+ description="Global RGB multipliers applied to the entire image across all tonal ranges. Use for overall color temperature shifts and global corrections."
40
+ )
41
+
42
+
43
+ PROMPT = """
44
+ You are an expert color grading AI that generates precise LUT (Look-Up Table) adjustments for cinematic color grading.
45
+
46
+ Your task is to analyze the user's style request and generate RGB channel multipliers for different tonal ranges (shadows, midtones, highlights, and global).
47
+
48
+ ## Understanding LUT Color Grading:
49
+ - A LUT transforms input colors to output colors for consistent color grading
50
+ - RGB multipliers adjust the intensity of red, green, and blue channels
51
+ - Values: 1.0 = no change, >1.0 = increase channel, <1.0 = decrease channel
52
+ - Range: 0.0 to 2.0 for each channel
53
+
54
+ ## Tonal Ranges:
55
+ - **Shadows** (0-33% luminance): Dark areas, deep shadows, black levels
56
+ - **Midtones** (33-66% luminance): Primary subject matter, skin tones, main image content
57
+ - **Highlights** (66-100% luminance): Bright areas, sky, light sources, white levels
58
+ - **Global**: Overall adjustment applied to entire image
59
+
60
+ ## Common Style Guidelines:
61
+
62
+ ### Cinematic Styles:
63
+ - **Cool/Blue shadows**: shadows.b = 1.1-1.3, shadows.r = 0.8-0.9
64
+ - **Warm highlights**: highlights.r = 1.1-1.2, highlights.b = 0.9-0.95
65
+ - **Desaturated midtones**: midtones values closer to 1.0
66
+ - **Teal & Orange**: shadows.b = 1.2, highlights.r = 1.15, midtones.g = 0.95
67
+
68
+ ### Mood Adjustments:
69
+ - **Warm/Golden**: Increase red/green, decrease blue across tones
70
+ - **Cool/Cold**: Increase blue, decrease red across tones
71
+ - **Vintage**: Slightly reduce contrast, warm highlights, cool shadows
72
+ - **Modern/Clean**: Balanced adjustments, slight contrast enhancement
73
+
74
+ ### Technical Considerations:
75
+ - Subtle adjustments (0.05-0.2 change) for natural looks
76
+ - Dramatic adjustments (0.3+ change) for stylized looks
77
+ - Maintain skin tone integrity in midtones
78
+ - Avoid extreme values that cause color clipping
79
+
80
+ ## Examples:
81
+ User: "Make this more cinematic with cool shadows"
82
+ Response: shadows with blue increased (1.2), red decreased (0.85), warm highlights
83
+
84
+ User: "Vintage film look with warm tones"
85
+ Response: Overall warm adjustment, slightly reduced contrast, golden highlights
86
+
87
+ User: "Dark and moody atmosphere"
88
+ Response: Reduced overall brightness, cool shadows, maintain highlight detail
89
+
90
+ ## Output Requirements:
91
+ - Generate precise RGB multipliers for shadows, midtones, highlights, and global
92
+ - Each RGB value must be between 0.0 and 2.0
93
+ - Consider the interaction between different tonal ranges
94
+ - Aim for cohesive, professional color grading results
95
+
96
+ Analyze the user's request and generate appropriate LUT adjustments that achieve their desired look while maintaining image quality and professional color grading standards.
97
+ """
98
+
99
+
100
+ def generate_cube(user_prompt: str) -> CubeAI:
101
+ prompt = ChatPromptTemplate.from_messages(
102
+ [
103
+ ("system", PROMPT),
104
+ ("user", "{user_prompt}")
105
+ ]
106
+ )
107
+ model = ChatOpenAI(model="gpt-4.1-mini", temperature=0).with_structured_output(
108
+ CubeAI
109
+ )
110
+ chain = prompt | model
111
+ response = chain.invoke({"user_prompt": user_prompt})
112
+ return response
main.py ADDED
@@ -0,0 +1,425 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import os
3
+ import re
4
+ import sqlite3
5
+ import tempfile
6
+ import uuid
7
+ from io import BytesIO
8
+ from typing import Dict, List, Optional
9
+
10
+ import cv2
11
+ import numpy as np
12
+ from PIL import Image
13
+ from fastapi import FastAPI, HTTPException, UploadFile, File
14
+ from pydantic import BaseModel
15
+ from starlette.middleware.cors import CORSMiddleware
16
+
17
+ from ai import generate_cube
18
+
19
+ app = FastAPI(title="LUT Transformation API", version="1.0.0")
20
+
21
+ app.add_middleware(
22
+ CORSMiddleware,
23
+ allow_origins=["*"],
24
+ allow_credentials=True,
25
+ allow_methods=["*"],
26
+ allow_headers=["*"],
27
+ )
28
+
29
+
30
+ class LUTRequest(BaseModel):
31
+ file_id: str
32
+ user_prompt: str
33
+
34
+
35
+ class LUTTransformRequest(BaseModel):
36
+ file_id: str
37
+ user_prompt: str
38
+
39
+
40
+ class CubeFileResponse(BaseModel):
41
+ file_id: str
42
+ file_name: str
43
+
44
+
45
+ class CubeFileListItem(BaseModel):
46
+ file_id: str
47
+ file_name: str
48
+ upload_date: str
49
+
50
+
51
+ DATABASE_PATH = "cube_files.db"
52
+
53
+
54
+ def init_database():
55
+ """Initialize SQLite database and create tables"""
56
+ conn = sqlite3.connect(DATABASE_PATH)
57
+ cursor = conn.cursor()
58
+
59
+ cursor.execute("""
60
+ CREATE TABLE IF NOT EXISTS cube_files (
61
+ id TEXT PRIMARY KEY,
62
+ file_name TEXT NOT NULL,
63
+ file_data BLOB NOT NULL,
64
+ upload_date TIMESTAMP DEFAULT CURRENT_TIMESTAMP
65
+ )
66
+ """)
67
+
68
+ conn.commit()
69
+ conn.close()
70
+
71
+
72
+ def save_cube_file_to_db(file_name: str, file_data: bytes) -> str:
73
+ """Save cube file to database and return file ID"""
74
+ file_id = str(uuid.uuid4())
75
+ conn = sqlite3.connect(DATABASE_PATH)
76
+ cursor = conn.cursor()
77
+
78
+ cursor.execute(
79
+ "INSERT INTO cube_files (id, file_name, file_data) VALUES (?, ?, ?)",
80
+ (file_id, file_name, file_data)
81
+ )
82
+
83
+ conn.commit()
84
+ conn.close()
85
+ return file_id
86
+
87
+
88
+ def get_cube_file_from_db(file_id: str) -> Optional[tuple]:
89
+ """Retrieve cube file from database by ID"""
90
+ conn = sqlite3.connect(DATABASE_PATH)
91
+ cursor = conn.cursor()
92
+
93
+ cursor.execute(
94
+ "SELECT file_name, file_data FROM cube_files WHERE id = ?",
95
+ (file_id,)
96
+ )
97
+
98
+ result = cursor.fetchone()
99
+ conn.close()
100
+ return result
101
+
102
+
103
+ def list_cube_files_from_db() -> List[tuple]:
104
+ """List all cube files from database"""
105
+ conn = sqlite3.connect(DATABASE_PATH)
106
+ cursor = conn.cursor()
107
+
108
+ cursor.execute(
109
+ "SELECT id, file_name, upload_date FROM cube_files ORDER BY upload_date DESC"
110
+ )
111
+
112
+ results = cursor.fetchall()
113
+ conn.close()
114
+ return results
115
+
116
+
117
+ class LUTTransformer:
118
+ def __init__(self):
119
+ self.title = ""
120
+ self.size = 0
121
+ self.lut_data = []
122
+
123
+ def parse_cube_file(self, filepath: str) -> bool:
124
+ """Parse .cube file and extract LUT data"""
125
+ try:
126
+ with open(filepath, "r") as file:
127
+ lines = file.readlines()
128
+
129
+ self.lut_data = []
130
+
131
+ for line in lines:
132
+ line = line.strip()
133
+
134
+ if not line or line.startswith("#"):
135
+ continue
136
+
137
+ if line.startswith("TITLE"):
138
+ self.title = line.split('"')[1] if '"' in line else line.split()[1]
139
+
140
+ elif line.startswith("LUT_3D_SIZE"):
141
+ self.size = int(line.split()[1])
142
+
143
+ else:
144
+ rgb_match = re.findall(r"[\d.]+", line)
145
+ if len(rgb_match) >= 3:
146
+ r, g, b = map(float, rgb_match[:3])
147
+ self.lut_data.append([r, g, b])
148
+
149
+ return len(self.lut_data) > 0
150
+
151
+ except Exception as e:
152
+ print(f"Error parsing cube file: {e}")
153
+ return False
154
+
155
+ def apply_json_transformation(self, json_adjustments: Dict) -> bool:
156
+ """Apply JSON color adjustments to LUT data"""
157
+ try:
158
+ lut_array = np.array(self.lut_data)
159
+
160
+ for i, (r, g, b) in enumerate(lut_array):
161
+ luminance = 0.299 * r + 0.587 * g + 0.114 * b
162
+
163
+ if luminance < 0.33:
164
+ if "shadows" in json_adjustments:
165
+ adj = json_adjustments["shadows"]
166
+ lut_array[i] *= [
167
+ adj.get("r", 1.0),
168
+ adj.get("g", 1.0),
169
+ adj.get("b", 1.0),
170
+ ]
171
+
172
+ elif luminance < 0.66:
173
+ if "midtones" in json_adjustments:
174
+ adj = json_adjustments["midtones"]
175
+ lut_array[i] *= [
176
+ adj.get("r", 1.0),
177
+ adj.get("g", 1.0),
178
+ adj.get("b", 1.0),
179
+ ]
180
+
181
+ else:
182
+ if "highlights" in json_adjustments:
183
+ adj = json_adjustments["highlights"]
184
+ lut_array[i] *= [
185
+ adj.get("r", 1.0),
186
+ adj.get("g", 1.0),
187
+ adj.get("b", 1.0),
188
+ ]
189
+
190
+ if "glob" in json_adjustments:
191
+ global_adj = json_adjustments["glob"]
192
+ lut_array *= [
193
+ global_adj.get("r", 1.0),
194
+ global_adj.get("g", 1.0),
195
+ global_adj.get("b", 1.0),
196
+ ]
197
+
198
+ lut_array = np.clip(lut_array, 0.0, 1.0)
199
+ self.lut_data = lut_array.tolist()
200
+
201
+ return True
202
+
203
+ except Exception as e:
204
+ print(f"Error applying transformation: {e}")
205
+ return False
206
+
207
+ def save_cube_file(self, output_path: str, new_title: str = None) -> bool:
208
+ """Save modified LUT as .cube file"""
209
+ try:
210
+ with open(output_path, "w") as file:
211
+ title = new_title if new_title else f"{self.title}_modified"
212
+ file.write(f'TITLE "{title}"\n')
213
+ file.write(f"LUT_3D_SIZE {self.size}\n\n")
214
+
215
+ for r, g, b in self.lut_data:
216
+ file.write(f"{r:.6f} {g:.6f} {b:.6f}\n")
217
+
218
+ return True
219
+
220
+ except Exception as e:
221
+ print(f"Error saving cube file: {e}")
222
+ return False
223
+
224
+
225
+ def generate_new_cube(user_prompt: str) -> dict:
226
+ """
227
+ Placeholder for AI function that generates JSON adjustments based on user prompt.
228
+ This function should be replaced with the actual AI implementation.
229
+ """
230
+ response = generate_cube(user_prompt)
231
+ return response.model_dump(mode="json")
232
+
233
+
234
+ def apply_lut_to_image(image_path: str, lut_path: str) -> np.ndarray:
235
+ """Apply LUT to image using OpenCV"""
236
+ try:
237
+ img = cv2.imread(image_path)
238
+ if img is None:
239
+ raise ValueError(f"Could not load image: {image_path}")
240
+
241
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
242
+
243
+ transformer = LUTTransformer()
244
+ if not transformer.parse_cube_file(lut_path):
245
+ raise ValueError(f"Could not parse LUT file: {lut_path}")
246
+
247
+ lut_data = np.array(transformer.lut_data)
248
+ lut_size = transformer.size
249
+
250
+ lut_3d = lut_data.reshape((lut_size, lut_size, lut_size, 3))
251
+
252
+ img_normalized = img.astype(np.float32) / 255.0
253
+
254
+ result = np.zeros_like(img_normalized)
255
+
256
+ for i in range(img.shape[0]):
257
+ for j in range(img.shape[1]):
258
+ r, g, b = img_normalized[i, j]
259
+
260
+ r_idx = min(int(r * (lut_size - 1)), lut_size - 1)
261
+ g_idx = min(int(g * (lut_size - 1)), lut_size - 1)
262
+ b_idx = min(int(b * (lut_size - 1)), lut_size - 1)
263
+
264
+ result[i, j] = lut_3d[r_idx, g_idx, b_idx]
265
+
266
+ result = np.clip(result * 255, 0, 255).astype(np.uint8)
267
+ return result
268
+
269
+ except Exception as e:
270
+ print(f"Error applying LUT to image: {e}")
271
+ raise
272
+
273
+
274
+ def create_split_preview(
275
+ original_lut_path: str, new_lut_path: str, sample_image_path: str
276
+ ) -> str:
277
+ """Create a split preview image and return as base64"""
278
+ try:
279
+ original_processed = apply_lut_to_image(sample_image_path, original_lut_path)
280
+ new_processed = apply_lut_to_image(sample_image_path, new_lut_path)
281
+
282
+ height, width = original_processed.shape[:2]
283
+ split_image = np.zeros_like(original_processed)
284
+
285
+ mid_point = width // 2
286
+ split_image[:, :mid_point] = original_processed[:, :mid_point]
287
+ split_image[:, mid_point:] = new_processed[:, mid_point:]
288
+
289
+ cv2.line(split_image, (mid_point, 0), (mid_point, height), (255, 255, 255), 2)
290
+
291
+ pil_image = Image.fromarray(split_image)
292
+
293
+ buffer = BytesIO()
294
+ pil_image.save(buffer, format="PNG")
295
+ buffer.seek(0)
296
+
297
+ base64_string = base64.b64encode(buffer.getvalue()).decode("utf-8")
298
+ return base64_string
299
+
300
+ except Exception as e:
301
+ print(f"Error creating split preview: {e}")
302
+ raise
303
+
304
+
305
+ @app.on_event("startup")
306
+ async def startup_event():
307
+ init_database()
308
+
309
+
310
+ @app.get("/")
311
+ async def root():
312
+ return {"message": "LUT Transformation API", "version": "1.0.0"}
313
+
314
+
315
+ @app.post("/upload-cube", response_model=CubeFileResponse)
316
+ async def upload_cube_file(file: UploadFile = File(...)):
317
+ """
318
+ Upload a .cube file and save it to the database
319
+ """
320
+ try:
321
+ if not file.filename.endswith('.cube'):
322
+ raise HTTPException(status_code=400, detail="Only .cube files are allowed")
323
+
324
+ file_data = await file.read()
325
+
326
+ file_id = save_cube_file_to_db(file.filename, file_data)
327
+
328
+ return CubeFileResponse(
329
+ file_id=file_id,
330
+ file_name=file.filename
331
+ )
332
+
333
+ except Exception as e:
334
+ raise HTTPException(status_code=500, detail=f"Error uploading file: {str(e)}")
335
+
336
+
337
+ @app.get("/cube-files", response_model=List[CubeFileListItem])
338
+ async def list_cube_files():
339
+ """
340
+ List all uploaded cube files with their IDs and names
341
+ """
342
+ try:
343
+ files = list_cube_files_from_db()
344
+ return [
345
+ CubeFileListItem(
346
+ file_id=file_id,
347
+ file_name=file_name,
348
+ upload_date=upload_date
349
+ )
350
+ for file_id, file_name, upload_date in files
351
+ ]
352
+ except Exception as e:
353
+ raise HTTPException(status_code=500, detail=f"Error listing files: {str(e)}")
354
+
355
+
356
+ @app.post("/transform-lut")
357
+ async def transform_lut(request: LUTTransformRequest):
358
+ """
359
+ Transform a LUT based on file ID and user prompt, return split preview image
360
+ """
361
+ try:
362
+ file_data = get_cube_file_from_db(request.file_id)
363
+ if not file_data:
364
+ raise HTTPException(status_code=404, detail="Cube file not found")
365
+
366
+ file_name, cube_data = file_data
367
+
368
+ with tempfile.NamedTemporaryFile(
369
+ mode="wb", suffix=".cube", delete=False
370
+ ) as temp_cube:
371
+ temp_cube.write(cube_data)
372
+ original_cube_path = temp_cube.name
373
+
374
+ try:
375
+ adjustments = generate_new_cube(request.user_prompt)
376
+ transformer = LUTTransformer()
377
+ if not transformer.parse_cube_file(original_cube_path):
378
+ raise HTTPException(status_code=400, detail="Failed to parse cube file")
379
+
380
+ if not transformer.apply_json_transformation(adjustments):
381
+ raise HTTPException(
382
+ status_code=500, detail="Failed to apply transformations"
383
+ )
384
+
385
+ with tempfile.NamedTemporaryFile(
386
+ mode="w", suffix=".cube", delete=False
387
+ ) as temp_new_cube:
388
+ new_cube_path = temp_new_cube.name
389
+
390
+ if not transformer.save_cube_file(
391
+ new_cube_path, f"{transformer.title}_AI_Modified"
392
+ ):
393
+ raise HTTPException(
394
+ status_code=500, detail="Failed to save new cube file"
395
+ )
396
+
397
+ sample_image_path = "sample.jpg"
398
+ if not os.path.exists(sample_image_path):
399
+ raise HTTPException(status_code=404, detail="Sample image not found")
400
+
401
+ split_preview_base64 = create_split_preview(
402
+ original_cube_path, new_cube_path, sample_image_path
403
+ )
404
+
405
+ return {
406
+ "success": True,
407
+ "message": "LUT transformation completed successfully",
408
+ "file_name": file_name,
409
+ "adjustments_applied": adjustments,
410
+ "split_preview_base64": split_preview_base64,
411
+ }
412
+
413
+ finally:
414
+ if os.path.exists(original_cube_path):
415
+ os.unlink(original_cube_path)
416
+ if "new_cube_path" in locals() and os.path.exists(new_cube_path):
417
+ os.unlink(new_cube_path)
418
+
419
+ except Exception as e:
420
+ raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
421
+
422
+
423
+ @app.get("/health")
424
+ async def health_check():
425
+ return {"status": "healthy", "sample_image_exists": os.path.exists("sample.jpg")}
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ fastapi==0.104.1
2
+ uvicorn[standard]==0.24.0
3
+ python-multipart==0.0.6
4
+ pillow==10.1.0
5
+ numpy==1.24.3
6
+ opencv-python==4.8.1.78
7
+ pydantic==2.4.2
test_cube_generator.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Utility script to generate a sample .cube file and convert it to base64 for testing the LUT API.
4
+ """
5
+
6
+ import base64
7
+ import json
8
+
9
+ def create_sample_cube_file(filename: str = "sample.cube", size: int = 8):
10
+ """Create a simple identity LUT cube file for testing"""
11
+ with open(filename, 'w') as f:
12
+ f.write('TITLE "Sample Identity LUT"\n')
13
+ f.write(f'LUT_3D_SIZE {size}\n\n')
14
+
15
+ for r in range(size):
16
+ for g in range(size):
17
+ for b in range(size):
18
+ red_val = r / (size - 1)
19
+ green_val = g / (size - 1)
20
+ blue_val = b / (size - 1)
21
+ f.write(f'{red_val:.6f} {green_val:.6f} {blue_val:.6f}\n')
22
+
23
+ print(f"Created sample cube file: {filename}")
24
+ return filename
25
+
26
+ def cube_file_to_base64(cube_file_path: str) -> str:
27
+ """Convert a cube file to base64 string"""
28
+ with open(cube_file_path, 'rb') as f:
29
+ cube_data = f.read()
30
+
31
+ base64_string = base64.b64encode(cube_data).decode('utf-8')
32
+ return base64_string
33
+
34
+ def generate_test_request(cube_file_path: str, prompt: str = "Make this LUT more cinematic with cool shadows") -> dict:
35
+ """Generate a complete test request JSON"""
36
+ base64_cube = cube_file_to_base64(cube_file_path)
37
+
38
+ request_data = {
39
+ "cube_file_base64": base64_cube,
40
+ "user_prompt": prompt
41
+ }
42
+
43
+ return request_data
44
+
45
+ if __name__ == "__main__":
46
+ cube_file = create_sample_cube_file()
47
+
48
+ test_request = generate_test_request(cube_file)
49
+
50
+ print(f"\nBase64 encoded cube file:")
51
+ print(f"Length: {len(test_request['cube_file_base64'])} characters")
52
+ print(f"First 100 chars: {test_request['cube_file_base64'][:100]}...")
53
+
54
+ with open("test_request.json", "w") as f:
55
+ json.dump(test_request, f, indent=2)
56
+
57
+ print(f"\nSaved complete test request to: test_request.json")
58
+ print(f"You can use this to test the API endpoint.")
59
+
60
+ print(f"\nExample curl command:")
61
+ print(f"curl -X POST \"http://localhost:8000/transform-lut\" \\")
62
+ print(f" -H \"Content-Type: application/json\" \\")
63
+ print(f" -d @test_request.json")
test_main.http ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ### Test the root endpoint
2
+ GET http://localhost:8000/
3
+
4
+ ### Test health check
5
+ GET http://localhost:8000/health
6
+
7
+ ### Test LUT transformation endpoint
8
+ POST http://localhost:8000/transform-lut
9
+ Content-Type: application/json
10
+
11
+ {
12
+ "cube_file_base64": "VElUTEUgIlRlc3QgTFVUIgpMVVRfM0RfU0laRSAyCgowLjAwMDAwMCAwLjAwMDAwMCAwLjAwMDAwMAowLjUwMDAwMCAwLjAwMDAwMCAwLjAwMDAwMAowLjAwMDAwMCAwLjUwMDAwMCAwLjAwMDAwMAowLjUwMDAwMCAwLjUwMDAwMCAwLjAwMDAwMAowLjAwMDAwMCAwLjAwMDAwMCAwLjUwMDAwMAowLjUwMDAwMCAwLjAwMDAwMCAwLjUwMDAwMAowLjAwMDAwMCAwLjUwMDAwMCAwLjUwMDAwMAoxLjAwMDAwMCAxLjAwMDAwMCAxLjAwMDAwMA==",
13
+ "user_prompt": "Make this LUT more cinematic with cool shadows"
14
+ }