# Hhhh / musicgen_api.py
# Uploaded by Hjgugugjhuhjggg ("Upload 28 files", commit e83e49f, verified; 1.18 kB).
from flask import jsonify, send_file, request
from main import *
import torch
import soundfile as sf
import numpy as np
import io
def generate_music(prompt, output_path="output_music.wav", *, sample_rate=32000, duration=10):
    """Generate a music clip from a text prompt and write it to a WAV file.

    Args:
        prompt: Text description, passed to the model as its only attribute.
        output_path: Destination path for the rendered audio file.
        sample_rate: Sample rate requested from the model and also used when
            writing the file. Defaults to 32000, the previously hard-coded value.
        duration: Requested clip length, forwarded unchanged to the model
            (presumably seconds -- confirm against the model's ``sample`` API).

    Returns:
        ``output_path`` on success. Returns the exact string
        ``"Music generation model not initialized."`` when the module-level
        ``musicgen_model`` is ``None``; callers compare against that string.
    """
    # `musicgen_model` is provided by `from main import *`; it may be None if
    # the model was not loaded.
    if musicgen_model is None:
        return "Music generation model not initialized."

    # Pure inference: disabling autograd avoids building a graph (saves memory)
    # and guarantees the result can be converted with `.numpy()`.
    with torch.no_grad():
        audio_values = musicgen_model.sample(
            attributes=[prompt],
            sample_rate=sample_rate,
            duration=duration,
        )

    # squeeze() drops all size-1 axes; for a mono clip this leaves a 1-D array,
    # which soundfile writes as a single channel.
    # NOTE(review): a multi-channel result must be (frames, channels) for
    # sf.write -- confirm the model's output layout if stereo is ever returned.
    output_audio = audio_values.detach().cpu().numpy().squeeze()
    sf.write(output_path, output_audio, sample_rate)
    return output_path
def musicgen_api():
    """Flask view: generate music for the request's JSON ``prompt`` and return a WAV download.

    Expects a JSON object body such as ``{"prompt": "..."}``.

    Returns:
        On success, a ``send_file`` response carrying the WAV bytes
        (``audio/wav``, attachment named ``output.wav``).
        ``({"error": "Prompt is required"}, 400)`` when the body is missing,
        is not a JSON object, or has no non-blank string ``prompt``.
        ``({"error": "Music generation failed"}, 500)`` when the model is not
        initialized or generation/reading the result raises.
    """
    import logging
    import os
    import tempfile

    # silent=True: a missing or malformed JSON body yields None instead of
    # raising, so it is reported as the 400 below rather than an opaque error.
    data = request.get_json(silent=True)
    prompt = data.get('prompt') if isinstance(data, dict) else None
    if not isinstance(prompt, str) or not prompt.strip():
        return jsonify({"error": "Prompt is required"}), 400

    # Render into a per-request temporary file. The helper's default path is a
    # single shared file, which would let concurrent requests overwrite (and
    # serve) each other's audio.
    fd, output_path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)
    try:
        output_file = generate_music(prompt, output_path)
        # generate_music signals a missing model with this exact string.
        if output_file == "Music generation model not initialized.":
            return jsonify({"error": "Music generation failed"}), 500
        with open(output_file, 'rb') as f:
            audio_content = f.read()
    except Exception:
        # HTTP boundary: log the traceback and answer with the endpoint's JSON error.
        logging.getLogger(__name__).exception("Music generation failed")
        return jsonify({"error": "Music generation failed"}), 500
    finally:
        # Best-effort cleanup; on success the audio is already held in memory.
        try:
            os.remove(output_path)
        except OSError:
            pass

    return send_file(io.BytesIO(audio_content), mimetype="audio/wav", as_attachment=True, download_name="output.wav")