ggerganov HF Staff commited on
Commit
7dd37dc
·
1 Parent(s): 96dc902

metal : enable shader debugging (cmake option) (llama/4705)

Browse files

* ggml : disable fast-math for Metal (cmake build only)

ggml-ci

* metal : fix Metal API debug warnings

* cmake : add -fno-inline for Metal build (llama/4545)

* metal : fix API debug warnings

* metal : fix compile warnings

* metal : use uint64_t for strides

* cmake : rename option to LLAMA_METAL_SHADER_DEBUG

* metal : fix mat-vec Q8_0 kernel for BS > 1

* metal : normalize mat-vec kernel signatures

* cmake : respect LLAMA_QKK_64 option

* metal : fix mat-vec Q4_K kernel for QK_K == 64

ggml-ci

Files changed (2) hide show
  1. ggml-metal.m +19 -9
  2. ggml-metal.metal +265 -210
ggml-metal.m CHANGED
@@ -257,13 +257,14 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
257
  bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
258
  #endif
259
  NSError * error = nil;
260
- NSString * libPath = [bundle pathForResource:@"default" ofType:@"metallib"];
261
  if (libPath != nil) {
 
262
  NSURL * libURL = [NSURL fileURLWithPath:libPath];
263
  GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [libPath UTF8String]);
264
  ctx->library = [ctx->device newLibraryWithURL:libURL error:&error];
265
  } else {
266
- GGML_METAL_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__);
267
 
268
  NSString * sourcePath;
269
  NSString * ggmlMetalPathResources = [[NSProcessInfo processInfo].environment objectForKey:@"GGML_METAL_PATH_RESOURCES"];
@@ -291,6 +292,13 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
291
  options = [MTLCompileOptions new];
292
  options.preprocessorMacros = @{ @"QK_K" : @(64) };
293
  #endif
 
 
 
 
 
 
 
294
  ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error];
295
  }
296
 
