import os
import torch
import torchaudio
import torch.nn as nn
import torchaudio.transforms as T
import noisereduce as nr
from asteroid.models import DCCRNet

TEMP_DIR = "temp_filtered"
OUTPUT_PATH = os.path.join(TEMP_DIR, "ivestis.wav")
os.makedirs(TEMP_DIR, exist_ok=True)

class WaveUNet(nn.Module):
    """Minimal 1-D conv encoder/decoder stand-in for Wave-U-Net (no skip connections)."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv1d(1, 16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv1d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv1d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose1d(64, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.ConvTranspose1d(32, 16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.ConvTranspose1d(16, 1, kernel_size=3, stride=1, padding=1)
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x
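
# Note: WaveUNet starts from random weights, so it will not actually denoise
# until trained. A minimal, hypothetical sketch for loading a checkpoint
# (the file "waveunet_weights.pt" is an assumption, not part of this repo):
#
#     model = WaveUNet()
#     model.load_state_dict(torch.load("waveunet_weights.pt", map_location="cpu"))
#     model.eval()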

def filtruoti_su_waveunet(input_path, output_path):
    print("🔧 Wave-U-Net filtering...")
    model = WaveUNet()  # weights are random unless a trained checkpoint is loaded
    model.eval()
    mixture, sr = torchaudio.load(input_path)
    if sr != 16000:
        print("🔁 Resampling to 16 kHz...")
        resampler = T.Resample(orig_freq=sr, new_freq=16000).to(mixture.device)
        mixture = resampler(mixture)
    # Downmix to mono: the first Conv1d expects a single input channel.
    if mixture.size(0) > 1:
        mixture = mixture.mean(dim=0, keepdim=True)
    mixture = mixture.unsqueeze(0)  # (1, 1, samples)
    with torch.no_grad():
        output = model(mixture)
    output = output.squeeze(0)  # (1, samples) for torchaudio.save
    torchaudio.save(output_path, output, 16000)
    print(f"✅ Wave-U-Net saved: {output_path}")

def filtruoti_su_denoiser(input_path, output_path):
    print("🔧 Denoiser (DCCRNet)...")
    model = DCCRNet.from_pretrained("JorisCos/DCCRNet_Libri1Mix_enhsingle_16k")
    mixture, sr = torchaudio.load(input_path)
    if sr != 16000:
        print("🔁 Resampling to 16 kHz...")
        resampler = T.Resample(orig_freq=sr, new_freq=16000).to(mixture.device)
        mixture = resampler(mixture)
    # The pretrained model enhances single-channel audio; downmix stereo to mono.
    if mixture.size(0) > 1:
        mixture = mixture.mean(dim=0, keepdim=True)
    with torch.no_grad():
        est_source = model.separate(mixture)
    torchaudio.save(output_path, est_source[0], 16000)
    print(f"✅ Denoiser saved: {output_path}")

def filtruoti_su_noisereduce(input_path, output_path):
    print("🔧 Noisereduce filtering...")
    waveform, sr = torchaudio.load(input_path)
    audio = waveform.detach().cpu().numpy()[0]  # first channel only
    reduced = nr.reduce_noise(y=audio, sr=sr)
    reduced_tensor = torch.from_numpy(reduced).float().unsqueeze(0)
    torchaudio.save(output_path, reduced_tensor, sr)
    print(f"✅ Noisereduce saved: {output_path}")

def filtruoti_audio(input_path: str, metodas: str) -> str:
    if metodas == "Denoiser":
        filtruoti_su_denoiser(input_path, OUTPUT_PATH)
    elif metodas == "Wave-U-Net":
        filtruoti_su_waveunet(input_path, OUTPUT_PATH)
    elif metodas == "Noisereduce":
        filtruoti_su_noisereduce(input_path, OUTPUT_PATH)
    else:
        raise ValueError("Nepalaikomas filtravimo metodas")

    return OUTPUT_PATH
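

# Minimal usage sketch. The input file name "irasas.wav" is a hypothetical
# example (not part of this repo); the method string must match one of the
# branches in filtruoti_audio above.
if __name__ == "__main__":
    result_path = filtruoti_audio("irasas.wav", "Noisereduce")
    print(f"Filtered audio written to: {result_path}")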