# projektas1-demo-hf1 / filtravimas.py
# Uploaded by Elanas ("Upload filtravimas.py", commit 12dd383, verified).
# File size at upload: 3.24 kB.
import functools
import os

import noisereduce as nr
import numpy as np
import torch
import torch.nn as nn
import torchaudio
import torchaudio.transforms as T
from asteroid.models import DCCRNet
# Scratch directory for filtered audio; created eagerly when the module is imported.
TEMP_DIR = "temp_filtered"
os.makedirs(TEMP_DIR, exist_ok=True)

# Single fixed destination used by every filtering method (see filtruoti_audio),
# so a later call replaces the file written by an earlier one.
OUTPUT_PATH = os.path.join(TEMP_DIR, "ivestis.wav")
class WaveUNet(nn.Module):
    """Small fully-convolutional encoder/decoder over single-channel 1-D signals.

    Encoder: three ``Conv1d`` layers (1 -> 16 -> 32 -> 64 channels), each
    followed by ReLU. Decoder: three ``ConvTranspose1d`` layers
    (64 -> 32 -> 16 -> 1 channels) with ReLU between them but not after the
    last one. Every layer uses kernel 3, stride 1, padding 1, so the time
    length is preserved end to end.

    NOTE(review): despite the name there are no skip connections and no
    down/upsampling, so this is not the Wave-U-Net architecture.
    """

    def __init__(self):
        super().__init__()
        layer_opts = {"kernel_size": 3, "stride": 1, "padding": 1}

        down_layers = []
        for ch_in, ch_out in ((1, 16), (16, 32), (32, 64)):
            down_layers.append(nn.Conv1d(ch_in, ch_out, **layer_opts))
            down_layers.append(nn.ReLU())
        self.encoder = nn.Sequential(*down_layers)

        up_pairs = ((64, 32), (32, 16), (16, 1))
        up_layers = []
        for position, (ch_in, ch_out) in enumerate(up_pairs):
            up_layers.append(nn.ConvTranspose1d(ch_in, ch_out, **layer_opts))
            # The final projection back to one channel stays linear.
            if position < len(up_pairs) - 1:
                up_layers.append(nn.ReLU())
        self.decoder = nn.Sequential(*up_layers)

    def forward(self, x):
        """Run ``x`` through the encoder and then the decoder."""
        return self.decoder(self.encoder(x))
def filtruoti_su_waveunet(input_path, output_path):
    """Run the audio at ``input_path`` through a fresh ``WaveUNet`` and save it.

    NOTE(review): the model is built with freshly initialised (random) weights
    on every call and no checkpoint is loaded, so the result is not a trained
    denoising output. Load trained weights before relying on this method.

    Args:
        input_path: audio file readable by ``torchaudio.load``.
        output_path: destination file, written at 16 kHz.
    """
    print("🔧 Wave-U-Net filtravimas...")
    model = WaveUNet()
    model.eval()
    mixture, sr = torchaudio.load(input_path)
    # WaveUNet's first layer is Conv1d(1, ...) and accepts exactly one input
    # channel. Average multichannel (e.g. stereo) audio down to mono, otherwise
    # the (1, channels, time) batch would raise a channel-mismatch error.
    if mixture.size(0) > 1:
        mixture = mixture.mean(dim=0, keepdim=True)
    if sr != 16000:
        print("🔁 Resample į 16kHz...")
        resampler = T.Resample(orig_freq=sr, new_freq=16000).to(mixture.device)
        mixture = resampler(mixture)
    # Add a batch dimension: (channels, time) -> (1, channels, time).
    if mixture.dim() == 2:
        mixture = mixture.unsqueeze(0)
    with torch.no_grad():
        output = model(mixture)
    # Drop the batch dimension again for torchaudio.save: (1, 1, time) -> (1, time).
    output = output.squeeze(0)
    torchaudio.save(output_path, output, 16000)
    print(f"✅ Wave-U-Net išsaugota: {output_path}")
@functools.lru_cache(maxsize=1)
def _load_dccrnet():
    """Load the pretrained DCCRNet once, switch it to inference mode, and reuse it."""
    model = DCCRNet.from_pretrained("JorisCos/DCCRNet_Libri1Mix_enhsingle_16k")
    # Inference mode: disables training-time behaviour of any layers that
    # have one (e.g. normalisation/dropout), matching the Wave-U-Net path.
    model.eval()
    return model


def filtruoti_su_denoiser(input_path, output_path):
    """Enhance ``input_path`` with the pretrained DCCRNet and save the result.

    The audio is resampled to 16 kHz if needed. The first element of the model's
    separated output is written to ``output_path`` at 16 kHz.

    Args:
        input_path: audio file readable by ``torchaudio.load``.
        output_path: destination file, written at 16 kHz.
    """
    print("🔧 Denoiser (DCCRNet)...")
    # Loaded (and downloaded, if necessary) only on the first call; cached afterwards.
    model = _load_dccrnet()
    mixture, sr = torchaudio.load(input_path)
    if sr != 16000:
        print("🔁 Resample į 16kHz...")
        resampler = T.Resample(orig_freq=sr, new_freq=16000).to(mixture.device)
        mixture = resampler(mixture)
    with torch.no_grad():
        est_source = model.separate(mixture)
    # NOTE(review): for multichannel input, asteroid presumably treats channels
    # as a batch, so only the first channel's estimate is saved — confirm.
    torchaudio.save(output_path, est_source[0], 16000)
    print(f"✅ Denoiser išsaugota: {output_path}")
def filtruoti_su_noisereduce(input_path, output_path):
    """Apply ``noisereduce`` spectral gating to ``input_path`` and save the result.

    Only the first channel of the input is processed. Other channels are
    discarded, so the output is always mono. The output keeps the input's
    original sample rate.

    Args:
        input_path: audio file readable by ``torchaudio.load``.
        output_path: destination file.
    """
    print("🔧 Noisereduce filtravimas...")
    waveform, sr = torchaudio.load(input_path)
    # First channel only, as a 1-D numpy signal.
    audio = waveform.detach().cpu().numpy()[0]
    reduced = nr.reduce_noise(y=audio, sr=sr)
    # noisereduce's output dtype is not pinned by this code (it may come back
    # as float64). Force contiguous float32 so torch.from_numpy yields a
    # float32 tensor, like the ones torchaudio.load returns, and the file is
    # not written as 64-bit float samples.
    reduced = np.ascontiguousarray(reduced, dtype=np.float32)
    reduced_tensor = torch.from_numpy(reduced).unsqueeze(0)
    torchaudio.save(output_path, reduced_tensor, sr)
    print(f"✅ Noisereduce išsaugota: {output_path}")
def filtruoti_audio(input_path: str, metodas: str) -> str:
    """Filter ``input_path`` with the method named by ``metodas``.

    Supported names: "Denoiser", "Wave-U-Net", "Noisereduce". The selected
    function writes its result to the shared ``OUTPUT_PATH``, which is returned.

    Raises:
        ValueError: if ``metodas`` is not one of the supported names.
    """
    # Checked in order with ==, exactly like an if/elif chain.
    handlers = (
        ("Denoiser", filtruoti_su_denoiser),
        ("Wave-U-Net", filtruoti_su_waveunet),
        ("Noisereduce", filtruoti_su_noisereduce),
    )
    for label, handler in handlers:
        if metodas == label:
            handler(input_path, OUTPUT_PATH)
            return OUTPUT_PATH
    raise ValueError("Nepalaikomas filtravimo metodas")