@@ -1230,7 +1238,7 @@ void ggml_metal_graph_compute(
1230
  // not sure how to avoid this
1231
  // TODO: make a simpler cpy_bytes kernel
1232
 
1233
- const int nth = MIN(1024, ne00);
1234
 
1235
  [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32];
1236
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -1285,7 +1293,7 @@ void ggml_metal_graph_compute(
1285
  [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26];
1286
  [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
1287
 
1288
- const int nth = MIN(1024, ne0);
1289
 
1290
  [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1291
  } break;
@@ -1785,8 +1793,9 @@ void ggml_metal_graph_compute(
1785
  [encoder setBytes:&r3 length:sizeof(r3) atIndex:17];
1786
  [encoder setBytes:&idx length:sizeof(idx) atIndex:18];
1787
  // TODO: how to make this an array? read Metal docs
1788
- for (int j = 0; j < n_as; ++j) {
1789
- struct ggml_tensor * src_cur = dst->src[2 + j];
 
1790
 
1791
  size_t offs_src_cur = 0;
1792
  id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
@@ -1909,8 +1918,9 @@ void ggml_metal_graph_compute(
1909
  [encoder setBytes:&r3 length:sizeof(r3) atIndex:21];
1910
  [encoder setBytes:&idx length:sizeof(idx) atIndex:22];
1911
  // TODO: how to make this an array? read Metal docs
1912
- for (int j = 0; j < n_as; ++j) {
1913
- struct ggml_tensor * src_cur = dst->src[2 + j];
 
1914
 
1915
  size_t offs_src_cur = 0;
1916
  id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
@@ -2229,7 +2239,7 @@ void ggml_metal_graph_compute(
2229
  [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
2230
  [encoder setBytes:&sf length:sizeof(sf) atIndex:18];
2231
 
2232
- const int nth = MIN(1024, ne0);
2233
 
2234
  [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2235
  } break;
 
257
  bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
258
  #endif
259
  NSError * error = nil;
260
+ NSString * libPath = [bundle pathForResource:@"ggml" ofType:@"metallib"];
261
  if (libPath != nil) {
262
+ // pre-compiled library found
263
  NSURL * libURL = [NSURL fileURLWithPath:libPath];
264
  GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [libPath UTF8String]);
265
  ctx->library = [ctx->device newLibraryWithURL:libURL error:&error];
266
  } else {
267
+ GGML_METAL_LOG_INFO("%s: ggml.metallib not found, loading from source\n", __func__);
268
 
269
  NSString * sourcePath;
270
  NSString * ggmlMetalPathResources = [[NSProcessInfo processInfo].environment objectForKey:@"GGML_METAL_PATH_RESOURCES"];
 
292
  options = [MTLCompileOptions new];
293
  options.preprocessorMacros = @{ @"QK_K" : @(64) };
294
  #endif
295
+ // try to disable fast-math
296
+ // NOTE: this seems to have no effect whatsoever
297
+ // instead, in order to disable fast-math, we have to build ggml.metallib from the command line
298
+ // using xcrun -sdk macosx metal -fno-fast-math -c ggml-metal.metal -o ggml-metal.air
299
+ // and go through the "pre-compiled library found" path above
300
+ //[options setFastMathEnabled:false];
301
+
302
  ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error];
303
  }
304
 
 
1238
  // not sure how to avoid this
1239
  // TODO: make a simpler cpy_bytes kernel
1240
 
1241
+ const int nth = MIN((int) ctx->pipeline_cpy_f32_f32.maxTotalThreadsPerThreadgroup, ne00);
1242
 
1243
  [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32];
1244
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
 
1293
  [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26];
1294
  [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
1295
 
1296
+ const int nth = MIN((int) ctx->pipeline_add.maxTotalThreadsPerThreadgroup, ne00);
1297
 
1298
  [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1299
  } break;
 
1793
  [encoder setBytes:&r3 length:sizeof(r3) atIndex:17];
1794
  [encoder setBytes:&idx length:sizeof(idx) atIndex:18];
1795
  // TODO: how to make this an array? read Metal docs
1796
+ for (int j = 0; j < 8; ++j) {
1797
+ // NOTE: this is done like this to avoid uninitialized kernel arguments when n_as < 8
1798
+ struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)];
1799
 
1800
  size_t offs_src_cur = 0;
1801
  id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
 
1918
  [encoder setBytes:&r3 length:sizeof(r3) atIndex:21];
1919
  [encoder setBytes:&idx length:sizeof(idx) atIndex:22];
1920
  // TODO: how to make this an array? read Metal docs
1921
+ for (int j = 0; j < 8; ++j) {
1922
+ // NOTE: this is done like this to avoid uninitialized kernel arguments when n_as < 8
1923
+ struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)];
1924
 
1925
  size_t offs_src_cur = 0;
1926
  id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
 
2239
  [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
2240
  [encoder setBytes:&sf length:sizeof(sf) atIndex:18];
2241
 
2242
+ const int nth = MIN((int) ctx->pipeline_upscale_f32.maxTotalThreadsPerThreadgroup, ne0);
2243
 
2244
  [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2245
  } break;
ggml-metal.metal CHANGED
@@ -59,26 +59,26 @@ kernel void kernel_add(
59
  constant int64_t & ne01,
60
  constant int64_t & ne02,
61
  constant int64_t & ne03,
62
- constant int64_t & nb00,
63
- constant int64_t & nb01,
64
- constant int64_t & nb02,
65
- constant int64_t & nb03,
66
  constant int64_t & ne10,
67
  constant int64_t & ne11,
68
  constant int64_t & ne12,
69
  constant int64_t & ne13,
70
- constant int64_t & nb10,
71
- constant int64_t & nb11,
72
- constant int64_t & nb12,
73
- constant int64_t & nb13,
74
  constant int64_t & ne0,
75
  constant int64_t & ne1,
76
  constant int64_t & ne2,
77
  constant int64_t & ne3,
78
- constant int64_t & nb0,
79
- constant int64_t & nb1,
80
- constant int64_t & nb2,
81
- constant int64_t & nb3,
82
  constant int64_t & offs,
83
  uint3 tgpig[[threadgroup_position_in_grid]],
84
  uint3 tpitg[[thread_position_in_threadgroup]],
@@ -109,26 +109,26 @@ kernel void kernel_mul(
109
  constant int64_t & ne01,
110
  constant int64_t & ne02,
111
  constant int64_t & ne03,
112
- constant int64_t & nb00,
113
- constant int64_t & nb01,
114
- constant int64_t & nb02,
115
- constant int64_t & nb03,
116
  constant int64_t & ne10,
117
  constant int64_t & ne11,
118
  constant int64_t & ne12,
119
  constant int64_t & ne13,
120
- constant int64_t & nb10,
121
- constant int64_t & nb11,
122
- constant int64_t & nb12,
123
- constant int64_t & nb13,
124
  constant int64_t & ne0,
125
  constant int64_t & ne1,
126
  constant int64_t & ne2,
127
  constant int64_t & ne3,
128
- constant int64_t & nb0,
129
- constant int64_t & nb1,
130
- constant int64_t & nb2,
131
- constant int64_t & nb3,
132
  uint3 tgpig[[threadgroup_position_in_grid]],
133
  uint3 tpitg[[thread_position_in_threadgroup]],
134
  uint3 ntg[[threads_per_threadgroup]]) {
@@ -158,26 +158,26 @@ kernel void kernel_div(
158
  constant int64_t & ne01,
159
  constant int64_t & ne02,
160
  constant int64_t & ne03,
161
- constant int64_t & nb00,
162
- constant int64_t & nb01,
163
- constant int64_t & nb02,
164
- constant int64_t & nb03,
165
  constant int64_t & ne10,
166
  constant int64_t & ne11,
167
  constant int64_t & ne12,
168
  constant int64_t & ne13,
169
- constant int64_t & nb10,
170
- constant int64_t & nb11,
171
- constant int64_t & nb12,
172
- constant int64_t & nb13,
173
  constant int64_t & ne0,
174
  constant int64_t & ne1,
175
  constant int64_t & ne2,
176
  constant int64_t & ne3,
177
- constant int64_t & nb0,
178
- constant int64_t & nb1,
179
- constant int64_t & nb2,
180
- constant int64_t & nb3,
181
  uint3 tgpig[[threadgroup_position_in_grid]],
182
  uint3 tpitg[[thread_position_in_threadgroup]],
183
  uint3 ntg[[threads_per_threadgroup]]) {
@@ -205,7 +205,7 @@ kernel void kernel_add_row(
205
  device const float4 * src0,
206
  device const float4 * src1,
207
  device float4 * dst,
208
- constant int64_t & nb [[buffer(28)]],
209
  uint tpig[[thread_position_in_grid]]) {
210
  dst[tpig] = src0[tpig] + src1[tpig % nb];
211
  }
@@ -214,7 +214,7 @@ kernel void kernel_mul_row(
214
  device const float4 * src0,
215
  device const float4 * src1,
216
  device float4 * dst,
217
- constant int64_t & nb [[buffer(28)]],
218
  uint tpig[[thread_position_in_grid]]) {
219
  dst[tpig] = src0[tpig] * src1[tpig % nb];
220
  }
@@ -223,7 +223,7 @@ kernel void kernel_div_row(
223
  device const float4 * src0,
224
  device const float4 * src1,
225
  device float4 * dst,
226
- constant int64_t & nb [[buffer(28)]],
227
  uint tpig[[thread_position_in_grid]]) {
228
  dst[tpig] = src0[tpig] / src1[tpig % nb];
229
  }
@@ -307,26 +307,26 @@ kernel void kernel_sum_rows(
307
  constant int64_t & ne01,
308
  constant int64_t & ne02,
309
  constant int64_t & ne03,
310
- constant int64_t & nb00,
311
- constant int64_t & nb01,
312
- constant int64_t & nb02,
313
- constant int64_t & nb03,
314
  constant int64_t & ne10,
315
  constant int64_t & ne11,
316
  constant int64_t & ne12,
317
  constant int64_t & ne13,
318
- constant int64_t & nb10,
319
- constant int64_t & nb11,
320
- constant int64_t & nb12,
321
- constant int64_t & nb13,
322
  constant int64_t & ne0,
323
  constant int64_t & ne1,
324
  constant int64_t & ne2,
325
  constant int64_t & ne3,
326
- constant int64_t & nb0,
327
- constant int64_t & nb1,
328
- constant int64_t & nb2,
329
- constant int64_t & nb3,
330
  uint3 tpig[[thread_position_in_grid]]) {
331
  int64_t i3 = tpig.z;
332
  int64_t i2 = tpig.y;
@@ -920,14 +920,21 @@ kernel void kernel_mul_mv_q4_0_f32(
920
  device const float * src1,
921
  device float * dst,
922
  constant int64_t & ne00,
923
- constant int64_t & ne01[[buffer(4)]],
924
- constant int64_t & ne02[[buffer(5)]],
925
- constant int64_t & ne10[[buffer(9)]],
926
- constant int64_t & ne12[[buffer(11)]],
927
- constant int64_t & ne0 [[buffer(15)]],
928
- constant int64_t & ne1 [[buffer(16)]],
929
- constant uint & r2 [[buffer(17)]],
930
- constant uint & r3 [[buffer(18)]],
 
 
 
 
 
 
 
931
  uint3 tgpig[[threadgroup_position_in_grid]],
932
  uint tiisg[[thread_index_in_simdgroup]],
933
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -939,14 +946,21 @@ kernel void kernel_mul_mv_q4_1_f32(
939
  device const float * src1,
940
  device float * dst,
941
  constant int64_t & ne00,
942
- constant int64_t & ne01[[buffer(4)]],
943
- constant int64_t & ne02[[buffer(5)]],
944
- constant int64_t & ne10[[buffer(9)]],
945
- constant int64_t & ne12[[buffer(11)]],
946
- constant int64_t & ne0 [[buffer(15)]],
947
- constant int64_t & ne1 [[buffer(16)]],
948
- constant uint & r2 [[buffer(17)]],
949
- constant uint & r3 [[buffer(18)]],
 
 
 
 
 
 
 
950
  uint3 tgpig[[threadgroup_position_in_grid]],
951
  uint tiisg[[thread_index_in_simdgroup]],
952
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -958,14 +972,21 @@ kernel void kernel_mul_mv_q5_0_f32(
958
  device const float * src1,
959
  device float * dst,
960
  constant int64_t & ne00,
961
- constant int64_t & ne01[[buffer(4)]],
962
- constant int64_t & ne02[[buffer(5)]],
963
- constant int64_t & ne10[[buffer(9)]],
964
- constant int64_t & ne12[[buffer(11)]],
965
- constant int64_t & ne0 [[buffer(15)]],
966
- constant int64_t & ne1 [[buffer(16)]],
967
- constant uint & r2 [[buffer(17)]],
968
- constant uint & r3 [[buffer(18)]],
 
 
 
 
 
 
 
969
  uint3 tgpig[[threadgroup_position_in_grid]],
970
  uint tiisg[[thread_index_in_simdgroup]],
971
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -977,14 +998,21 @@ kernel void kernel_mul_mv_q5_1_f32(
977
  device const float * src1,
978
  device float * dst,
979
  constant int64_t & ne00,
980
- constant int64_t & ne01[[buffer(4)]],
981
- constant int64_t & ne02[[buffer(5)]],
982
- constant int64_t & ne10[[buffer(9)]],
983
- constant int64_t & ne12[[buffer(11)]],
984
- constant int64_t & ne0 [[buffer(15)]],
985
- constant int64_t & ne1 [[buffer(16)]],
986
- constant uint & r2 [[buffer(17)]],
987
- constant uint & r3 [[buffer(18)]],
 
 
 
 
 
 
 
988
  uint3 tgpig[[threadgroup_position_in_grid]],
989
  uint tiisg[[thread_index_in_simdgroup]],
990
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -1071,12 +1099,19 @@ kernel void kernel_mul_mv_q8_0_f32(
1071
  constant int64_t & ne00,
1072
  constant int64_t & ne01,
1073
  constant int64_t & ne02,
 
 
 
1074
  constant int64_t & ne10,
 
1075
  constant int64_t & ne12,
 
 
 
1076
  constant int64_t & ne0,
1077
  constant int64_t & ne1,
1078
- constant uint & r2 [[buffer(17)]],
1079
- constant uint & r3 [[buffer(18)]],
1080
  uint3 tgpig[[threadgroup_position_in_grid]],
1081
  uint tiisg[[thread_index_in_simdgroup]],
1082
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -1182,8 +1217,8 @@ kernel void kernel_mul_mv_f32_f32(
1182
  constant uint64_t & nb12,
1183
  constant int64_t & ne0,
1184
  constant int64_t & ne1,
1185
- constant uint & r2 [[buffer(17)]],
1186
- constant uint & r3 [[buffer(18)]],
1187
  uint3 tgpig[[threadgroup_position_in_grid]],
1188
  uint tiisg[[thread_index_in_simdgroup]]) {
1189
  kernel_mul_mv_f32_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
@@ -1209,8 +1244,8 @@ kernel void kernel_mul_mv_f16_f16(
1209
  constant uint64_t & nb12,
1210
  constant int64_t & ne0,
1211
  constant int64_t & ne1,
1212
- constant uint & r2 [[buffer(17)]],
1213
- constant uint & r3 [[buffer(18)]],
1214
  uint3 tgpig[[threadgroup_position_in_grid]],
1215
  uint tiisg[[thread_index_in_simdgroup]]) {
1216
 
@@ -1346,8 +1381,8 @@ kernel void kernel_mul_mv_f16_f32_1row(
1346
  constant uint64_t & nb12,
1347
  constant int64_t & ne0,
1348
  constant int64_t & ne1,
1349
- constant uint & r2 [[buffer(17)]],
1350
- constant uint & r3 [[buffer(18)]],
1351
  uint3 tgpig[[threadgroup_position_in_grid]],
1352
  uint tiisg[[thread_index_in_simdgroup]]) {
1353
  kernel_mul_mv_f16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
@@ -1452,8 +1487,8 @@ kernel void kernel_mul_mv_f16_f32(
1452
  constant uint64_t & nb12,
1453
  constant int64_t & ne0,
1454
  constant int64_t & ne1,
1455
- constant uint & r2 [[buffer(17)]],
1456
- constant uint & r3 [[buffer(18)]],
1457
  uint3 tgpig[[threadgroup_position_in_grid]],
1458
  uint tiisg[[thread_index_in_simdgroup]]) {
1459
  kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
@@ -1478,8 +1513,8 @@ kernel void kernel_mul_mv_f16_f32_l4(
1478
  constant uint64_t & nb12,
1479
  constant int64_t & ne0,
1480
  constant int64_t & ne1,
1481
- constant uint & r2 [[buffer(17)]],
1482
- constant uint & r3 [[buffer(18)]],
1483
  uint3 tgpig[[threadgroup_position_in_grid]],
1484
  uint tiisg[[thread_index_in_simdgroup]]) {
1485
 
@@ -1543,7 +1578,8 @@ kernel void kernel_alibi_f32(
1543
  const int64_t i3 = n / (ne2*ne1*ne0);
1544
  const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
1545
  const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
1546
- const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
 
1547
  const int64_t k = i3*ne3 + i2;
1548
 
1549
  float m_k;
@@ -2410,22 +2446,6 @@ typedef struct {
2410
  } block_q6_K;
2411
  // 210 bytes / block
2412
 
2413
- static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
2414
- uchar4 r;
2415
- if (j < 4) {
2416
- r[0] = q[j+0] & 63;
2417
- r[2] = q[j+1] & 63;
2418
- r[1] = q[j+4] & 63;
2419
- r[3] = q[j+5] & 63;
2420
- } else {
2421
- r[0] = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
2422
- r[2] = (q[j+5] & 0xF) | ((q[j-3] >> 6) << 4);
2423
- r[1] = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
2424
- r[3] = (q[j+5] >> 4) | ((q[j+1] >> 6) << 4);
2425
- }
2426
- return r;
2427
- }
2428
-
2429
  //====================================== dot products =========================
2430
 
2431
  void kernel_mul_mv_q2_K_f32_impl(
@@ -2584,14 +2604,21 @@ kernel void kernel_mul_mv_q2_K_f32(
2584
  device const float * src1,
2585
  device float * dst,
2586
  constant int64_t & ne00,
2587
- constant int64_t & ne01[[buffer(4)]],
2588
- constant int64_t & ne02[[buffer(5)]],
2589
- constant int64_t & ne10[[buffer(9)]],
2590
- constant int64_t & ne12[[buffer(11)]],
2591
- constant int64_t & ne0 [[buffer(15)]],
2592
- constant int64_t & ne1 [[buffer(16)]],
2593
- constant uint & r2 [[buffer(17)]],
2594
- constant uint & r3 [[buffer(18)]],
 
 
 
 
 
 
 
2595
  uint3 tgpig[[threadgroup_position_in_grid]],
2596
  uint tiisg[[thread_index_in_simdgroup]],
2597
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2841,14 +2868,21 @@ kernel void kernel_mul_mv_q3_K_f32(
2841
  device const float * src1,
2842
  device float * dst,
2843
  constant int64_t & ne00,
2844
- constant int64_t & ne01[[buffer(4)]],
2845
- constant int64_t & ne02[[buffer(5)]],
2846
- constant int64_t & ne10[[buffer(9)]],
2847
- constant int64_t & ne12[[buffer(11)]],
2848
- constant int64_t & ne0 [[buffer(15)]],
2849
- constant int64_t & ne1 [[buffer(16)]],
2850
- constant uint & r2 [[buffer(17)]],
2851
- constant uint & r3 [[buffer(18)]],
 
 
 
 
 
 
 
2852
  uint3 tgpig[[threadgroup_position_in_grid]],
2853
  uint tiisg[[thread_index_in_simdgroup]],
2854
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2984,8 +3018,8 @@ void kernel_mul_mv_q4_K_f32_impl(
2984
  constant uint & r2,
2985
  constant uint & r3,
2986
  uint3 tgpig[[threadgroup_position_in_grid]],
2987
- uint tiisg[[thread_index_in_simdgroup]],
2988
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
2989
 
2990
  const int ix = tiisg/4; // 0...7
2991
  const int it = tiisg%4; // 0...3
@@ -2994,7 +3028,7 @@ void kernel_mul_mv_q4_K_f32_impl(
2994
  const int r0 = tgpig.x;
2995
  const int r1 = tgpig.y;
2996
  const int im = tgpig.z;
2997
- const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
2998
  const int ib_row = first_row * nb;
2999
 
3000
  const uint i12 = im%ne12;
@@ -3060,7 +3094,7 @@ void kernel_mul_mv_q4_K_f32_impl(
3060
  for (int row = 0; row < N_DST; ++row) {
3061
  all_sum = simd_sum(sumf[row]);
3062
  if (tiisg == 0) {
3063
- dst[r1*ne0+ im*ne0*ne1 + first_row + row] = all_sum;
3064
  }
3065
  }
3066
  }
@@ -3072,14 +3106,21 @@ kernel void kernel_mul_mv_q4_K_f32(
3072
  device const float * src1,
3073
  device float * dst,
3074
  constant int64_t & ne00,
3075
- constant int64_t & ne01[[buffer(4)]],
3076
- constant int64_t & ne02[[buffer(5)]],
3077
- constant int64_t & ne10[[buffer(9)]],
3078
- constant int64_t & ne12[[buffer(11)]],
3079
- constant int64_t & ne0 [[buffer(15)]],
3080
- constant int64_t & ne1 [[buffer(16)]],
3081
- constant uint & r2 [[buffer(17)]],
3082
- constant uint & r3 [[buffer(18)]],
 
 
 
 
 
 
 
3083
  uint3 tgpig[[threadgroup_position_in_grid]],
3084
  uint tiisg[[thread_index_in_simdgroup]],
3085
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -3271,14 +3312,21 @@ kernel void kernel_mul_mv_q5_K_f32(
3271
  device const float * src1,
3272
  device float * dst,
3273
  constant int64_t & ne00,
3274
- constant int64_t & ne01[[buffer(4)]],
3275
- constant int64_t & ne02[[buffer(5)]],
3276
- constant int64_t & ne10[[buffer(9)]],
3277
- constant int64_t & ne12[[buffer(11)]],
3278
- constant int64_t & ne0 [[buffer(15)]],
3279
- constant int64_t & ne1 [[buffer(16)]],
3280
- constant uint & r2 [[buffer(17)]],
3281
- constant uint & r3 [[buffer(18)]],
 
 
 
 
 
 
 
3282
  uint3 tgpig[[threadgroup_position_in_grid]],
3283
  uint tiisg[[thread_index_in_simdgroup]],
3284
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -3398,14 +3446,21 @@ kernel void kernel_mul_mv_q6_K_f32(
3398
  device const float * src1,
3399
  device float * dst,
3400
  constant int64_t & ne00,
3401
- constant int64_t & ne01[[buffer(4)]],
3402
- constant int64_t & ne02[[buffer(5)]],
3403
- constant int64_t & ne10[[buffer(9)]],
3404
- constant int64_t & ne12[[buffer(11)]],
3405
- constant int64_t & ne0 [[buffer(15)]],
3406
- constant int64_t & ne1 [[buffer(16)]],
3407
- constant uint & r2 [[buffer(17)]],
3408
- constant uint & r3 [[buffer(18)]],
 
 
 
 
 
 
 
3409
  uint3 tgpig[[threadgroup_position_in_grid]],
3410
  uint tiisg[[thread_index_in_simdgroup]],
3411
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -3523,7 +3578,7 @@ void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg
3523
  device const int8_t * qs = ((device const int8_t *)xb->qs);
3524
  const half d = xb->d;
3525
 
3526
- for (int i=0;i<16;i++) {
3527
  reg[i/4][i%4] = (qs[i + 16*il] * d);
3528
  }
3529
  }
@@ -3792,12 +3847,12 @@ void kernel_mul_mm_impl(device const uchar * src0,
3792
  device float * dst,
3793
  constant int64_t & ne00,
3794
  constant int64_t & ne02,
3795
- constant int64_t & nb01,
3796
- constant int64_t & nb02,
3797
  constant int64_t & ne12,
3798
- constant int64_t & nb10,
3799
- constant int64_t & nb11,
3800
- constant int64_t & nb12,
3801
  constant int64_t & ne0,
3802
  constant int64_t & ne1,
3803
  constant uint & r2,
@@ -3924,12 +3979,12 @@ kernel void kernel_mul_mm(device const uchar * src0,
3924
  device float * dst,
3925
  constant int64_t & ne00,
3926
  constant int64_t & ne02,
3927
- constant int64_t & nb01,
3928
- constant int64_t & nb02,
3929
  constant int64_t & ne12,
3930
- constant int64_t & nb10,
3931
- constant int64_t & nb11,
3932
- constant int64_t & nb12,
3933
  constant int64_t & ne0,
3934
  constant int64_t & ne1,
3935
  constant uint & r2,
@@ -3965,19 +4020,19 @@ kernel void kernel_mul_mm_id(
3965
  device const uchar * ids,
3966
  device const uchar * src1,
3967
  device uchar * dst,
3968
- constant int64_t & nbi1,
3969
  constant int64_t & ne00,
3970
  constant int64_t & ne02,
3971
- constant int64_t & nb01,
3972
- constant int64_t & nb02,
3973
  constant int64_t & ne12,
3974
  constant int64_t & ne13,
3975
- constant int64_t & nb10,
3976
- constant int64_t & nb11,
3977
- constant int64_t & nb12,
3978
  constant int64_t & ne0,
3979
  constant int64_t & ne1,
3980
- constant int64_t & nb1,
3981
  constant uint & r2,
3982
  constant uint & r3,
3983
  constant int & idx,
@@ -4070,12 +4125,12 @@ typedef void (mat_mm_t)(
4070
  device float * dst,
4071
  constant int64_t & ne00,
4072
  constant int64_t & ne02,
4073
- constant int64_t & nb01,
4074
- constant int64_t & nb02,
4075
  constant int64_t & ne12,
4076
- constant int64_t & nb10,
4077
- constant int64_t & nb11,
4078
- constant int64_t & nb12,
4079
  constant int64_t & ne0,
4080
  constant int64_t & ne1,
4081
  constant uint & r2,
@@ -4104,19 +4159,19 @@ typedef void (mat_mm_id_t)(
4104
  device const uchar * ids,
4105
  device const uchar * src1,
4106
  device uchar * dst,
4107
- constant int64_t & nbi1,
4108
  constant int64_t & ne00,
4109
  constant int64_t & ne02,
4110
- constant int64_t & nb01,
4111
- constant int64_t & nb02,
4112
  constant int64_t & ne12,
4113
  constant int64_t & ne13,
4114
- constant int64_t & nb10,
4115
- constant int64_t & nb11,
4116
- constant int64_t & nb12,
4117
  constant int64_t & ne0,
4118
  constant int64_t & ne1,
4119
- constant int64_t & nb1,
4120
  constant uint & r2,
4121
  constant uint & r3,
4122
  constant int & idx,
@@ -4153,7 +4208,7 @@ kernel void kernel_mul_mv_id_f32_f32(
4153
  device const char * ids,
4154
  device const char * src1,
4155
  device uchar * dst,
4156
- constant int64_t & nbi1,
4157
  constant int64_t & ne00,
4158
  constant int64_t & ne01,
4159
  constant int64_t & ne02,
@@ -4169,7 +4224,7 @@ kernel void kernel_mul_mv_id_f32_f32(
4169
  constant uint64_t & nb12,
4170
  constant int64_t & ne0,
4171
  constant int64_t & ne1,
4172
- constant int64_t & nb1,
4173
  constant uint & r2,
4174
  constant uint & r3,
4175
  constant int & idx,
@@ -4222,7 +4277,7 @@ kernel void kernel_mul_mv_id_f16_f32(
4222
  device const char * ids,
4223
  device const char * src1,
4224
  device uchar * dst,
4225
- constant int64_t & nbi1,
4226
  constant int64_t & ne00,
4227
  constant int64_t & ne01,
4228
  constant int64_t & ne02,
@@ -4238,7 +4293,7 @@ kernel void kernel_mul_mv_id_f16_f32(
4238
  constant uint64_t & nb12,
4239
  constant int64_t & ne0,
4240
  constant int64_t & ne1,
4241
- constant int64_t & nb1,
4242
  constant uint & r2,
4243
  constant uint & r3,
4244
  constant int & idx,
@@ -4291,7 +4346,7 @@ kernel void kernel_mul_mv_id_q8_0_f32(
4291
  device const char * ids,
4292
  device const char * src1,
4293
  device uchar * dst,
4294
- constant int64_t & nbi1,
4295
  constant int64_t & ne00,
4296
  constant int64_t & ne01,
4297
  constant int64_t & ne02,
@@ -4307,7 +4362,7 @@ kernel void kernel_mul_mv_id_q8_0_f32(
4307
  constant uint64_t & nb12,
4308
  constant int64_t & ne0,
4309
  constant int64_t & ne1,
4310
- constant int64_t & nb1,
4311
  constant uint & r2,
4312
  constant uint & r3,
4313
  constant int & idx,
@@ -4354,7 +4409,7 @@ kernel void kernel_mul_mv_id_q4_0_f32(
4354
  device const char * ids,
4355
  device const char * src1,
4356
  device uchar * dst,
4357
- constant int64_t & nbi1,
4358
  constant int64_t & ne00,
4359
  constant int64_t & ne01,
4360
  constant int64_t & ne02,
@@ -4370,7 +4425,7 @@ kernel void kernel_mul_mv_id_q4_0_f32(
4370
  constant uint64_t & nb12,
4371
  constant int64_t & ne0,
4372
  constant int64_t & ne1,
4373
- constant int64_t & nb1,
4374
  constant uint & r2,
4375
  constant uint & r3,
4376
  constant int & idx,
@@ -4417,7 +4472,7 @@ kernel void kernel_mul_mv_id_q4_1_f32(
4417
  device const char * ids,
4418
  device const char * src1,
4419
  device uchar * dst,
4420
- constant int64_t & nbi1,
4421
  constant int64_t & ne00,
4422
  constant int64_t & ne01,
4423
  constant int64_t & ne02,
@@ -4433,7 +4488,7 @@ kernel void kernel_mul_mv_id_q4_1_f32(
4433
  constant uint64_t & nb12,
4434
  constant int64_t & ne0,
4435
  constant int64_t & ne1,
4436
- constant int64_t & nb1,
4437
  constant uint & r2,
4438
  constant uint & r3,
4439
  constant int & idx,
@@ -4480,7 +4535,7 @@ kernel void kernel_mul_mv_id_q5_0_f32(
4480
  device const char * ids,
4481
  device const char * src1,
4482
  device uchar * dst,
4483
- constant int64_t & nbi1,
4484
  constant int64_t & ne00,
4485
  constant int64_t & ne01,
4486
  constant int64_t & ne02,
@@ -4496,7 +4551,7 @@ kernel void kernel_mul_mv_id_q5_0_f32(
4496
  constant uint64_t & nb12,
4497
  constant int64_t & ne0,
4498
  constant int64_t & ne1,
4499
- constant int64_t & nb1,
4500
  constant uint & r2,
4501
  constant uint & r3,
4502
  constant int & idx,
@@ -4543,7 +4598,7 @@ kernel void kernel_mul_mv_id_q5_1_f32(
4543
  device const char * ids,
4544
  device const char * src1,
4545
  device uchar * dst,
4546
- constant int64_t & nbi1,
4547
  constant int64_t & ne00,
4548
  constant int64_t & ne01,
4549
  constant int64_t & ne02,
@@ -4559,7 +4614,7 @@ kernel void kernel_mul_mv_id_q5_1_f32(
4559
  constant uint64_t & nb12,
4560
  constant int64_t & ne0,
4561
  constant int64_t & ne1,
4562
- constant int64_t & nb1,
4563
  constant uint & r2,
4564
  constant uint & r3,
4565
  constant int & idx,
@@ -4606,7 +4661,7 @@ kernel void kernel_mul_mv_id_q2_K_f32(
4606
  device const char * ids,
4607
  device const char * src1,
4608
  device uchar * dst,
4609
- constant int64_t & nbi1,
4610
  constant int64_t & ne00,
4611
  constant int64_t & ne01,
4612
  constant int64_t & ne02,
@@ -4622,7 +4677,7 @@ kernel void kernel_mul_mv_id_q2_K_f32(
4622
  constant uint64_t & nb12,
4623
  constant int64_t & ne0,
4624
  constant int64_t & ne1,
4625
- constant int64_t & nb1,
4626
  constant uint & r2,
4627
  constant uint & r3,
4628
  constant int & idx,
@@ -4669,7 +4724,7 @@ kernel void kernel_mul_mv_id_q3_K_f32(
4669
  device const char * ids,
4670
  device const char * src1,
4671
  device uchar * dst,
4672
- constant int64_t & nbi1,
4673
  constant int64_t & ne00,
4674
  constant int64_t & ne01,
4675
  constant int64_t & ne02,
@@ -4685,7 +4740,7 @@ kernel void kernel_mul_mv_id_q3_K_f32(
4685
  constant uint64_t & nb12,
4686
  constant int64_t & ne0,
4687
  constant int64_t & ne1,
4688
- constant int64_t & nb1,
4689
  constant uint & r2,
4690
  constant uint & r3,
4691
  constant int & idx,
@@ -4732,7 +4787,7 @@ kernel void kernel_mul_mv_id_q4_K_f32(
4732
  device const char * ids,
4733
  device const char * src1,
4734
  device uchar * dst,
4735
- constant int64_t & nbi1,
4736
  constant int64_t & ne00,
4737
  constant int64_t & ne01,
4738
  constant int64_t & ne02,
@@ -4748,7 +4803,7 @@ kernel void kernel_mul_mv_id_q4_K_f32(
4748
  constant uint64_t & nb12,
4749
  constant int64_t & ne0,
4750
  constant int64_t & ne1,
4751
- constant int64_t & nb1,
4752
  constant uint & r2,
4753
  constant uint & r3,
4754
  constant int & idx,
@@ -4795,7 +4850,7 @@ kernel void kernel_mul_mv_id_q5_K_f32(
4795
  device const char * ids,
4796
  device const char * src1,
4797
  device uchar * dst,
4798
- constant int64_t & nbi1,
4799
  constant int64_t & ne00,
4800
  constant int64_t & ne01,
4801
  constant int64_t & ne02,
@@ -4811,7 +4866,7 @@ kernel void kernel_mul_mv_id_q5_K_f32(
4811
  constant uint64_t & nb12,
4812
  constant int64_t & ne0,
4813
  constant int64_t & ne1,
4814
- constant int64_t & nb1,
4815
  constant uint & r2,
4816
  constant uint & r3,
4817
  constant int & idx,
@@ -4858,7 +4913,7 @@ kernel void kernel_mul_mv_id_q6_K_f32(
4858
  device const char * ids,
4859
  device const char * src1,
4860
  device uchar * dst,
4861
- constant int64_t & nbi1,
4862
  constant int64_t & ne00,
4863
  constant int64_t & ne01,
4864
  constant int64_t & ne02,
@@ -4874,7 +4929,7 @@ kernel void kernel_mul_mv_id_q6_K_f32(
4874
  constant uint64_t & nb12,
4875
  constant int64_t & ne0,
4876
  constant int64_t & ne1,
4877
- constant int64_t & nb1,
4878
  constant uint & r2,
4879
  constant uint & r3,
4880
  constant int & idx,
 
59
  constant int64_t & ne01,
60
  constant int64_t & ne02,
61
  constant int64_t & ne03,
62
+ constant uint64_t & nb00,
63
+ constant uint64_t & nb01,
64
+ constant uint64_t & nb02,
65
+ constant uint64_t & nb03,
66
  constant int64_t & ne10,
67
  constant int64_t & ne11,
68
  constant int64_t & ne12,
69
  constant int64_t & ne13,
70
+ constant uint64_t & nb10,
71
+ constant uint64_t & nb11,
72
+ constant uint64_t & nb12,
73
+ constant uint64_t & nb13,
74
  constant int64_t & ne0,
75
  constant int64_t & ne1,
76
  constant int64_t & ne2,
77
  constant int64_t & ne3,
78
+ constant uint64_t & nb0,
79
+ constant uint64_t & nb1,
80
+ constant uint64_t & nb2,
81
+ constant uint64_t & nb3,
82
  constant int64_t & offs,
83
  uint3 tgpig[[threadgroup_position_in_grid]],
84
  uint3 tpitg[[thread_position_in_threadgroup]],
 
109
  constant int64_t & ne01,
110
  constant int64_t & ne02,
111
  constant int64_t & ne03,
112
+ constant uint64_t & nb00,
113
+ constant uint64_t & nb01,
114
+ constant uint64_t & nb02,
115
+ constant uint64_t & nb03,
116
  constant int64_t & ne10,
117
  constant int64_t & ne11,
118
  constant int64_t & ne12,
119
  constant int64_t & ne13,
120
+ constant uint64_t & nb10,
121
+ constant uint64_t & nb11,
122
+ constant uint64_t & nb12,
123
+ constant uint64_t & nb13,
124
  constant int64_t & ne0,
125
  constant int64_t & ne1,
126
  constant int64_t & ne2,
127
  constant int64_t & ne3,
128
+ constant uint64_t & nb0,
129
+ constant uint64_t & nb1,
130
+ constant uint64_t & nb2,
131
+ constant uint64_t & nb3,
132
  uint3 tgpig[[threadgroup_position_in_grid]],
133
  uint3 tpitg[[thread_position_in_threadgroup]],
134
  uint3 ntg[[threads_per_threadgroup]]) {
 
158
  constant int64_t & ne01,
159
  constant int64_t & ne02,
160
  constant int64_t & ne03,
161
+ constant uint64_t & nb00,
162
+ constant uint64_t & nb01,
163
+ constant uint64_t & nb02,
164
+ constant uint64_t & nb03,
165
  constant int64_t & ne10,
166
  constant int64_t & ne11,
167
  constant int64_t & ne12,
168
  constant int64_t & ne13,
169
+ constant uint64_t & nb10,
170
+ constant uint64_t & nb11,
171
+ constant uint64_t & nb12,
172
+ constant uint64_t & nb13,
173
  constant int64_t & ne0,
174
  constant int64_t & ne1,
175
  constant int64_t & ne2,
176
  constant int64_t & ne3,
177
+ constant uint64_t & nb0,
178
+ constant uint64_t & nb1,
179
+ constant uint64_t & nb2,
180
+ constant uint64_t & nb3,
181
  uint3 tgpig[[threadgroup_position_in_grid]],
182
  uint3 tpitg[[thread_position_in_threadgroup]],
183
  uint3 ntg[[threads_per_threadgroup]]) {
 
205
  device const float4 * src0,
206
  device const float4 * src1,
207
  device float4 * dst,
208
+ constant uint64_t & nb [[buffer(28)]],
209
  uint tpig[[thread_position_in_grid]]) {
210
  dst[tpig] = src0[tpig] + src1[tpig % nb];
211
  }
 
214
  device const float4 * src0,
215
  device const float4 * src1,
216
  device float4 * dst,
217
+ constant uint64_t & nb [[buffer(28)]],
218
  uint tpig[[thread_position_in_grid]]) {
219
  dst[tpig] = src0[tpig] * src1[tpig % nb];
220
  }
 
223
  device const float4 * src0,
224
  device const float4 * src1,
225
  device float4 * dst,
226
+ constant uint64_t & nb [[buffer(28)]],
227
  uint tpig[[thread_position_in_grid]]) {
228
  dst[tpig] = src0[tpig] / src1[tpig % nb];
229
  }
 
307
  constant int64_t & ne01,
308
  constant int64_t & ne02,
309
  constant int64_t & ne03,
310
+ constant uint64_t & nb00,
311
+ constant uint64_t & nb01,
312
+ constant uint64_t & nb02,
313
+ constant uint64_t & nb03,
314
  constant int64_t & ne10,
315
  constant int64_t & ne11,
316
  constant int64_t & ne12,
317
  constant int64_t & ne13,
318
+ constant uint64_t & nb10,
319
+ constant uint64_t & nb11,
320
+ constant uint64_t & nb12,
321
+ constant uint64_t & nb13,
322
  constant int64_t & ne0,
323
  constant int64_t & ne1,
324
  constant int64_t & ne2,
325
  constant int64_t & ne3,
326
+ constant uint64_t & nb0,
327
+ constant uint64_t & nb1,
328
+ constant uint64_t & nb2,
329
+ constant uint64_t & nb3,
330
  uint3 tpig[[thread_position_in_grid]]) {
331
  int64_t i3 = tpig.z;
332
  int64_t i2 = tpig.y;
 
920
  device const float * src1,
921
  device float * dst,
922
  constant int64_t & ne00,
923
+ constant int64_t & ne01,
924
+ constant int64_t & ne02,
925
+ constant uint64_t & nb00,
926
+ constant uint64_t & nb01,
927
+ constant uint64_t & nb02,
928
+ constant int64_t & ne10,
929
+ constant int64_t & ne11,
930
+ constant int64_t & ne12,
931
+ constant uint64_t & nb10,
932
+ constant uint64_t & nb11,
933
+ constant uint64_t & nb12,
934
+ constant int64_t & ne0,
935
+ constant int64_t & ne1,
936
+ constant uint & r2,
937
+ constant uint & r3,
938
  uint3 tgpig[[threadgroup_position_in_grid]],
939
  uint tiisg[[thread_index_in_simdgroup]],
940
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
946
  device const float * src1,
947
  device float * dst,
948
  constant int64_t & ne00,
949
+ constant int64_t & ne01,
950
+ constant int64_t & ne02,
951
+ constant uint64_t & nb00,
952
+ constant uint64_t & nb01,
953
+ constant uint64_t & nb02,
954
+ constant int64_t & ne10,
955
+ constant int64_t & ne11,
956
+ constant int64_t & ne12,
957
+ constant uint64_t & nb10,
958
+ constant uint64_t & nb11,
959
+ constant uint64_t & nb12,
960
+ constant int64_t & ne0,
961
+ constant int64_t & ne1,
962
+ constant uint & r2,
963
+ constant uint & r3,
964
  uint3 tgpig[[threadgroup_position_in_grid]],
965
  uint tiisg[[thread_index_in_simdgroup]],
966
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
972
  device const float * src1,
973
  device float * dst,
974
  constant int64_t & ne00,
975
+ constant int64_t & ne01,
976
+ constant int64_t & ne02,
977
+ constant uint64_t & nb00,
978
+ constant uint64_t & nb01,
979
+ constant uint64_t & nb02,
980
+ constant int64_t & ne10,
981
+ constant int64_t & ne11,
982
+ constant int64_t & ne12,
983
+ constant uint64_t & nb10,
984
+ constant uint64_t & nb11,
985
+ constant uint64_t & nb12,
986
+ constant int64_t & ne0,
987
+ constant int64_t & ne1,
988
+ constant uint & r2,
989
+ constant uint & r3,
990
  uint3 tgpig[[threadgroup_position_in_grid]],
991
  uint tiisg[[thread_index_in_simdgroup]],
992
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
998
  device const float * src1,
999
  device float * dst,
1000
  constant int64_t & ne00,
1001
+ constant int64_t & ne01,
1002
+ constant int64_t & ne02,
1003
+ constant uint64_t & nb00,
1004
+ constant uint64_t & nb01,
1005
+ constant uint64_t & nb02,
1006
+ constant int64_t & ne10,
1007
+ constant int64_t & ne11,
1008
+ constant int64_t & ne12,
1009
+ constant uint64_t & nb10,
1010
+ constant uint64_t & nb11,
1011
+ constant uint64_t & nb12,
1012
+ constant int64_t & ne0,
1013
+ constant int64_t & ne1,
1014
+ constant uint & r2,
1015
+ constant uint & r3,
1016
  uint3 tgpig[[threadgroup_position_in_grid]],
1017
  uint tiisg[[thread_index_in_simdgroup]],
1018
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
1099
  constant int64_t & ne00,
1100
  constant int64_t & ne01,
1101
  constant int64_t & ne02,
1102
+ constant uint64_t & nb00,
1103
+ constant uint64_t & nb01,
1104
+ constant uint64_t & nb02,
1105
  constant int64_t & ne10,
1106
+ constant int64_t & ne11,
1107
  constant int64_t & ne12,
1108
+ constant uint64_t & nb10,
1109
+ constant uint64_t & nb11,
1110
+ constant uint64_t & nb12,
1111
  constant int64_t & ne0,
1112
  constant int64_t & ne1,
1113
+ constant uint & r2,
1114
+ constant uint & r3,
1115
  uint3 tgpig[[threadgroup_position_in_grid]],
1116
  uint tiisg[[thread_index_in_simdgroup]],
1117
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
1217
  constant uint64_t & nb12,
1218
  constant int64_t & ne0,
1219
  constant int64_t & ne1,
1220
+ constant uint & r2,
1221
+ constant uint & r3,
1222
  uint3 tgpig[[threadgroup_position_in_grid]],
1223
  uint tiisg[[thread_index_in_simdgroup]]) {
1224
  kernel_mul_mv_f32_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
 
1244
  constant uint64_t & nb12,
1245
  constant int64_t & ne0,
1246
  constant int64_t & ne1,
1247
+ constant uint & r2,
1248
+ constant uint & r3,
1249
  uint3 tgpig[[threadgroup_position_in_grid]],
1250
  uint tiisg[[thread_index_in_simdgroup]]) {
1251
 
 
1381
  constant uint64_t & nb12,
1382
  constant int64_t & ne0,
1383
  constant int64_t & ne1,
1384
+ constant uint & r2,
1385
+ constant uint & r3,
1386
  uint3 tgpig[[threadgroup_position_in_grid]],
1387
  uint tiisg[[thread_index_in_simdgroup]]) {
1388
  kernel_mul_mv_f16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
 
1487
  constant uint64_t & nb12,
1488
  constant int64_t & ne0,
1489
  constant int64_t & ne1,
1490
+ constant uint & r2,
1491
+ constant uint & r3,
1492
  uint3 tgpig[[threadgroup_position_in_grid]],
1493
  uint tiisg[[thread_index_in_simdgroup]]) {
1494
  kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
 
1513
  constant uint64_t & nb12,
1514
  constant int64_t & ne0,
1515
  constant int64_t & ne1,
1516
+ constant uint & r2,
1517
+ constant uint & r3,
1518
  uint3 tgpig[[threadgroup_position_in_grid]],
1519
  uint tiisg[[thread_index_in_simdgroup]]) {
1520
 
 
1578
  const int64_t i3 = n / (ne2*ne1*ne0);
1579
  const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
1580
  const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
1581
+ //const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
1582
+
1583
  const int64_t k = i3*ne3 + i2;
1584
 
1585
  float m_k;
 
2446
  } block_q6_K;
2447
  // 210 bytes / block
2448
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2449
  //====================================== dot products =========================
2450
 
2451
  void kernel_mul_mv_q2_K_f32_impl(
 
2604
  device const float * src1,
2605
  device float * dst,
2606
  constant int64_t & ne00,
2607
+ constant int64_t & ne01,
2608
+ constant int64_t & ne02,
2609
+ constant uint64_t & nb00,
2610
+ constant uint64_t & nb01,
2611
+ constant uint64_t & nb02,
2612
+ constant int64_t & ne10,
2613
+ constant int64_t & ne11,
2614
+ constant int64_t & ne12,
2615
+ constant uint64_t & nb10,
2616
+ constant uint64_t & nb11,
2617
+ constant uint64_t & nb12,
2618
+ constant int64_t & ne0,
2619
+ constant int64_t & ne1,
2620
+ constant uint & r2,
2621
+ constant uint & r3,
2622
  uint3 tgpig[[threadgroup_position_in_grid]],
2623
  uint tiisg[[thread_index_in_simdgroup]],
2624
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
2868
  device const float * src1,
2869
  device float * dst,
2870
  constant int64_t & ne00,
2871
+ constant int64_t & ne01,
2872
+ constant int64_t & ne02,
2873
+ constant uint64_t & nb00,
2874
+ constant uint64_t & nb01,
2875
+ constant uint64_t & nb02,
2876
+ constant int64_t & ne10,
2877
+ constant int64_t & ne11,
2878
+ constant int64_t & ne12,
2879
+ constant uint64_t & nb10,
2880
+ constant uint64_t & nb11,
2881
+ constant uint64_t & nb12,
2882
+ constant int64_t & ne0,
2883
+ constant int64_t & ne1,
2884
+ constant uint & r2,
2885
+ constant uint & r3,
2886
  uint3 tgpig[[threadgroup_position_in_grid]],
2887
  uint tiisg[[thread_index_in_simdgroup]],
2888
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
3018
  constant uint & r2,
3019
  constant uint & r3,
3020
  uint3 tgpig[[threadgroup_position_in_grid]],
3021
+ uint tiisg[[thread_index_in_simdgroup]],
3022
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
3023
 
3024
  const int ix = tiisg/4; // 0...7
3025
  const int it = tiisg%4; // 0...3
 
3028
  const int r0 = tgpig.x;
3029
  const int r1 = tgpig.y;
3030
  const int im = tgpig.z;
3031
+ const int first_row = r0 * N_DST;
3032
  const int ib_row = first_row * nb;
3033
 
3034
  const uint i12 = im%ne12;
 
3094
  for (int row = 0; row < N_DST; ++row) {
3095
  all_sum = simd_sum(sumf[row]);
3096
  if (tiisg == 0) {
3097
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
3098
  }
3099
  }
3100
  }
 
3106
  device const float * src1,
3107
  device float * dst,
3108
  constant int64_t & ne00,
3109
+ constant int64_t & ne01,
3110
+ constant int64_t & ne02,
3111
+ constant uint64_t & nb00,
3112
+ constant uint64_t & nb01,
3113
+ constant uint64_t & nb02,
3114
+ constant int64_t & ne10,
3115
+ constant int64_t & ne11,
3116
+ constant int64_t & ne12,
3117
+ constant uint64_t & nb10,
3118
+ constant uint64_t & nb11,
3119
+ constant uint64_t & nb12,
3120
+ constant int64_t & ne0,
3121
+ constant int64_t & ne1,
3122
+ constant uint & r2,
3123
+ constant uint & r3,
3124
  uint3 tgpig[[threadgroup_position_in_grid]],
3125
  uint tiisg[[thread_index_in_simdgroup]],
3126
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
3312
  device const float * src1,
3313
  device float * dst,
3314
  constant int64_t & ne00,
3315
+ constant int64_t & ne01,
3316
+ constant int64_t & ne02,
3317
+ constant uint64_t & nb00,
3318
+ constant uint64_t & nb01,
3319
+ constant uint64_t & nb02,
3320
+ constant int64_t & ne10,
3321
+ constant int64_t & ne11,
3322
+ constant int64_t & ne12,
3323
+ constant uint64_t & nb10,
3324
+ constant uint64_t & nb11,
3325
+ constant uint64_t & nb12,
3326
+ constant int64_t & ne0,
3327
+ constant int64_t & ne1,
3328
+ constant uint & r2,
3329
+ constant uint & r3,
3330
  uint3 tgpig[[threadgroup_position_in_grid]],
3331
  uint tiisg[[thread_index_in_simdgroup]],
3332
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
3446
  device const float * src1,
3447
  device float * dst,
3448
  constant int64_t & ne00,
3449
+ constant int64_t & ne01,
3450
+ constant int64_t & ne02,
3451
+ constant uint64_t & nb00,
3452
+ constant uint64_t & nb01,
3453
+ constant uint64_t & nb02,
3454
+ constant int64_t & ne10,
3455
+ constant int64_t & ne11,
3456
+ constant int64_t & ne12,
3457
+ constant uint64_t & nb10,
3458
+ constant uint64_t & nb11,
3459
+ constant uint64_t & nb12,
3460
+ constant int64_t & ne0,
3461
+ constant int64_t & ne1,
3462
+ constant uint & r2,
3463
+ constant uint & r3,
3464
  uint3 tgpig[[threadgroup_position_in_grid]],
3465
  uint tiisg[[thread_index_in_simdgroup]],
3466
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
3578
  device const int8_t * qs = ((device const int8_t *)xb->qs);
3579
  const half d = xb->d;
3580
 
3581
+ for (int i = 0; i < 16; i++) {
3582
  reg[i/4][i%4] = (qs[i + 16*il] * d);
3583
  }
3584
  }
 
3847
  device float * dst,
3848
  constant int64_t & ne00,
3849
  constant int64_t & ne02,
3850
+ constant uint64_t & nb01,
3851
+ constant uint64_t & nb02,
3852
  constant int64_t & ne12,
3853
+ constant uint64_t & nb10,
3854
+ constant uint64_t & nb11,
3855
+ constant uint64_t & nb12,
3856
  constant int64_t & ne0,
3857
  constant int64_t & ne1,
3858
  constant uint & r2,
 
3979
  device float * dst,
3980
  constant int64_t & ne00,
3981
  constant int64_t & ne02,
3982
+ constant uint64_t & nb01,
3983
+ constant uint64_t & nb02,
3984
  constant int64_t & ne12,
3985
+ constant uint64_t & nb10,
3986
+ constant uint64_t & nb11,
3987
+ constant uint64_t & nb12,
3988
  constant int64_t & ne0,
3989
  constant int64_t & ne1,
3990
  constant uint & r2,
 
4020
  device const uchar * ids,
4021
  device const uchar * src1,
4022
  device uchar * dst,
4023
+ constant uint64_t & nbi1,
4024
  constant int64_t & ne00,
4025
  constant int64_t & ne02,
4026
+ constant uint64_t & nb01,
4027
+ constant uint64_t & nb02,
4028
  constant int64_t & ne12,
4029
  constant int64_t & ne13,
4030
+ constant uint64_t & nb10,
4031
+ constant uint64_t & nb11,
4032
+ constant uint64_t & nb12,
4033
  constant int64_t & ne0,
4034
  constant int64_t & ne1,
4035
+ constant uint64_t & nb1,
4036
  constant uint & r2,
4037
  constant uint & r3,
4038
  constant int & idx,
 
4125
  device float * dst,
4126
  constant int64_t & ne00,
4127
  constant int64_t & ne02,
4128
+ constant uint64_t & nb01,
4129
+ constant uint64_t & nb02,
4130
  constant int64_t & ne12,
4131
+ constant uint64_t & nb10,
4132
+ constant uint64_t & nb11,
4133
+ constant uint64_t & nb12,
4134
  constant int64_t & ne0,
4135
  constant int64_t & ne1,
4136
  constant uint & r2,
 
4159
  device const uchar * ids,
4160
  device const uchar * src1,
4161
  device uchar * dst,
4162
+ constant uint64_t & nbi1,
4163
  constant int64_t & ne00,
4164
  constant int64_t & ne02,
4165
+ constant uint64_t & nb01,
4166
+ constant uint64_t & nb02,
4167
  constant int64_t & ne12,
4168
  constant int64_t & ne13,
4169
+ constant uint64_t & nb10,
4170
+ constant uint64_t & nb11,
4171
+ constant uint64_t & nb12,
4172
  constant int64_t & ne0,
4173
  constant int64_t & ne1,
4174
+ constant uint64_t & nb1,
4175
  constant uint & r2,
4176
  constant uint & r3,
4177
  constant int & idx,
 
4208
  device const char * ids,
4209
  device const char * src1,
4210
  device uchar * dst,
4211
+ constant uint64_t & nbi1,
4212
  constant int64_t & ne00,
4213
  constant int64_t & ne01,
4214
  constant int64_t & ne02,
 
4224
  constant uint64_t & nb12,
4225
  constant int64_t & ne0,
4226
  constant int64_t & ne1,
4227
+ constant uint64_t & nb1,
4228
  constant uint & r2,
4229
  constant uint & r3,
4230
  constant int & idx,
 
4277
  device const char * ids,
4278
  device const char * src1,
4279
  device uchar * dst,
4280
+ constant uint64_t & nbi1,
4281
  constant int64_t & ne00,
4282
  constant int64_t & ne01,
4283
  constant int64_t & ne02,
 
4293
  constant uint64_t & nb12,
4294
  constant int64_t & ne0,
4295
  constant int64_t & ne1,
4296
+ constant uint64_t & nb1,
4297
  constant uint & r2,
4298
  constant uint & r3,
4299
  constant int & idx,
 
4346
  device const char * ids,
4347
  device const char * src1,
4348
  device uchar * dst,
4349
+ constant uint64_t & nbi1,
4350
  constant int64_t & ne00,
4351
  constant int64_t & ne01,
4352
  constant int64_t & ne02,
 
4362
  constant uint64_t & nb12,
4363
  constant int64_t & ne0,
4364
  constant int64_t & ne1,
4365
+ constant uint64_t & nb1,
4366
  constant uint & r2,
4367
  constant uint & r3,
4368
  constant int & idx,
 
4409
  device const char * ids,
4410
  device const char * src1,
4411
  device uchar * dst,
4412
+ constant uint64_t & nbi1,
4413
  constant int64_t & ne00,
4414
  constant int64_t & ne01,
4415
  constant int64_t & ne02,
 
4425
  constant uint64_t & nb12,
4426
  constant int64_t & ne0,
4427
  constant int64_t & ne1,
4428
+ constant uint64_t & nb1,
4429
  constant uint & r2,
4430
  constant uint & r3,
4431
  constant int & idx,
 
4472
  device const char * ids,
4473
  device const char * src1,
4474
  device uchar * dst,
4475
+ constant uint64_t & nbi1,
4476
  constant int64_t & ne00,
4477
  constant int64_t & ne01,
4478
  constant int64_t & ne02,
 
4488
  constant uint64_t & nb12,
4489
  constant int64_t & ne0,
4490
  constant int64_t & ne1,
4491
+ constant uint64_t & nb1,
4492
  constant uint & r2,
4493
  constant uint & r3,
4494
  constant int & idx,
 
4535
  device const char * ids,
4536
  device const char * src1,
4537
  device uchar * dst,
4538
+ constant uint64_t & nbi1,
4539
  constant int64_t & ne00,
4540
  constant int64_t & ne01,
4541
  constant int64_t & ne02,
 
4551
  constant uint64_t & nb12,
4552
  constant int64_t & ne0,
4553
  constant int64_t & ne1,
4554
+ constant uint64_t & nb1,
4555
  constant uint & r2,
4556
  constant uint & r3,
4557
  constant int & idx,
 
4598
  device const char * ids,
4599
  device const char * src1,
4600
  device uchar * dst,
4601
+ constant uint64_t & nbi1,
4602
  constant int64_t & ne00,
4603
  constant int64_t & ne01,
4604
  constant int64_t & ne02,
 
4614
  constant uint64_t & nb12,
4615
  constant int64_t & ne0,
4616
  constant int64_t & ne1,
4617
+ constant uint64_t & nb1,
4618
  constant uint & r2,
4619
  constant uint & r3,
4620
  constant int & idx,
 
4661
  device const char * ids,
4662
  device const char * src1,
4663
  device uchar * dst,
4664
+ constant uint64_t & nbi1,
4665
  constant int64_t & ne00,
4666
  constant int64_t & ne01,
4667
  constant int64_t & ne02,
 
4677
  constant uint64_t & nb12,
4678
  constant int64_t & ne0,
4679
  constant int64_t & ne1,
4680
+ constant uint64_t & nb1,
4681
  constant uint & r2,
4682
  constant uint & r3,
4683
  constant int & idx,
 
4724
  device const char * ids,
4725
  device const char * src1,
4726
  device uchar * dst,
4727
+ constant uint64_t & nbi1,
4728
  constant int64_t & ne00,
4729
  constant int64_t & ne01,
4730
  constant int64_t & ne02,
 
4740
  constant uint64_t & nb12,
4741
  constant int64_t & ne0,
4742
  constant int64_t & ne1,
4743
+ constant uint64_t & nb1,
4744
  constant uint & r2,
4745
  constant uint & r3,
4746
  constant int & idx,
 
4787
  device const char * ids,
4788
  device const char * src1,
4789
  device uchar * dst,
4790
+ constant uint64_t & nbi1,
4791
  constant int64_t & ne00,
4792
  constant int64_t & ne01,
4793
  constant int64_t & ne02,
 
4803
  constant uint64_t & nb12,
4804
  constant int64_t & ne0,
4805
  constant int64_t & ne1,
4806
+ constant uint64_t & nb1,
4807
  constant uint & r2,
4808
  constant uint & r3,
4809
  constant int & idx,
 
4850
  device const char * ids,
4851
  device const char * src1,
4852
  device uchar * dst,
4853
+ constant uint64_t & nbi1,
4854
  constant int64_t & ne00,
4855
  constant int64_t & ne01,
4856
  constant int64_t & ne02,
 
4866
  constant uint64_t & nb12,
4867
  constant int64_t & ne0,
4868
  constant int64_t & ne1,
4869
+ constant uint64_t & nb1,
4870
  constant uint & r2,
4871
  constant uint & r3,
4872
  constant int & idx,
 
4913
  device const char * ids,
4914
  device const char * src1,
4915
  device uchar * dst,
4916
+ constant uint64_t & nbi1,
4917
  constant int64_t & ne00,
4918
  constant int64_t & ne01,
4919
  constant int64_t & ne02,
 
4929
  constant uint64_t & nb12,
4930
  constant int64_t & ne0,
4931
  constant int64_t & ne1,
4932
+ constant uint64_t & nb1,
4933
  constant uint & r2,
4934
  constant uint & r3,
4935
  constant int & idx,