File size: 54,009 Bytes
ccbb2de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
program(1.3)
[buildInfo = dict<string, string>({{"coremlc-component-MIL", "3510.2.1"}, {"coremlc-version", "3500.32.1"}, {"coremltools-component-torch", "2.5.0"}, {"coremltools-source-dialect", "TorchScript"}, {"coremltools-version", "9.0"}})]
{
    func main<ios18>(tensor<fp16, [1, 1, 1152]> hidden_states) {
            tensor<int32, [3]> var_5 = const()[name = string("op_5"), val = tensor<int32, [3]>([0, 2, 1])];
            tensor<int32, [1]> input_axes_0 = const()[name = string("input_axes_0"), val = tensor<int32, [1]>([2])];
            tensor<fp16, [1, 1152, 1]> var_6_cast_fp16 = transpose(perm = var_5, x = hidden_states)[name = string("transpose_16")];
            tensor<fp16, [1, 1152, 1, 1]> input_cast_fp16 = expand_dims(axes = input_axes_0, x = var_6_cast_fp16)[name = string("input_cast_fp16")];
            string var_29_pad_type_0 = const()[name = string("op_29_pad_type_0"), val = string("valid")];
            tensor<int32, [2]> var_29_strides_0 = const()[name = string("op_29_strides_0"), val = tensor<int32, [2]>([1, 1])];
            tensor<int32, [4]> var_29_pad_0 = const()[name = string("op_29_pad_0"), val = tensor<int32, [4]>([0, 0, 0, 0])];
            tensor<int32, [2]> var_29_dilations_0 = const()[name = string("op_29_dilations_0"), val = tensor<int32, [2]>([1, 1])];
            int32 var_29_groups_0 = const()[name = string("op_29_groups_0"), val = int32(1)];
            tensor<fp16, [16384, 1152, 1, 1]> var_9_promoted_to_fp16 = const()[name = string("op_9_promoted_to_fp16"), val = tensor<fp16, [16384, 1152, 1, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(64)))];
            tensor<fp16, [1, 16384, 1, 1]> var_29_cast_fp16 = conv(dilations = var_29_dilations_0, groups = var_29_groups_0, pad = var_29_pad_0, pad_type = var_29_pad_type_0, strides = var_29_strides_0, weight = var_9_promoted_to_fp16, x = input_cast_fp16)[name = string("op_29_cast_fp16")];
            tensor<int32, [1]> var_31_axes_0 = const()[name = string("op_31_axes_0"), val = tensor<int32, [1]>([2])];
            tensor<fp16, [1, 16384, 1]> var_31_cast_fp16 = squeeze(axes = var_31_axes_0, x = var_29_cast_fp16)[name = string("op_31_cast_fp16")];
            tensor<int32, [3]> logits_1_perm_0 = const()[name = string("logits_1_perm_0"), val = tensor<int32, [3]>([0, 2, 1])];
            string var_55_pad_type_0 = const()[name = string("op_55_pad_type_0"), val = string("valid")];
            tensor<int32, [2]> var_55_strides_0 = const()[name = string("op_55_strides_0"), val = tensor<int32, [2]>([1, 1])];
            tensor<int32, [4]> var_55_pad_0 = const()[name = string("op_55_pad_0"), val = tensor<int32, [4]>([0, 0, 0, 0])];
            tensor<int32, [2]> var_55_dilations_0 = const()[name = string("op_55_dilations_0"), val = tensor<int32, [2]>([1, 1])];
            int32 var_55_groups_0 = const()[name = string("op_55_groups_0"), val = int32(1)];
            tensor<fp16, [16384, 1152, 1, 1]> var_35_promoted_to_fp16 = const()[name = string("op_35_promoted_to_fp16"), val = tensor<fp16, [16384, 1152, 1, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(37748864)))];
            tensor<fp16, [1, 16384, 1, 1]> var_55_cast_fp16 = conv(dilations = var_55_dilations_0, groups = var_55_groups_0, pad = var_55_pad_0, pad_type = var_55_pad_type_0, strides = var_55_strides_0, weight = var_35_promoted_to_fp16, x = input_cast_fp16)[name = string("op_55_cast_fp16")];
            tensor<int32, [1]> var_57_axes_0 = const()[name = string("op_57_axes_0"), val = tensor<int32, [1]>([2])];
            tensor<fp16, [1, 16384, 1]> var_57_cast_fp16 = squeeze(axes = var_57_axes_0, x = var_55_cast_fp16)[name = string("op_57_cast_fp16")];
            tensor<int32, [3]> logits_3_perm_0 = const()[name = string("logits_3_perm_0"), val = tensor<int32, [3]>([0, 2, 1])];
            string var_81_pad_type_0 = const()[name = string("op_81_pad_type_0"), val = string("valid")];
            tensor<int32, [2]> var_81_strides_0 = const()[name = string("op_81_strides_0"), val = tensor<int32, [2]>([1, 1])];
            tensor<int32, [4]> var_81_pad_0 = const()[name = string("op_81_pad_0"), val = tensor<int32, [4]>([0, 0, 0, 0])];
            tensor<int32, [2]> var_81_dilations_0 = const()[name = string("op_81_dilations_0"), val = tensor<int32, [2]>([1, 1])];
            int32 var_81_groups_0 = const()[name = string("op_81_groups_0"), val = int32(1)];
            tensor<fp16, [16384, 1152, 1, 1]> var_61_promoted_to_fp16 = const()[name = string("op_61_promoted_to_fp16"), val = tensor<fp16, [16384, 1152, 1, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(75497664)))];
            tensor<fp16, [1, 16384, 1, 1]> var_81_cast_fp16 = conv(dilations = var_81_dilations_0, groups = var_81_groups_0, pad = var_81_pad_0, pad_type = var_81_pad_type_0, strides = var_81_strides_0, weight = var_61_promoted_to_fp16, x = input_cast_fp16)[name = string("op_81_cast_fp16")];
            tensor<int32, [1]> var_83_axes_0 = const()[name = string("op_83_axes_0"), val = tensor<int32, [1]>([2])];
            tensor<fp16, [1, 16384, 1]> var_83_cast_fp16 = squeeze(axes = var_83_axes_0, x = var_81_cast_fp16)[name = string("op_83_cast_fp16")];
            tensor<int32, [3]> logits_5_perm_0 = const()[name = string("logits_5_perm_0"), val = tensor<int32, [3]>([0, 2, 1])];
            string var_107_pad_type_0 = const()[name = string("op_107_pad_type_0"), val = string("valid")];
            tensor<int32, [2]> var_107_strides_0 = const()[name = string("op_107_strides_0"), val = tensor<int32, [2]>([1, 1])];
            tensor<int32, [4]> var_107_pad_0 = const()[name = string("op_107_pad_0"), val = tensor<int32, [4]>([0, 0, 0, 0])];
            tensor<int32, [2]> var_107_dilations_0 = const()[name = string("op_107_dilations_0"), val = tensor<int32, [2]>([1, 1])];
            int32 var_107_groups_0 = const()[name = string("op_107_groups_0"), val = int32(1)];
            tensor<fp16, [16384, 1152, 1, 1]> var_87_promoted_to_fp16 = const()[name = string("op_87_promoted_to_fp16"), val = tensor<fp16, [16384, 1152, 1, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(113246464)))];
            tensor<fp16, [1, 16384, 1, 1]> var_107_cast_fp16 = conv(dilations = var_107_dilations_0, groups = var_107_groups_0, pad = var_107_pad_0, pad_type = var_107_pad_type_0, strides = var_107_strides_0, weight = var_87_promoted_to_fp16, x = input_cast_fp16)[name = string("op_107_cast_fp16")];
            tensor<int32, [1]> var_109_axes_0 = const()[name = string("op_109_axes_0"), val = tensor<int32, [1]>([2])];
            tensor<fp16, [1, 16384, 1]> var_109_cast_fp16 = squeeze(axes = var_109_axes_0, x = var_107_cast_fp16)[name = string("op_109_cast_fp16")];
            tensor<int32, [3]> logits_7_perm_0 = const()[name = string("logits_7_perm_0"), val = tensor<int32, [3]>([0, 2, 1])];
            string var_133_pad_type_0 = const()[name = string("op_133_pad_type_0"), val = string("valid")];
            tensor<int32, [2]> var_133_strides_0 = const()[name = string("op_133_strides_0"), val = tensor<int32, [2]>([1, 1])];
            tensor<int32, [4]> var_133_pad_0 = const()[name = string("op_133_pad_0"), val = tensor<int32, [4]>([0, 0, 0, 0])];
            tensor<int32, [2]> var_133_dilations_0 = const()[name = string("op_133_dilations_0"), val = tensor<int32, [2]>([1, 1])];
            int32 var_133_groups_0 = const()[name = string("op_133_groups_0"), val = int32(1)];
            tensor<fp16, [16384, 1152, 1, 1]> var_113_promoted_to_fp16 = const()[name = string("op_113_promoted_to_fp16"), val = tensor<fp16, [16384, 1152, 1, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(150995264)))];
            tensor<fp16, [1, 16384, 1, 1]> var_133_cast_fp16 = conv(dilations = var_133_dilations_0, groups = var_133_groups_0, pad = var_133_pad_0, pad_type = var_133_pad_type_0, strides = var_133_strides_0, weight = var_113_promoted_to_fp16, x = input_cast_fp16)[name = string("op_133_cast_fp16")];
            tensor<int32, [1]> var_135_axes_0 = const()[name = string("op_135_axes_0"), val = tensor<int32, [1]>([2])];
            tensor<fp16, [1, 16384, 1]> var_135_cast_fp16 = squeeze(axes = var_135_axes_0, x = var_133_cast_fp16)[name = string("op_135_cast_fp16")];
            tensor<int32, [3]> logits_9_perm_0 = const()[name = string("logits_9_perm_0"), val = tensor<int32, [3]>([0, 2, 1])];
            string var_159_pad_type_0 = const()[name = string("op_159_pad_type_0"), val = string("valid")];
            tensor<int32, [2]> var_159_strides_0 = const()[name = string("op_159_strides_0"), val = tensor<int32, [2]>([1, 1])];
            tensor<int32, [4]> var_159_pad_0 = const()[name = string("op_159_pad_0"), val = tensor<int32, [4]>([0, 0, 0, 0])];
            tensor<int32, [2]> var_159_dilations_0 = const()[name = string("op_159_dilations_0"), val = tensor<int32, [2]>([1, 1])];
            int32 var_159_groups_0 = const()[name = string("op_159_groups_0"), val = int32(1)];
            tensor<fp16, [16384, 1152, 1, 1]> var_139_promoted_to_fp16 = const()[name = string("op_139_promoted_to_fp16"), val = tensor<fp16, [16384, 1152, 1, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(188744064)))];
            tensor<fp16, [1, 16384, 1, 1]> var_159_cast_fp16 = conv(dilations = var_159_dilations_0, groups = var_159_groups_0, pad = var_159_pad_0, pad_type = var_159_pad_type_0, strides = var_159_strides_0, weight = var_139_promoted_to_fp16, x = input_cast_fp16)[name = string("op_159_cast_fp16")];
            tensor<int32, [1]> var_161_axes_0 = const()[name = string("op_161_axes_0"), val = tensor<int32, [1]>([2])];
            tensor<fp16, [1, 16384, 1]> var_161_cast_fp16 = squeeze(axes = var_161_axes_0, x = var_159_cast_fp16)[name = string("op_161_cast_fp16")];
            tensor<int32, [3]> logits_11_perm_0 = const()[name = string("logits_11_perm_0"), val = tensor<int32, [3]>([0, 2, 1])];
            string var_185_pad_type_0 = const()[name = string("op_185_pad_type_0"), val = string("valid")];
            tensor<int32, [2]> var_185_strides_0 = const()[name = string("op_185_strides_0"), val = tensor<int32, [2]>([1, 1])];
            tensor<int32, [4]> var_185_pad_0 = const()[name = string("op_185_pad_0"), val = tensor<int32, [4]>([0, 0, 0, 0])];
            tensor<int32, [2]> var_185_dilations_0 = const()[name = string("op_185_dilations_0"), val = tensor<int32, [2]>([1, 1])];
            int32 var_185_groups_0 = const()[name = string("op_185_groups_0"), val = int32(1)];
            tensor<fp16, [16384, 1152, 1, 1]> var_165_promoted_to_fp16 = const()[name = string("op_165_promoted_to_fp16"), val = tensor<fp16, [16384, 1152, 1, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(226492864)))];
            tensor<fp16, [1, 16384, 1, 1]> var_185_cast_fp16 = conv(dilations = var_185_dilations_0, groups = var_185_groups_0, pad = var_185_pad_0, pad_type = var_185_pad_type_0, strides = var_185_strides_0, weight = var_165_promoted_to_fp16, x = input_cast_fp16)[name = string("op_185_cast_fp16")];
            tensor<int32, [1]> var_187_axes_0 = const()[name = string("op_187_axes_0"), val = tensor<int32, [1]>([2])];
            tensor<fp16, [1, 16384, 1]> var_187_cast_fp16 = squeeze(axes = var_187_axes_0, x = var_185_cast_fp16)[name = string("op_187_cast_fp16")];
            tensor<int32, [3]> logits_13_perm_0 = const()[name = string("logits_13_perm_0"), val = tensor<int32, [3]>([0, 2, 1])];
            string var_211_pad_type_0 = const()[name = string("op_211_pad_type_0"), val = string("valid")];
            tensor<int32, [2]> var_211_strides_0 = const()[name = string("op_211_strides_0"), val = tensor<int32, [2]>([1, 1])];
            tensor<int32, [4]> var_211_pad_0 = const()[name = string("op_211_pad_0"), val = tensor<int32, [4]>([0, 0, 0, 0])];
            tensor<int32, [2]> var_211_dilations_0 = const()[name = string("op_211_dilations_0"), val = tensor<int32, [2]>([1, 1])];
            int32 var_211_groups_0 = const()[name = string("op_211_groups_0"), val = int32(1)];
            tensor<fp16, [16384, 1152, 1, 1]> var_191_promoted_to_fp16 = const()[name = string("op_191_promoted_to_fp16"), val = tensor<fp16, [16384, 1152, 1, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(264241664)))];
            tensor<fp16, [1, 16384, 1, 1]> var_211_cast_fp16 = conv(dilations = var_211_dilations_0, groups = var_211_groups_0, pad = var_211_pad_0, pad_type = var_211_pad_type_0, strides = var_211_strides_0, weight = var_191_promoted_to_fp16, x = input_cast_fp16)[name = string("op_211_cast_fp16")];
            tensor<int32, [1]> var_213_axes_0 = const()[name = string("op_213_axes_0"), val = tensor<int32, [1]>([2])];
            tensor<fp16, [1, 16384, 1]> var_213_cast_fp16 = squeeze(axes = var_213_axes_0, x = var_211_cast_fp16)[name = string("op_213_cast_fp16")];
            tensor<int32, [3]> logits_15_perm_0 = const()[name = string("logits_15_perm_0"), val = tensor<int32, [3]>([0, 2, 1])];
            string var_237_pad_type_0 = const()[name = string("op_237_pad_type_0"), val = string("valid")];
            tensor<int32, [2]> var_237_strides_0 = const()[name = string("op_237_strides_0"), val = tensor<int32, [2]>([1, 1])];
            tensor<int32, [4]> var_237_pad_0 = const()[name = string("op_237_pad_0"), val = tensor<int32, [4]>([0, 0, 0, 0])];
            tensor<int32, [2]> var_237_dilations_0 = const()[name = string("op_237_dilations_0"), val = tensor<int32, [2]>([1, 1])];
            int32 var_237_groups_0 = const()[name = string("op_237_groups_0"), val = int32(1)];
            tensor<fp16, [16384, 1152, 1, 1]> var_217_promoted_to_fp16 = const()[name = string("op_217_promoted_to_fp16"), val = tensor<fp16, [16384, 1152, 1, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(301990464)))];
            tensor<fp16, [1, 16384, 1, 1]> var_237_cast_fp16 = conv(dilations = var_237_dilations_0, groups = var_237_groups_0, pad = var_237_pad_0, pad_type = var_237_pad_type_0, strides = var_237_strides_0, weight = var_217_promoted_to_fp16, x = input_cast_fp16)[name = string("op_237_cast_fp16")];
            tensor<int32, [1]> var_239_axes_0 = const()[name = string("op_239_axes_0"), val = tensor<int32, [1]>([2])];
            tensor<fp16, [1, 16384, 1]> var_239_cast_fp16 = squeeze(axes = var_239_axes_0, x = var_237_cast_fp16)[name = string("op_239_cast_fp16")];
            tensor<int32, [3]> logits_17_perm_0 = const()[name = string("logits_17_perm_0"), val = tensor<int32, [3]>([0, 2, 1])];
            string var_263_pad_type_0 = const()[name = string("op_263_pad_type_0"), val = string("valid")];
            tensor<int32, [2]> var_263_strides_0 = const()[name = string("op_263_strides_0"), val = tensor<int32, [2]>([1, 1])];
            tensor<int32, [4]> var_263_pad_0 = const()[name = string("op_263_pad_0"), val = tensor<int32, [4]>([0, 0, 0, 0])];
            tensor<int32, [2]> var_263_dilations_0 = const()[name = string("op_263_dilations_0"), val = tensor<int32, [2]>([1, 1])];
            int32 var_263_groups_0 = const()[name = string("op_263_groups_0"), val = int32(1)];
            tensor<fp16, [16384, 1152, 1, 1]> var_243_promoted_to_fp16 = const()[name = string("op_243_promoted_to_fp16"), val = tensor<fp16, [16384, 1152, 1, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(339739264)))];
            tensor<fp16, [1, 16384, 1, 1]> var_263_cast_fp16 = conv(dilations = var_263_dilations_0, groups = var_263_groups_0, pad = var_263_pad_0, pad_type = var_263_pad_type_0, strides = var_263_strides_0, weight = var_243_promoted_to_fp16, x = input_cast_fp16)[name = string("op_263_cast_fp16")];
            tensor<int32, [1]> var_265_axes_0 = const()[name = string("op_265_axes_0"), val = tensor<int32, [1]>([2])];
            tensor<fp16, [1, 16384, 1]> var_265_cast_fp16 = squeeze(axes = var_265_axes_0, x = var_263_cast_fp16)[name = string("op_265_cast_fp16")];
            tensor<int32, [3]> logits_19_perm_0 = const()[name = string("logits_19_perm_0"), val = tensor<int32, [3]>([0, 2, 1])];
            string var_289_pad_type_0 = const()[name = string("op_289_pad_type_0"), val = string("valid")];
            tensor<int32, [2]> var_289_strides_0 = const()[name = string("op_289_strides_0"), val = tensor<int32, [2]>([1, 1])];
            tensor<int32, [4]> var_289_pad_0 = const()[name = string("op_289_pad_0"), val = tensor<int32, [4]>([0, 0, 0, 0])];
            tensor<int32, [2]> var_289_dilations_0 = const()[name = string("op_289_dilations_0"), val = tensor<int32, [2]>([1, 1])];
            int32 var_289_groups_0 = const()[name = string("op_289_groups_0"), val = int32(1)];
            tensor<fp16, [16384, 1152, 1, 1]> var_269_promoted_to_fp16 = const()[name = string("op_269_promoted_to_fp16"), val = tensor<fp16, [16384, 1152, 1, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(377488064)))];
            tensor<fp16, [1, 16384, 1, 1]> var_289_cast_fp16 = conv(dilations = var_289_dilations_0, groups = var_289_groups_0, pad = var_289_pad_0, pad_type = var_289_pad_type_0, strides = var_289_strides_0, weight = var_269_promoted_to_fp16, x = input_cast_fp16)[name = string("op_289_cast_fp16")];
            tensor<int32, [1]> var_291_axes_0 = const()[name = string("op_291_axes_0"), val = tensor<int32, [1]>([2])];
            tensor<fp16, [1, 16384, 1]> var_291_cast_fp16 = squeeze(axes = var_291_axes_0, x = var_289_cast_fp16)[name = string("op_291_cast_fp16")];
            tensor<int32, [3]> logits_21_perm_0 = const()[name = string("logits_21_perm_0"), val = tensor<int32, [3]>([0, 2, 1])];
            string var_315_pad_type_0 = const()[name = string("op_315_pad_type_0"), val = string("valid")];
            tensor<int32, [2]> var_315_strides_0 = const()[name = string("op_315_strides_0"), val = tensor<int32, [2]>([1, 1])];
            tensor<int32, [4]> var_315_pad_0 = const()[name = string("op_315_pad_0"), val = tensor<int32, [4]>([0, 0, 0, 0])];
            tensor<int32, [2]> var_315_dilations_0 = const()[name = string("op_315_dilations_0"), val = tensor<int32, [2]>([1, 1])];
            int32 var_315_groups_0 = const()[name = string("op_315_groups_0"), val = int32(1)];
            tensor<fp16, [16384, 1152, 1, 1]> var_295_promoted_to_fp16 = const()[name = string("op_295_promoted_to_fp16"), val = tensor<fp16, [16384, 1152, 1, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(415236864)))];
            tensor<fp16, [1, 16384, 1, 1]> var_315_cast_fp16 = conv(dilations = var_315_dilations_0, groups = var_315_groups_0, pad = var_315_pad_0, pad_type = var_315_pad_type_0, strides = var_315_strides_0, weight = var_295_promoted_to_fp16, x = input_cast_fp16)[name = string("op_315_cast_fp16")];
            tensor<int32, [1]> var_317_axes_0 = const()[name = string("op_317_axes_0"), val = tensor<int32, [1]>([2])];
            tensor<fp16, [1, 16384, 1]> var_317_cast_fp16 = squeeze(axes = var_317_axes_0, x = var_315_cast_fp16)[name = string("op_317_cast_fp16")];
            tensor<int32, [3]> logits_23_perm_0 = const()[name = string("logits_23_perm_0"), val = tensor<int32, [3]>([0, 2, 1])];
            string var_341_pad_type_0 = const()[name = string("op_341_pad_type_0"), val = string("valid")];
            tensor<int32, [2]> var_341_strides_0 = const()[name = string("op_341_strides_0"), val = tensor<int32, [2]>([1, 1])];
            tensor<int32, [4]> var_341_pad_0 = const()[name = string("op_341_pad_0"), val = tensor<int32, [4]>([0, 0, 0, 0])];
            tensor<int32, [2]> var_341_dilations_0 = const()[name = string("op_341_dilations_0"), val = tensor<int32, [2]>([1, 1])];
            int32 var_341_groups_0 = const()[name = string("op_341_groups_0"), val = int32(1)];
            tensor<fp16, [16384, 1152, 1, 1]> var_321_promoted_to_fp16 = const()[name = string("op_321_promoted_to_fp16"), val = tensor<fp16, [16384, 1152, 1, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(452985664)))];
            tensor<fp16, [1, 16384, 1, 1]> var_341_cast_fp16 = conv(dilations = var_341_dilations_0, groups = var_341_groups_0, pad = var_341_pad_0, pad_type = var_341_pad_type_0, strides = var_341_strides_0, weight = var_321_promoted_to_fp16, x = input_cast_fp16)[name = string("op_341_cast_fp16")];
            tensor<int32, [1]> var_343_axes_0 = const()[name = string("op_343_axes_0"), val = tensor<int32, [1]>([2])];
            tensor<fp16, [1, 16384, 1]> var_343_cast_fp16 = squeeze(axes = var_343_axes_0, x = var_341_cast_fp16)[name = string("op_343_cast_fp16")];
            tensor<int32, [3]> logits_25_perm_0 = const()[name = string("logits_25_perm_0"), val = tensor<int32, [3]>([0, 2, 1])];
            string var_367_pad_type_0 = const()[name = string("op_367_pad_type_0"), val = string("valid")];
            tensor<int32, [2]> var_367_strides_0 = const()[name = string("op_367_strides_0"), val = tensor<int32, [2]>([1, 1])];
            tensor<int32, [4]> var_367_pad_0 = const()[name = string("op_367_pad_0"), val = tensor<int32, [4]>([0, 0, 0, 0])];
            tensor<int32, [2]> var_367_dilations_0 = const()[name = string("op_367_dilations_0"), val = tensor<int32, [2]>([1, 1])];
            int32 var_367_groups_0 = const()[name = string("op_367_groups_0"), val = int32(1)];
            tensor<fp16, [16384, 1152, 1, 1]> var_347_promoted_to_fp16 = const()[name = string("op_347_promoted_to_fp16"), val = tensor<fp16, [16384, 1152, 1, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(490734464)))];
            tensor<fp16, [1, 16384, 1, 1]> var_367_cast_fp16 = conv(dilations = var_367_dilations_0, groups = var_367_groups_0, pad = var_367_pad_0, pad_type = var_367_pad_type_0, strides = var_367_strides_0, weight = var_347_promoted_to_fp16, x = input_cast_fp16)[name = string("op_367_cast_fp16")];
            tensor<int32, [1]> var_369_axes_0 = const()[name = string("op_369_axes_0"), val = tensor<int32, [1]>([2])];
            tensor<fp16, [1, 16384, 1]> var_369_cast_fp16 = squeeze(axes = var_369_axes_0, x = var_367_cast_fp16)[name = string("op_369_cast_fp16")];
            tensor<int32, [3]> logits_27_perm_0 = const()[name = string("logits_27_perm_0"), val = tensor<int32, [3]>([0, 2, 1])];
            string var_393_pad_type_0 = const()[name = string("op_393_pad_type_0"), val = string("valid")];
            tensor<int32, [2]> var_393_strides_0 = const()[name = string("op_393_strides_0"), val = tensor<int32, [2]>([1, 1])];
            tensor<int32, [4]> var_393_pad_0 = const()[name = string("op_393_pad_0"), val = tensor<int32, [4]>([0, 0, 0, 0])];
            tensor<int32, [2]> var_393_dilations_0 = const()[name = string("op_393_dilations_0"), val = tensor<int32, [2]>([1, 1])];
            int32 var_393_groups_0 = const()[name = string("op_393_groups_0"), val = int32(1)];
            tensor<fp16, [16384, 1152, 1, 1]> var_373_promoted_to_fp16 = const()[name = string("op_373_promoted_to_fp16"), val = tensor<fp16, [16384, 1152, 1, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(528483264)))];
            tensor<fp16, [1, 16384, 1, 1]> var_393_cast_fp16 = conv(dilations = var_393_dilations_0, groups = var_393_groups_0, pad = var_393_pad_0, pad_type = var_393_pad_type_0, strides = var_393_strides_0, weight = var_373_promoted_to_fp16, x = input_cast_fp16)[name = string("op_393_cast_fp16")];
            tensor<int32, [1]> var_395_axes_0 = const()[name = string("op_395_axes_0"), val = tensor<int32, [1]>([2])];
            tensor<fp16, [1, 16384, 1]> var_395_cast_fp16 = squeeze(axes = var_395_axes_0, x = var_393_cast_fp16)[name = string("op_395_cast_fp16")];
            tensor<int32, [3]> logits_29_perm_0 = const()[name = string("logits_29_perm_0"), val = tensor<int32, [3]>([0, 2, 1])];
            string var_419_pad_type_0 = const()[name = string("op_419_pad_type_0"), val = string("valid")];
            tensor<int32, [2]> var_419_strides_0 = const()[name = string("op_419_strides_0"), val = tensor<int32, [2]>([1, 1])];
            tensor<int32, [4]> var_419_pad_0 = const()[name = string("op_419_pad_0"), val = tensor<int32, [4]>([0, 0, 0, 0])];
            tensor<int32, [2]> var_419_dilations_0 = const()[name = string("op_419_dilations_0"), val = tensor<int32, [2]>([1, 1])];
            int32 var_419_groups_0 = const()[name = string("op_419_groups_0"), val = int32(1)];
            tensor<fp16, [16384, 1152, 1, 1]> var_399_promoted_to_fp16 = const()[name = string("op_399_promoted_to_fp16"), val = tensor<fp16, [16384, 1152, 1, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(566232064)))];
            tensor<fp16, [1, 16384, 1, 1]> var_419_cast_fp16 = conv(dilations = var_419_dilations_0, groups = var_419_groups_0, pad = var_419_pad_0, pad_type = var_419_pad_type_0, strides = var_419_strides_0, weight = var_399_promoted_to_fp16, x = input_cast_fp16)[name = string("op_419_cast_fp16")];
            tensor<int32, [1]> var_421_axes_0 = const()[name = string("op_421_axes_0"), val = tensor<int32, [1]>([2])];
            tensor<fp16, [1, 16384, 1]> var_421_cast_fp16 = squeeze(axes = var_421_axes_0, x = var_419_cast_fp16)[name = string("op_421_cast_fp16")];
            tensor<int32, [3]> logits_perm_0 = const()[name = string("logits_perm_0"), val = tensor<int32, [3]>([0, 2, 1])];
            int32 chunk_argmax_1_axis_0 = const()[name = string("chunk_argmax_1_axis_0"), val = int32(-1)];
            bool chunk_argmax_1_keep_dims_0 = const()[name = string("chunk_argmax_1_keep_dims_0"), val = bool(true)];
            string chunk_argmax_1_output_dtype_0 = const()[name = string("chunk_argmax_1_output_dtype_0"), val = string("int32")];
            tensor<fp16, [1, 1, 16384]> logits_1_cast_fp16 = transpose(perm = logits_1_perm_0, x = var_31_cast_fp16)[name = string("transpose_15")];
            tensor<int32, [1, 1, 1]> chunk_argmax_1_cast_fp16 = reduce_argmax(axis = chunk_argmax_1_axis_0, keep_dims = chunk_argmax_1_keep_dims_0, output_dtype = chunk_argmax_1_output_dtype_0, x = logits_1_cast_fp16)[name = string("chunk_argmax_1_cast_fp16")];
            int32 var_428 = const()[name = string("op_428"), val = int32(-1)];
            bool var_430_validate_indices_0 = const()[name = string("op_430_validate_indices_0"), val = bool(false)];
            string chunk_argmax_1_cast_fp16_to_uint16_dtype_0 = const()[name = string("chunk_argmax_1_cast_fp16_to_uint16_dtype_0"), val = string("uint16")];
            tensor<uint16, [1, 1, 1]> chunk_argmax_1_cast_fp16_to_uint16 = cast(dtype = chunk_argmax_1_cast_fp16_to_uint16_dtype_0, x = chunk_argmax_1_cast_fp16)[name = string("cast_19")];
            tensor<fp16, [1, 1, 1]> var_430_cast_fp16_cast_int16 = gather_along_axis(axis = var_428, indices = chunk_argmax_1_cast_fp16_to_uint16, validate_indices = var_430_validate_indices_0, x = logits_1_cast_fp16)[name = string("op_430_cast_fp16_cast_int16")];
            int32 chunk_argmax_3_axis_0 = const()[name = string("chunk_argmax_3_axis_0"), val = int32(-1)];
            bool chunk_argmax_3_keep_dims_0 = const()[name = string("chunk_argmax_3_keep_dims_0"), val = bool(true)];
            string chunk_argmax_3_output_dtype_0 = const()[name = string("chunk_argmax_3_output_dtype_0"), val = string("int32")];
            tensor<fp16, [1, 1, 16384]> logits_3_cast_fp16 = transpose(perm = logits_3_perm_0, x = var_57_cast_fp16)[name = string("transpose_14")];
            tensor<int32, [1, 1, 1]> chunk_argmax_3_cast_fp16 = reduce_argmax(axis = chunk_argmax_3_axis_0, keep_dims = chunk_argmax_3_keep_dims_0, output_dtype = chunk_argmax_3_output_dtype_0, x = logits_3_cast_fp16)[name = string("chunk_argmax_3_cast_fp16")];
            int32 var_439 = const()[name = string("op_439"), val = int32(-1)];
            bool var_441_validate_indices_0 = const()[name = string("op_441_validate_indices_0"), val = bool(false)];
            string chunk_argmax_3_cast_fp16_to_int16_dtype_0 = const()[name = string("chunk_argmax_3_cast_fp16_to_int16_dtype_0"), val = string("int16")];
            tensor<int16, [1, 1, 1]> chunk_argmax_3_cast_fp16_to_int16 = cast(dtype = chunk_argmax_3_cast_fp16_to_int16_dtype_0, x = chunk_argmax_3_cast_fp16)[name = string("cast_18")];
            tensor<fp16, [1, 1, 1]> var_441_cast_fp16_cast_int16 = gather_along_axis(axis = var_439, indices = chunk_argmax_3_cast_fp16_to_int16, validate_indices = var_441_validate_indices_0, x = logits_3_cast_fp16)[name = string("op_441_cast_fp16_cast_int16")];
            int32 chunk_argmax_5_axis_0 = const()[name = string("chunk_argmax_5_axis_0"), val = int32(-1)];
            bool chunk_argmax_5_keep_dims_0 = const()[name = string("chunk_argmax_5_keep_dims_0"), val = bool(true)];
            string chunk_argmax_5_output_dtype_0 = const()[name = string("chunk_argmax_5_output_dtype_0"), val = string("int32")];
            tensor<fp16, [1, 1, 16384]> logits_5_cast_fp16 = transpose(perm = logits_5_perm_0, x = var_83_cast_fp16)[name = string("transpose_13")];
            tensor<int32, [1, 1, 1]> chunk_argmax_5_cast_fp16 = reduce_argmax(axis = chunk_argmax_5_axis_0, keep_dims = chunk_argmax_5_keep_dims_0, output_dtype = chunk_argmax_5_output_dtype_0, x = logits_5_cast_fp16)[name = string("chunk_argmax_5_cast_fp16")];
            int32 var_450 = const()[name = string("op_450"), val = int32(-1)];
            bool var_452_validate_indices_0 = const()[name = string("op_452_validate_indices_0"), val = bool(false)];
            string chunk_argmax_5_cast_fp16_to_int16_dtype_0 = const()[name = string("chunk_argmax_5_cast_fp16_to_int16_dtype_0"), val = string("int16")];
            tensor<int16, [1, 1, 1]> chunk_argmax_5_cast_fp16_to_int16 = cast(dtype = chunk_argmax_5_cast_fp16_to_int16_dtype_0, x = chunk_argmax_5_cast_fp16)[name = string("cast_17")];
            tensor<fp16, [1, 1, 1]> var_452_cast_fp16_cast_int16 = gather_along_axis(axis = var_450, indices = chunk_argmax_5_cast_fp16_to_int16, validate_indices = var_452_validate_indices_0, x = logits_5_cast_fp16)[name = string("op_452_cast_fp16_cast_int16")];
            int32 chunk_argmax_7_axis_0 = const()[name = string("chunk_argmax_7_axis_0"), val = int32(-1)];
            bool chunk_argmax_7_keep_dims_0 = const()[name = string("chunk_argmax_7_keep_dims_0"), val = bool(true)];
            string chunk_argmax_7_output_dtype_0 = const()[name = string("chunk_argmax_7_output_dtype_0"), val = string("int32")];
            tensor<fp16, [1, 1, 16384]> logits_7_cast_fp16 = transpose(perm = logits_7_perm_0, x = var_109_cast_fp16)[name = string("transpose_12")];
            tensor<int32, [1, 1, 1]> chunk_argmax_7_cast_fp16 = reduce_argmax(axis = chunk_argmax_7_axis_0, keep_dims = chunk_argmax_7_keep_dims_0, output_dtype = chunk_argmax_7_output_dtype_0, x = logits_7_cast_fp16)[name = string("chunk_argmax_7_cast_fp16")];
            int32 var_461 = const()[name = string("op_461"), val = int32(-1)];
            bool var_463_validate_indices_0 = const()[name = string("op_463_validate_indices_0"), val = bool(false)];
            string chunk_argmax_7_cast_fp16_to_int16_dtype_0 = const()[name = string("chunk_argmax_7_cast_fp16_to_int16_dtype_0"), val = string("int16")];
            tensor<int16, [1, 1, 1]> chunk_argmax_7_cast_fp16_to_int16 = cast(dtype = chunk_argmax_7_cast_fp16_to_int16_dtype_0, x = chunk_argmax_7_cast_fp16)[name = string("cast_16")];
            tensor<fp16, [1, 1, 1]> var_463_cast_fp16_cast_int16 = gather_along_axis(axis = var_461, indices = chunk_argmax_7_cast_fp16_to_int16, validate_indices = var_463_validate_indices_0, x = logits_7_cast_fp16)[name = string("op_463_cast_fp16_cast_int16")];
            int32 chunk_argmax_9_axis_0 = const()[name = string("chunk_argmax_9_axis_0"), val = int32(-1)];
            bool chunk_argmax_9_keep_dims_0 = const()[name = string("chunk_argmax_9_keep_dims_0"), val = bool(true)];
            string chunk_argmax_9_output_dtype_0 = const()[name = string("chunk_argmax_9_output_dtype_0"), val = string("int32")];
            tensor<fp16, [1, 1, 16384]> logits_9_cast_fp16 = transpose(perm = logits_9_perm_0, x = var_135_cast_fp16)[name = string("transpose_11")];
            tensor<int32, [1, 1, 1]> chunk_argmax_9_cast_fp16 = reduce_argmax(axis = chunk_argmax_9_axis_0, keep_dims = chunk_argmax_9_keep_dims_0, output_dtype = chunk_argmax_9_output_dtype_0, x = logits_9_cast_fp16)[name = string("chunk_argmax_9_cast_fp16")];
            int32 var_472 = const()[name = string("op_472"), val = int32(-1)];
            bool var_474_validate_indices_0 = const()[name = string("op_474_validate_indices_0"), val = bool(false)];
            string chunk_argmax_9_cast_fp16_to_int16_dtype_0 = const()[name = string("chunk_argmax_9_cast_fp16_to_int16_dtype_0"), val = string("int16")];
            tensor<int16, [1, 1, 1]> chunk_argmax_9_cast_fp16_to_int16 = cast(dtype = chunk_argmax_9_cast_fp16_to_int16_dtype_0, x = chunk_argmax_9_cast_fp16)[name = string("cast_15")];
            tensor<fp16, [1, 1, 1]> var_474_cast_fp16_cast_int16 = gather_along_axis(axis = var_472, indices = chunk_argmax_9_cast_fp16_to_int16, validate_indices = var_474_validate_indices_0, x = logits_9_cast_fp16)[name = string("op_474_cast_fp16_cast_int16")];
            int32 chunk_argmax_11_axis_0 = const()[name = string("chunk_argmax_11_axis_0"), val = int32(-1)];
            bool chunk_argmax_11_keep_dims_0 = const()[name = string("chunk_argmax_11_keep_dims_0"), val = bool(true)];
            string chunk_argmax_11_output_dtype_0 = const()[name = string("chunk_argmax_11_output_dtype_0"), val = string("int32")];
            tensor<fp16, [1, 1, 16384]> logits_11_cast_fp16 = transpose(perm = logits_11_perm_0, x = var_161_cast_fp16)[name = string("transpose_10")];
            tensor<int32, [1, 1, 1]> chunk_argmax_11_cast_fp16 = reduce_argmax(axis = chunk_argmax_11_axis_0, keep_dims = chunk_argmax_11_keep_dims_0, output_dtype = chunk_argmax_11_output_dtype_0, x = logits_11_cast_fp16)[name = string("chunk_argmax_11_cast_fp16")];
            int32 var_483 = const()[name = string("op_483"), val = int32(-1)];
            bool var_485_validate_indices_0 = const()[name = string("op_485_validate_indices_0"), val = bool(false)];
            string chunk_argmax_11_cast_fp16_to_int16_dtype_0 = const()[name = string("chunk_argmax_11_cast_fp16_to_int16_dtype_0"), val = string("int16")];
            tensor<int16, [1, 1, 1]> chunk_argmax_11_cast_fp16_to_int16 = cast(dtype = chunk_argmax_11_cast_fp16_to_int16_dtype_0, x = chunk_argmax_11_cast_fp16)[name = string("cast_14")];
            tensor<fp16, [1, 1, 1]> var_485_cast_fp16_cast_int16 = gather_along_axis(axis = var_483, indices = chunk_argmax_11_cast_fp16_to_int16, validate_indices = var_485_validate_indices_0, x = logits_11_cast_fp16)[name = string("op_485_cast_fp16_cast_int16")];
            int32 chunk_argmax_13_axis_0 = const()[name = string("chunk_argmax_13_axis_0"), val = int32(-1)];
            bool chunk_argmax_13_keep_dims_0 = const()[name = string("chunk_argmax_13_keep_dims_0"), val = bool(true)];
            string chunk_argmax_13_output_dtype_0 = const()[name = string("chunk_argmax_13_output_dtype_0"), val = string("int32")];
            tensor<fp16, [1, 1, 16384]> logits_13_cast_fp16 = transpose(perm = logits_13_perm_0, x = var_187_cast_fp16)[name = string("transpose_9")];
            tensor<int32, [1, 1, 1]> chunk_argmax_13_cast_fp16 = reduce_argmax(axis = chunk_argmax_13_axis_0, keep_dims = chunk_argmax_13_keep_dims_0, output_dtype = chunk_argmax_13_output_dtype_0, x = logits_13_cast_fp16)[name = string("chunk_argmax_13_cast_fp16")];
            int32 var_494 = const()[name = string("op_494"), val = int32(-1)];
            bool var_496_validate_indices_0 = const()[name = string("op_496_validate_indices_0"), val = bool(false)];
            string chunk_argmax_13_cast_fp16_to_int16_dtype_0 = const()[name = string("chunk_argmax_13_cast_fp16_to_int16_dtype_0"), val = string("int16")];
            tensor<int16, [1, 1, 1]> chunk_argmax_13_cast_fp16_to_int16 = cast(dtype = chunk_argmax_13_cast_fp16_to_int16_dtype_0, x = chunk_argmax_13_cast_fp16)[name = string("cast_13")];
            tensor<fp16, [1, 1, 1]> var_496_cast_fp16_cast_int16 = gather_along_axis(axis = var_494, indices = chunk_argmax_13_cast_fp16_to_int16, validate_indices = var_496_validate_indices_0, x = logits_13_cast_fp16)[name = string("op_496_cast_fp16_cast_int16")];
            int32 chunk_argmax_15_axis_0 = const()[name = string("chunk_argmax_15_axis_0"), val = int32(-1)];
            bool chunk_argmax_15_keep_dims_0 = const()[name = string("chunk_argmax_15_keep_dims_0"), val = bool(true)];
            string chunk_argmax_15_output_dtype_0 = const()[name = string("chunk_argmax_15_output_dtype_0"), val = string("int32")];
            tensor<fp16, [1, 1, 16384]> logits_15_cast_fp16 = transpose(perm = logits_15_perm_0, x = var_213_cast_fp16)[name = string("transpose_8")];
            tensor<int32, [1, 1, 1]> chunk_argmax_15_cast_fp16 = reduce_argmax(axis = chunk_argmax_15_axis_0, keep_dims = chunk_argmax_15_keep_dims_0, output_dtype = chunk_argmax_15_output_dtype_0, x = logits_15_cast_fp16)[name = string("chunk_argmax_15_cast_fp16")];
            int32 var_505 = const()[name = string("op_505"), val = int32(-1)];
            bool var_507_validate_indices_0 = const()[name = string("op_507_validate_indices_0"), val = bool(false)];
            string chunk_argmax_15_cast_fp16_to_int16_dtype_0 = const()[name = string("chunk_argmax_15_cast_fp16_to_int16_dtype_0"), val = string("int16")];
            tensor<int16, [1, 1, 1]> chunk_argmax_15_cast_fp16_to_int16 = cast(dtype = chunk_argmax_15_cast_fp16_to_int16_dtype_0, x = chunk_argmax_15_cast_fp16)[name = string("cast_12")];
            tensor<fp16, [1, 1, 1]> var_507_cast_fp16_cast_int16 = gather_along_axis(axis = var_505, indices = chunk_argmax_15_cast_fp16_to_int16, validate_indices = var_507_validate_indices_0, x = logits_15_cast_fp16)[name = string("op_507_cast_fp16_cast_int16")];
            int32 chunk_argmax_17_axis_0 = const()[name = string("chunk_argmax_17_axis_0"), val = int32(-1)];
            bool chunk_argmax_17_keep_dims_0 = const()[name = string("chunk_argmax_17_keep_dims_0"), val = bool(true)];
            string chunk_argmax_17_output_dtype_0 = const()[name = string("chunk_argmax_17_output_dtype_0"), val = string("int32")];
            tensor<fp16, [1, 1, 16384]> logits_17_cast_fp16 = transpose(perm = logits_17_perm_0, x = var_239_cast_fp16)[name = string("transpose_7")];
            tensor<int32, [1, 1, 1]> chunk_argmax_17_cast_fp16 = reduce_argmax(axis = chunk_argmax_17_axis_0, keep_dims = chunk_argmax_17_keep_dims_0, output_dtype = chunk_argmax_17_output_dtype_0, x = logits_17_cast_fp16)[name = string("chunk_argmax_17_cast_fp16")];
            int32 var_516 = const()[name = string("op_516"), val = int32(-1)];
            bool var_518_validate_indices_0 = const()[name = string("op_518_validate_indices_0"), val = bool(false)];
            string chunk_argmax_17_cast_fp16_to_int16_dtype_0 = const()[name = string("chunk_argmax_17_cast_fp16_to_int16_dtype_0"), val = string("int16")];
            tensor<int16, [1, 1, 1]> chunk_argmax_17_cast_fp16_to_int16 = cast(dtype = chunk_argmax_17_cast_fp16_to_int16_dtype_0, x = chunk_argmax_17_cast_fp16)[name = string("cast_11")];
            tensor<fp16, [1, 1, 1]> var_518_cast_fp16_cast_int16 = gather_along_axis(axis = var_516, indices = chunk_argmax_17_cast_fp16_to_int16, validate_indices = var_518_validate_indices_0, x = logits_17_cast_fp16)[name = string("op_518_cast_fp16_cast_int16")];
            int32 chunk_argmax_19_axis_0 = const()[name = string("chunk_argmax_19_axis_0"), val = int32(-1)];
            bool chunk_argmax_19_keep_dims_0 = const()[name = string("chunk_argmax_19_keep_dims_0"), val = bool(true)];
            string chunk_argmax_19_output_dtype_0 = const()[name = string("chunk_argmax_19_output_dtype_0"), val = string("int32")];
            tensor<fp16, [1, 1, 16384]> logits_19_cast_fp16 = transpose(perm = logits_19_perm_0, x = var_265_cast_fp16)[name = string("transpose_6")];
            tensor<int32, [1, 1, 1]> chunk_argmax_19_cast_fp16 = reduce_argmax(axis = chunk_argmax_19_axis_0, keep_dims = chunk_argmax_19_keep_dims_0, output_dtype = chunk_argmax_19_output_dtype_0, x = logits_19_cast_fp16)[name = string("chunk_argmax_19_cast_fp16")];
            int32 var_527 = const()[name = string("op_527"), val = int32(-1)];
            bool var_529_validate_indices_0 = const()[name = string("op_529_validate_indices_0"), val = bool(false)];
            string chunk_argmax_19_cast_fp16_to_int16_dtype_0 = const()[name = string("chunk_argmax_19_cast_fp16_to_int16_dtype_0"), val = string("int16")];
            tensor<int16, [1, 1, 1]> chunk_argmax_19_cast_fp16_to_int16 = cast(dtype = chunk_argmax_19_cast_fp16_to_int16_dtype_0, x = chunk_argmax_19_cast_fp16)[name = string("cast_10")];
            tensor<fp16, [1, 1, 1]> var_529_cast_fp16_cast_int16 = gather_along_axis(axis = var_527, indices = chunk_argmax_19_cast_fp16_to_int16, validate_indices = var_529_validate_indices_0, x = logits_19_cast_fp16)[name = string("op_529_cast_fp16_cast_int16")];
            int32 chunk_argmax_21_axis_0 = const()[name = string("chunk_argmax_21_axis_0"), val = int32(-1)];
            bool chunk_argmax_21_keep_dims_0 = const()[name = string("chunk_argmax_21_keep_dims_0"), val = bool(true)];
            string chunk_argmax_21_output_dtype_0 = const()[name = string("chunk_argmax_21_output_dtype_0"), val = string("int32")];
            tensor<fp16, [1, 1, 16384]> logits_21_cast_fp16 = transpose(perm = logits_21_perm_0, x = var_291_cast_fp16)[name = string("transpose_5")];
            tensor<int32, [1, 1, 1]> chunk_argmax_21_cast_fp16 = reduce_argmax(axis = chunk_argmax_21_axis_0, keep_dims = chunk_argmax_21_keep_dims_0, output_dtype = chunk_argmax_21_output_dtype_0, x = logits_21_cast_fp16)[name = string("chunk_argmax_21_cast_fp16")];
            int32 var_538 = const()[name = string("op_538"), val = int32(-1)];
            bool var_540_validate_indices_0 = const()[name = string("op_540_validate_indices_0"), val = bool(false)];
            string chunk_argmax_21_cast_fp16_to_int16_dtype_0 = const()[name = string("chunk_argmax_21_cast_fp16_to_int16_dtype_0"), val = string("int16")];
            tensor<int16, [1, 1, 1]> chunk_argmax_21_cast_fp16_to_int16 = cast(dtype = chunk_argmax_21_cast_fp16_to_int16_dtype_0, x = chunk_argmax_21_cast_fp16)[name = string("cast_9")];
            tensor<fp16, [1, 1, 1]> var_540_cast_fp16_cast_int16 = gather_along_axis(axis = var_538, indices = chunk_argmax_21_cast_fp16_to_int16, validate_indices = var_540_validate_indices_0, x = logits_21_cast_fp16)[name = string("op_540_cast_fp16_cast_int16")];
            int32 chunk_argmax_23_axis_0 = const()[name = string("chunk_argmax_23_axis_0"), val = int32(-1)];
            bool chunk_argmax_23_keep_dims_0 = const()[name = string("chunk_argmax_23_keep_dims_0"), val = bool(true)];
            string chunk_argmax_23_output_dtype_0 = const()[name = string("chunk_argmax_23_output_dtype_0"), val = string("int32")];
            tensor<fp16, [1, 1, 16384]> logits_23_cast_fp16 = transpose(perm = logits_23_perm_0, x = var_317_cast_fp16)[name = string("transpose_4")];
            tensor<int32, [1, 1, 1]> chunk_argmax_23_cast_fp16 = reduce_argmax(axis = chunk_argmax_23_axis_0, keep_dims = chunk_argmax_23_keep_dims_0, output_dtype = chunk_argmax_23_output_dtype_0, x = logits_23_cast_fp16)[name = string("chunk_argmax_23_cast_fp16")];
            int32 var_549 = const()[name = string("op_549"), val = int32(-1)];
            bool var_551_validate_indices_0 = const()[name = string("op_551_validate_indices_0"), val = bool(false)];
            string chunk_argmax_23_cast_fp16_to_int16_dtype_0 = const()[name = string("chunk_argmax_23_cast_fp16_to_int16_dtype_0"), val = string("int16")];
            tensor<int16, [1, 1, 1]> chunk_argmax_23_cast_fp16_to_int16 = cast(dtype = chunk_argmax_23_cast_fp16_to_int16_dtype_0, x = chunk_argmax_23_cast_fp16)[name = string("cast_8")];
            tensor<fp16, [1, 1, 1]> var_551_cast_fp16_cast_int16 = gather_along_axis(axis = var_549, indices = chunk_argmax_23_cast_fp16_to_int16, validate_indices = var_551_validate_indices_0, x = logits_23_cast_fp16)[name = string("op_551_cast_fp16_cast_int16")];
            int32 chunk_argmax_25_axis_0 = const()[name = string("chunk_argmax_25_axis_0"), val = int32(-1)];
            bool chunk_argmax_25_keep_dims_0 = const()[name = string("chunk_argmax_25_keep_dims_0"), val = bool(true)];
            string chunk_argmax_25_output_dtype_0 = const()[name = string("chunk_argmax_25_output_dtype_0"), val = string("int32")];
            tensor<fp16, [1, 1, 16384]> logits_25_cast_fp16 = transpose(perm = logits_25_perm_0, x = var_343_cast_fp16)[name = string("transpose_3")];
            tensor<int32, [1, 1, 1]> chunk_argmax_25_cast_fp16 = reduce_argmax(axis = chunk_argmax_25_axis_0, keep_dims = chunk_argmax_25_keep_dims_0, output_dtype = chunk_argmax_25_output_dtype_0, x = logits_25_cast_fp16)[name = string("chunk_argmax_25_cast_fp16")];
            int32 var_560 = const()[name = string("op_560"), val = int32(-1)];
            bool var_562_validate_indices_0 = const()[name = string("op_562_validate_indices_0"), val = bool(false)];
            string chunk_argmax_25_cast_fp16_to_int16_dtype_0 = const()[name = string("chunk_argmax_25_cast_fp16_to_int16_dtype_0"), val = string("int16")];
            tensor<int16, [1, 1, 1]> chunk_argmax_25_cast_fp16_to_int16 = cast(dtype = chunk_argmax_25_cast_fp16_to_int16_dtype_0, x = chunk_argmax_25_cast_fp16)[name = string("cast_7")];
            tensor<fp16, [1, 1, 1]> var_562_cast_fp16_cast_int16 = gather_along_axis(axis = var_560, indices = chunk_argmax_25_cast_fp16_to_int16, validate_indices = var_562_validate_indices_0, x = logits_25_cast_fp16)[name = string("op_562_cast_fp16_cast_int16")];
            int32 chunk_argmax_27_axis_0 = const()[name = string("chunk_argmax_27_axis_0"), val = int32(-1)];
            bool chunk_argmax_27_keep_dims_0 = const()[name = string("chunk_argmax_27_keep_dims_0"), val = bool(true)];
            string chunk_argmax_27_output_dtype_0 = const()[name = string("chunk_argmax_27_output_dtype_0"), val = string("int32")];
            tensor<fp16, [1, 1, 16384]> logits_27_cast_fp16 = transpose(perm = logits_27_perm_0, x = var_369_cast_fp16)[name = string("transpose_2")];
            tensor<int32, [1, 1, 1]> chunk_argmax_27_cast_fp16 = reduce_argmax(axis = chunk_argmax_27_axis_0, keep_dims = chunk_argmax_27_keep_dims_0, output_dtype = chunk_argmax_27_output_dtype_0, x = logits_27_cast_fp16)[name = string("chunk_argmax_27_cast_fp16")];
            int32 var_571 = const()[name = string("op_571"), val = int32(-1)];
            bool var_573_validate_indices_0 = const()[name = string("op_573_validate_indices_0"), val = bool(false)];
            string chunk_argmax_27_cast_fp16_to_int16_dtype_0 = const()[name = string("chunk_argmax_27_cast_fp16_to_int16_dtype_0"), val = string("int16")];
            tensor<int16, [1, 1, 1]> chunk_argmax_27_cast_fp16_to_int16 = cast(dtype = chunk_argmax_27_cast_fp16_to_int16_dtype_0, x = chunk_argmax_27_cast_fp16)[name = string("cast_6")];
            tensor<fp16, [1, 1, 1]> var_573_cast_fp16_cast_int16 = gather_along_axis(axis = var_571, indices = chunk_argmax_27_cast_fp16_to_int16, validate_indices = var_573_validate_indices_0, x = logits_27_cast_fp16)[name = string("op_573_cast_fp16_cast_int16")];
            int32 chunk_argmax_29_axis_0 = const()[name = string("chunk_argmax_29_axis_0"), val = int32(-1)];
            bool chunk_argmax_29_keep_dims_0 = const()[name = string("chunk_argmax_29_keep_dims_0"), val = bool(true)];
            string chunk_argmax_29_output_dtype_0 = const()[name = string("chunk_argmax_29_output_dtype_0"), val = string("int32")];
            tensor<fp16, [1, 1, 16384]> logits_29_cast_fp16 = transpose(perm = logits_29_perm_0, x = var_395_cast_fp16)[name = string("transpose_1")];
            tensor<int32, [1, 1, 1]> chunk_argmax_29_cast_fp16 = reduce_argmax(axis = chunk_argmax_29_axis_0, keep_dims = chunk_argmax_29_keep_dims_0, output_dtype = chunk_argmax_29_output_dtype_0, x = logits_29_cast_fp16)[name = string("chunk_argmax_29_cast_fp16")];
            int32 var_582 = const()[name = string("op_582"), val = int32(-1)];
            bool var_584_validate_indices_0 = const()[name = string("op_584_validate_indices_0"), val = bool(false)];
            string chunk_argmax_29_cast_fp16_to_int16_dtype_0 = const()[name = string("chunk_argmax_29_cast_fp16_to_int16_dtype_0"), val = string("int16")];
            tensor<int16, [1, 1, 1]> chunk_argmax_29_cast_fp16_to_int16 = cast(dtype = chunk_argmax_29_cast_fp16_to_int16_dtype_0, x = chunk_argmax_29_cast_fp16)[name = string("cast_5")];
            tensor<fp16, [1, 1, 1]> var_584_cast_fp16_cast_int16 = gather_along_axis(axis = var_582, indices = chunk_argmax_29_cast_fp16_to_int16, validate_indices = var_584_validate_indices_0, x = logits_29_cast_fp16)[name = string("op_584_cast_fp16_cast_int16")];
            int32 chunk_argmax_axis_0 = const()[name = string("chunk_argmax_axis_0"), val = int32(-1)];
            bool chunk_argmax_keep_dims_0 = const()[name = string("chunk_argmax_keep_dims_0"), val = bool(true)];
            string chunk_argmax_output_dtype_0 = const()[name = string("chunk_argmax_output_dtype_0"), val = string("int32")];
            tensor<fp16, [1, 1, 16384]> logits_cast_fp16 = transpose(perm = logits_perm_0, x = var_421_cast_fp16)[name = string("transpose_0")];
            tensor<int32, [1, 1, 1]> chunk_argmax_cast_fp16 = reduce_argmax(axis = chunk_argmax_axis_0, keep_dims = chunk_argmax_keep_dims_0, output_dtype = chunk_argmax_output_dtype_0, x = logits_cast_fp16)[name = string("chunk_argmax_cast_fp16")];
            int32 var_593 = const()[name = string("op_593"), val = int32(-1)];
            bool chunk_max_val_validate_indices_0 = const()[name = string("chunk_max_val_validate_indices_0"), val = bool(false)];
            string chunk_argmax_cast_fp16_to_int16_dtype_0 = const()[name = string("chunk_argmax_cast_fp16_to_int16_dtype_0"), val = string("int16")];
            tensor<int16, [1, 1, 1]> chunk_argmax_cast_fp16_to_int16 = cast(dtype = chunk_argmax_cast_fp16_to_int16_dtype_0, x = chunk_argmax_cast_fp16)[name = string("cast_4")];
            tensor<fp16, [1, 1, 1]> chunk_max_val_cast_fp16_cast_int16 = gather_along_axis(axis = var_593, indices = chunk_argmax_cast_fp16_to_int16, validate_indices = chunk_max_val_validate_indices_0, x = logits_cast_fp16)[name = string("chunk_max_val_cast_fp16_cast_int16")];
            int32 var_602 = const()[name = string("op_602"), val = int32(-1)];
            bool var_603_interleave_0 = const()[name = string("op_603_interleave_0"), val = bool(false)];
            tensor<int32, [1, 1, 16]> var_603 = concat(axis = var_602, interleave = var_603_interleave_0, values = (chunk_argmax_1_cast_fp16, chunk_argmax_3_cast_fp16, chunk_argmax_5_cast_fp16, chunk_argmax_7_cast_fp16, chunk_argmax_9_cast_fp16, chunk_argmax_11_cast_fp16, chunk_argmax_13_cast_fp16, chunk_argmax_15_cast_fp16, chunk_argmax_17_cast_fp16, chunk_argmax_19_cast_fp16, chunk_argmax_21_cast_fp16, chunk_argmax_23_cast_fp16, chunk_argmax_25_cast_fp16, chunk_argmax_27_cast_fp16, chunk_argmax_29_cast_fp16, chunk_argmax_cast_fp16))[name = string("op_603")];
            tensor<int32, [1]> var_605_axes_0 = const()[name = string("op_605_axes_0"), val = tensor<int32, [1]>([0])];
            string var_603_to_int16_dtype_0 = const()[name = string("op_603_to_int16_dtype_0"), val = string("int16")];
            tensor<int16, [1, 1, 16]> var_603_to_int16 = cast(dtype = var_603_to_int16_dtype_0, x = var_603)[name = string("cast_3")];
            tensor<int16, [1, 16]> var_605_cast_uint16 = squeeze(axes = var_605_axes_0, x = var_603_to_int16)[name = string("op_605_cast_uint16")];
            tensor<int32, [1]> var_607_axes_0 = const()[name = string("op_607_axes_0"), val = tensor<int32, [1]>([0])];
            tensor<int16, [16]> var_607_cast_uint16 = squeeze(axes = var_607_axes_0, x = var_605_cast_uint16)[name = string("op_607_cast_uint16")];
            string var_607_cast_uint16_to_int32_dtype_0 = const()[name = string("op_607_cast_uint16_to_int32_dtype_0"), val = string("int32")];
            int32 var_609 = const()[name = string("op_609"), val = int32(-1)];
            bool var_610_interleave_0 = const()[name = string("op_610_interleave_0"), val = bool(false)];
            tensor<fp16, [1, 1, 16]> var_610_cast_fp16 = concat(axis = var_609, interleave = var_610_interleave_0, values = (var_430_cast_fp16_cast_int16, var_441_cast_fp16_cast_int16, var_452_cast_fp16_cast_int16, var_463_cast_fp16_cast_int16, var_474_cast_fp16_cast_int16, var_485_cast_fp16_cast_int16, var_496_cast_fp16_cast_int16, var_507_cast_fp16_cast_int16, var_518_cast_fp16_cast_int16, var_529_cast_fp16_cast_int16, var_540_cast_fp16_cast_int16, var_551_cast_fp16_cast_int16, var_562_cast_fp16_cast_int16, var_573_cast_fp16_cast_int16, var_584_cast_fp16_cast_int16, chunk_max_val_cast_fp16_cast_int16))[name = string("op_610_cast_fp16")];
            tensor<int32, [1]> var_612_axes_0 = const()[name = string("op_612_axes_0"), val = tensor<int32, [1]>([0])];
            tensor<fp16, [1, 16]> var_612_cast_fp16 = squeeze(axes = var_612_axes_0, x = var_610_cast_fp16)[name = string("op_612_cast_fp16")];
            tensor<int32, [1]> var_614_axes_0 = const()[name = string("op_614_axes_0"), val = tensor<int32, [1]>([0])];
            tensor<fp16, [16]> argmax_val = squeeze(axes = var_614_axes_0, x = var_612_cast_fp16)[name = string("op_614_cast_fp16")];
            tensor<int32, [16]> argmax_idx = cast(dtype = var_607_cast_uint16_to_int32_dtype_0, x = var_607_cast_uint16)[name = string("cast_2")];
        } -> (argmax_idx, argmax_val);
}