Elanas commited on
Commit
bbb4657
·
verified ·
1 Parent(s): 755c7ae

Upload filtravimas.py

Browse files
Files changed (1) hide show
  1. filtravimas.py +86 -0
filtravimas.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # filtravimas.py
2
+
3
+ import os
4
+ import torch
5
+ import torchaudio
6
+ import torch.nn as nn
7
+ import torchaudio.transforms as T
8
+ from asteroid.models import DCCRNet
9
+
10
+ # Laikinas katalogas išfiltruotiems failams
11
+ TEMP_DIR = "temp_filtered"
12
+ OUTPUT_PATH = os.path.join(TEMP_DIR, "ivestis.wav")
13
+
14
+ # Užtikriname, kad aplankas egzistuoja
15
+ os.makedirs(TEMP_DIR, exist_ok=True)
16
+
17
+ # Paprastas Wave-U-Net modelis
18
+ class WaveUNet(nn.Module):
19
+ def __init__(self):
20
+ super(WaveUNet, self).__init__()
21
+ self.encoder = nn.Sequential(
22
+ nn.Conv1d(1, 16, kernel_size=3, stride=1, padding=1),
23
+ nn.ReLU(),
24
+ nn.Conv1d(16, 32, kernel_size=3, stride=1, padding=1),
25
+ nn.ReLU(),
26
+ nn.Conv1d(32, 64, kernel_size=3, stride=1, padding=1),
27
+ nn.ReLU(),
28
+ )
29
+ self.decoder = nn.Sequential(
30
+ nn.ConvTranspose1d(64, 32, kernel_size=3, stride=1, padding=1),
31
+ nn.ReLU(),
32
+ nn.ConvTranspose1d(32, 16, kernel_size=3, stride=1, padding=1),
33
+ nn.ReLU(),
34
+ nn.ConvTranspose1d(16, 1, kernel_size=3, stride=1, padding=1)
35
+ )
36
+
37
+ def forward(self, x):
38
+ x = self.encoder(x)
39
+ x = self.decoder(x)
40
+ return x
41
+
42
+ # Wave-U-Net filtravimas
43
+ def filtruoti_su_waveunet(input_path, output_path):
44
+ print("🔧 Wave-U-Net filtravimas...")
45
+ model = WaveUNet()
46
+ model.eval()
47
+ mixture, sr = torchaudio.load(input_path)
48
+
49
+ if sr != 16000:
50
+ print("🔁 Resample į 16kHz...")
51
+ mixture = T.Resample(orig_freq=sr, new_freq=16000)(mixture)
52
+
53
+ mixture = mixture.unsqueeze(0)
54
+ with torch.no_grad():
55
+ output = model(mixture)
56
+
57
+ output = output.squeeze(0)
58
+ torchaudio.save(output_path, output, 16000)
59
+ print(f"✅ Išsaugotas Wave-U-Net rezultatas: {output_path}")
60
+
61
+ # Denoiser (DCCRNet)
62
+ def filtruoti_su_denoiser(input_path, output_path):
63
+ print("🔧 Denoiser (DCCRNet)...")
64
+ model = DCCRNet.from_pretrained("JorisCos/DCCRNet_Libri1Mix_enhsingle_16k")
65
+ mixture, sr = torchaudio.load(input_path)
66
+
67
+ if sr != 16000:
68
+ print("🔁 Resample į 16kHz...")
69
+ mixture = T.Resample(orig_freq=sr, new_freq=16000)(mixture)
70
+
71
+ with torch.no_grad():
72
+ est_source = model.separate(mixture)
73
+
74
+ torchaudio.save(output_path, est_source[0], 16000)
75
+ print(f"✅ Išsaugotas Denoiser rezultatas: {output_path}")
76
+
77
+ # Pasirinkimo funkcija
78
+ def filtruoti_audio(input_path: str, metodas: str) -> str:
79
+ if metodas == "Denoiser":
80
+ filtruoti_su_denoiser(input_path, OUTPUT_PATH)
81
+ elif metodas == "Wave-U-Net":
82
+ filtruoti_su_waveunet(input_path, OUTPUT_PATH)
83
+ else:
84
+ raise ValueError("Nepalaikomas filtravimo metodas")
85
+
86
+ return OUTPUT_PATH