saurabhati committed
Commit 6dfbaa2 · verified · 1 Parent(s): f08c7a5

Upload DASSForAudioClassification

Files changed (5):
  1. README.md +199 -0
  2. config.json +1089 -0
  3. configuration_dass.py +91 -0
  4. model.safetensors +3 -0
  5. modeling_dass.py +1228 -0
README.md ADDED
@@ -0,0 +1,199 @@
+ ---
+ library_name: transformers
+ tags: []
+ ---
+
+ # Model Card for Model ID
+
+ <!-- Provide a quick summary of what the model is/does. -->
+
+
+
+ ## Model Details
+
+ ### Model Description
+
+ <!-- Provide a longer summary of what this model is. -->
+
+ This is the model card of a 🤗 transformers model that has been pushed to the Hub. This model card has been automatically generated.
+
+ - **Developed by:** [More Information Needed]
+ - **Funded by [optional]:** [More Information Needed]
+ - **Shared by [optional]:** [More Information Needed]
+ - **Model type:** [More Information Needed]
+ - **Language(s) (NLP):** [More Information Needed]
+ - **License:** [More Information Needed]
+ - **Finetuned from model [optional]:** [More Information Needed]
+
+ ### Model Sources [optional]
+
+ <!-- Provide the basic links for the model. -->
+
+ - **Repository:** [More Information Needed]
+ - **Paper [optional]:** [More Information Needed]
+ - **Demo [optional]:** [More Information Needed]
+
+ ## Uses
+
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
+
+ ### Direct Use
+
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
+
+ [More Information Needed]
+
+ ### Downstream Use [optional]
+
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
+
+ [More Information Needed]
+
+ ### Out-of-Scope Use
+
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
+
+ [More Information Needed]
+
+ ## Bias, Risks, and Limitations
+
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
+
+ [More Information Needed]
+
+ ### Recommendations
+
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
+
+ Users (both direct and downstream) should be made aware of the risks, biases, and limitations of the model. More information is needed for further recommendations.
+
+ ## How to Get Started with the Model
+
+ Use the code below to get started with the model.
+
+ [More Information Needed]
+
+ ## Training Details
+
+ ### Training Data
+
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
+
+ [More Information Needed]
+
+ ### Training Procedure
+
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
+
+ #### Preprocessing [optional]
+
+ [More Information Needed]
+
+
+ #### Training Hyperparameters
+
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
+
+ #### Speeds, Sizes, Times [optional]
+
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
+
+ [More Information Needed]
+
+ ## Evaluation
+
+ <!-- This section describes the evaluation protocols and provides the results. -->
+
+ ### Testing Data, Factors & Metrics
+
+ #### Testing Data
+
+ <!-- This should link to a Dataset Card if possible. -->
+
+ [More Information Needed]
+
+ #### Factors
+
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
+
+ [More Information Needed]
+
+ #### Metrics
+
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
+
+ [More Information Needed]
+
+ ### Results
+
+ [More Information Needed]
+
+ #### Summary
+
+
+
+ ## Model Examination [optional]
+
+ <!-- Relevant interpretability work for the model goes here -->
+
+ [More Information Needed]
+
+ ## Environmental Impact
+
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
+
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
+
+ - **Hardware Type:** [More Information Needed]
+ - **Hours used:** [More Information Needed]
+ - **Cloud Provider:** [More Information Needed]
+ - **Compute Region:** [More Information Needed]
+ - **Carbon Emitted:** [More Information Needed]
+
+ ## Technical Specifications [optional]
+
+ ### Model Architecture and Objective
+
+ [More Information Needed]
+
+ ### Compute Infrastructure
+
+ [More Information Needed]
+
+ #### Hardware
+
+ [More Information Needed]
+
+ #### Software
+
+ [More Information Needed]
+
+ ## Citation [optional]
+
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
+
+ **BibTeX:**
+
+ [More Information Needed]
+
+ **APA:**
+
+ [More Information Needed]
+
+ ## Glossary [optional]
+
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
+
+ [More Information Needed]
+
+ ## More Information [optional]
+
+ [More Information Needed]
+
+ ## Model Card Authors [optional]
+
+ [More Information Needed]
+
+ ## Model Card Contact
+
+ [More Information Needed]
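The README's "How to Get Started with the Model" section is still a placeholder. The config.json in this commit registers custom classes via `auto_map` (`configuration_dass.DASSConfig`, `modeling_dass.DASSForAudioClassification`), so the checkpoint would typically be loaded with `AutoModelForAudioClassification.from_pretrained(<repo_id>, trust_remote_code=True)`. As a minimal, self-contained sketch, the snippet below only shows the step that follows inference: mapping the highest-scoring class index back to an AudioSet label via the `id2label` table from config.json. The three-entry dictionary is an excerpt standing in for the full 527-class map, and the scores are made up for illustration.

```python
# Excerpt of the id2label table from config.json; the real table has 527
# AudioSet classes (indices "0" through "526"). Keys are strings, exactly
# as they are serialized in the JSON file.
id2label = {"0": "Speech", "137": "Music", "500": "Silence"}

def top_label(scores):
    """Return the AudioSet label of the highest-scoring class index."""
    idx = max(range(len(scores)), key=scores.__getitem__)
    return id2label[str(idx)]

# Stand-in classification scores over 527 classes, with class 137 dominant.
scores = [0.0] * 527
scores[137] = 10.0
print(top_label(scores))  # Music
```

With the real model, `scores` would be the logits returned by `DASSForAudioClassification`, and the full `id2label` map comes from the loaded config object.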
config.json ADDED
@@ -0,0 +1,1089 @@
1
+ {
2
+ "architectures": [
3
+ "DASSForAudioClassification"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_dass.DASSConfig",
7
+ "AutoModelForAudioClassification": "modeling_dass.DASSForAudioClassification"
8
+ },
9
+ "depths": [
10
+ 2,
11
+ 2,
12
+ 20,
13
+ 2
14
+ ],
15
+ "dims": [
16
+ 96,
17
+ 192,
18
+ 384,
19
+ 768
20
+ ],
21
+ "drop_path_rate": 0.2,
22
+ "embed_dim": 96,
23
+ "id2label": {
24
+ "0": "Speech",
25
+ "1": "Male speech, man speaking",
26
+ "2": "Female speech, woman speaking",
27
+ "3": "Child speech, kid speaking",
28
+ "4": "Conversation",
29
+ "5": "Narration, monologue",
30
+ "6": "Babbling",
31
+ "7": "Speech synthesizer",
32
+ "8": "Shout",
33
+ "9": "Bellow",
34
+ "10": "Whoop",
35
+ "11": "Yell",
36
+ "12": "Battle cry",
37
+ "13": "Children shouting",
38
+ "14": "Screaming",
39
+ "15": "Whispering",
40
+ "16": "Laughter",
41
+ "17": "Baby laughter",
42
+ "18": "Giggle",
43
+ "19": "Snicker",
44
+ "20": "Belly laugh",
45
+ "21": "Chuckle, chortle",
46
+ "22": "Crying, sobbing",
47
+ "23": "Baby cry, infant cry",
48
+ "24": "Whimper",
49
+ "25": "Wail, moan",
50
+ "26": "Sigh",
51
+ "27": "Singing",
52
+ "28": "Choir",
53
+ "29": "Yodeling",
54
+ "30": "Chant",
55
+ "31": "Mantra",
56
+ "32": "Male singing",
57
+ "33": "Female singing",
58
+ "34": "Child singing",
59
+ "35": "Synthetic singing",
60
+ "36": "Rapping",
61
+ "37": "Humming",
62
+ "38": "Groan",
63
+ "39": "Grunt",
64
+ "40": "Whistling",
65
+ "41": "Breathing",
66
+ "42": "Wheeze",
67
+ "43": "Snoring",
68
+ "44": "Gasp",
69
+ "45": "Pant",
70
+ "46": "Snort",
71
+ "47": "Cough",
72
+ "48": "Throat clearing",
73
+ "49": "Sneeze",
74
+ "50": "Sniff",
75
+ "51": "Run",
76
+ "52": "Shuffle",
77
+ "53": "Walk, footsteps",
78
+ "54": "Chewing, mastication",
79
+ "55": "Biting",
80
+ "56": "Gargling",
81
+ "57": "Stomach rumble",
82
+ "58": "Burping, eructation",
83
+ "59": "Hiccup",
84
+ "60": "Fart",
85
+ "61": "Hands",
86
+ "62": "Finger snapping",
87
+ "63": "Clapping",
88
+ "64": "Heart sounds, heartbeat",
89
+ "65": "Heart murmur",
90
+ "66": "Cheering",
91
+ "67": "Applause",
92
+ "68": "Chatter",
93
+ "69": "Crowd",
94
+ "70": "Hubbub, speech noise, speech babble",
95
+ "71": "Children playing",
96
+ "72": "Animal",
97
+ "73": "Domestic animals, pets",
98
+ "74": "Dog",
99
+ "75": "Bark",
100
+ "76": "Yip",
101
+ "77": "Howl",
102
+ "78": "Bow-wow",
103
+ "79": "Growling",
104
+ "80": "Whimper (dog)",
105
+ "81": "Cat",
106
+ "82": "Purr",
107
+ "83": "Meow",
108
+ "84": "Hiss",
109
+ "85": "Caterwaul",
110
+ "86": "Livestock, farm animals, working animals",
111
+ "87": "Horse",
112
+ "88": "Clip-clop",
113
+ "89": "Neigh, whinny",
114
+ "90": "Cattle, bovinae",
115
+ "91": "Moo",
116
+ "92": "Cowbell",
117
+ "93": "Pig",
118
+ "94": "Oink",
119
+ "95": "Goat",
120
+ "96": "Bleat",
121
+ "97": "Sheep",
122
+ "98": "Fowl",
123
+ "99": "Chicken, rooster",
124
+ "100": "Cluck",
125
+ "101": "Crowing, cock-a-doodle-doo",
126
+ "102": "Turkey",
127
+ "103": "Gobble",
128
+ "104": "Duck",
129
+ "105": "Quack",
130
+ "106": "Goose",
131
+ "107": "Honk",
132
+ "108": "Wild animals",
133
+ "109": "Roaring cats (lions, tigers)",
134
+ "110": "Roar",
135
+ "111": "Bird",
136
+ "112": "Bird vocalization, bird call, bird song",
137
+ "113": "Chirp, tweet",
138
+ "114": "Squawk",
139
+ "115": "Pigeon, dove",
140
+ "116": "Coo",
141
+ "117": "Crow",
142
+ "118": "Caw",
143
+ "119": "Owl",
144
+ "120": "Hoot",
145
+ "121": "Bird flight, flapping wings",
146
+ "122": "Canidae, dogs, wolves",
147
+ "123": "Rodents, rats, mice",
148
+ "124": "Mouse",
149
+ "125": "Patter",
150
+ "126": "Insect",
151
+ "127": "Cricket",
152
+ "128": "Mosquito",
153
+ "129": "Fly, housefly",
154
+ "130": "Buzz",
155
+ "131": "Bee, wasp, etc.",
156
+ "132": "Frog",
157
+ "133": "Croak",
158
+ "134": "Snake",
159
+ "135": "Rattle",
160
+ "136": "Whale vocalization",
161
+ "137": "Music",
162
+ "138": "Musical instrument",
163
+ "139": "Plucked string instrument",
164
+ "140": "Guitar",
165
+ "141": "Electric guitar",
166
+ "142": "Bass guitar",
167
+ "143": "Acoustic guitar",
168
+ "144": "Steel guitar, slide guitar",
169
+ "145": "Tapping (guitar technique)",
170
+ "146": "Strum",
171
+ "147": "Banjo",
172
+ "148": "Sitar",
173
+ "149": "Mandolin",
174
+ "150": "Zither",
175
+ "151": "Ukulele",
176
+ "152": "Keyboard (musical)",
177
+ "153": "Piano",
178
+ "154": "Electric piano",
179
+ "155": "Organ",
180
+ "156": "Electronic organ",
181
+ "157": "Hammond organ",
182
+ "158": "Synthesizer",
183
+ "159": "Sampler",
184
+ "160": "Harpsichord",
185
+ "161": "Percussion",
186
+ "162": "Drum kit",
187
+ "163": "Drum machine",
188
+ "164": "Drum",
189
+ "165": "Snare drum",
190
+ "166": "Rimshot",
191
+ "167": "Drum roll",
192
+ "168": "Bass drum",
193
+ "169": "Timpani",
194
+ "170": "Tabla",
195
+ "171": "Cymbal",
196
+ "172": "Hi-hat",
197
+ "173": "Wood block",
198
+ "174": "Tambourine",
199
+ "175": "Rattle (instrument)",
200
+ "176": "Maraca",
201
+ "177": "Gong",
202
+ "178": "Tubular bells",
203
+ "179": "Mallet percussion",
204
+ "180": "Marimba, xylophone",
205
+ "181": "Glockenspiel",
206
+ "182": "Vibraphone",
207
+ "183": "Steelpan",
208
+ "184": "Orchestra",
209
+ "185": "Brass instrument",
210
+ "186": "French horn",
211
+ "187": "Trumpet",
212
+ "188": "Trombone",
213
+ "189": "Bowed string instrument",
214
+ "190": "String section",
215
+ "191": "Violin, fiddle",
216
+ "192": "Pizzicato",
217
+ "193": "Cello",
218
+ "194": "Double bass",
219
+ "195": "Wind instrument, woodwind instrument",
220
+ "196": "Flute",
221
+ "197": "Saxophone",
222
+ "198": "Clarinet",
223
+ "199": "Harp",
224
+ "200": "Bell",
225
+ "201": "Church bell",
226
+ "202": "Jingle bell",
227
+ "203": "Bicycle bell",
228
+ "204": "Tuning fork",
229
+ "205": "Chime",
230
+ "206": "Wind chime",
231
+ "207": "Change ringing (campanology)",
232
+ "208": "Harmonica",
233
+ "209": "Accordion",
234
+ "210": "Bagpipes",
235
+ "211": "Didgeridoo",
236
+ "212": "Shofar",
237
+ "213": "Theremin",
238
+ "214": "Singing bowl",
239
+ "215": "Scratching (performance technique)",
240
+ "216": "Pop music",
241
+ "217": "Hip hop music",
242
+ "218": "Beatboxing",
243
+ "219": "Rock music",
244
+ "220": "Heavy metal",
245
+ "221": "Punk rock",
246
+ "222": "Grunge",
247
+ "223": "Progressive rock",
248
+ "224": "Rock and roll",
249
+ "225": "Psychedelic rock",
250
+ "226": "Rhythm and blues",
251
+ "227": "Soul music",
252
+ "228": "Reggae",
253
+ "229": "Country",
254
+ "230": "Swing music",
255
+ "231": "Bluegrass",
256
+ "232": "Funk",
257
+ "233": "Folk music",
258
+ "234": "Middle Eastern music",
259
+ "235": "Jazz",
260
+ "236": "Disco",
261
+ "237": "Classical music",
262
+ "238": "Opera",
263
+ "239": "Electronic music",
264
+ "240": "House music",
265
+ "241": "Techno",
266
+ "242": "Dubstep",
267
+ "243": "Drum and bass",
268
+ "244": "Electronica",
269
+ "245": "Electronic dance music",
270
+ "246": "Ambient music",
271
+ "247": "Trance music",
272
+ "248": "Music of Latin America",
273
+ "249": "Salsa music",
274
+ "250": "Flamenco",
275
+ "251": "Blues",
276
+ "252": "Music for children",
277
+ "253": "New-age music",
278
+ "254": "Vocal music",
279
+ "255": "A capella",
280
+ "256": "Music of Africa",
281
+ "257": "Afrobeat",
282
+ "258": "Christian music",
283
+ "259": "Gospel music",
284
+ "260": "Music of Asia",
285
+ "261": "Carnatic music",
286
+ "262": "Music of Bollywood",
287
+ "263": "Ska",
288
+ "264": "Traditional music",
289
+ "265": "Independent music",
290
+ "266": "Song",
291
+ "267": "Background music",
292
+ "268": "Theme music",
293
+ "269": "Jingle (music)",
294
+ "270": "Soundtrack music",
295
+ "271": "Lullaby",
296
+ "272": "Video game music",
297
+ "273": "Christmas music",
298
+ "274": "Dance music",
299
+ "275": "Wedding music",
300
+ "276": "Happy music",
301
+ "277": "Funny music",
302
+ "278": "Sad music",
303
+ "279": "Tender music",
304
+ "280": "Exciting music",
305
+ "281": "Angry music",
306
+ "282": "Scary music",
307
+ "283": "Wind",
308
+ "284": "Rustling leaves",
309
+ "285": "Wind noise (microphone)",
310
+ "286": "Thunderstorm",
311
+ "287": "Thunder",
312
+ "288": "Water",
313
+ "289": "Rain",
314
+ "290": "Raindrop",
315
+ "291": "Rain on surface",
316
+ "292": "Stream",
317
+ "293": "Waterfall",
318
+ "294": "Ocean",
319
+ "295": "Waves, surf",
320
+ "296": "Steam",
321
+ "297": "Gurgling",
322
+ "298": "Fire",
323
+ "299": "Crackle",
324
+ "300": "Vehicle",
325
+ "301": "Boat, Water vehicle",
326
+ "302": "Sailboat, sailing ship",
327
+ "303": "Rowboat, canoe, kayak",
328
+ "304": "Motorboat, speedboat",
329
+ "305": "Ship",
330
+ "306": "Motor vehicle (road)",
331
+ "307": "Car",
332
+ "308": "Vehicle horn, car horn, honking",
333
+ "309": "Toot",
334
+ "310": "Car alarm",
335
+ "311": "Power windows, electric windows",
336
+ "312": "Skidding",
337
+ "313": "Tire squeal",
338
+ "314": "Car passing by",
339
+ "315": "Race car, auto racing",
340
+ "316": "Truck",
341
+ "317": "Air brake",
342
+ "318": "Air horn, truck horn",
343
+ "319": "Reversing beeps",
344
+ "320": "Ice cream truck, ice cream van",
345
+ "321": "Bus",
346
+ "322": "Emergency vehicle",
347
+ "323": "Police car (siren)",
348
+ "324": "Ambulance (siren)",
349
+ "325": "Fire engine, fire truck (siren)",
350
+ "326": "Motorcycle",
351
+ "327": "Traffic noise, roadway noise",
352
+ "328": "Rail transport",
353
+ "329": "Train",
354
+ "330": "Train whistle",
355
+ "331": "Train horn",
356
+ "332": "Railroad car, train wagon",
357
+ "333": "Train wheels squealing",
358
+ "334": "Subway, metro, underground",
359
+ "335": "Aircraft",
360
+ "336": "Aircraft engine",
361
+ "337": "Jet engine",
362
+ "338": "Propeller, airscrew",
363
+ "339": "Helicopter",
364
+ "340": "Fixed-wing aircraft, airplane",
365
+ "341": "Bicycle",
366
+ "342": "Skateboard",
367
+ "343": "Engine",
368
+ "344": "Light engine (high frequency)",
369
+ "345": "Dental drill, dentist's drill",
370
+ "346": "Lawn mower",
371
+ "347": "Chainsaw",
372
+ "348": "Medium engine (mid frequency)",
373
+ "349": "Heavy engine (low frequency)",
374
+ "350": "Engine knocking",
375
+ "351": "Engine starting",
376
+ "352": "Idling",
377
+ "353": "Accelerating, revving, vroom",
378
+ "354": "Door",
379
+ "355": "Doorbell",
380
+ "356": "Ding-dong",
381
+ "357": "Sliding door",
382
+ "358": "Slam",
383
+ "359": "Knock",
384
+ "360": "Tap",
385
+ "361": "Squeak",
386
+ "362": "Cupboard open or close",
387
+ "363": "Drawer open or close",
388
+ "364": "Dishes, pots, and pans",
389
+ "365": "Cutlery, silverware",
390
+ "366": "Chopping (food)",
391
+ "367": "Frying (food)",
392
+ "368": "Microwave oven",
393
+ "369": "Blender",
394
+ "370": "Water tap, faucet",
395
+ "371": "Sink (filling or washing)",
396
+ "372": "Bathtub (filling or washing)",
397
+ "373": "Hair dryer",
398
+ "374": "Toilet flush",
399
+ "375": "Toothbrush",
400
+ "376": "Electric toothbrush",
401
+ "377": "Vacuum cleaner",
402
+ "378": "Zipper (clothing)",
403
+ "379": "Keys jangling",
404
+ "380": "Coin (dropping)",
405
+ "381": "Scissors",
406
+ "382": "Electric shaver, electric razor",
407
+ "383": "Shuffling cards",
408
+ "384": "Typing",
409
+ "385": "Typewriter",
410
+ "386": "Computer keyboard",
411
+ "387": "Writing",
412
+ "388": "Alarm",
413
+ "389": "Telephone",
414
+ "390": "Telephone bell ringing",
415
+ "391": "Ringtone",
416
+ "392": "Telephone dialing, DTMF",
417
+ "393": "Dial tone",
418
+ "394": "Busy signal",
419
+ "395": "Alarm clock",
420
+ "396": "Siren",
421
+ "397": "Civil defense siren",
422
+ "398": "Buzzer",
423
+ "399": "Smoke detector, smoke alarm",
424
+ "400": "Fire alarm",
425
+ "401": "Foghorn",
426
+ "402": "Whistle",
427
+ "403": "Steam whistle",
428
+ "404": "Mechanisms",
429
+ "405": "Ratchet, pawl",
430
+ "406": "Clock",
431
+ "407": "Tick",
432
+ "408": "Tick-tock",
433
+ "409": "Gears",
434
+ "410": "Pulleys",
435
+ "411": "Sewing machine",
436
+ "412": "Mechanical fan",
437
+ "413": "Air conditioning",
438
+ "414": "Cash register",
439
+ "415": "Printer",
440
+ "416": "Camera",
441
+ "417": "Single-lens reflex camera",
442
+ "418": "Tools",
443
+ "419": "Hammer",
444
+ "420": "Jackhammer",
445
+ "421": "Sawing",
446
+ "422": "Filing (rasp)",
447
+ "423": "Sanding",
448
+ "424": "Power tool",
449
+ "425": "Drill",
450
+ "426": "Explosion",
451
+ "427": "Gunshot, gunfire",
452
+ "428": "Machine gun",
453
+ "429": "Fusillade",
454
+ "430": "Artillery fire",
455
+ "431": "Cap gun",
456
+ "432": "Fireworks",
457
+ "433": "Firecracker",
458
+ "434": "Burst, pop",
459
+ "435": "Eruption",
460
+ "436": "Boom",
461
+ "437": "Wood",
462
+ "438": "Chop",
463
+ "439": "Splinter",
464
+ "440": "Crack",
465
+ "441": "Glass",
466
+ "442": "Chink, clink",
467
+ "443": "Shatter",
468
+ "444": "Liquid",
469
+ "445": "Splash, splatter",
470
+ "446": "Slosh",
471
+ "447": "Squish",
472
+ "448": "Drip",
473
+ "449": "Pour",
474
+ "450": "Trickle, dribble",
475
+ "451": "Gush",
476
+ "452": "Fill (with liquid)",
477
+ "453": "Spray",
478
+ "454": "Pump (liquid)",
479
+ "455": "Stir",
480
+ "456": "Boiling",
481
+ "457": "Sonar",
482
+ "458": "Arrow",
483
+ "459": "Whoosh, swoosh, swish",
484
+ "460": "Thump, thud",
485
+ "461": "Thunk",
486
+ "462": "Electronic tuner",
487
+ "463": "Effects unit",
488
+ "464": "Chorus effect",
489
+ "465": "Basketball bounce",
490
+ "466": "Bang",
491
+ "467": "Slap, smack",
492
+ "468": "Whack, thwack",
493
+ "469": "Smash, crash",
494
+ "470": "Breaking",
495
+ "471": "Bouncing",
496
+ "472": "Whip",
497
+ "473": "Flap",
498
+ "474": "Scratch",
499
+ "475": "Scrape",
500
+ "476": "Rub",
501
+ "477": "Roll",
502
+ "478": "Crushing",
503
+ "479": "Crumpling, crinkling",
504
+ "480": "Tearing",
505
+ "481": "Beep, bleep",
506
+ "482": "Ping",
507
+ "483": "Ding",
508
+ "484": "Clang",
509
+ "485": "Squeal",
510
+ "486": "Creak",
511
+ "487": "Rustle",
512
+ "488": "Whir",
513
+ "489": "Clatter",
514
+ "490": "Sizzle",
515
+ "491": "Clicking",
516
+ "492": "Clickety-clack",
517
+ "493": "Rumble",
518
+ "494": "Plop",
519
+ "495": "Jingle, tinkle",
520
+ "496": "Hum",
521
+ "497": "Zing",
522
+ "498": "Boing",
523
+ "499": "Crunch",
524
+ "500": "Silence",
525
+ "501": "Sine wave",
526
+ "502": "Harmonic",
527
+ "503": "Chirp tone",
528
+ "504": "Sound effect",
529
+ "505": "Pulse",
530
+ "506": "Inside, small room",
531
+ "507": "Inside, large room or hall",
532
+ "508": "Inside, public space",
533
+ "509": "Outside, urban or manmade",
534
+ "510": "Outside, rural or natural",
535
+ "511": "Reverberation",
536
+ "512": "Echo",
537
+ "513": "Noise",
538
+ "514": "Environmental noise",
539
+ "515": "Static",
540
+ "516": "Mains hum",
541
+ "517": "Distortion",
542
+ "518": "Sidetone",
543
+ "519": "Cacophony",
544
+ "520": "White noise",
545
+ "521": "Pink noise",
546
+ "522": "Throbbing",
547
+ "523": "Vibration",
548
+ "524": "Television",
549
+ "525": "Radio",
550
+ "526": "Field recording"
551
+ },
552
+ "label2id": {
553
+ "A capella": 255,
554
+ "Accelerating, revving, vroom": 353,
555
+ "Accordion": 209,
556
+ "Acoustic guitar": 143,
557
+ "Afrobeat": 257,
558
+ "Air brake": 317,
559
+ "Air conditioning": 413,
560
+ "Air horn, truck horn": 318,
561
+ "Aircraft": 335,
562
+ "Aircraft engine": 336,
563
+ "Alarm": 388,
564
+ "Alarm clock": 395,
565
+ "Ambient music": 246,
566
+ "Ambulance (siren)": 324,
567
+ "Angry music": 281,
568
+ "Animal": 72,
569
+ "Applause": 67,
570
+ "Arrow": 458,
571
+ "Artillery fire": 430,
572
+ "Babbling": 6,
573
+ "Baby cry, infant cry": 23,
574
+ "Baby laughter": 17,
575
+ "Background music": 267,
576
+ "Bagpipes": 210,
577
+ "Bang": 466,
578
+ "Banjo": 147,
579
+ "Bark": 75,
580
+ "Basketball bounce": 465,
581
+ "Bass drum": 168,
582
+ "Bass guitar": 142,
583
+ "Bathtub (filling or washing)": 372,
584
+ "Battle cry": 12,
585
+ "Beatboxing": 218,
586
+ "Bee, wasp, etc.": 131,
587
+ "Beep, bleep": 481,
588
+ "Bell": 200,
589
+ "Bellow": 9,
590
+ "Belly laugh": 20,
591
+ "Bicycle": 341,
592
+ "Bicycle bell": 203,
593
+ "Bird": 111,
594
+ "Bird flight, flapping wings": 121,
595
+ "Bird vocalization, bird call, bird song": 112,
596
+ "Biting": 55,
597
+ "Bleat": 96,
598
+ "Blender": 369,
599
+ "Bluegrass": 231,
600
+ "Blues": 251,
601
+ "Boat, Water vehicle": 301,
602
+ "Boiling": 456,
603
+ "Boing": 498,
604
+ "Boom": 436,
605
+ "Bouncing": 471,
606
+ "Bow-wow": 78,
607
+ "Bowed string instrument": 189,
608
+ "Brass instrument": 185,
609
+ "Breaking": 470,
610
+ "Breathing": 41,
611
+ "Burping, eructation": 58,
612
+ "Burst, pop": 434,
613
+ "Bus": 321,
614
+ "Busy signal": 394,
615
+ "Buzz": 130,
616
+ "Buzzer": 398,
617
+ "Cacophony": 519,
618
+ "Camera": 416,
619
+ "Canidae, dogs, wolves": 122,
620
+ "Cap gun": 431,
621
+ "Car": 307,
622
+ "Car alarm": 310,
623
+ "Car passing by": 314,
624
+ "Carnatic music": 261,
625
+ "Cash register": 414,
626
+ "Cat": 81,
627
+ "Caterwaul": 85,
628
+ "Cattle, bovinae": 90,
629
+ "Caw": 118,
630
+ "Cello": 193,
631
+ "Chainsaw": 347,
632
+ "Change ringing (campanology)": 207,
633
+ "Chant": 30,
634
+ "Chatter": 68,
635
+ "Cheering": 66,
636
+ "Chewing, mastication": 54,
637
+ "Chicken, rooster": 99,
638
+ "Child singing": 34,
639
+ "Child speech, kid speaking": 3,
640
+ "Children playing": 71,
641
+ "Children shouting": 13,
642
+ "Chime": 205,
643
+ "Chink, clink": 442,
644
+ "Chirp tone": 503,
645
+ "Chirp, tweet": 113,
646
+ "Choir": 28,
647
+ "Chop": 438,
648
+ "Chopping (food)": 366,
649
+ "Chorus effect": 464,
650
+ "Christian music": 258,
651
+ "Christmas music": 273,
652
+ "Chuckle, chortle": 21,
653
+ "Church bell": 201,
654
+ "Civil defense siren": 397,
655
+ "Clang": 484,
656
+ "Clapping": 63,
657
+ "Clarinet": 198,
658
+ "Classical music": 237,
659
+ "Clatter": 489,
660
+ "Clickety-clack": 492,
661
+ "Clicking": 491,
662
+ "Clip-clop": 88,
663
+ "Clock": 406,
664
+ "Cluck": 100,
665
+ "Coin (dropping)": 380,
666
+ "Computer keyboard": 386,
667
+ "Conversation": 4,
668
+ "Coo": 116,
669
+ "Cough": 47,
670
+ "Country": 229,
671
+ "Cowbell": 92,
672
+ "Crack": 440,
673
+ "Crackle": 299,
674
+ "Creak": 486,
675
+ "Cricket": 127,
676
+ "Croak": 133,
677
+ "Crow": 117,
678
+ "Crowd": 69,
679
+ "Crowing, cock-a-doodle-doo": 101,
680
+ "Crumpling, crinkling": 479,
681
+ "Crunch": 499,
682
+ "Crushing": 478,
683
+ "Crying, sobbing": 22,
684
+ "Cupboard open or close": 362,
685
+ "Cutlery, silverware": 365,
686
+ "Cymbal": 171,
687
+ "Dance music": 274,
688
+ "Dental drill, dentist's drill": 345,
689
+ "Dial tone": 393,
690
+ "Didgeridoo": 211,
691
+ "Ding": 483,
692
+ "Ding-dong": 356,
693
+ "Disco": 236,
694
+ "Dishes, pots, and pans": 364,
695
+ "Distortion": 517,
696
+ "Dog": 74,
697
+ "Domestic animals, pets": 73,
698
+ "Door": 354,
699
+ "Doorbell": 355,
700
+ "Double bass": 194,
701
+ "Drawer open or close": 363,
702
+ "Drill": 425,
703
+ "Drip": 448,
704
+ "Drum": 164,
705
+ "Drum and bass": 243,
706
+ "Drum kit": 162,
707
+ "Drum machine": 163,
708
+ "Drum roll": 167,
709
+ "Dubstep": 242,
710
+ "Duck": 104,
711
+ "Echo": 512,
712
+ "Effects unit": 463,
713
+ "Electric guitar": 141,
714
+ "Electric piano": 154,
715
+ "Electric shaver, electric razor": 382,
716
+ "Electric toothbrush": 376,
717
+ "Electronic dance music": 245,
718
+ "Electronic music": 239,
719
+ "Electronic organ": 156,
720
+ "Electronic tuner": 462,
721
+ "Electronica": 244,
722
+ "Emergency vehicle": 322,
723
+ "Engine": 343,
724
+ "Engine knocking": 350,
725
+ "Engine starting": 351,
726
+ "Environmental noise": 514,
727
+ "Eruption": 435,
728
+ "Exciting music": 280,
729
+ "Explosion": 426,
730
+ "Fart": 60,
731
+ "Female singing": 33,
732
+ "Female speech, woman speaking": 2,
733
+ "Field recording": 526,
734
+ "Filing (rasp)": 422,
735
+ "Fill (with liquid)": 452,
736
+ "Finger snapping": 62,
737
+ "Fire": 298,
738
+ "Fire alarm": 400,
739
+ "Fire engine, fire truck (siren)": 325,
740
+ "Firecracker": 433,
741
+ "Fireworks": 432,
742
+ "Fixed-wing aircraft, airplane": 340,
743
+ "Flamenco": 250,
744
+ "Flap": 473,
745
+ "Flute": 196,
746
+ "Fly, housefly": 129,
747
+ "Foghorn": 401,
748
+ "Folk music": 233,
749
+ "Fowl": 98,
750
+ "French horn": 186,
751
+ "Frog": 132,
752
+ "Frying (food)": 367,
753
+ "Funk": 232,
754
+ "Funny music": 277,
755
+ "Fusillade": 429,
756
+ "Gargling": 56,
757
+ "Gasp": 44,
758
+ "Gears": 409,
759
+ "Giggle": 18,
760
+ "Glass": 441,
761
+ "Glockenspiel": 181,
762
+ "Goat": 95,
763
+ "Gobble": 103,
764
+ "Gong": 177,
765
+ "Goose": 106,
766
+ "Gospel music": 259,
767
+ "Groan": 38,
768
+ "Growling": 79,
769
+ "Grunge": 222,
770
+ "Grunt": 39,
771
+ "Guitar": 140,
772
+ "Gunshot, gunfire": 427,
773
+ "Gurgling": 297,
774
+ "Gush": 451,
775
+ "Hair dryer": 373,
776
+ "Hammer": 419,
777
+ "Hammond organ": 157,
778
+ "Hands": 61,
779
+ "Happy music": 276,
780
+ "Harmonic": 502,
781
+ "Harmonica": 208,
782
+ "Harp": 199,
783
+ "Harpsichord": 160,
784
+ "Heart murmur": 65,
785
+ "Heart sounds, heartbeat": 64,
786
+ "Heavy engine (low frequency)": 349,
787
+ "Heavy metal": 220,
788
+ "Helicopter": 339,
789
+ "Hi-hat": 172,
790
+ "Hiccup": 59,
791
+ "Hip hop music": 217,
792
+ "Hiss": 84,
793
+ "Honk": 107,
794
+ "Hoot": 120,
795
+ "Horse": 87,
796
+ "House music": 240,
797
+ "Howl": 77,
798
+ "Hubbub, speech noise, speech babble": 70,
799
+ "Hum": 496,
800
+ "Humming": 37,
801
+ "Ice cream truck, ice cream van": 320,
802
+ "Idling": 352,
803
+ "Independent music": 265,
804
+ "Insect": 126,
805
+ "Inside, large room or hall": 507,
806
+ "Inside, public space": 508,
807
+ "Inside, small room": 506,
808
+ "Jackhammer": 420,
809
+ "Jazz": 235,
810
+ "Jet engine": 337,
811
+ "Jingle (music)": 269,
812
+ "Jingle bell": 202,
813
+ "Jingle, tinkle": 495,
814
+ "Keyboard (musical)": 152,
815
+ "Keys jangling": 379,
816
+ "Knock": 359,
817
+ "Laughter": 16,
818
+ "Lawn mower": 346,
819
+ "Light engine (high frequency)": 344,
820
+ "Liquid": 444,
821
+ "Livestock, farm animals, working animals": 86,
822
+ "Lullaby": 271,
823
+ "Machine gun": 428,
824
+ "Mains hum": 516,
825
+ "Male singing": 32,
826
+ "Male speech, man speaking": 1,
827
+ "Mallet percussion": 179,
828
+ "Mandolin": 149,
829
+ "Mantra": 31,
830
+ "Maraca": 176,
831
+ "Marimba, xylophone": 180,
832
+ "Mechanical fan": 412,
833
+ "Mechanisms": 404,
834
+ "Medium engine (mid frequency)": 348,
835
+ "Meow": 83,
836
+ "Microwave oven": 368,
837
+ "Middle Eastern music": 234,
838
+ "Moo": 91,
839
+ "Mosquito": 128,
840
+ "Motor vehicle (road)": 306,
841
+ "Motorboat, speedboat": 304,
842
+ "Motorcycle": 326,
843
+ "Mouse": 124,
844
+ "Music": 137,
845
+ "Music for children": 252,
846
+ "Music of Africa": 256,
847
+ "Music of Asia": 260,
848
+ "Music of Bollywood": 262,
849
+ "Music of Latin America": 248,
850
+ "Musical instrument": 138,
851
+ "Narration, monologue": 5,
852
+ "Neigh, whinny": 89,
853
+ "New-age music": 253,
854
+ "Noise": 513,
855
+ "Ocean": 294,
856
+ "Oink": 94,
857
+ "Opera": 238,
858
+ "Orchestra": 184,
859
+ "Organ": 155,
860
+ "Outside, rural or natural": 510,
861
+ "Outside, urban or manmade": 509,
862
+ "Owl": 119,
863
+ "Pant": 45,
864
+ "Patter": 125,
865
+ "Percussion": 161,
866
+ "Piano": 153,
867
+ "Pig": 93,
868
+ "Pigeon, dove": 115,
869
+ "Ping": 482,
870
+ "Pink noise": 521,
871
+ "Pizzicato": 192,
872
+ "Plop": 494,
873
+ "Plucked string instrument": 139,
874
+ "Police car (siren)": 323,
875
+ "Pop music": 216,
876
+ "Pour": 449,
877
+ "Power tool": 424,
878
+ "Power windows, electric windows": 311,
879
+ "Printer": 415,
880
+ "Progressive rock": 223,
881
+ "Propeller, airscrew": 338,
882
+ "Psychedelic rock": 225,
883
+ "Pulleys": 410,
884
+ "Pulse": 505,
885
+ "Pump (liquid)": 454,
886
+ "Punk rock": 221,
887
+ "Purr": 82,
888
+ "Quack": 105,
889
+ "Race car, auto racing": 315,
890
+ "Radio": 525,
891
+ "Rail transport": 328,
892
+ "Railroad car, train wagon": 332,
893
+ "Rain": 289,
894
+ "Rain on surface": 291,
895
+ "Raindrop": 290,
896
+ "Rapping": 36,
897
+ "Ratchet, pawl": 405,
898
+ "Rattle": 135,
899
+ "Rattle (instrument)": 175,
900
+ "Reggae": 228,
901
+ "Reverberation": 511,
902
+ "Reversing beeps": 319,
903
+ "Rhythm and blues": 226,
904
+ "Rimshot": 166,
905
+ "Ringtone": 391,
906
+ "Roar": 110,
907
+ "Roaring cats (lions, tigers)": 109,
908
+ "Rock and roll": 224,
909
+ "Rock music": 219,
910
+ "Rodents, rats, mice": 123,
911
+ "Roll": 477,
912
+ "Rowboat, canoe, kayak": 303,
913
+ "Rub": 476,
914
+ "Rumble": 493,
915
+ "Run": 51,
916
+ "Rustle": 487,
917
+ "Rustling leaves": 284,
918
+ "Sad music": 278,
919
+ "Sailboat, sailing ship": 302,
920
+ "Salsa music": 249,
921
+ "Sampler": 159,
922
+ "Sanding": 423,
923
+ "Sawing": 421,
924
+ "Saxophone": 197,
925
+ "Scary music": 282,
926
+ "Scissors": 381,
927
+ "Scrape": 475,
928
+ "Scratch": 474,
929
+ "Scratching (performance technique)": 215,
930
+ "Screaming": 14,
931
+ "Sewing machine": 411,
932
+ "Shatter": 443,
933
+ "Sheep": 97,
934
+ "Ship": 305,
935
+ "Shofar": 212,
936
+ "Shout": 8,
937
+ "Shuffle": 52,
938
+ "Shuffling cards": 383,
939
+ "Sidetone": 518,
940
+ "Sigh": 26,
941
+ "Silence": 500,
942
+ "Sine wave": 501,
943
+ "Singing": 27,
944
+ "Singing bowl": 214,
945
+ "Single-lens reflex camera": 417,
946
+ "Sink (filling or washing)": 371,
947
+ "Siren": 396,
948
+ "Sitar": 148,
949
+ "Sizzle": 490,
950
+ "Ska": 263,
951
+ "Skateboard": 342,
952
+ "Skidding": 312,
953
+ "Slam": 358,
954
+ "Slap, smack": 467,
955
+ "Sliding door": 357,
956
+ "Slosh": 446,
957
+ "Smash, crash": 469,
958
+ "Smoke detector, smoke alarm": 399,
959
+ "Snake": 134,
960
+ "Snare drum": 165,
961
+ "Sneeze": 49,
962
+ "Snicker": 19,
963
+ "Sniff": 50,
964
+ "Snoring": 43,
965
+ "Snort": 46,
966
+ "Sonar": 457,
967
+ "Song": 266,
968
+ "Soul music": 227,
969
+ "Sound effect": 504,
970
+ "Soundtrack music": 270,
971
+ "Speech": 0,
972
+ "Speech synthesizer": 7,
973
+ "Splash, splatter": 445,
974
+ "Splinter": 439,
975
+ "Spray": 453,
976
+ "Squawk": 114,
977
+ "Squeak": 361,
978
+ "Squeal": 485,
979
+ "Squish": 447,
980
+ "Static": 515,
981
+ "Steam": 296,
982
+ "Steam whistle": 403,
983
+ "Steel guitar, slide guitar": 144,
984
+ "Steelpan": 183,
985
+ "Stir": 455,
986
+ "Stomach rumble": 57,
987
+ "Stream": 292,
988
+ "String section": 190,
989
+ "Strum": 146,
990
+ "Subway, metro, underground": 334,
991
+ "Swing music": 230,
992
+ "Synthesizer": 158,
993
+ "Synthetic singing": 35,
994
+ "Tabla": 170,
995
+ "Tambourine": 174,
996
+ "Tap": 360,
997
+ "Tapping (guitar technique)": 145,
998
+ "Tearing": 480,
999
+ "Techno": 241,
1000
+ "Telephone": 389,
1001
+ "Telephone bell ringing": 390,
1002
+ "Telephone dialing, DTMF": 392,
1003
+ "Television": 524,
1004
+ "Tender music": 279,
1005
+ "Theme music": 268,
1006
+ "Theremin": 213,
1007
+ "Throat clearing": 48,
1008
+ "Throbbing": 522,
1009
+ "Thump, thud": 460,
1010
+ "Thunder": 287,
1011
+ "Thunderstorm": 286,
1012
+ "Thunk": 461,
1013
+ "Tick": 407,
1014
+ "Tick-tock": 408,
1015
+ "Timpani": 169,
1016
+ "Tire squeal": 313,
1017
+ "Toilet flush": 374,
1018
+ "Tools": 418,
1019
+ "Toot": 309,
1020
+ "Toothbrush": 375,
1021
+ "Traditional music": 264,
1022
+ "Traffic noise, roadway noise": 327,
1023
+ "Train": 329,
1024
+ "Train horn": 331,
1025
+ "Train wheels squealing": 333,
1026
+ "Train whistle": 330,
1027
+ "Trance music": 247,
1028
+ "Trickle, dribble": 450,
1029
+ "Trombone": 188,
1030
+ "Truck": 316,
1031
+ "Trumpet": 187,
1032
+ "Tubular bells": 178,
1033
+ "Tuning fork": 204,
1034
+ "Turkey": 102,
1035
+ "Typewriter": 385,
1036
+ "Typing": 384,
1037
+ "Ukulele": 151,
1038
+ "Vacuum cleaner": 377,
1039
+ "Vehicle": 300,
1040
+ "Vehicle horn, car horn, honking": 308,
1041
+ "Vibraphone": 182,
1042
+ "Vibration": 523,
1043
+ "Video game music": 272,
1044
+ "Violin, fiddle": 191,
1045
+ "Vocal music": 254,
1046
+ "Wail, moan": 25,
1047
+ "Walk, footsteps": 53,
1048
+ "Water": 288,
1049
+ "Water tap, faucet": 370,
1050
+ "Waterfall": 293,
1051
+ "Waves, surf": 295,
1052
+ "Wedding music": 275,
1053
+ "Whack, thwack": 468,
1054
+ "Whale vocalization": 136,
1055
+ "Wheeze": 42,
1056
+ "Whimper": 24,
1057
+ "Whimper (dog)": 80,
1058
+ "Whip": 472,
1059
+ "Whir": 488,
1060
+ "Whispering": 15,
1061
+ "Whistle": 402,
1062
+ "Whistling": 40,
1063
+ "White noise": 520,
1064
+ "Whoop": 10,
1065
+ "Whoosh, swoosh, swish": 459,
1066
+ "Wild animals": 108,
1067
+ "Wind": 283,
1068
+ "Wind chime": 206,
1069
+ "Wind instrument, woodwind instrument": 195,
1070
+ "Wind noise (microphone)": 285,
1071
+ "Wood": 437,
1072
+ "Wood block": 173,
1073
+ "Writing": 387,
1074
+ "Yell": 11,
1075
+ "Yip": 76,
1076
+ "Yodeling": 29,
1077
+ "Zing": 497,
1078
+ "Zipper (clothing)": 378,
1079
+ "Zither": 150
1080
+ },
1081
+ "max_length": 1024,
1082
+ "model_type": "dass",
1083
+ "num_classes": 527,
1084
+ "num_mel_bins": 128,
1085
+ "patch_size": 4,
1086
+ "torch_dtype": "float32",
1087
+ "transformers_version": "4.50.0.dev0",
1088
+ "use_checkpoint": false
1089
+ }
configuration_dass.py ADDED
@@ -0,0 +1,91 @@
1
+ # coding=utf-8
2
+ """Distilled Audio State-Space Model (DASS) configuration"""
3
+
4
+ from typing import Any, Dict
5
+
6
+ from transformers.configuration_utils import PretrainedConfig
7
+ from transformers.utils import logging
8
+
9
+
10
+ logger = logging.get_logger(__name__)
11
+
12
+ class DASSConfig(PretrainedConfig):
13
+ r"""
14
+ This is the configuration class to store the configuration of a [`DASSModel`]. It is used to instantiate a DASS
15
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
16
+ defaults will yield a similar configuration to that of the
17
+ [DASS-small](https://github.com/Saurabhbhati/DASS/) architecture.
18
+
19
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
20
+ documentation from [`PretrainedConfig`] for more information.
21
+
22
+ Args:
23
+ patch_size (`int`, *optional*, defaults to 4):
24
+ The size (resolution) of each patch.
25
+ embed_dim (`int`, *optional*, defaults to 96):
26
+ Dimensionality of patch embedding.
27
+ depths (`list(int)`, *optional*, defaults to `[2, 2, 8, 2]`):
28
+ Depth of each layer in the DASS encoder.
29
+ dims (`list(int)`, *optional*, defaults to `[96, 192, 384, 768]`):
30
+ Dimensionality of each layer in the DASS encoder.
31
+ drop_path_rate (`float`, *optional*, defaults to 0.2):
32
+ Stochastic depth rate.
33
+ num_classes (`int`, *optional*, defaults to 527):
34
+ Number of classes for classification.
35
+ max_length (`int`, *optional*, defaults to 1024):
36
+ Temporal dimension of the spectrograms.
37
+ num_mel_bins (`int`, *optional*, defaults to 128):
38
+ Frequency dimension of the spectrograms (number of Mel-frequency bins).
39
+ use_checkpoint (`bool`, *optional*, defaults to `False`):
40
+ Whether to use checkpointing to save memory.
41
+
42
+ Example:
43
+
44
+ ```python
45
+ >>> from transformers import DASSConfig, DASSModel
46
+
47
+ >>> # Initializing a DASS small style configuration
48
+ >>> configuration = DASSConfig()
49
+
50
+ >>> # Initializing a model (with random weights) from the DASS small style configuration
51
+ >>> model = DASSModel(configuration)
52
+
53
+ >>> # Accessing the model configuration
54
+ >>> configuration = model.config
55
+ ```"""
56
+
57
+ model_type = "dass"
58
+
59
+ def __init__(
60
+ self,
61
+ patch_size: int = 4,
62
+ embed_dim: int = 96,
63
+ depths: list = [2, 2, 8, 2],
64
+ dims: list = [96, 192, 384, 768],
65
+ drop_path_rate: float = 0.2,
66
+ num_classes: int = 527,
67
+ max_length: int = 1024,
68
+ num_mel_bins: int = 128,
69
+ use_checkpoint: bool = False,
70
+ **kwargs,
71
+ ):
72
+ super().__init__(**kwargs)
73
+
74
+ self.patch_size = patch_size
75
+ self.embed_dim = embed_dim
76
+ self.depths = depths
77
+ self.dims = dims
78
+ self.drop_path_rate = drop_path_rate
79
+ self.num_classes = num_classes
80
+ self.max_length = max_length
81
+ self.num_mel_bins = num_mel_bins
82
+ self.use_checkpoint = use_checkpoint
83
+
84
+ # Overwritten from the parent class: DASS is not compatible with `generate`, but has a config parameter sharing the
85
+ # same name (`max_length`). Sharing the same name triggers checks regarding the config -> generation_config
86
+ # generative parameters deprecation cycle, overwriting this function prevents this from happening.
87
+ def _get_non_default_generation_parameters(self) -> Dict[str, Any]:
88
+ return {}
89
+
90
+
91
+ __all__ = ["DASSConfig"]
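As a quick sanity check of how the default `max_length`, `num_mel_bins`, and `patch_size` values relate, the patch-grid arithmetic can be sketched in plain Python (a sketch only; the actual patch embedding lives in `modeling_dass.py`, and the variable names below are illustrative):

```python
# Sketch: patch-grid arithmetic implied by the default DASSConfig values.
max_length = 1024     # temporal frames of the spectrogram
num_mel_bins = 128    # frequency bins of the spectrogram
patch_size = 4        # side length of each square patch

patches_time = max_length // patch_size    # patches along the time axis
patches_freq = num_mel_bins // patch_size  # patches along the frequency axis
num_patches = patches_time * patches_freq  # total patches fed to the encoder

print(patches_time, patches_freq, num_patches)  # 256 32 8192
```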
model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cd691a24b6075e62531fe8ee242d5eb196f045ab678bc468715d648a217bd8e2
3
+ size 194645532
modeling_dass.py ADDED
@@ -0,0 +1,1228 @@
1
+ # coding=utf-8
2
+ # VMamba backbone is from https://github.com/MzeroMiko/VMamba/blob/main/vmamba.py
3
+ # DASSLayer, DASSModel, DASSForAudioClassification are implemented based on VMamba and AST
4
+ #
5
+ """Distilled Audio State-Space Model (DASS) model"""
6
+
7
+ import math
8
+ import torch
9
+ import warnings
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import torch.utils.checkpoint as checkpoint
13
+ from functools import partial
14
+ from typing import Optional, Callable, Any, Union
15
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss
16
+ from transformers.modeling_outputs import SequenceClassifierOutput
17
+
18
+ from transformers.utils import logging
19
+ from transformers.modeling_utils import PreTrainedModel
20
+
21
+ from .configuration_dass import DASSConfig
22
+
23
+ logger = logging.get_logger(__name__)
24
+
25
+ # General docstring
26
+ _CONFIG_FOR_DOC = "DASSConfig"
27
+
28
+ WITH_TRITON = True
29
+ # WITH_TRITON = False
30
+ try:
31
+ import triton
32
+ import triton.language as tl
33
+ except ImportError:
34
+ WITH_TRITON = False
35
+ warnings.warn("Triton not installed; falling back to the PyTorch implementation.")
36
+
37
+ # to make sure cached_property can be loaded for triton
38
+ if WITH_TRITON:
39
+ try:
40
+ from functools import cached_property
41
+ except ImportError:
42
+ warnings.warn("if you are using Python 3.7, add this line to functools.py: "
43
+ "cached_property = lambda func: property(lru_cache()(func))")
44
+
45
+ # torch implementation ========================================
46
+ def cross_scan_fwd(x: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=0):
47
+ if in_channel_first:
48
+ B, C, H, W = x.shape
49
+ if scans == 0:
50
+ y = x.new_empty((B, 4, C, H * W))
51
+ y[:, 0, :, :] = x.flatten(2, 3)
52
+ y[:, 1, :, :] = x.transpose(dim0=2, dim1=3).flatten(2, 3)
53
+ y[:, 2:4, :, :] = torch.flip(y[:, 0:2, :, :], dims=[-1])
54
+ elif scans == 1:
55
+ y = x.view(B, 1, C, H * W).repeat(1, 4, 1, 1)
56
+ elif scans == 2:
57
+ y = x.view(B, 1, C, H * W).repeat(1, 2, 1, 1)
58
+ y = torch.cat([y, y.flip(dims=[-1])], dim=1)
59
+ elif scans == 3:
60
+ y = x.new_empty((B, 4, C, H * W))
61
+ y[:, 0, :, :] = x.flatten(2, 3)
62
+ y[:, 1, :, :] = torch.rot90(x, 1, dims=(2, 3)).flatten(2, 3)
63
+ y[:, 2, :, :] = torch.rot90(x, 2, dims=(2, 3)).flatten(2, 3)
64
+ y[:, 3, :, :] = torch.rot90(x, 3, dims=(2, 3)).flatten(2, 3)
65
+ else:
66
+ B, H, W, C = x.shape
67
+ if scans == 0:
68
+ y = x.new_empty((B, H * W, 4, C))
69
+ y[:, :, 0, :] = x.flatten(1, 2)
70
+ y[:, :, 1, :] = x.transpose(dim0=1, dim1=2).flatten(1, 2)
71
+ y[:, :, 2:4, :] = torch.flip(y[:, :, 0:2, :], dims=[1])
72
+ elif scans == 1:
73
+ y = x.view(B, H * W, 1, C).repeat(1, 1, 4, 1)
74
+ elif scans == 2:
75
+ y = x.view(B, H * W, 1, C).repeat(1, 1, 2, 1)
76
+ y = torch.cat([y, y.flip(dims=[1])], dim=2)
77
+ elif scans == 3:
78
+ y = x.new_empty((B, H * W, 4, C))
79
+ y[:, :, 0, :] = x.flatten(1, 2)
80
+ y[:, :, 1, :] = torch.rot90(x, 1, dims=(1, 2)).flatten(1, 2)
81
+ y[:, :, 2, :] = torch.rot90(x, 2, dims=(1, 2)).flatten(1, 2)
82
+ y[:, :, 3, :] = torch.rot90(x, 3, dims=(1, 2)).flatten(1, 2)
83
+
84
+ if in_channel_first and (not out_channel_first):
85
+ y = y.permute(0, 3, 1, 2).contiguous()
86
+ elif (not in_channel_first) and out_channel_first:
87
+ y = y.permute(0, 2, 3, 1).contiguous()
88
+
89
+ return y
90
+
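The four traversal orders that `cross_scan_fwd` builds for `scans=0` (row-major, transposed, and their reverses) can be illustrated on a tiny 2x2 grid in plain Python (a sketch only; the real function operates on batched channel-first or channel-last tensors):

```python
# Sketch of the four scan orders used by cross_scan_fwd with scans=0,
# shown on a 2x2 grid of labels (pure Python, no torch required).
grid = [["a", "b"],
        ["c", "d"]]

row_major = [v for row in grid for v in row]                  # flatten(2, 3)
col_major = [grid[r][c] for c in range(2) for r in range(2)]  # transpose then flatten
rev_row = row_major[::-1]                                     # flip of row-major
rev_col = col_major[::-1]                                     # flip of transposed

print(row_major, col_major, rev_row, rev_col)
```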
91
+
92
+ def cross_merge_fwd(y: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=0):
93
+ if out_channel_first:
94
+ B, K, D, H, W = y.shape
95
+ y = y.view(B, K, D, -1)
96
+ if scans == 0:
97
+ y = y[:, 0:2] + y[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
98
+ y = y[:, 0] + y[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, D, -1)
99
+ elif scans == 1:
100
+ y = y.sum(1)
101
+ elif scans == 2:
102
+ y = y[:, 0:2] + y[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
103
+ y = y.sum(1)
104
+ elif scans == 3:
105
+ oy = y[:, 0, :, :].contiguous().view(B, D, -1)
106
+ oy = oy + torch.rot90(y.view(B, K, D, W, H)[:, 1, :, :, :], -1, dims=(2, 3)).flatten(2, 3)
107
+ oy = oy + torch.rot90(y.view(B, K, D, H, W)[:, 2, :, :, :], -2, dims=(2, 3)).flatten(2, 3)
108
+ oy = oy + torch.rot90(y.view(B, K, D, W, H)[:, 3, :, :, :], -3, dims=(2, 3)).flatten(2, 3)
109
+ y = oy
110
+ else:
111
+ B, H, W, K, D = y.shape
112
+ y = y.view(B, -1, K, D)
113
+ if scans == 0:
114
+ y = y[:, :, 0:2] + y[:, :, 2:4].flip(dims=[1]).view(B, -1, 2, D)
115
+ y = y[:, :, 0] + y[:, :, 1].view(B, W, H, -1).transpose(dim0=1, dim1=2).contiguous().view(B, -1, D)
116
+ elif scans == 1:
117
+ y = y.sum(2)
118
+ elif scans == 2:
119
+ y = y[:, :, 0:2] + y[:, :, 2:4].flip(dims=[1]).view(B, -1, 2, D)
120
+ y = y.sum(2)
121
+ elif scans == 3:
122
+ oy = y[:, :, 0, :].contiguous().view(B, -1, D)
123
+ oy = oy + torch.rot90(y.view(B, W, H, K, D)[:, :, :, 1, :], -1, dims=(1, 2)).flatten(1, 2)
124
+ oy = oy + torch.rot90(y.view(B, H, W, K, D)[:, :, :, 2, :], -2, dims=(1, 2)).flatten(1, 2)
125
+ oy = oy + torch.rot90(y.view(B, W, H, K, D)[:, :, :, 3, :], -3, dims=(1, 2)).flatten(1, 2)
126
+ y = oy
127
+
128
+ if in_channel_first and (not out_channel_first):
129
+ y = y.permute(0, 2, 1).contiguous()
130
+ elif (not in_channel_first) and out_channel_first:
131
+ y = y.permute(0, 2, 1).contiguous()
132
+
133
+ return y
134
+
135
+
136
+ def cross_scan1b1_fwd(x: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=0):
137
+ if in_channel_first:
138
+ B, _, C, H, W = x.shape
139
+ if scans == 0:
140
+ y = torch.stack([
141
+ x[:, 0].flatten(2, 3),
142
+ x[:, 1].transpose(dim0=2, dim1=3).flatten(2, 3),
143
+ torch.flip(x[:, 2].flatten(2, 3), dims=[-1]),
144
+ torch.flip(x[:, 3].transpose(dim0=2, dim1=3).flatten(2, 3), dims=[-1]),
145
+ ], dim=1)
146
+ elif scans == 1:
147
+ y = x.flatten(2, 3)
148
+ elif scans == 2:
149
+ y = torch.stack([
150
+ x[:, 0].flatten(2, 3),
151
+ x[:, 1].flatten(2, 3),
152
+ torch.flip(x[:, 2].flatten(2, 3), dims=[-1]),
153
+ torch.flip(x[:, 3].flatten(2, 3), dims=[-1]),
154
+ ], dim=1)
155
+ elif scans == 3:
156
+ y = torch.stack([
157
+ x[:, 0, :, :, :].flatten(2, 3),
158
+ torch.rot90(x[:, 1, :, :, :], 1, dims=(2, 3)).flatten(2, 3),
159
+ torch.rot90(x[:, 2, :, :, :], 2, dims=(2, 3)).flatten(2, 3),
160
+ torch.rot90(x[:, 3, :, :, :], 3, dims=(2, 3)).flatten(2, 3),
161
+ ], dim=1)
162
+
163
+ else:
164
+ B, H, W, _, C = x.shape
165
+ if scans == 0:
166
+ y = torch.stack([
167
+ x[:, :, :, 0].flatten(1, 2),
168
+ x[:, :, :, 1].transpose(dim0=1, dim1=2).flatten(1, 2),
169
+ torch.flip(x[:, :, :, 2].flatten(1, 2), dims=[1]),
170
+ torch.flip(x[:, :, :, 3].transpose(dim0=1, dim1=2).flatten(1, 2), dims=[1]),
171
+ ], dim=2)
172
+ elif scans == 1:
173
+ y = x.flatten(1, 2)
174
+ elif scans == 2:
175
+ y = torch.stack([
176
+ x[:, 0].flatten(1, 2),
177
+ x[:, 1].flatten(1, 2),
178
+ torch.flip(x[:, 2].flatten(1, 2), dims=[-1]),
179
+ torch.flip(x[:, 3].flatten(1, 2), dims=[-1]),
180
+ ], dim=2)
181
+ elif scans == 3:
182
+ y = torch.stack([
183
+ x[:, :, :, 0, :].flatten(1, 2),
184
+ torch.rot90(x[:, :, :, 1, :], 1, dims=(1, 2)).flatten(1, 2),
185
+ torch.rot90(x[:, :, :, 2, :], 2, dims=(1, 2)).flatten(1, 2),
186
+ torch.rot90(x[:, :, :, 3, :], 3, dims=(1, 2)).flatten(1, 2),
187
+ ], dim=1)
188
+
189
+ if in_channel_first and (not out_channel_first):
190
+ y = y.permute(0, 3, 1, 2).contiguous()
191
+ elif (not in_channel_first) and out_channel_first:
192
+ y = y.permute(0, 2, 3, 1).contiguous()
193
+
194
+ return y
195
+
196
+
197
+ def cross_merge1b1_fwd(y: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=0):
198
+ if out_channel_first:
199
+ B, K, D, H, W = y.shape
200
+ y = y.view(B, K, D, -1)
201
+ if scans == 0:
202
+ y = torch.stack([
203
+ y[:, 0],
204
+ y[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).flatten(2, 3),
205
+ torch.flip(y[:, 2], dims=[-1]),
206
+ torch.flip(y[:, 3].view(B, -1, W, H).transpose(dim0=2, dim1=3).flatten(2, 3), dims=[-1]),
207
+ ], dim=1)
208
+ elif scans == 1:
209
+ y = y
210
+ elif scans == 2:
211
+ y = torch.stack([
212
+ y[:, 0],
213
+ y[:, 1],
214
+ torch.flip(y[:, 2], dims=[-1]),
215
+ torch.flip(y[:, 3], dims=[-1]),
216
+ ], dim=1)
217
+ elif scans == 3:
218
+ y = torch.stack([
219
+ y[:, 0, :, :].contiguous().view(B, D, -1),
220
+ torch.rot90(y.view(B, K, D, W, H)[:, 1, :, :, :], -1, dims=(2, 3)).flatten(2, 3),
221
+ torch.rot90(y.view(B, K, D, H, W)[:, 2, :, :, :], -2, dims=(2, 3)).flatten(2, 3),
222
+ torch.rot90(y.view(B, K, D, W, H)[:, 3, :, :, :], -3, dims=(2, 3)).flatten(2, 3),
223
+ ], dim=1)
224
+ else:
225
+ B, H, W, K, D = y.shape
226
+ y = y.view(B, -1, K, D)
227
+ if scans == 0:
228
+ y = torch.stack([
229
+ y[:, :, 0],
230
+ y[:, :, 1].view(B, W, H, -1).transpose(dim0=1, dim1=2).flatten(1, 2),
231
+ torch.flip(y[:, :, 2], dims=[1]),
232
+ torch.flip(y[:, :, 3].view(B, W, H, -1).transpose(dim0=1, dim1=2).flatten(1, 2), dims=[1]),
233
+ ], dim=2)
234
+ elif scans == 1:
235
+ y = y
236
+ elif scans == 2:
237
+ y = torch.stack([
238
+ y[:, :, 0],
239
+ y[:, :, 1],
240
+ torch.flip(y[:, :, 2], dims=[1]),
241
+ torch.flip(y[:, :, 3], dims=[1]),
242
+ ], dim=2)
243
+ elif scans == 3:
244
+ y = torch.stack([
245
+ y[:, :, 0, :].contiguous().view(B, -1, D),
246
+ torch.rot90(y.view(B, W, H, K, D)[:, :, :, 1, :], -1, dims=(1, 2)).flatten(1, 2),
247
+ torch.rot90(y.view(B, H, W, K, D)[:, :, :, 2, :], -2, dims=(1, 2)).flatten(1, 2),
248
+ torch.rot90(y.view(B, W, H, K, D)[:, :, :, 3, :], -3, dims=(1, 2)).flatten(1, 2),
249
+ ], dim=2)
250
+
251
+ if out_channel_first and (not in_channel_first):
252
+ y = y.permute(0, 3, 1, 2).contiguous()
253
+ elif (not out_channel_first) and in_channel_first:
254
+ y = y.permute(0, 2, 3, 1).contiguous()
255
+
256
+ return y
257
+
258
+
259
+ class CrossScanF(torch.autograd.Function):
260
+ @staticmethod
261
+ def forward(ctx, x: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0):
262
+ # x: (B, C, H, W) | (B, H, W, C) | (B, 4, C, H, W) | (B, H, W, 4, C)
263
+ # y: (B, 4, C, H * W) | (B, H * W, 4, C)
264
+ ctx.in_channel_first = in_channel_first
265
+ ctx.out_channel_first = out_channel_first
266
+ ctx.one_by_one = one_by_one
267
+ ctx.scans = scans
268
+
269
+ if one_by_one:
270
+ B, K, C, H, W = x.shape
271
+ if not in_channel_first:
272
+ B, H, W, K, C = x.shape
273
+ else:
274
+ B, C, H, W = x.shape
275
+ if not in_channel_first:
276
+ B, H, W, C = x.shape
277
+ ctx.shape = (B, C, H, W)
278
+
279
+ _fn = cross_scan1b1_fwd if one_by_one else cross_scan_fwd
280
+ y = _fn(x, in_channel_first, out_channel_first, scans)
281
+
282
+ return y
283
+
284
+ @staticmethod
285
+ def backward(ctx, ys: torch.Tensor):
286
+ # out: (b, k, d, l)
287
+ in_channel_first = ctx.in_channel_first
288
+ out_channel_first = ctx.out_channel_first
289
+ one_by_one = ctx.one_by_one
290
+ scans = ctx.scans
291
+ B, C, H, W = ctx.shape
292
+
293
+ ys = ys.view(B, -1, C, H, W) if out_channel_first else ys.view(B, H, W, -1, C)
294
+ _fn = cross_merge1b1_fwd if one_by_one else cross_merge_fwd
295
+ y = _fn(ys, in_channel_first, out_channel_first, scans)
296
+
297
+ if one_by_one:
298
+ y = y.view(B, 4, -1, H, W) if in_channel_first else y.view(B, H, W, 4, -1)
299
+ else:
300
+ y = y.view(B, -1, H, W) if in_channel_first else y.view(B, H, W, -1)
301
+
302
+ return y, None, None, None, None
303
+
304
+
305
+ class CrossMergeF(torch.autograd.Function):
306
+ @staticmethod
307
+ def forward(ctx, ys: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0):
308
+ # x: (B, C, H, W) | (B, H, W, C) | (B, 4, C, H, W) | (B, H, W, 4, C)
309
+ # y: (B, 4, C, H * W) | (B, H * W, 4, C)
310
+ ctx.in_channel_first = in_channel_first
311
+ ctx.out_channel_first = out_channel_first
312
+ ctx.one_by_one = one_by_one
313
+ ctx.scans = scans
314
+
315
+ B, K, C, H, W = ys.shape
316
+ if not out_channel_first:
317
+ B, H, W, K, C = ys.shape
318
+ ctx.shape = (B, C, H, W)
319
+
320
+ _fn = cross_merge1b1_fwd if one_by_one else cross_merge_fwd
321
+ y = _fn(ys, in_channel_first, out_channel_first, scans)
322
+
323
+ return y
324
+
325
+ @staticmethod
326
+ def backward(ctx, x: torch.Tensor):
327
+ # B, D, L = x.shape
328
+ # out: (b, k, d, h, w)
329
+ in_channel_first = ctx.in_channel_first
330
+ out_channel_first = ctx.out_channel_first
331
+ one_by_one = ctx.one_by_one
332
+ scans = ctx.scans
333
+ B, C, H, W = ctx.shape
334
+
335
+ if not one_by_one:
336
+ if in_channel_first:
337
+ x = x.view(B, C, H, W)
338
+ else:
339
+ x = x.view(B, H, W, C)
340
+ else:
341
+ if in_channel_first:
342
+ x = x.view(B, 4, C, H, W)
343
+ else:
344
+ x = x.view(B, H, W, 4, C)
345
+
346
+ _fn = cross_scan1b1_fwd if one_by_one else cross_scan_fwd
347
+ x = _fn(x, in_channel_first, out_channel_first, scans)
348
+ x = x.view(B, 4, C, H, W) if out_channel_first else x.view(B, H, W, 4, C)
349
+
350
+ return x, None, None, None, None
351
+
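One useful property of the scan/merge pair: for `scans=1` (pure repetition), `cross_scan_fwd` copies the sequence four times and `cross_merge_fwd` sums the four branches, so a scan-then-merge round trip scales the input by 4. A minimal plain-Python sketch of that relationship (illustrative only; the real functions work on tensors):

```python
# Sketch of the scan/merge adjoint relationship for scans=1 (repetition):
# scanning yields 4 identical copies, merging sums the 4 branch outputs.
x = [1.0, 2.0, 3.0]

scanned = [list(x) for _ in range(4)]         # like cross_scan with scans=1
merged = [sum(col) for col in zip(*scanned)]  # like cross_merge's y.sum(1)

print(merged)  # [4.0, 8.0, 12.0]
```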
352
+
353
+ # triton implementation ========================================
354
+
355
+ @triton.jit
356
+ def triton_cross_scan_flex(
357
+ x: tl.tensor, # (B, C, H, W) | (B, H, W, C) | (B, 4, C, H, W) | (B, H, W, 4, C)
358
+ y: tl.tensor, # (B, 4, C, H, W) | (B, H, W, 4, C)
359
+ x_layout: tl.constexpr,
360
+ y_layout: tl.constexpr,
361
+ operation: tl.constexpr,
362
+ onebyone: tl.constexpr,
363
+ scans: tl.constexpr,
364
+ BC: tl.constexpr,
365
+ BH: tl.constexpr,
366
+ BW: tl.constexpr,
367
+ DC: tl.constexpr,
368
+ DH: tl.constexpr,
369
+ DW: tl.constexpr,
370
+ NH: tl.constexpr,
371
+ NW: tl.constexpr,
372
+ ):
373
+ # x_layout = 0
374
+ # y_layout = 1 # 0 BCHW, 1 BHWC
375
+ # operation = 0 # 0 scan, 1 merge
376
+ # onebyone = 0 # 0 false, 1 true
377
+ # scans = 0 # 0 cross scan, 1 unidirectional, 2 bidirectional, 3 rotational
378
+
379
+ i_hw, i_c, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2)
380
+ i_h, i_w = (i_hw // NW), (i_hw % NW)
381
+ _mask_h = (i_h * BH + tl.arange(0, BH)) < DH
382
+ _mask_w = (i_w * BW + tl.arange(0, BW)) < DW
383
+ _mask_hw = _mask_h[:, None] & _mask_w[None, :]
384
+ _for_C = min(DC - i_c * BC, BC)
385
+
386
+ pos_h = (i_h * BH + tl.arange(0, BH)[:, None])
387
+ pos_w = (i_w * BW + tl.arange(0, BW)[None, :])
388
+ neg_h = (DH - i_h * BH - 1 - tl.arange(0, BH)[:, None])
389
+ neg_w = (DW - i_w * BW - 1 - tl.arange(0, BW)[None, :])
390
+ if scans == 0:
391
+ # none; trans; flip; trans + flip;
392
+ HWRoute0 = pos_h * DW + pos_w
393
+ HWRoute1 = pos_w * DH + pos_h # trans
394
+ HWRoute2 = neg_h * DW + neg_w # flip
395
+ HWRoute3 = neg_w * DH + neg_h # trans + flip
396
+ elif scans == 1:
397
+ # none; none; none; none;
398
+ HWRoute0 = pos_h * DW + pos_w
399
+ HWRoute1 = HWRoute0
400
+ HWRoute2 = HWRoute0
401
+ HWRoute3 = HWRoute0
402
+ elif scans == 2:
403
+ # none; none; flip; flip;
404
+ HWRoute0 = pos_h * DW + pos_w
405
+ HWRoute1 = HWRoute0
406
+ HWRoute2 = neg_h * DW + neg_w # flip
407
+ HWRoute3 = HWRoute2
408
+ elif scans == 3:
409
+ # none; rot90; rot180==flip; rot270;
410
+ HWRoute0 = pos_h * DW + pos_w
411
+ HWRoute1 = neg_w * DH + pos_h
412
+ HWRoute2 = neg_h * DW + neg_w
413
+ HWRoute3 = pos_w * DH + neg_h
414
+
415
+ _tmp1 = DC * DH * DW
416
+
417
+ y_ptr_base = y + i_b * 4 * _tmp1 + (i_c * BC * DH * DW if y_layout == 0 else i_c * BC)
418
+ if y_layout == 0:
419
+ p_y1 = y_ptr_base + HWRoute0
420
+ p_y2 = y_ptr_base + _tmp1 + HWRoute1
421
+ p_y3 = y_ptr_base + 2 * _tmp1 + HWRoute2
422
+ p_y4 = y_ptr_base + 3 * _tmp1 + HWRoute3
423
+ else:
424
+ p_y1 = y_ptr_base + HWRoute0 * 4 * DC
425
+ p_y2 = y_ptr_base + DC + HWRoute1 * 4 * DC
426
+ p_y3 = y_ptr_base + 2 * DC + HWRoute2 * 4 * DC
427
+ p_y4 = y_ptr_base + 3 * DC + HWRoute3 * 4 * DC
428
+
429
+ if onebyone == 0:
430
+ x_ptr_base = x + i_b * _tmp1 + (i_c * BC * DH * DW if x_layout == 0 else i_c * BC)
431
+ if x_layout == 0:
432
+ p_x = x_ptr_base + HWRoute0
433
+ else:
434
+ p_x = x_ptr_base + HWRoute0 * DC
435
+
436
+ if operation == 0:
437
+ for idxc in range(_for_C):
438
+ _idx_x = idxc * DH * DW if x_layout == 0 else idxc
439
+ _idx_y = idxc * DH * DW if y_layout == 0 else idxc
440
+ _x = tl.load(p_x + _idx_x, mask=_mask_hw)
441
+ tl.store(p_y1 + _idx_y, _x, mask=_mask_hw)
442
+ tl.store(p_y2 + _idx_y, _x, mask=_mask_hw)
443
+ tl.store(p_y3 + _idx_y, _x, mask=_mask_hw)
444
+ tl.store(p_y4 + _idx_y, _x, mask=_mask_hw)
445
+ elif operation == 1:
446
+ for idxc in range(_for_C):
447
+ _idx_x = idxc * DH * DW if x_layout == 0 else idxc
448
+ _idx_y = idxc * DH * DW if y_layout == 0 else idxc
449
+ _y1 = tl.load(p_y1 + _idx_y, mask=_mask_hw)
450
+ _y2 = tl.load(p_y2 + _idx_y, mask=_mask_hw)
451
+ _y3 = tl.load(p_y3 + _idx_y, mask=_mask_hw)
452
+ _y4 = tl.load(p_y4 + _idx_y, mask=_mask_hw)
453
+ tl.store(p_x + _idx_x, _y1 + _y2 + _y3 + _y4, mask=_mask_hw)
454
+
455
+ else:
456
+ x_ptr_base = x + i_b * 4 * _tmp1 + (i_c * BC * DH * DW if x_layout == 0 else i_c * BC)
457
+ if x_layout == 0:
458
+ p_x1 = x_ptr_base + HWRoute0
459
+ p_x2 = p_x1 + _tmp1
460
+ p_x3 = p_x2 + _tmp1
461
+ p_x4 = p_x3 + _tmp1
462
+ else:
463
+ p_x1 = x_ptr_base + HWRoute0 * 4 * DC
464
+ p_x2 = p_x1 + DC
465
+ p_x3 = p_x2 + DC
466
+ p_x4 = p_x3 + DC
467
+
468
+ if operation == 0:
469
+ for idxc in range(_for_C):
470
+ _idx_x = idxc * DH * DW if x_layout == 0 else idxc
471
+ _idx_y = idxc * DH * DW if y_layout == 0 else idxc
472
+ tl.store(p_y1 + _idx_y, tl.load(p_x1 + _idx_x, mask=_mask_hw), mask=_mask_hw)
473
+ tl.store(p_y2 + _idx_y, tl.load(p_x2 + _idx_x, mask=_mask_hw), mask=_mask_hw)
474
+ tl.store(p_y3 + _idx_y, tl.load(p_x3 + _idx_x, mask=_mask_hw), mask=_mask_hw)
475
+ tl.store(p_y4 + _idx_y, tl.load(p_x4 + _idx_x, mask=_mask_hw), mask=_mask_hw)
476
+ else:
477
+ for idxc in range(_for_C):
478
+ _idx_x = idxc * DH * DW if x_layout == 0 else idxc
479
+ _idx_y = idxc * DH * DW if y_layout == 0 else idxc
480
+ tl.store(p_x1 + _idx_x, tl.load(p_y1 + _idx_y), mask=_mask_hw)
481
+ tl.store(p_x2 + _idx_x, tl.load(p_y2 + _idx_y), mask=_mask_hw)
482
+ tl.store(p_x3 + _idx_x, tl.load(p_y3 + _idx_y), mask=_mask_hw)
483
+ tl.store(p_x4 + _idx_x, tl.load(p_y4 + _idx_y), mask=_mask_hw)
484
+
485
+
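The `HWRoute` tables in the kernel above encode the same four traversals as the PyTorch path, but as flat index offsets. Their arithmetic for `scans == 0` can be checked in plain Python (a sketch with a hypothetical `routes` helper; `DH`/`DW` stand in for the kernel's grid dimensions):

```python
# Sketch: flat-index routes for a single (h, w) position on a DH x DW grid,
# mirroring HWRoute0..HWRoute3 for scans == 0 in the Triton kernel above.
DH, DW = 2, 3  # tiny grid for illustration

def routes(h, w):
    nh, nw = DH - 1 - h, DW - 1 - w  # reversed ("neg") coordinates
    return (
        h * DW + w,    # HWRoute0: row-major
        w * DH + h,    # HWRoute1: transposed
        nh * DW + nw,  # HWRoute2: flipped row-major
        nw * DH + nh,  # HWRoute3: transposed + flipped
    )

print(routes(0, 0), routes(1, 2))  # (0, 0, 5, 5) (5, 5, 0, 0)
```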
486
+ class CrossScanTritonF(torch.autograd.Function):
487
+ @staticmethod
488
+ def forward(ctx, x: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0):
489
+ if one_by_one:
490
+ if in_channel_first:
491
+ B, _, C, H, W = x.shape
492
+ else:
493
+ B, H, W, _, C = x.shape
494
+ else:
495
+ if in_channel_first:
496
+ B, C, H, W = x.shape
497
+ else:
498
+ B, H, W, C = x.shape
499
+ B, C, H, W = int(B), int(C), int(H), int(W)
500
+ BC, BH, BW = 1, 32, 32
501
+ NH, NW, NC = triton.cdiv(H, BH), triton.cdiv(W, BW), triton.cdiv(C, BC)
502
+
503
+ ctx.in_channel_first = in_channel_first
504
+ ctx.out_channel_first = out_channel_first
505
+ ctx.one_by_one = one_by_one
506
+ ctx.scans = scans
507
+ ctx.shape = (B, C, H, W)
508
+ ctx.triton_shape = (BC, BH, BW, NC, NH, NW)
509
+
510
+ y = x.new_empty((B, 4, C, H * W)) if out_channel_first else x.new_empty((B, H * W, 4, C))
511
+ triton_cross_scan_flex[(NH * NW, NC, B)](
512
+ x.contiguous(), y,
513
+ (0 if in_channel_first else 1), (0 if out_channel_first else 1), 0, (0 if not one_by_one else 1), scans,
514
+ BC, BH, BW, C, H, W, NH, NW
515
+ )
516
+ return y
517
+
518
+ @staticmethod
519
+ def backward(ctx, y: torch.Tensor):
520
+ in_channel_first = ctx.in_channel_first
521
+ out_channel_first = ctx.out_channel_first
522
+ one_by_one = ctx.one_by_one
523
+ scans = ctx.scans
524
+ B, C, H, W = ctx.shape
525
+ BC, BH, BW, NC, NH, NW = ctx.triton_shape
526
+ if one_by_one:
527
+ x = y.new_empty((B, 4, C, H, W)) if in_channel_first else y.new_empty((B, H, W, 4, C))
528
+ else:
529
+ x = y.new_empty((B, C, H, W)) if in_channel_first else y.new_empty((B, H, W, C))
530
+
531
+ triton_cross_scan_flex[(NH * NW, NC, B)](
532
+ x, y.contiguous(),
533
+ (0 if in_channel_first else 1), (0 if out_channel_first else 1), 1, (0 if not one_by_one else 1), scans,
534
+ BC, BH, BW, C, H, W, NH, NW
535
+ )
536
+ return x, None, None, None, None
537
+
538
+
539
+ class CrossMergeTritonF(torch.autograd.Function):
540
+ @staticmethod
541
+ def forward(ctx, y: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0):
542
+ if out_channel_first:
543
+ B, _, C, H, W = y.shape
544
+ else:
545
+ B, H, W, _, C = y.shape
546
+ B, C, H, W = int(B), int(C), int(H), int(W)
547
+ BC, BH, BW = 1, 32, 32
548
+ NH, NW, NC = triton.cdiv(H, BH), triton.cdiv(W, BW), triton.cdiv(C, BC)
549
+ ctx.in_channel_first = in_channel_first
550
+ ctx.out_channel_first = out_channel_first
551
+ ctx.one_by_one = one_by_one
552
+ ctx.scans = scans
553
+ ctx.shape = (B, C, H, W)
554
+ ctx.triton_shape = (BC, BH, BW, NC, NH, NW)
555
+ if one_by_one:
556
+ x = y.new_empty((B, 4, C, H * W)) if in_channel_first else y.new_empty((B, H * W, 4, C))
557
+ else:
558
+ x = y.new_empty((B, C, H * W)) if in_channel_first else y.new_empty((B, H * W, C))
559
+ triton_cross_scan_flex[(NH * NW, NC, B)](
560
+ x, y.contiguous(),
561
+ (0 if in_channel_first else 1), (0 if out_channel_first else 1), 1, (0 if not one_by_one else 1), scans,
562
+ BC, BH, BW, C, H, W, NH, NW
563
+ )
564
+ return x
565
+
566
+ @staticmethod
567
+ def backward(ctx, x: torch.Tensor):
568
+ in_channel_first = ctx.in_channel_first
569
+ out_channel_first = ctx.out_channel_first
570
+ one_by_one = ctx.one_by_one
571
+ scans = ctx.scans
572
+ B, C, H, W = ctx.shape
573
+ BC, BH, BW, NC, NH, NW = ctx.triton_shape
574
+ y = x.new_empty((B, 4, C, H, W)) if out_channel_first else x.new_empty((B, H, W, 4, C))
575
+ triton_cross_scan_flex[(NH * NW, NC, B)](
576
+ x.contiguous(), y,
577
+ (0 if in_channel_first else 1), (0 if out_channel_first else 1), 0, (0 if not one_by_one else 1), scans,
578
+ BC, BH, BW, C, H, W, NH, NW
579
+ )
580
+ return y, None, None, None, None, None
581
+
582
+
583
+ # @torch.compile(options={"triton.cudagraphs": True}, fullgraph=True)
584
+ def cross_scan_fn(x: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0, force_torch=False):
585
+ # x: (B, C, H, W) | (B, H, W, C) | (B, 4, C, H, W) | (B, H, W, 4, C)
586
+ # y: (B, 4, C, L) | (B, L, 4, C)
587
+ # scans: 0: cross scan; 1 unidirectional; 2: bidirectional;
588
+ CSF = CrossScanTritonF if WITH_TRITON and x.is_cuda and (not force_torch) else CrossScanF
589
+ if x.is_cuda:
590
+ with torch.cuda.device(x.device):
591
+ return CSF.apply(x, in_channel_first, out_channel_first, one_by_one, scans)
592
+ else:
593
+ return CrossScanF.apply(x, in_channel_first, out_channel_first, one_by_one, scans)
594
+
595
+
596
+ # @torch.compile(options={"triton.cudagraphs": True}, fullgraph=True)
597
+ def cross_merge_fn(y: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0, force_torch=False):
598
+ # y: (B, 4, C, L) | (B, L, 4, C)
599
+ # x: (B, C, H * W) | (B, H * W, C) | (B, 4, C, H * W) | (B, H * W, 4, C)
600
+ # scans: 0: cross scan; 1 unidirectional; 2: bidirectional;
601
+ CMF = CrossMergeTritonF if WITH_TRITON and y.is_cuda and (not force_torch) else CrossMergeF
602
+ if y.is_cuda:
603
+ with torch.cuda.device(y.device):
604
+ return CMF.apply(y, in_channel_first, out_channel_first, one_by_one, scans)
605
+ else:
606
+ return CrossMergeF.apply(y, in_channel_first, out_channel_first, one_by_one, scans)
607
+
608
+
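As a reference for the `scans=0` semantics noted in the comments above (cross scan: row-major, column-major, and both reversed), the scan/merge pair can be sketched in pure Python on a toy grid. The helpers below (`cross_scan`, `cross_merge`) are hypothetical illustrations of the idea, not the Triton kernels:

```python
# Toy sketch of the four cross-scan traversal orders on an H x W grid
# (scans=0): row-major, column-major, and the reverses of both.
def cross_scan(grid):
    H, W = len(grid), len(grid[0])
    flat = [grid[i][j] for i in range(H) for j in range(W)]    # row-major
    trans = [grid[i][j] for j in range(W) for i in range(H)]   # column-major
    return [flat, trans, flat[::-1], trans[::-1]]              # + both reversed

def cross_merge(routes, H, W):
    # Inverse of cross_scan: undo each traversal order and sum the 4 routes.
    flat, trans = routes[0], routes[1]
    rflat, rtrans = routes[2][::-1], routes[3][::-1]
    out = [[0] * W for _ in range(H)]
    for idx in range(H * W):
        i, j = divmod(idx, W)
        out[i][j] += flat[idx] + rflat[idx]
    for idx in range(H * W):
        j, i = divmod(idx, H)
        out[i][j] += trans[idx] + rtrans[idx]
    return out

routes = cross_scan([[1, 2], [3, 4]])
merged = cross_merge(routes, 2, 2)  # each cell summed over the 4 routes
```

Merging the unmodified scans recovers 4x the input, which is why the real `cross_merge_fn` is the exact adjoint of `cross_scan_fn` in the autograd pair above.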
+ ##########################################################
+ # csms6s.py
+ ##########################################################
+
+ WITH_SELECTIVESCAN_MAMBA = True
+ try:
+     import selective_scan_cuda
+ except ImportError:
+     WITH_SELECTIVESCAN_MAMBA = False
+
+
+ def selective_scan_torch(
+     u: torch.Tensor,  # (B, K * C, L)
+     delta: torch.Tensor,  # (B, K * C, L)
+     A: torch.Tensor,  # (K * C, N)
+     B: torch.Tensor,  # (B, K, N, L)
+     C: torch.Tensor,  # (B, K, N, L)
+     D: torch.Tensor = None,  # (K * C)
+     delta_bias: torch.Tensor = None,  # (K * C)
+     delta_softplus=True,
+     oflex=True,
+     *args,
+     **kwargs
+ ):
+     dtype_in = u.dtype
+     Batch, K, N, L = B.shape
+     KCdim = u.shape[1]
+     Cdim = int(KCdim / K)
+     assert u.shape == (Batch, KCdim, L)
+     assert delta.shape == (Batch, KCdim, L)
+     assert A.shape == (KCdim, N)
+     assert C.shape == B.shape
+
+     if delta_bias is not None:
+         delta = delta + delta_bias[..., None]
+     if delta_softplus:
+         delta = torch.nn.functional.softplus(delta)
+
+     u, delta, A, B, C = u.float(), delta.float(), A.float(), B.float(), C.float()
+     B = B.view(Batch, K, 1, N, L).repeat(1, 1, Cdim, 1, 1).view(Batch, KCdim, N, L)
+     C = C.view(Batch, K, 1, N, L).repeat(1, 1, Cdim, 1, 1).view(Batch, KCdim, N, L)
+     deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
+     deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u)
+
+     x = A.new_zeros((Batch, KCdim, N))
+     ys = []
+     for i in range(L):
+         x = deltaA[:, :, i, :] * x + deltaB_u[:, :, i, :]
+         y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i])
+         ys.append(y)
+     y = torch.stack(ys, dim=2)  # (B, C, L)
+
+     out = y if D is None else y + u * D.unsqueeze(-1)
+     return out if oflex else out.to(dtype=dtype_in)
+
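The recurrence that `selective_scan_torch` unrolls over `L` reduces, for a single channel with state size N = 1, to a simple scalar update. The helper below is a hypothetical scalar sketch of that recurrence (names mirror the tensors above), not part of the model:

```python
import math

# Scalar sketch of the selective-scan recurrence for one channel, N = 1:
#   x_t = exp(delta_t * A) * x_{t-1} + delta_t * B_t * u_t
#   y_t = C_t * x_t + D * u_t   (D is the skip term)
def selective_scan_1d(u, delta, A, B, C, D=0.0):
    x, ys = 0.0, []
    for u_t, d_t, b_t, c_t in zip(u, delta, B, C):
        x = math.exp(d_t * A) * x + d_t * b_t * u_t  # state update (deltaA, deltaB_u)
        ys.append(c_t * x + D * u_t)                 # readout
    return ys

# With A = 0 and B = C = 1 at unit step size, the scan is a running sum.
ys = selective_scan_1d([1.0, 2.0, 3.0], [1.0] * 3, 0.0, [1.0] * 3, [1.0] * 3)
# ys == [1.0, 3.0, 6.0]
```

A negative `A` (as enforced by `As = -A_logs.exp()` later in the file) makes `exp(delta * A) < 1`, so the state decays and the scan becomes a stable, input-dependent exponential moving average.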
+
+ class SelectiveScanCuda(torch.autograd.Function):
+     @staticmethod
+     @torch.cuda.amp.custom_fwd
+     def forward(ctx, u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=False, oflex=True, backend=None):
+         ctx.delta_softplus = delta_softplus
+         # backend = "oflex" if WITH_SELECTIVESCAN_OFLEX and (backend is None) else backend
+         # backend = "core" if WITH_SELECTIVESCAN_CORE and (backend is None) else backend
+         backend = "mamba" if WITH_SELECTIVESCAN_MAMBA and (backend is None) else backend
+         ctx.backend = backend
+         if backend == "oflex":
+             out, x, *rest = selective_scan_cuda_oflex.fwd(u, delta, A, B, C, D, delta_bias, delta_softplus, 1, oflex)
+         elif backend == "mamba":
+             out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, None, delta_bias, delta_softplus)
+         ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
+         return out
+
+     @staticmethod
+     @torch.cuda.amp.custom_bwd
+     def backward(ctx, dout, *args):
+         u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
+         backend = ctx.backend
+         if dout.stride(-1) != 1:
+             dout = dout.contiguous()
+         if backend == "oflex":
+             du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda_oflex.bwd(
+                 u, delta, A, B, C, D, delta_bias, dout, x, ctx.delta_softplus, 1
+             )
+         elif backend == "mamba":
+             du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd(
+                 u, delta, A, B, C, D, None, delta_bias, dout, x, None, None, ctx.delta_softplus,
+                 False
+             )
+         return du, ddelta, dA, dB, dC, dD, ddelta_bias, None, None, None
+
+
+ def selective_scan_fn(
+     u: torch.Tensor,  # (B, K * C, L)
+     delta: torch.Tensor,  # (B, K * C, L)
+     A: torch.Tensor,  # (K * C, N)
+     B: torch.Tensor,  # (B, K, N, L)
+     C: torch.Tensor,  # (B, K, N, L)
+     D: torch.Tensor = None,  # (K * C)
+     delta_bias: torch.Tensor = None,  # (K * C)
+     delta_softplus=True,
+     oflex=True,
+     backend=None,
+ ):
+     fn = selective_scan_torch if backend == "torch" or (not WITH_SELECTIVESCAN_MAMBA) else SelectiveScanCuda.apply
+     return fn(u, delta, A, B, C, D, delta_bias, delta_softplus, oflex, backend)
+
+ ##########################################################
+ ############## HuggingFace modeling file #################
+ ##########################################################
+
+ def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
+     """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+     This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
+     the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+     See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
+     changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
+     'survival rate' as the argument.
+     """
+     if drop_prob == 0. or not training:
+         return x
+     keep_prob = 1 - drop_prob
+     shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
+     random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+     if keep_prob > 0.0 and scale_by_keep:
+         random_tensor.div_(keep_prob)
+     return x * random_tensor
+
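The division by `keep_prob` above is what keeps stochastic depth unbiased: scaling the surviving samples by `1 / keep_prob` leaves the expected activation unchanged. A small stdlib sketch of that expectation (the helper name is illustrative):

```python
# Expected value of drop_path for a scalar input x: with probability
# keep_prob the sample survives (optionally scaled by 1 / keep_prob),
# otherwise the path contributes 0.
def expected_drop_path(x, drop_prob, scale_by_keep=True):
    keep_prob = 1.0 - drop_prob
    survive = x / keep_prob if scale_by_keep else x
    return keep_prob * survive + drop_prob * 0.0

biased = expected_drop_path(2.0, 0.5, scale_by_keep=False)  # 1.0, halved
unbiased = expected_drop_path(2.0, 0.5)                     # 2.0, preserved
```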
+ class DropPath(nn.Module):
+     """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+     def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
+         super(DropPath, self).__init__()
+         self.drop_prob = drop_prob
+         self.scale_by_keep = scale_by_keep
+
+     def forward(self, x):
+         return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
+
+     def extra_repr(self):
+         return f'drop_prob={round(self.drop_prob, 3):0.3f}'
+
+ class DASSLinear2d(nn.Linear):
+     def __init__(self, *args, groups=1, **kwargs):
+         nn.Linear.__init__(self, *args, **kwargs)
+         self.groups = groups
+
+     def forward(self, x: torch.Tensor):
+         if len(x.shape) == 4:
+             return F.conv2d(x, self.weight[:, :, None, None], self.bias, groups=self.groups)
+         elif len(x.shape) == 3:
+             return F.conv1d(x, self.weight[:, :, None], self.bias, groups=self.groups)
+
+     def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
+         self_state_dict = self.state_dict()
+         load_state_dict_keys = list(state_dict.keys())
+         if prefix + "weight" in load_state_dict_keys:
+             state_dict[prefix + "weight"] = state_dict[prefix + "weight"].view_as(self_state_dict["weight"])
+         return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+
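`DASSLinear2d` applies an `nn.Linear` as a 1x1 convolution so it can act directly on `(B, C, H, W)` maps. A 1x1 convolution is exactly a per-position matrix multiply over channels, which the plain-Python sketch below illustrates (the helper is hypothetical, for exposition only):

```python
# A 1x1 convolution with weight [C_out][C_in] is a channel matmul applied
# independently at every spatial position (h, w) -- the identity that lets
# DASSLinear2d reuse nn.Linear weights via F.conv2d.
def linear_as_1x1_conv(x, weight, bias):
    # x: [C_in][H][W], weight: [C_out][C_in], bias: [C_out]
    C_in, H, W = len(x), len(x[0]), len(x[0][0])
    C_out = len(weight)
    return [[[sum(weight[o][i] * x[i][h][w] for i in range(C_in)) + bias[o]
              for w in range(W)] for h in range(H)] for o in range(C_out)]

x = [[[1.0, 2.0]], [[3.0, 4.0]]]     # 2 input channels over a 1x2 map
weight = [[1.0, 1.0], [2.0, 0.0]]    # 2 output channels
out = linear_as_1x1_conv(x, weight, [0.0, 0.0])
# out == [[[4.0, 6.0]], [[2.0, 4.0]]]
```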
+
+ class DASSLayerNorm2d(nn.LayerNorm):
+     def __init__(self, *args, **kwargs):
+         nn.LayerNorm.__init__(self, *args, **kwargs)
+
+     def forward(self, x: torch.Tensor):
+         x = x.permute(0, 2, 3, 1)
+         x = nn.LayerNorm.forward(self, x)
+         x = x.permute(0, 3, 1, 2)
+         return x
+
+
+ class DASSPatchEmbeddings(nn.Module):
+     """
+     This class turns `input_values` into the initial `hidden_states` (patch embeddings) of shape `(batch_size,
+     seq_length, hidden_size)` to be consumed by a State-space model.
+     """
+
+     def __init__(self, patch_size=4, embed_dim=96):
+         super().__init__()
+
+         stride = patch_size // 2
+         kernel_size = stride + 1
+         padding = 1
+
+         self.projection = nn.Sequential(
+             nn.Conv2d(1, embed_dim // 2, kernel_size=kernel_size, stride=stride, padding=padding),
+             DASSLayerNorm2d(embed_dim // 2),
+             nn.GELU(),
+             nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding),
+             DASSLayerNorm2d(embed_dim),
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = x.unsqueeze(1)
+         x = x.transpose(2, 3)
+         x = self.projection(x)
+         return x
+
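With the default `patch_size=4`, the stem above runs two convolutions with `stride=2`, `kernel_size=3`, `padding=1`, each halving the spatial size, so the overall downsampling factor is 4. The standard conv output-size arithmetic, sketched with illustrative helper names:

```python
# Standard convolution output size: floor((size + 2*padding - kernel) / stride) + 1.
def conv_out(size, kernel, stride, padding):
    return (size + 2 * padding - kernel) // stride + 1

# Spatial size after the two-conv patch stem above (patch_size=4 by default).
def stem_out(size, patch_size=4):
    stride = patch_size // 2
    kernel = stride + 1
    size = conv_out(size, kernel, stride, 1)  # first conv: halves the size
    size = conv_out(size, kernel, stride, 1)  # second conv: halves again
    return size

# e.g. 128 mel frames: 128 -> 64 -> 32
```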
+
+ class DASSDowsample(nn.Module):
+     """
+     This class downsamples the input tensor using a convolutional layer followed by a layer normalization.
+     """
+
+     def __init__(self, dim, out_dim, use_norm=True):
+         super().__init__()
+         self.down = nn.Conv2d(dim, out_dim, kernel_size=3, stride=2, padding=1)
+         self.norm = DASSLayerNorm2d(out_dim) if use_norm else nn.Identity()
+
+     def forward(self, x):
+         x = self.down(x)
+         x = self.norm(x)
+         return x
+
+
+ class DASSMlp(nn.Module):
+     def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+         super().__init__()
+         out_features = out_features or in_features
+         hidden_features = hidden_features or in_features
+         self.fc1 = DASSLinear2d(in_features, hidden_features)
+         self.act = act_layer()
+         self.fc2 = DASSLinear2d(hidden_features, out_features)
+         self.drop = nn.Dropout(drop)
+
+     def forward(self, x):
+         x = self.fc1(x)
+         x = self.act(x)
+         x = self.drop(x)
+         x = self.fc2(x)
+         x = self.drop(x)
+         return x
+
+
+ class SS2D(nn.Module):
+     def __init__(
+         self,
+         # basic dims ===========
+         d_model=96,
+         d_state=16,
+         ssm_ratio=2.0,
+         dt_rank="auto",
+         act_layer=nn.SiLU,
+         # dwconv ===============
+         d_conv=3,
+         conv_bias=True,
+         # ======================
+         dropout=0.0,
+         bias=False,
+         # dt init ==============
+         dt_min=0.001,
+         dt_max=0.1,
+         dt_init="random",
+         dt_scale=1.0,
+         dt_init_floor=1e-4,
+         # forward_type="v05_noz" is always used
+         # ======================
+         **kwargs,
+     ):
+         super().__init__()
+         self.k_group = 4
+         self.d_model = int(d_model)
+         self.d_state = int(d_state)
+         self.d_inner = int(ssm_ratio * d_model)
+         self.dt_rank = int(math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank)
+         self.forward_core = partial(self.forward_corev2, force_fp32=False, no_einsum=True)
+         self.with_dconv = d_conv > 1
+
+         # In projection
+         self.in_proj = DASSLinear2d(self.d_model, self.d_inner, bias=bias)
+         self.act: nn.Module = act_layer()
+
+         # Convolution
+         if self.with_dconv:
+             self.conv2d = nn.Conv2d(
+                 in_channels=self.d_inner,
+                 out_channels=self.d_inner,
+                 groups=self.d_inner,
+                 bias=conv_bias,
+                 kernel_size=d_conv,
+                 padding=(d_conv - 1) // 2,
+             )
+
+         # x_proj and dt_proj
+         self.x_proj = DASSLinear2d(self.d_inner, self.k_group * (self.dt_rank + self.d_state * 2), groups=self.k_group, bias=False)
+         self.dt_projs = DASSLinear2d(self.dt_rank, self.k_group * self.d_inner, groups=self.k_group, bias=False)
+
+         # out projection
+         self.out_proj = DASSLinear2d(self.d_inner, self.d_model, bias=bias)
+         self.dropout = nn.Dropout(dropout) if dropout > 0. else nn.Identity()
+
+         # Initialization
+         self.A_logs, self.Ds, self.dt_projs_weight, self.dt_projs_bias = self.init_dt_A_D(
+             self.d_state, self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, k_group=self.k_group,
+         )
+         self.dt_projs.weight.data = self.dt_projs_weight.data.view(self.dt_projs.weight.shape)
+         # self.dt_projs.bias.data = self.dt_projs_bias.data.view(self.dt_projs.bias.shape)
+         del self.dt_projs_weight
+         # del self.dt_projs_bias
+         # Define out_norm directly with "LN2D"
+         self.out_norm = DASSLayerNorm2d(self.d_inner)
+
+     @staticmethod
+     def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4):
+         dt_proj = nn.Linear(dt_rank, d_inner, bias=True)
+
+         dt_init_std = dt_rank**-0.5 * dt_scale
+         if dt_init == "constant":
+             nn.init.constant_(dt_proj.weight, dt_init_std)
+         elif dt_init == "random":
+             nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std)
+         else:
+             raise NotImplementedError
+
+         dt = torch.exp(
+             torch.rand(d_inner) * (math.log(dt_max) - math.log(dt_min))
+             + math.log(dt_min)
+         ).clamp(min=dt_init_floor)
+
+         inv_dt = dt + torch.log(-torch.expm1(-dt))
+         with torch.no_grad():
+             dt_proj.bias.copy_(inv_dt)
+
+         return dt_proj
+
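The bias stored by `dt_init`, `inv_dt = dt + log(-expm1(-dt))`, is the exact inverse of softplus, so `softplus(bias)` reproduces the sampled `dt` at initialization (the scan applies `delta_softplus=True`). A quick stdlib check of that identity:

```python
import math

# Inverse softplus, as used for the dt bias: softplus(inv_softplus(dt)) == dt.
# Derivation: exp(inv_dt) = exp(dt) * (1 - exp(-dt)) = exp(dt) - 1,
# hence log(1 + exp(inv_dt)) = dt exactly.
def inv_softplus(dt):
    return dt + math.log(-math.expm1(-dt))

def softplus(x):
    return math.log1p(math.exp(x))

recovered = softplus(inv_softplus(0.01))  # ~0.01, up to float rounding
```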
+     @staticmethod
+     def A_log_init(d_state, d_inner, copies=-1, device=None, merge=True):
+         A = torch.arange(1, d_state + 1, dtype=torch.float32, device=device).view(1, -1).repeat(d_inner, 1).contiguous()
+         A_log = torch.log(A)
+         if copies > 0:
+             A_log = A_log[None].repeat(copies, 1, 1).contiguous()
+             if merge:
+                 A_log = A_log.flatten(0, 1)
+         A_log = nn.Parameter(A_log)
+         A_log._no_weight_decay = True
+         return A_log
+
+     @staticmethod
+     def D_init(d_inner, copies=-1, device=None, merge=True):
+         D = torch.ones(d_inner, device=device)
+         if copies > 0:
+             D = D[None].repeat(copies, 1).contiguous()
+             if merge:
+                 D = D.flatten(0, 1)
+         D = nn.Parameter(D)
+         D._no_weight_decay = True
+         return D
+
+     @classmethod
+     def init_dt_A_D(cls, d_state, dt_rank, d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, k_group=4):
+         dt_projs = [
+             cls.dt_init(dt_rank, d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor)
+             for _ in range(k_group)
+         ]
+         dt_projs_weight = nn.Parameter(torch.stack([t.weight for t in dt_projs], dim=0))
+         dt_projs_bias = nn.Parameter(torch.stack([t.bias for t in dt_projs], dim=0))
+         del dt_projs
+
+         A_logs = cls.A_log_init(d_state, d_inner, copies=k_group, merge=True)
+         Ds = cls.D_init(d_inner, copies=k_group, merge=True)
+         return A_logs, Ds, dt_projs_weight, dt_projs_bias
+
+     def forward_corev2(
+         self,
+         x: torch.Tensor,
+         force_fp32=False,
+         no_einsum=True,
+     ):
+         B, D, H, W = x.shape
+         N = self.d_state
+         L = H * W
+
+         xs = cross_scan_fn(x, in_channel_first=True, out_channel_first=True)
+         x_dbl = self.x_proj(xs.view(B, -1, L))
+         dts, Bs, Cs = torch.split(x_dbl.view(B, self.k_group, -1, L), [self.dt_rank, N, N], dim=2)
+         dts = dts.contiguous().view(B, -1, L)
+         dts = self.dt_projs(dts)
+
+         xs = xs.view(B, -1, L)
+         As = -self.A_logs.to(torch.float32).exp()
+         Ds = self.Ds.to(torch.float32)
+         Bs = Bs.contiguous().view(B, self.k_group, N, L)
+         Cs = Cs.contiguous().view(B, self.k_group, N, L)
+         delta_bias = self.dt_projs_bias.view(-1).to(torch.float32)
+
+         ys = selective_scan_fn(
+             xs, dts, As, Bs, Cs, Ds, delta_bias, delta_softplus=True, backend="mamba"
+         ).view(B, self.k_group, -1, H, W)
+
+         y = cross_merge_fn(ys, in_channel_first=True, out_channel_first=True)
+         y = y.view(B, -1, H, W)
+         y = self.out_norm(y)
+         return y.to(x.dtype)
+
+     def forward(self, x: torch.Tensor):
+         x = self.in_proj(x)
+         if self.with_dconv:
+             x = self.conv2d(x)
+         x = self.act(x)
+         y = self.forward_core(x)
+         out = self.dropout(self.out_proj(y))
+         return out
+
+
+ class VSSBlock(nn.Module):
+     def __init__(
+         self,
+         hidden_dim: int = 0,
+         drop_path: float = 0,
+         ssm_d_state: int = 1,
+         ssm_ratio=1.0,
+         ssm_dt_rank: Any = "auto",
+         ssm_act_layer=nn.SiLU,
+         ssm_conv: int = 3,
+         ssm_conv_bias=False,
+         ssm_drop_rate: float = 0,
+         mlp_ratio=4.0,
+         mlp_act_layer=nn.GELU,
+         mlp_drop_rate: float = 0.0,
+         use_checkpoint: bool = False,
+         post_norm: bool = False,
+         **kwargs,
+     ):
+         super().__init__()
+         self.ssm_branch = ssm_ratio > 0
+         self.mlp_branch = mlp_ratio > 0
+         self.use_checkpoint = use_checkpoint
+         self.post_norm = post_norm
+
+         if self.ssm_branch:
+             self.norm = DASSLayerNorm2d(hidden_dim)
+             self.op = SS2D(
+                 d_model=hidden_dim,
+                 d_state=ssm_d_state,
+                 ssm_ratio=ssm_ratio,
+                 dt_rank=ssm_dt_rank,
+                 act_layer=ssm_act_layer,
+                 d_conv=ssm_conv,
+                 conv_bias=ssm_conv_bias,
+                 dropout=ssm_drop_rate,
+             )
+
+         self.drop_path = DropPath(drop_path)
+
+         if self.mlp_branch:
+             self.norm2 = DASSLayerNorm2d(hidden_dim)
+             mlp_hidden_dim = int(hidden_dim * mlp_ratio)
+             self.mlp = DASSMlp(in_features=hidden_dim, hidden_features=mlp_hidden_dim, act_layer=mlp_act_layer, drop=mlp_drop_rate)
+
+     def _forward(self, input: torch.Tensor):
+         x = input
+         if self.ssm_branch:
+             if self.post_norm:
+                 x = x + self.drop_path(self.norm(self.op(x)))
+             else:
+                 x = x + self.drop_path(self.op(self.norm(x)))
+         if self.mlp_branch:
+             if self.post_norm:
+                 x = x + self.drop_path(self.norm2(self.mlp(x)))
+             else:
+                 x = x + self.drop_path(self.mlp(self.norm2(x)))
+         return x
+
+     def forward(self, input: torch.Tensor):
+         if self.use_checkpoint:
+             return checkpoint.checkpoint(self._forward, input)
+         else:
+             return self._forward(input)
+
+ class DASSLayer(nn.Module):
+
+     def __init__(
+         self,
+         input_dim,
+         depth,
+         drop_path=0.0,
+         norm_layer=DASSLayerNorm2d,
+         downsample=nn.Identity(),
+         use_checkpoint=False,
+         **kwargs,
+     ):
+         super().__init__()
+         self.input_dim = input_dim
+         self.use_checkpoint = use_checkpoint
+
+         self.blocks = nn.ModuleList()
+         for i in range(depth):
+             self.blocks.append(
+                 VSSBlock(
+                     hidden_dim=input_dim,
+                     drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
+                     norm_layer=norm_layer, use_checkpoint=use_checkpoint, **kwargs,
+                 )
+             )
+
+         self.downsample = downsample
+
+     def forward(self, x):
+         for block in self.blocks:
+             x = block(x)
+
+         x = self.downsample(x)
+         return x
+
+ class DASSPreTrainedModel(PreTrainedModel):
+     """
+     An abstract class to handle weights initialization and
+     a simple interface for downloading and loading pretrained models.
+     """
+
+     config_class = DASSConfig
+     base_model_prefix = "dass"
+     supports_gradient_checkpointing = False
+
+     def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
+         """Initialize the weights"""
+         if isinstance(module, nn.Linear):
+             nn.init.trunc_normal_(module.weight, std=0.02)
+             if module.bias is not None:
+                 nn.init.constant_(module.bias, 0)
+         elif isinstance(module, nn.LayerNorm):
+             nn.init.constant_(module.bias, 0)
+             nn.init.constant_(module.weight, 1.0)
+
+
+ class DASSModel(DASSPreTrainedModel):
+     def __init__(self, config):
+         super().__init__(config)
+         self.config = config
+
+         # num_layers must be set before expanding an integer `dims` into per-stage widths.
+         self.num_layers = len(config.depths)
+         dims = config.dims
+         if isinstance(dims, int):
+             dims = [int(dims * 2**i_layer) for i_layer in range(self.num_layers)]
+
+         self.dims = dims
+         self.patch_embeddings = DASSPatchEmbeddings(patch_size=config.patch_size,
+                                                     embed_dim=dims[0])
+
+         dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]
+         self.num_features = dims[-1]
+
+         self.layers = nn.ModuleList()
+         for i in range(self.num_layers):
+             layer = DASSLayer(
+                 input_dim=self.dims[i],
+                 depth=config.depths[i],
+                 drop_path=dpr[sum(config.depths[:i]):sum(config.depths[:i + 1])],
+                 downsample=DASSDowsample(self.dims[i], self.dims[i + 1]) if i < self.num_layers - 1 else nn.Identity(),
+                 use_checkpoint=config.use_checkpoint,
+             )
+             self.layers.append(layer)
+
+         self.norm = DASSLayerNorm2d(self.num_features)
+         self.avgpool = nn.AdaptiveAvgPool2d(1)
+
+     def get_input_embeddings(self) -> DASSPatchEmbeddings:
+         return self.patch_embeddings
+
+     def forward(self, input_values: torch.Tensor):
+         x = self.patch_embeddings(input_values)
+         for layer in self.layers:
+             x = layer(x)
+         x = self.norm(x)
+         x = self.avgpool(x).flatten(1)
+         return x
+
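The `dpr` list built in `__init__` is a linear stochastic-depth schedule: one drop rate per block, rising from 0 to `config.drop_path_rate` over all blocks, then sliced per stage by `depths`. A stdlib sketch of the same schedule (the helper name is illustrative):

```python
# Linear stochastic-depth schedule: mirrors
#   dpr = torch.linspace(0, drop_path_rate, sum(depths))
# followed by the per-stage slicing dpr[sum(depths[:i]):sum(depths[:i+1])].
def drop_path_schedule(depths, drop_path_rate):
    total = sum(depths)
    dpr = [drop_path_rate * i / (total - 1) for i in range(total)]
    stages, start = [], 0
    for d in depths:
        stages.append(dpr[start:start + d])
        start += d
    return stages

sched = drop_path_schedule([2, 2, 4], 0.2)
# first block gets rate 0.0, the last gets 0.2, increasing linearly
```

Deeper blocks thus get dropped more often during training, which is the usual stochastic-depth recipe for hierarchical backbones.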
+
+ class DASSForAudioClassification(DASSPreTrainedModel):
+     def __init__(self, config):
+         super().__init__(config)
+
+         self.num_classes = config.num_classes
+         self.dass = DASSModel(config)
+         self.head = nn.Linear(self.dass.num_features, self.num_classes) if self.num_classes > 0 else nn.Identity()
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def forward(
+         self,
+         input_values: Optional[torch.Tensor] = None,
+         labels: Optional[torch.Tensor] = None,
+         return_dict: Optional[bool] = None,
+     ):
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         outputs = self.dass(input_values)
+         logits = self.head(outputs)
+
+         loss = None
+         if labels is not None:
+             labels = labels.to(logits.device)
+             if self.config.loss_type == "ce":
+                 loss_fct = CrossEntropyLoss()
+                 loss = loss_fct(logits.view(-1, self.num_classes), labels.view(-1))
+             elif self.config.loss_type == "bce":
+                 loss_fct = BCEWithLogitsLoss()
+                 loss = loss_fct(logits, labels)
+
+         if not return_dict:
+             output = (logits,) + (outputs,)
+             return ((loss,) + output) if loss is not None else output
+
+         return SequenceClassifierOutput(
+             loss=loss,
+             logits=logits,
+             hidden_states=outputs,
+         )
+
+ __all__ = [
+     "DASSModel",
+     "DASSPreTrainedModel",
+     "DASSForAudioClassification",
+ ]