gsaon commited on
Commit
dc71e31
·
verified ·
1 Parent(s): ee8e2c4

Delete feature_extraction_granite_speech.py

Browse files
Files changed (1) hide show
  1. feature_extraction_granite_speech.py +0 -118
feature_extraction_granite_speech.py DELETED
@@ -1,118 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2025 The HuggingFace Inc. team.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- """
16
- Feature extractor class for Speech Granite
17
- """
18
-
19
- import math
20
- from typing import List, Optional
21
-
22
- from transformers.feature_extraction_utils import BatchFeature, FeatureExtractionMixin
23
- from transformers.utils import is_torch_available, is_torchaudio_available, logging
24
-
25
-
26
- logger = logging.get_logger(__name__)
27
-
28
- if is_torch_available():
29
- import torch
30
-
31
- if is_torchaudio_available():
32
- import torchaudio
33
-
34
-
35
- class GraniteSpeechFeatureExtractor(FeatureExtractionMixin):
36
- model_input_names = ["input_features"]
37
-
38
- def __init__(
39
- self,
40
- sampling_rate=16000,
41
- n_fft=512,
42
- win_length=400,
43
- hop_length=160,
44
- n_mels=80,
45
- projector_window_size=15,
46
- projector_downsample_rate=5,
47
- **kwargs,
48
- ):
49
- super().__init__(**kwargs)
50
- self.melspec_kwargs = {
51
- "sample_rate": sampling_rate,
52
- "n_fft": n_fft,
53
- "win_length": win_length,
54
- "hop_length": hop_length,
55
- "n_mels": n_mels,
56
- }
57
- # HACK - for now, lazily initialize the mel spectrogram transform;
58
- # the feature extractor mixin explodes otherwise because
59
- # it tries to log the feature extractor, and the melspectrogram
60
- # transform isn't json serializable...
61
- self.melspec = None
62
- self.projector_window_size = projector_window_size
63
- self.projector_downsample_rate = projector_downsample_rate
64
-
65
- def _ensure_melspec_transform_is_initialized(self):
66
- if self.melspec is None:
67
- self.melspec = torchaudio.transforms.MelSpectrogram(**self.melspec_kwargs)
68
-
69
- def __call__(
70
- self,
71
- x: torch.Tensor,
72
- device: Optional[str] = "cpu",
73
- ) -> BatchFeature:
74
- # TODO there is probably a better way to do both of these things...
75
- self._ensure_melspec_transform_is_initialized()
76
- if device is not None:
77
- melspec = self.melspec.to(device)
78
- x = x.to(device)
79
- else:
80
- melspec = self.melspec
81
-
82
- B, _ = x.shape
83
- with torch.no_grad():
84
- mel = melspec(x.float())
85
- logmel = mel.transpose(-1, -2).clip_(min=1e-10).log10_()
86
- mx = logmel.amax(dim=(-2, -1), keepdim=True)
87
- logmel = torch.maximum(logmel, mx - 8.0).div_(4).add_(1)
88
- if logmel.shape[1] % 2 == 1:
89
- logmel = logmel[:, :-1] # remove last frame if odd
90
- x = logmel.reshape(B, -1, 2 * logmel.shape[-1]) # stacking and skipping by 2
91
-
92
- if x.device != "cpu":
93
- return x.detach().cpu()
94
- return x
95
-
96
- def _get_num_audio_features(self, audio_lengths: List[int]) -> List[int]:
97
- """
98
- Gets the (variable length) variable length number of features
99
- (i.e., projector output) for the sequences being considered.
100
- """
101
- hop_length = self.melspec_kwargs["hop_length"]
102
- effective_window_size = self.projector_window_size // self.projector_downsample_rate
103
-
104
- projector_lengths = []
105
- for raw_length in audio_lengths:
106
- # mel sequence length computation
107
- mel_length = raw_length // hop_length + 1
108
- # encoder frame takes two mel features
109
- encoder_length = mel_length // 2
110
- nblocks = math.ceil(encoder_length / self.projector_window_size)
111
- # projector output length
112
- projector_length = nblocks * effective_window_size
113
- projector_lengths.append(projector_length)
114
-
115
- return projector_lengths
116
-
117
-
118
- __all__ = ["GraniteSpeechFeatureExtractor"]