slaren commited on
Commit
e9910b5
·
1 Parent(s): 1706870

metal : unify mul_mv_id kernels (llama/6556)

Browse files
Files changed (3) hide show
  1. ggml-metal.m +5 -0
  2. ggml-metal.metal +135 -1054
  3. ggml.c +0 -1
ggml-metal.m CHANGED
@@ -1941,7 +1941,12 @@ static enum ggml_status ggml_metal_graph_compute(
1941
  {
1942
  nth0 = 4;
1943
  nth1 = 16;
 
 
 
1944
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline;
 
 
1945
  } break;
1946
  default:
1947
  {
 
1941
  {
1942
  nth0 = 4;
1943
  nth1 = 16;
1944
+ #if QK_K == 64
1945
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline;
1946
+ #else
1947
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline;
1948
+ #endif
1949
+
1950
  } break;
1951
  default:
1952
  {
ggml-metal.metal CHANGED
@@ -864,15 +864,16 @@ void mul_vec_q_n_f32_impl(
864
  device const void * src0,
865
  device const float * src1,
866
  device float * dst,
867
- int64_t ne00,
868
- int64_t ne01,
869
- int64_t ne02,
870
- int64_t ne10,
871
- int64_t ne12,
872
- int64_t ne0,
873
- int64_t ne1,
874
- uint r2,
875
- uint r3,
 
876
  uint3 tgpig, uint tiisg, uint sgitg) {
877
  const int nb = ne00/QK4_0;
878
 
@@ -949,7 +950,7 @@ kernel void kernel_mul_mv_q4_0_f32(
949
  uint3 tgpig[[threadgroup_position_in_grid]],
950
  uint tiisg[[thread_index_in_simdgroup]],
951
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
952
- mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
953
  }
954
 
955
  kernel void kernel_mul_mv_q4_1_f32(
@@ -975,7 +976,7 @@ kernel void kernel_mul_mv_q4_1_f32(
975
  uint3 tgpig[[threadgroup_position_in_grid]],
976
  uint tiisg[[thread_index_in_simdgroup]],
977
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
978
- mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
979
  }
980
 
981
  kernel void kernel_mul_mv_q5_0_f32(
@@ -1001,7 +1002,7 @@ kernel void kernel_mul_mv_q5_0_f32(
1001
  uint3 tgpig[[threadgroup_position_in_grid]],
1002
  uint tiisg[[thread_index_in_simdgroup]],
1003
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
1004
- mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
1005
  }
1006
 
1007
  kernel void kernel_mul_mv_q5_1_f32(
@@ -1027,7 +1028,7 @@ kernel void kernel_mul_mv_q5_1_f32(
1027
  uint3 tgpig[[threadgroup_position_in_grid]],
1028
  uint tiisg[[thread_index_in_simdgroup]],
1029
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
1030
- mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
1031
  }
1032
 
1033
 
@@ -1046,6 +1047,7 @@ void kernel_mul_mv_q8_0_f32_impl(
1046
  constant int64_t & ne1,
1047
  constant uint & r2,
1048
  constant uint & r3,
 
1049
  uint3 tgpig[[threadgroup_position_in_grid]],
1050
  uint tiisg[[thread_index_in_simdgroup]],
1051
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -1126,7 +1128,7 @@ kernel void kernel_mul_mv_q8_0_f32(
1126
  uint3 tgpig[[threadgroup_position_in_grid]],
1127
  uint tiisg[[thread_index_in_simdgroup]],
1128
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
1129
- kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
1130
  }
1131
 
1132
  #define N_F32_F32 4
@@ -2716,6 +2718,7 @@ void kernel_mul_mv_q2_K_f32_impl(
2716
  constant int64_t & ne1,
2717
  constant uint & r2,
2718
  constant uint & r3,
 
2719
  uint3 tgpig[[threadgroup_position_in_grid]],
2720
  uint tiisg[[thread_index_in_simdgroup]],
2721
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2878,7 +2881,7 @@ kernel void kernel_mul_mv_q2_K_f32(
2878
  uint tiisg[[thread_index_in_simdgroup]],
2879
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
2880
 
2881
- kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
2882
  }
2883
 
2884
  #if QK_K == 256
@@ -2895,6 +2898,7 @@ void kernel_mul_mv_q3_K_f32_impl(
2895
  constant int64_t & ne1,
2896
  constant uint & r2,
2897
  constant uint & r3,
 
2898
  uint3 tgpig[[threadgroup_position_in_grid]],
2899
  uint tiisg[[thread_index_in_simdgroup]],
2900
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -3053,6 +3057,7 @@ void kernel_mul_mv_q3_K_f32_impl(
3053
  constant int64_t & ne1,
3054
  constant uint & r2,
3055
  constant uint & r3,
 
3056
  uint3 tgpig[[threadgroup_position_in_grid]],
3057
  uint tiisg[[thread_index_in_simdgroup]],
3058
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -3142,7 +3147,7 @@ kernel void kernel_mul_mv_q3_K_f32(
3142
  uint tiisg[[thread_index_in_simdgroup]],
3143
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
3144
 
3145
- kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
3146
  }
3147
 
3148
  #if QK_K == 256
@@ -3159,6 +3164,7 @@ void kernel_mul_mv_q4_K_f32_impl(
3159
  constant int64_t & ne1,
3160
  constant uint & r2,
3161
  constant uint & r3,
 
3162
  uint3 tgpig[[threadgroup_position_in_grid]],
3163
  uint tiisg[[thread_index_in_simdgroup]],
3164
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -3272,6 +3278,7 @@ void kernel_mul_mv_q4_K_f32_impl(
3272
  constant int64_t & ne1,
3273
  constant uint & r2,
3274
  constant uint & r3,
 
3275
  uint3 tgpig[[threadgroup_position_in_grid]],
3276
  uint tiisg[[thread_index_in_simdgroup]],
3277
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -3380,7 +3387,7 @@ kernel void kernel_mul_mv_q4_K_f32(
3380
  uint tiisg[[thread_index_in_simdgroup]],
3381
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
3382
 
3383
- kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
3384
  }
3385
 
3386
  void kernel_mul_mv_q5_K_f32_impl(
@@ -3396,6 +3403,7 @@ void kernel_mul_mv_q5_K_f32_impl(
3396
  constant int64_t & ne1,
3397
  constant uint & r2,
3398
  constant uint & r3,
 
3399
  uint3 tgpig[[threadgroup_position_in_grid]],
3400
  uint tiisg[[thread_index_in_simdgroup]],
3401
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -3586,7 +3594,7 @@ kernel void kernel_mul_mv_q5_K_f32(
3586
  uint tiisg[[thread_index_in_simdgroup]],
3587
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
3588
 
3589
- kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
3590
  }
3591
 
3592
  void kernel_mul_mv_q6_K_f32_impl(
@@ -3602,6 +3610,7 @@ void kernel_mul_mv_q6_K_f32_impl(
3602
  constant int64_t & ne1,
3603
  constant uint & r2,
3604
  constant uint & r3,
 
3605
  uint3 tgpig[[threadgroup_position_in_grid]],
3606
  uint tiisg[[thread_index_in_simdgroup]],
3607
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -3720,7 +3729,7 @@ kernel void kernel_mul_mv_q6_K_f32(
3720
  uint tiisg[[thread_index_in_simdgroup]],
3721
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
3722
 
3723
- kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
3724
  }
3725
 
3726
  // ======================= "True" 2-bit
@@ -4403,6 +4412,7 @@ void kernel_mul_mv_iq1_s_f32_impl(
4403
  constant int64_t & ne1,
4404
  constant uint & r2,
4405
  constant uint & r3,
 
4406
  uint3 tgpig[[threadgroup_position_in_grid]],
4407
  uint tiisg[[thread_index_in_simdgroup]],
4408
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -4492,6 +4502,7 @@ void kernel_mul_mv_iq1_m_f32_impl(
4492
  constant int64_t & ne1,
4493
  constant uint & r2,
4494
  constant uint & r3,
 
4495
  uint3 tgpig[[threadgroup_position_in_grid]],
4496
  uint tiisg[[thread_index_in_simdgroup]],
4497
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -4600,11 +4611,12 @@ void kernel_mul_mv_iq4_nl_f32_impl(
4600
  constant int64_t & ne1,
4601
  constant uint & r2,
4602
  constant uint & r3,
4603
- threadgroup float * shared_values [[threadgroup(0)]],
4604
  uint3 tgpig[[threadgroup_position_in_grid]],
4605
  uint tiisg[[thread_index_in_simdgroup]],
4606
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
4607
 
 
4608
  const int nb = ne00/QK4_NL;
4609
  const int r0 = tgpig.x;
4610
  const int r1 = tgpig.y;
@@ -4694,11 +4706,11 @@ void kernel_mul_mv_iq4_xs_f32_impl(
4694
  constant int64_t & ne1,
4695
  constant uint & r2,
4696
  constant uint & r3,
4697
- threadgroup float * shared_values [[threadgroup(0)]],
4698
  uint3 tgpig[[threadgroup_position_in_grid]],
4699
  uint tiisg[[thread_index_in_simdgroup]],
4700
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
4701
-
4702
  const int nb = ne00/QK_K;
4703
  const int r0 = tgpig.x;
4704
  const int r1 = tgpig.y;
@@ -4801,7 +4813,7 @@ kernel void kernel_mul_mv_iq1_s_f32(
4801
  uint tiisg[[thread_index_in_simdgroup]],
4802
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
4803
 
4804
- kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
4805
  }
4806
 
4807
  [[host_name("kernel_mul_mv_iq1_m_f32")]]
@@ -4829,7 +4841,7 @@ kernel void kernel_mul_mv_iq1_m_f32(
4829
  uint tiisg[[thread_index_in_simdgroup]],
4830
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
4831
 
4832
- kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
4833
  }
4834
 
4835
  [[host_name("kernel_mul_mv_iq4_nl_f32")]]
@@ -4853,7 +4865,7 @@ kernel void kernel_mul_mv_iq4_nl_f32(
4853
  constant int64_t & ne1,
4854
  constant uint & r2,
4855
  constant uint & r3,
4856
- threadgroup float * shared_values [[threadgroup(0)]],
4857
  uint3 tgpig[[threadgroup_position_in_grid]],
4858
  uint tiisg[[thread_index_in_simdgroup]],
4859
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -4882,7 +4894,7 @@ kernel void kernel_mul_mv_iq4_xs_f32(
4882
  constant int64_t & ne1,
4883
  constant uint & r2,
4884
  constant uint & r3,
4885
- threadgroup float * shared_values [[threadgroup(0)]],
4886
  uint3 tgpig[[threadgroup_position_in_grid]],
4887
  uint tiisg[[thread_index_in_simdgroup]],
4888
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -6029,135 +6041,52 @@ template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel
6029
  // matrix-vector multiplication
6030
  //
6031
 
6032
- [[host_name("kernel_mul_mv_id_f32_f32")]]
6033
- kernel void kernel_mul_mv_id_f32_f32(
6034
- device const char * src0s,
6035
- device const char * src1,
6036
- device float * dst,
6037
- device const char * ids,
6038
- constant uint64_t & nbi1,
6039
- constant int64_t & ne00,
6040
- constant int64_t & ne01,
6041
- constant int64_t & ne02,
6042
- constant uint64_t & nb00,
6043
- constant uint64_t & nb01,
6044
- constant uint64_t & nb02,
6045
- constant int64_t & ne10,
6046
- constant int64_t & ne11,
6047
- constant int64_t & ne12,
6048
- constant int64_t & ne13,
6049
- constant uint64_t & nb10,
6050
- constant uint64_t & nb11,
6051
- constant uint64_t & nb12,
6052
- constant int64_t & ne0,
6053
- constant int64_t & ne1,
6054
- constant uint64_t & nb1,
6055
- constant uint & r2,
6056
- constant uint & r3,
6057
- constant int & idx,
6058
- uint3 tgpig[[threadgroup_position_in_grid]],
6059
- uint tiitg[[thread_index_in_threadgroup]],
6060
- uint tiisg[[thread_index_in_simdgroup]],
6061
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
6062
- const int64_t bid = tgpig.z/(ne12*ne13);
6063
-
6064
- tgpig.z = tgpig.z%(ne12*ne13);
6065
-
6066
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6067
- device const char * src0 = src0s + id*nb02;
6068
-
6069
- kernel_mul_mv_f32_f32_impl(
6070
- src0,
6071
- src1 + bid*nb11,
6072
- dst + bid*ne0,
6073
- ne00,
6074
- ne01,
6075
- ne02,
6076
- nb00,
6077
- nb01,
6078
- nb02,
6079
- ne10,
6080
- ne11,
6081
- ne12,
6082
- nb10,
6083
- nb11,
6084
- nb12,
6085
- ne0,
6086
- ne1,
6087
- r2,
6088
- r3,
6089
- tgpig,
6090
- tiisg);
6091
- }
6092
-
6093
- [[host_name("kernel_mul_mv_id_f16_f32")]]
6094
- kernel void kernel_mul_mv_id_f16_f32(
6095
- device const char * src0s,
6096
- device const char * src1,
6097
- device float * dst,
6098
- device const char * ids,
6099
- constant uint64_t & nbi1,
6100
- constant int64_t & ne00,
6101
- constant int64_t & ne01,
6102
- constant int64_t & ne02,
6103
- constant uint64_t & nb00,
6104
- constant uint64_t & nb01,
6105
- constant uint64_t & nb02,
6106
- constant int64_t & ne10,
6107
- constant int64_t & ne11,
6108
- constant int64_t & ne12,
6109
- constant int64_t & ne13,
6110
- constant uint64_t & nb10,
6111
- constant uint64_t & nb11,
6112
- constant uint64_t & nb12,
6113
- constant int64_t & ne0,
6114
- constant int64_t & ne1,
6115
- constant uint64_t & nb1,
6116
- constant uint & r2,
6117
- constant uint & r3,
6118
- constant int & idx,
6119
- uint3 tgpig[[threadgroup_position_in_grid]],
6120
- uint tiitg[[thread_index_in_threadgroup]],
6121
- uint tiisg[[thread_index_in_simdgroup]],
6122
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
6123
- const int64_t bid = tgpig.z/(ne12*ne13);
6124
-
6125
- tgpig.z = tgpig.z%(ne12*ne13);
6126
-
6127
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6128
- device const char * src0 = src0s + id*nb02;
6129
 
6130
- kernel_mul_mv_f16_f32_impl(
6131
- src0,
6132
- src1 + bid*nb11,
6133
- dst + bid*ne0,
6134
- ne00,
6135
- ne01,
6136
- ne02,
6137
- nb00,
6138
- nb01,
6139
- nb02,
6140
- ne10,
6141
- ne11,
6142
- ne12,
6143
- nb10,
6144
- nb11,
6145
- nb12,
6146
- ne0,
6147
- ne1,
6148
- r2,
6149
- r3,
6150
- tgpig,
6151
- tiisg);
6152
- }
6153
 
6154
- [[host_name("kernel_mul_mv_id_q8_0_f32")]]
6155
- kernel void kernel_mul_mv_id_q8_0_f32(
6156
- device const char * src0s,
6157
  device const char * src1,
6158
  device float * dst,
6159
- device const char * ids,
6160
- constant uint64_t & nbi1,
6161
  constant int64_t & ne00,
6162
  constant int64_t & ne01,
6163
  constant int64_t & ne02,
@@ -6176,43 +6105,19 @@ kernel void kernel_mul_mv_id_q8_0_f32(
6176
  constant uint64_t & nb1,
6177
  constant uint & r2,
6178
  constant uint & r3,
6179
- constant int & idx,
6180
  uint3 tgpig[[threadgroup_position_in_grid]],
6181
  uint tiitg[[thread_index_in_threadgroup]],
6182
  uint tiisg[[thread_index_in_simdgroup]],
6183
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
6184
- const int64_t bid = tgpig.z/(ne12*ne13);
6185
-
6186
- tgpig.z = tgpig.z%(ne12*ne13);
6187
-
6188
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6189
- device const char * src0 = src0s + id*nb02;
6190
-
6191
- kernel_mul_mv_q8_0_f32_impl(
6192
- src0,
6193
- (device const float *) (src1 + bid*nb11),
6194
- dst + bid*ne0,
6195
- ne00,
6196
- ne01,
6197
- ne02,
6198
- ne10,
6199
- ne12,
6200
- ne0,
6201
- ne1,
6202
- r2,
6203
- r3,
6204
- tgpig,
6205
- tiisg,
6206
- sgitg);
6207
  }
6208
 
6209
- [[host_name("kernel_mul_mv_id_q4_0_f32")]]
6210
- kernel void kernel_mul_mv_id_q4_0_f32(
6211
- device const char * src0s,
6212
  device const char * src1,
6213
  device float * dst,
6214
- device const char * ids,
6215
- constant uint64_t & nbi1,
6216
  constant int64_t & ne00,
6217
  constant int64_t & ne01,
6218
  constant int64_t & ne02,
@@ -6231,43 +6136,18 @@ kernel void kernel_mul_mv_id_q4_0_f32(
6231
  constant uint64_t & nb1,
6232
  constant uint & r2,
6233
  constant uint & r3,
6234
- constant int & idx,
6235
  uint3 tgpig[[threadgroup_position_in_grid]],
6236
  uint tiitg[[thread_index_in_threadgroup]],
6237
  uint tiisg[[thread_index_in_simdgroup]],
6238
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
6239
- const int64_t bid = tgpig.z/(ne12*ne13);
6240
-
6241
- tgpig.z = tgpig.z%(ne12*ne13);
6242
-
6243
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6244
- device const char * src0 = src0s + id*nb02;
6245
-
6246
- mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
6247
- src0,
6248
- (device const float *) (src1 + bid*nb11),
6249
- dst + bid*ne0,
6250
- ne00,
6251
- ne01,
6252
- ne02,
6253
- ne10,
6254
- ne12,
6255
- ne0,
6256
- ne1,
6257
- r2,
6258
- r3,
6259
- tgpig,
6260
- tiisg,
6261
- sgitg);
6262
  }
6263
 
6264
- [[host_name("kernel_mul_mv_id_q4_1_f32")]]
6265
- kernel void kernel_mul_mv_id_q4_1_f32(
6266
- device const char * src0s,
6267
  device const char * src1,
6268
  device float * dst,
6269
- device const char * ids,
6270
- constant uint64_t & nbi1,
6271
  constant int64_t & ne00,
6272
  constant int64_t & ne01,
6273
  constant int64_t & ne02,
@@ -6286,38 +6166,14 @@ kernel void kernel_mul_mv_id_q4_1_f32(
6286
  constant uint64_t & nb1,
6287
  constant uint & r2,
6288
  constant uint & r3,
6289
- constant int & idx,
6290
  uint3 tgpig[[threadgroup_position_in_grid]],
6291
  uint tiitg[[thread_index_in_threadgroup]],
6292
  uint tiisg[[thread_index_in_simdgroup]],
6293
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
6294
- const int64_t bid = tgpig.z/(ne12*ne13);
6295
-
6296
- tgpig.z = tgpig.z%(ne12*ne13);
6297
-
6298
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6299
- device const char * src0 = src0s + id*nb02;
6300
-
6301
- mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
6302
- src0,
6303
- (device const float *) (src1 + bid*nb11),
6304
- dst + bid*ne0,
6305
- ne00,
6306
- ne01,
6307
- ne02,
6308
- ne10,
6309
- ne12,
6310
- ne0,
6311
- ne1,
6312
- r2,
6313
- r3,
6314
- tgpig,
6315
- tiisg,
6316
- sgitg);
6317
- }
6318
 
6319
- [[host_name("kernel_mul_mv_id_q5_0_f32")]]
6320
- kernel void kernel_mul_mv_id_q5_0_f32(
6321
  device const char * src0s,
6322
  device const char * src1,
6323
  device float * dst,
@@ -6342,6 +6198,7 @@ kernel void kernel_mul_mv_id_q5_0_f32(
6342
  constant uint & r2,
6343
  constant uint & r3,
6344
  constant int & idx,
 
6345
  uint3 tgpig[[threadgroup_position_in_grid]],
6346
  uint tiitg[[thread_index_in_threadgroup]],
6347
  uint tiisg[[thread_index_in_simdgroup]],
@@ -6353,26 +6210,36 @@ kernel void kernel_mul_mv_id_q5_0_f32(
6353
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6354
  device const char * src0 = src0s + id*nb02;
6355
 
6356
- mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
6357
  src0,
6358
- (device const float *) (src1 + bid*nb11),
6359
- dst + bid*ne0,
6360
  ne00,
6361
  ne01,
6362
  ne02,
 
 
 
6363
  ne10,
 
6364
  ne12,
 
 
 
 
6365
  ne0,
6366
  ne1,
 
6367
  r2,
6368
  r3,
 
6369
  tgpig,
 
6370
  tiisg,
6371
  sgitg);
6372
  }
6373
 
6374
- [[host_name("kernel_mul_mv_id_q5_1_f32")]]
6375
- kernel void kernel_mul_mv_id_q5_1_f32(
6376
  device const char * src0s,
6377
  device const char * src1,
6378
  device float * dst,
@@ -6397,819 +6264,33 @@ kernel void kernel_mul_mv_id_q5_1_f32(
6397
  constant uint & r2,
6398
  constant uint & r3,
6399
  constant int & idx,
 
6400
  uint3 tgpig[[threadgroup_position_in_grid]],
6401
  uint tiitg[[thread_index_in_threadgroup]],
6402
  uint tiisg[[thread_index_in_simdgroup]],
6403
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
6404
- const int64_t bid = tgpig.z/(ne12*ne13);
6405
-
6406
- tgpig.z = tgpig.z%(ne12*ne13);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6407
 
6408
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6409
- device const char * src0 = src0s + id*nb02;
6410
-
6411
- mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
6412
- src0,
6413
- (device const float *) (src1 + bid*nb11),
6414
- dst + bid*ne0,
6415
- ne00,
6416
- ne01,
6417
- ne02,
6418
- ne10,
6419
- ne12,
6420
- ne0,
6421
- ne1,
6422
- r2,
6423
- r3,
6424
- tgpig,
6425
- tiisg,
6426
- sgitg);
6427
- }
6428
-
6429
- [[host_name("kernel_mul_mv_id_q2_K_f32")]]
6430
- kernel void kernel_mul_mv_id_q2_K_f32(
6431
- device const char * src0s,
6432
- device const char * src1,
6433
- device float * dst,
6434
- device const char * ids,
6435
- constant uint64_t & nbi1,
6436
- constant int64_t & ne00,
6437
- constant int64_t & ne01,
6438
- constant int64_t & ne02,
6439
- constant uint64_t & nb00,
6440
- constant uint64_t & nb01,
6441
- constant uint64_t & nb02,
6442
- constant int64_t & ne10,
6443
- constant int64_t & ne11,
6444
- constant int64_t & ne12,
6445
- constant int64_t & ne13,
6446
- constant uint64_t & nb10,
6447
- constant uint64_t & nb11,
6448
- constant uint64_t & nb12,
6449
- constant int64_t & ne0,
6450
- constant int64_t & ne1,
6451
- constant uint64_t & nb1,
6452
- constant uint & r2,
6453
- constant uint & r3,
6454
- constant int & idx,
6455
- uint3 tgpig[[threadgroup_position_in_grid]],
6456
- uint tiitg[[thread_index_in_threadgroup]],
6457
- uint tiisg[[thread_index_in_simdgroup]],
6458
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
6459
- const int64_t bid = tgpig.z/(ne12*ne13);
6460
-
6461
- tgpig.z = tgpig.z%(ne12*ne13);
6462
-
6463
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6464
- device const char * src0 = src0s + id*nb02;
6465
-
6466
- kernel_mul_mv_q2_K_f32_impl(
6467
- src0,
6468
- (device const float *) (src1 + bid*nb11),
6469
- dst + bid*ne0,
6470
- ne00,
6471
- ne01,
6472
- ne02,
6473
- ne10,
6474
- ne12,
6475
- ne0,
6476
- ne1,
6477
- r2,
6478
- r3,
6479
- tgpig,
6480
- tiisg,
6481
- sgitg);
6482
- }
6483
-
6484
- [[host_name("kernel_mul_mv_id_q3_K_f32")]]
6485
- kernel void kernel_mul_mv_id_q3_K_f32(
6486
- device const char * src0s,
6487
- device const char * src1,
6488
- device float * dst,
6489
- device const char * ids,
6490
- constant uint64_t & nbi1,
6491
- constant int64_t & ne00,
6492
- constant int64_t & ne01,
6493
- constant int64_t & ne02,
6494
- constant uint64_t & nb00,
6495
- constant uint64_t & nb01,
6496
- constant uint64_t & nb02,
6497
- constant int64_t & ne10,
6498
- constant int64_t & ne11,
6499
- constant int64_t & ne12,
6500
- constant int64_t & ne13,
6501
- constant uint64_t & nb10,
6502
- constant uint64_t & nb11,
6503
- constant uint64_t & nb12,
6504
- constant int64_t & ne0,
6505
- constant int64_t & ne1,
6506
- constant uint64_t & nb1,
6507
- constant uint & r2,
6508
- constant uint & r3,
6509
- constant int & idx,
6510
- uint3 tgpig[[threadgroup_position_in_grid]],
6511
- uint tiitg[[thread_index_in_threadgroup]],
6512
- uint tiisg[[thread_index_in_simdgroup]],
6513
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
6514
- const int64_t bid = tgpig.z/(ne12*ne13);
6515
-
6516
- tgpig.z = tgpig.z%(ne12*ne13);
6517
-
6518
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6519
- device const char * src0 = src0s + id*nb02;
6520
-
6521
- kernel_mul_mv_q3_K_f32_impl(
6522
- src0,
6523
- (device const float *) (src1 + bid*nb11),
6524
- dst + bid*ne0,
6525
- ne00,
6526
- ne01,
6527
- ne02,
6528
- ne10,
6529
- ne12,
6530
- ne0,
6531
- ne1,
6532
- r2,
6533
- r3,
6534
- tgpig,
6535
- tiisg,
6536
- sgitg);
6537
- }
6538
-
6539
- [[host_name("kernel_mul_mv_id_q4_K_f32")]]
6540
- kernel void kernel_mul_mv_id_q4_K_f32(
6541
- device const char * src0s,
6542
- device const char * src1,
6543
- device float * dst,
6544
- device const char * ids,
6545
- constant uint64_t & nbi1,
6546
- constant int64_t & ne00,
6547
- constant int64_t & ne01,
6548
- constant int64_t & ne02,
6549
- constant uint64_t & nb00,
6550
- constant uint64_t & nb01,
6551
- constant uint64_t & nb02,
6552
- constant int64_t & ne10,
6553
- constant int64_t & ne11,
6554
- constant int64_t & ne12,
6555
- constant int64_t & ne13,
6556
- constant uint64_t & nb10,
6557
- constant uint64_t & nb11,
6558
- constant uint64_t & nb12,
6559
- constant int64_t & ne0,
6560
- constant int64_t & ne1,
6561
- constant uint64_t & nb1,
6562
- constant uint & r2,
6563
- constant uint & r3,
6564
- constant int & idx,
6565
- uint3 tgpig[[threadgroup_position_in_grid]],
6566
- uint tiitg[[thread_index_in_threadgroup]],
6567
- uint tiisg[[thread_index_in_simdgroup]],
6568
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
6569
- const int64_t bid = tgpig.z/(ne12*ne13);
6570
-
6571
- tgpig.z = tgpig.z%(ne12*ne13);
6572
-
6573
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6574
- device const char * src0 = src0s + id*nb02;
6575
-
6576
- kernel_mul_mv_q4_K_f32_impl(
6577
- src0,
6578
- (device const float *) (src1 + bid*nb11),
6579
- dst + bid*ne0,
6580
- ne00,
6581
- ne01,
6582
- ne02,
6583
- ne10,
6584
- ne12,
6585
- ne0,
6586
- ne1,
6587
- r2,
6588
- r3,
6589
- tgpig,
6590
- tiisg,
6591
- sgitg);
6592
- }
6593
-
6594
- [[host_name("kernel_mul_mv_id_q5_K_f32")]]
6595
- kernel void kernel_mul_mv_id_q5_K_f32(
6596
- device const char * src0s,
6597
- device const char * src1,
6598
- device float * dst,
6599
- device const char * ids,
6600
- constant uint64_t & nbi1,
6601
- constant int64_t & ne00,
6602
- constant int64_t & ne01,
6603
- constant int64_t & ne02,
6604
- constant uint64_t & nb00,
6605
- constant uint64_t & nb01,
6606
- constant uint64_t & nb02,
6607
- constant int64_t & ne10,
6608
- constant int64_t & ne11,
6609
- constant int64_t & ne12,
6610
- constant int64_t & ne13,
6611
- constant uint64_t & nb10,
6612
- constant uint64_t & nb11,
6613
- constant uint64_t & nb12,
6614
- constant int64_t & ne0,
6615
- constant int64_t & ne1,
6616
- constant uint64_t & nb1,
6617
- constant uint & r2,
6618
- constant uint & r3,
6619
- constant int & idx,
6620
- uint3 tgpig[[threadgroup_position_in_grid]],
6621
- uint tiitg[[thread_index_in_threadgroup]],
6622
- uint tiisg[[thread_index_in_simdgroup]],
6623
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
6624
- const int64_t bid = tgpig.z/(ne12*ne13);
6625
-
6626
- tgpig.z = tgpig.z%(ne12*ne13);
6627
-
6628
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6629
- device const char * src0 = src0s + id*nb02;
6630
-
6631
- kernel_mul_mv_q5_K_f32_impl(
6632
- src0,
6633
- (device const float *) (src1 + bid*nb11),
6634
- dst + bid*ne0,
6635
- ne00,
6636
- ne01,
6637
- ne02,
6638
- ne10,
6639
- ne12,
6640
- ne0,
6641
- ne1,
6642
- r2,
6643
- r3,
6644
- tgpig,
6645
- tiisg,
6646
- sgitg);
6647
- }
6648
-
6649
- [[host_name("kernel_mul_mv_id_q6_K_f32")]]
6650
- kernel void kernel_mul_mv_id_q6_K_f32(
6651
- device const char * src0s,
6652
- device const char * src1,
6653
- device float * dst,
6654
- device const char * ids,
6655
- constant uint64_t & nbi1,
6656
- constant int64_t & ne00,
6657
- constant int64_t & ne01,
6658
- constant int64_t & ne02,
6659
- constant uint64_t & nb00,
6660
- constant uint64_t & nb01,
6661
- constant uint64_t & nb02,
6662
- constant int64_t & ne10,
6663
- constant int64_t & ne11,
6664
- constant int64_t & ne12,
6665
- constant int64_t & ne13,
6666
- constant uint64_t & nb10,
6667
- constant uint64_t & nb11,
6668
- constant uint64_t & nb12,
6669
- constant int64_t & ne0,
6670
- constant int64_t & ne1,
6671
- constant uint64_t & nb1,
6672
- constant uint & r2,
6673
- constant uint & r3,
6674
- constant int & idx,
6675
- uint3 tgpig[[threadgroup_position_in_grid]],
6676
- uint tiitg[[thread_index_in_threadgroup]],
6677
- uint tiisg[[thread_index_in_simdgroup]],
6678
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
6679
- const int64_t bid = tgpig.z/(ne12*ne13);
6680
-
6681
- tgpig.z = tgpig.z%(ne12*ne13);
6682
-
6683
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6684
- device const char * src0 = src0s + id*nb02;
6685
-
6686
- kernel_mul_mv_q6_K_f32_impl(
6687
- src0,
6688
- (device const float *) (src1 + bid*nb11),
6689
- dst + bid*ne0,
6690
- ne00,
6691
- ne01,
6692
- ne02,
6693
- ne10,
6694
- ne12,
6695
- ne0,
6696
- ne1,
6697
- r2,
6698
- r3,
6699
- tgpig,
6700
- tiisg,
6701
- sgitg);
6702
- }
6703
-
6704
- [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]]
6705
- kernel void kernel_mul_mv_id_iq2_xxs_f32(
6706
- device const char * src0s,
6707
- device const char * src1,
6708
- device float * dst,
6709
- device const char * ids,
6710
- constant uint64_t & nbi1,
6711
- constant int64_t & ne00,
6712
- constant int64_t & ne01,
6713
- constant int64_t & ne02,
6714
- constant uint64_t & nb00,
6715
- constant uint64_t & nb01,
6716
- constant uint64_t & nb02,
6717
- constant int64_t & ne10,
6718
- constant int64_t & ne11,
6719
- constant int64_t & ne12,
6720
- constant int64_t & ne13,
6721
- constant uint64_t & nb10,
6722
- constant uint64_t & nb11,
6723
- constant uint64_t & nb12,
6724
- constant int64_t & ne0,
6725
- constant int64_t & ne1,
6726
- constant uint64_t & nb1,
6727
- constant uint & r2,
6728
- constant uint & r3,
6729
- constant int & idx,
6730
- threadgroup int8_t * shared_values [[threadgroup(0)]],
6731
- uint3 tgpig[[threadgroup_position_in_grid]],
6732
- uint tiitg[[thread_index_in_threadgroup]],
6733
- uint tiisg[[thread_index_in_simdgroup]],
6734
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
6735
- const int64_t bid = tgpig.z/(ne12*ne13);
6736
-
6737
- tgpig.z = tgpig.z%(ne12*ne13);
6738
-
6739
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6740
- device const char * src0 = src0s + id*nb02;
6741
-
6742
- kernel_mul_mv_iq2_xxs_f32_impl(
6743
- src0,
6744
- (device const float *) (src1 + bid*nb11),
6745
- dst + bid*ne0,
6746
- ne00,
6747
- ne01,
6748
- ne02,
6749
- ne10,
6750
- ne12,
6751
- ne0,
6752
- ne1,
6753
- r2,
6754
- r3,
6755
- shared_values,
6756
- tgpig,
6757
- tiisg,
6758
- sgitg);
6759
- }
6760
-
6761
- [[host_name("kernel_mul_mv_id_iq2_xs_f32")]]
6762
- kernel void kernel_mul_mv_id_iq2_xs_f32(
6763
- device const char * src0s,
6764
- device const char * src1,
6765
- device float * dst,
6766
- device const char * ids,
6767
- constant uint64_t & nbi1,
6768
- constant int64_t & ne00,
6769
- constant int64_t & ne01,
6770
- constant int64_t & ne02,
6771
- constant uint64_t & nb00,
6772
- constant uint64_t & nb01,
6773
- constant uint64_t & nb02,
6774
- constant int64_t & ne10,
6775
- constant int64_t & ne11,
6776
- constant int64_t & ne12,
6777
- constant int64_t & ne13,
6778
- constant uint64_t & nb10,
6779
- constant uint64_t & nb11,
6780
- constant uint64_t & nb12,
6781
- constant int64_t & ne0,
6782
- constant int64_t & ne1,
6783
- constant uint64_t & nb1,
6784
- constant uint & r2,
6785
- constant uint & r3,
6786
- constant int & idx,
6787
- threadgroup int8_t * shared_values [[threadgroup(0)]],
6788
- uint3 tgpig[[threadgroup_position_in_grid]],
6789
- uint tiitg[[thread_index_in_threadgroup]],
6790
- uint tiisg[[thread_index_in_simdgroup]],
6791
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
6792
- const int64_t bid = tgpig.z/(ne12*ne13);
6793
-
6794
- tgpig.z = tgpig.z%(ne12*ne13);
6795
-
6796
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6797
- device const char * src0 = src0s + id*nb02;
6798
-
6799
- kernel_mul_mv_iq2_xs_f32_impl(
6800
- src0,
6801
- (device const float *) (src1 + bid*nb11),
6802
- dst + bid*ne0,
6803
- ne00,
6804
- ne01,
6805
- ne02,
6806
- ne10,
6807
- ne12,
6808
- ne0,
6809
- ne1,
6810
- r2,
6811
- r3,
6812
- shared_values,
6813
- tgpig,
6814
- tiisg,
6815
- sgitg);
6816
- }
6817
-
6818
- [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]]
6819
- kernel void kernel_mul_mv_id_iq3_xxs_f32(
6820
- device const char * src0s,
6821
- device const char * src1,
6822
- device float * dst,
6823
- device const char * ids,
6824
- constant uint64_t & nbi1,
6825
- constant int64_t & ne00,
6826
- constant int64_t & ne01,
6827
- constant int64_t & ne02,
6828
- constant uint64_t & nb00,
6829
- constant uint64_t & nb01,
6830
- constant uint64_t & nb02,
6831
- constant int64_t & ne10,
6832
- constant int64_t & ne11,
6833
- constant int64_t & ne12,
6834
- constant int64_t & ne13,
6835
- constant uint64_t & nb10,
6836
- constant uint64_t & nb11,
6837
- constant uint64_t & nb12,
6838
- constant int64_t & ne0,
6839
- constant int64_t & ne1,
6840
- constant uint64_t & nb1,
6841
- constant uint & r2,
6842
- constant uint & r3,
6843
- constant int & idx,
6844
- threadgroup int8_t * shared_values [[threadgroup(0)]],
6845
- uint3 tgpig[[threadgroup_position_in_grid]],
6846
- uint tiitg[[thread_index_in_threadgroup]],
6847
- uint tiisg[[thread_index_in_simdgroup]],
6848
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
6849
- const int64_t bid = tgpig.z/(ne12*ne13);
6850
-
6851
- tgpig.z = tgpig.z%(ne12*ne13);
6852
-
6853
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6854
- device const char * src0 = src0s + id*nb02;
6855
-
6856
- kernel_mul_mv_iq3_xxs_f32_impl(
6857
- src0,
6858
- (device const float *) (src1 + bid*nb11),
6859
- dst + bid*ne0,
6860
- ne00,
6861
- ne01,
6862
- ne02,
6863
- ne10,
6864
- ne12,
6865
- ne0,
6866
- ne1,
6867
- r2,
6868
- r3,
6869
- shared_values,
6870
- tgpig,
6871
- tiisg,
6872
- sgitg);
6873
- }
6874
-
6875
- [[host_name("kernel_mul_mv_id_iq3_s_f32")]]
6876
- kernel void kernel_mul_mv_id_iq3_s_f32(
6877
- device const char * src0s,
6878
- device const char * src1,
6879
- device float * dst,
6880
- device const char * ids,
6881
- constant uint64_t & nbi1,
6882
- constant int64_t & ne00,
6883
- constant int64_t & ne01,
6884
- constant int64_t & ne02,
6885
- constant uint64_t & nb00,
6886
- constant uint64_t & nb01,
6887
- constant uint64_t & nb02,
6888
- constant int64_t & ne10,
6889
- constant int64_t & ne11,
6890
- constant int64_t & ne12,
6891
- constant int64_t & ne13,
6892
- constant uint64_t & nb10,
6893
- constant uint64_t & nb11,
6894
- constant uint64_t & nb12,
6895
- constant int64_t & ne0,
6896
- constant int64_t & ne1,
6897
- constant uint64_t & nb1,
6898
- constant uint & r2,
6899
- constant uint & r3,
6900
- constant int & idx,
6901
- threadgroup int8_t * shared_values [[threadgroup(0)]],
6902
- uint3 tgpig[[threadgroup_position_in_grid]],
6903
- uint tiitg[[thread_index_in_threadgroup]],
6904
- uint tiisg[[thread_index_in_simdgroup]],
6905
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
6906
- const int64_t bid = tgpig.z/(ne12*ne13);
6907
-
6908
- tgpig.z = tgpig.z%(ne12*ne13);
6909
-
6910
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6911
- device const char * src0 = src0s + id*nb02;
6912
-
6913
- kernel_mul_mv_iq3_s_f32_impl(
6914
- src0,
6915
- (device const float *) (src1 + bid*nb11),
6916
- dst + bid*ne0,
6917
- ne00,
6918
- ne01,
6919
- ne02,
6920
- ne10,
6921
- ne12,
6922
- ne0,
6923
- ne1,
6924
- r2,
6925
- r3,
6926
- shared_values,
6927
- tgpig,
6928
- tiisg,
6929
- sgitg);
6930
- }
6931
-
6932
- [[host_name("kernel_mul_mv_id_iq2_s_f32")]]
6933
- kernel void kernel_mul_mv_id_iq2_s_f32(
6934
- device const char * src0s,
6935
- device const char * src1,
6936
- device float * dst,
6937
- device const char * ids,
6938
- constant uint64_t & nbi1,
6939
- constant int64_t & ne00,
6940
- constant int64_t & ne01,
6941
- constant int64_t & ne02,
6942
- constant uint64_t & nb00,
6943
- constant uint64_t & nb01,
6944
- constant uint64_t & nb02,
6945
- constant int64_t & ne10,
6946
- constant int64_t & ne11,
6947
- constant int64_t & ne12,
6948
- constant int64_t & ne13,
6949
- constant uint64_t & nb10,
6950
- constant uint64_t & nb11,
6951
- constant uint64_t & nb12,
6952
- constant int64_t & ne0,
6953
- constant int64_t & ne1,
6954
- constant uint64_t & nb1,
6955
- constant uint & r2,
6956
- constant uint & r3,
6957
- constant int & idx,
6958
- threadgroup int8_t * shared_values [[threadgroup(0)]],
6959
- uint3 tgpig[[threadgroup_position_in_grid]],
6960
- uint tiitg[[thread_index_in_threadgroup]],
6961
- uint tiisg[[thread_index_in_simdgroup]],
6962
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
6963
- const int64_t bid = tgpig.z/(ne12*ne13);
6964
-
6965
- tgpig.z = tgpig.z%(ne12*ne13);
6966
-
6967
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6968
- device const char * src0 = src0s + id*nb02;
6969
-
6970
- kernel_mul_mv_iq2_s_f32_impl(
6971
- src0,
6972
- (device const float *) (src1 + bid*nb11),
6973
- dst + bid*ne0,
6974
- ne00,
6975
- ne01,
6976
- ne02,
6977
- ne10,
6978
- ne12,
6979
- ne0,
6980
- ne1,
6981
- r2,
6982
- r3,
6983
- shared_values,
6984
- tgpig,
6985
- tiisg,
6986
- sgitg);
6987
- }
6988
-
6989
- [[host_name("kernel_mul_mv_id_iq1_s_f32")]]
6990
- kernel void kernel_mul_mv_id_iq1_s_f32(
6991
- device const char * src0s,
6992
- device const char * src1,
6993
- device float * dst,
6994
- device const char * ids,
6995
- constant uint64_t & nbi1,
6996
- constant int64_t & ne00,
6997
- constant int64_t & ne01,
6998
- constant int64_t & ne02,
6999
- constant uint64_t & nb00,
7000
- constant uint64_t & nb01,
7001
- constant uint64_t & nb02,
7002
- constant int64_t & ne10,
7003
- constant int64_t & ne11,
7004
- constant int64_t & ne12,
7005
- constant int64_t & ne13,
7006
- constant uint64_t & nb10,
7007
- constant uint64_t & nb11,
7008
- constant uint64_t & nb12,
7009
- constant int64_t & ne0,
7010
- constant int64_t & ne1,
7011
- constant uint64_t & nb1,
7012
- constant uint & r2,
7013
- constant uint & r3,
7014
- constant int & idx,
7015
- uint3 tgpig[[threadgroup_position_in_grid]],
7016
- uint tiitg[[thread_index_in_threadgroup]],
7017
- uint tiisg[[thread_index_in_simdgroup]],
7018
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
7019
- const int64_t bid = tgpig.z/(ne12*ne13);
7020
-
7021
- tgpig.z = tgpig.z%(ne12*ne13);
7022
-
7023
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
7024
- device const char * src0 = src0s + id*nb02;
7025
-
7026
- kernel_mul_mv_iq1_s_f32_impl(
7027
- src0,
7028
- (device const float *) (src1 + bid*nb11),
7029
- dst + bid*ne0,
7030
- ne00,
7031
- ne01,
7032
- ne02,
7033
- ne10,
7034
- ne12,
7035
- ne0,
7036
- ne1,
7037
- r2,
7038
- r3,
7039
- tgpig,
7040
- tiisg,
7041
- sgitg);
7042
- }
7043
-
7044
- [[host_name("kernel_mul_mv_id_iq1_m_f32")]]
7045
- kernel void kernel_mul_mv_id_iq1_m_f32(
7046
- device const char * src0s,
7047
- device const char * src1,
7048
- device float * dst,
7049
- device const char * ids,
7050
- constant uint64_t & nbi1,
7051
- constant int64_t & ne00,
7052
- constant int64_t & ne01,
7053
- constant int64_t & ne02,
7054
- constant uint64_t & nb00,
7055
- constant uint64_t & nb01,
7056
- constant uint64_t & nb02,
7057
- constant int64_t & ne10,
7058
- constant int64_t & ne11,
7059
- constant int64_t & ne12,
7060
- constant int64_t & ne13,
7061
- constant uint64_t & nb10,
7062
- constant uint64_t & nb11,
7063
- constant uint64_t & nb12,
7064
- constant int64_t & ne0,
7065
- constant int64_t & ne1,
7066
- constant uint64_t & nb1,
7067
- constant uint & r2,
7068
- constant uint & r3,
7069
- constant int & idx,
7070
- uint3 tgpig[[threadgroup_position_in_grid]],
7071
- uint tiitg[[thread_index_in_threadgroup]],
7072
- uint tiisg[[thread_index_in_simdgroup]],
7073
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
7074
- const int64_t bid = tgpig.z/(ne12*ne13);
7075
-
7076
- tgpig.z = tgpig.z%(ne12*ne13);
7077
-
7078
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
7079
- device const char * src0 = src0s + id*nb02;
7080
-
7081
- kernel_mul_mv_iq1_m_f32_impl(
7082
- src0,
7083
- (device const float *) (src1 + bid*nb11),
7084
- dst + bid*ne0,
7085
- ne00,
7086
- ne01,
7087
- ne02,
7088
- ne10,
7089
- ne12,
7090
- ne0,
7091
- ne1,
7092
- r2,
7093
- r3,
7094
- tgpig,
7095
- tiisg,
7096
- sgitg);
7097
- }
7098
-
7099
- [[host_name("kernel_mul_mv_id_iq4_nl_f32")]]
7100
- kernel void kernel_mul_mv_id_iq4_nl_f32(
7101
- device const char * src0s,
7102
- device const char * src1,
7103
- device float * dst,
7104
- device const char * ids,
7105
- constant uint64_t & nbi1,
7106
- constant int64_t & ne00,
7107
- constant int64_t & ne01,
7108
- constant int64_t & ne02,
7109
- constant uint64_t & nb00,
7110
- constant uint64_t & nb01,
7111
- constant uint64_t & nb02,
7112
- constant int64_t & ne10,
7113
- constant int64_t & ne11,
7114
- constant int64_t & ne12,
7115
- constant int64_t & ne13,
7116
- constant uint64_t & nb10,
7117
- constant uint64_t & nb11,
7118
- constant uint64_t & nb12,
7119
- constant int64_t & ne0,
7120
- constant int64_t & ne1,
7121
- constant uint64_t & nb1,
7122
- constant uint & r2,
7123
- constant uint & r3,
7124
- constant int & idx,
7125
- threadgroup float * shared_values [[threadgroup(0)]],
7126
- uint3 tgpig[[threadgroup_position_in_grid]],
7127
- uint tiitg[[thread_index_in_threadgroup]],
7128
- uint tiisg[[thread_index_in_simdgroup]],
7129
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
7130
- const int64_t bid = tgpig.z/(ne12*ne13);
7131
-
7132
- tgpig.z = tgpig.z%(ne12*ne13);
7133
-
7134
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
7135
- device const char * src0 = src0s + id*nb02;
7136
-
7137
- kernel_mul_mv_iq4_nl_f32_impl(
7138
- src0,
7139
- (device const float *) (src1 + bid*nb11),
7140
- dst + bid*ne0,
7141
- ne00,
7142
- ne01,
7143
- ne02,
7144
- ne10,
7145
- ne12,
7146
- ne0,
7147
- ne1,
7148
- r2,
7149
- r3,
7150
- shared_values,
7151
- tgpig,
7152
- tiisg,
7153
- sgitg);
7154
- }
7155
-
7156
- [[host_name("kernel_mul_mv_id_iq4_xs_f32")]]
7157
- kernel void kernel_mul_mv_id_iq4_xs_f32(
7158
- device const char * src0s,
7159
- device const char * src1,
7160
- device float * dst,
7161
- device const char * ids,
7162
- constant uint64_t & nbi1,
7163
- constant int64_t & ne00,
7164
- constant int64_t & ne01,
7165
- constant int64_t & ne02,
7166
- constant uint64_t & nb00,
7167
- constant uint64_t & nb01,
7168
- constant uint64_t & nb02,
7169
- constant int64_t & ne10,
7170
- constant int64_t & ne11,
7171
- constant int64_t & ne12,
7172
- constant int64_t & ne13,
7173
- constant uint64_t & nb10,
7174
- constant uint64_t & nb11,
7175
- constant uint64_t & nb12,
7176
- constant int64_t & ne0,
7177
- constant int64_t & ne1,
7178
- constant uint64_t & nb1,
7179
- constant uint & r2,
7180
- constant uint & r3,
7181
- constant int & idx,
7182
- threadgroup float * shared_values [[threadgroup(0)]],
7183
- uint3 tgpig[[threadgroup_position_in_grid]],
7184
- uint tiitg[[thread_index_in_threadgroup]],
7185
- uint tiisg[[thread_index_in_simdgroup]],
7186
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
7187
- const int64_t bid = tgpig.z/(ne12*ne13);
7188
-
7189
- tgpig.z = tgpig.z%(ne12*ne13);
7190
-
7191
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
7192
- device const char * src0 = src0s + id*nb02;
7193
-
7194
- #if QK_K == 64
7195
- kernel_mul_mv_iq4_nl_f32_impl(
7196
- #else
7197
- kernel_mul_mv_iq4_xs_f32_impl(
7198
- #endif
7199
- src0,
7200
- (device const float *) (src1 + bid*nb11),
7201
- dst + bid*ne0,
7202
- ne00,
7203
- ne01,
7204
- ne02,
7205
- ne10,
7206
- ne12,
7207
- ne0,
7208
- ne1,
7209
- r2,
7210
- r3,
7211
- shared_values,
7212
- tgpig,
7213
- tiisg,
7214
- sgitg);
7215
- }
 
864
  device const void * src0,
865
  device const float * src1,
866
  device float * dst,
867
+ constant int64_t & ne00,
868
+ constant int64_t & ne01,
869
+ constant int64_t & ne02,
870
+ constant int64_t & ne10,
871
+ constant int64_t & ne12,
872
+ constant int64_t & ne0,
873
+ constant int64_t & ne1,
874
+ constant uint & r2,
875
+ constant uint & r3,
876
+ threadgroup int8_t * shared_values,
877
  uint3 tgpig, uint tiisg, uint sgitg) {
878
  const int nb = ne00/QK4_0;
879
 
 
950
  uint3 tgpig[[threadgroup_position_in_grid]],
951
  uint tiisg[[thread_index_in_simdgroup]],
952
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
953
+ mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
954
  }
955
 
956
  kernel void kernel_mul_mv_q4_1_f32(
 
976
  uint3 tgpig[[threadgroup_position_in_grid]],
977
  uint tiisg[[thread_index_in_simdgroup]],
978
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
979
+ mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
980
  }
981
 
982
  kernel void kernel_mul_mv_q5_0_f32(
 
1002
  uint3 tgpig[[threadgroup_position_in_grid]],
1003
  uint tiisg[[thread_index_in_simdgroup]],
1004
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
1005
+ mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
1006
  }
1007
 
1008
  kernel void kernel_mul_mv_q5_1_f32(
 
1028
  uint3 tgpig[[threadgroup_position_in_grid]],
1029
  uint tiisg[[thread_index_in_simdgroup]],
1030
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
1031
+ mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
1032
  }
1033
 
1034
 
 
1047
  constant int64_t & ne1,
1048
  constant uint & r2,
1049
  constant uint & r3,
1050
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
1051
  uint3 tgpig[[threadgroup_position_in_grid]],
1052
  uint tiisg[[thread_index_in_simdgroup]],
1053
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
1128
  uint3 tgpig[[threadgroup_position_in_grid]],
1129
  uint tiisg[[thread_index_in_simdgroup]],
1130
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
1131
+ kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
1132
  }
1133
 
1134
  #define N_F32_F32 4
 
2718
  constant int64_t & ne1,
2719
  constant uint & r2,
2720
  constant uint & r3,
2721
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
2722
  uint3 tgpig[[threadgroup_position_in_grid]],
2723
  uint tiisg[[thread_index_in_simdgroup]],
2724
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
2881
  uint tiisg[[thread_index_in_simdgroup]],
2882
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
2883
 
2884
+ kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
2885
  }
2886
 
2887
  #if QK_K == 256
 
2898
  constant int64_t & ne1,
2899
  constant uint & r2,
2900
  constant uint & r3,
2901
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
2902
  uint3 tgpig[[threadgroup_position_in_grid]],
2903
  uint tiisg[[thread_index_in_simdgroup]],
2904
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
3057
  constant int64_t & ne1,
3058
  constant uint & r2,
3059
  constant uint & r3,
3060
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
3061
  uint3 tgpig[[threadgroup_position_in_grid]],
3062
  uint tiisg[[thread_index_in_simdgroup]],
3063
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
3147
  uint tiisg[[thread_index_in_simdgroup]],
3148
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
3149
 
3150
+ kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
3151
  }
3152
 
3153
  #if QK_K == 256
 
3164
  constant int64_t & ne1,
3165
  constant uint & r2,
3166
  constant uint & r3,
3167
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
3168
  uint3 tgpig[[threadgroup_position_in_grid]],
3169
  uint tiisg[[thread_index_in_simdgroup]],
3170
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
3278
  constant int64_t & ne1,
3279
  constant uint & r2,
3280
  constant uint & r3,
3281
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
3282
  uint3 tgpig[[threadgroup_position_in_grid]],
3283
  uint tiisg[[thread_index_in_simdgroup]],
3284
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
3387
  uint tiisg[[thread_index_in_simdgroup]],
3388
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
3389
 
3390
+ kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
3391
  }
3392
 
3393
  void kernel_mul_mv_q5_K_f32_impl(
 
3403
  constant int64_t & ne1,
3404
  constant uint & r2,
3405
  constant uint & r3,
3406
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
3407
  uint3 tgpig[[threadgroup_position_in_grid]],
3408
  uint tiisg[[thread_index_in_simdgroup]],
3409
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
3594
  uint tiisg[[thread_index_in_simdgroup]],
3595
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
3596
 
3597
+ kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
3598
  }
3599
 
3600
  void kernel_mul_mv_q6_K_f32_impl(
 
3610
  constant int64_t & ne1,
3611
  constant uint & r2,
3612
  constant uint & r3,
3613
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
3614
  uint3 tgpig[[threadgroup_position_in_grid]],
3615
  uint tiisg[[thread_index_in_simdgroup]],
3616
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
3729
  uint tiisg[[thread_index_in_simdgroup]],
3730
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
3731
 
3732
+ kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
3733
  }
3734
 
3735
  // ======================= "True" 2-bit
 
4412
  constant int64_t & ne1,
4413
  constant uint & r2,
4414
  constant uint & r3,
4415
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
4416
  uint3 tgpig[[threadgroup_position_in_grid]],
4417
  uint tiisg[[thread_index_in_simdgroup]],
4418
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
4502
  constant int64_t & ne1,
4503
  constant uint & r2,
4504
  constant uint & r3,
4505
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
4506
  uint3 tgpig[[threadgroup_position_in_grid]],
4507
  uint tiisg[[thread_index_in_simdgroup]],
4508
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
4611
  constant int64_t & ne1,
4612
  constant uint & r2,
4613
  constant uint & r3,
4614
+ threadgroup int8_t * shared_values_i8 [[threadgroup(0)]],
4615
  uint3 tgpig[[threadgroup_position_in_grid]],
4616
  uint tiisg[[thread_index_in_simdgroup]],
4617
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
4618
 
4619
+ threadgroup float * shared_values = (threadgroup float *)shared_values_i8;
4620
  const int nb = ne00/QK4_NL;
4621
  const int r0 = tgpig.x;
4622
  const int r1 = tgpig.y;
 
4706
  constant int64_t & ne1,
4707
  constant uint & r2,
4708
  constant uint & r3,
4709
+ threadgroup int8_t * shared_values_i8 [[threadgroup(0)]],
4710
  uint3 tgpig[[threadgroup_position_in_grid]],
4711
  uint tiisg[[thread_index_in_simdgroup]],
4712
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
4713
+ threadgroup float * shared_values = (threadgroup float *)shared_values_i8;
4714
  const int nb = ne00/QK_K;
4715
  const int r0 = tgpig.x;
4716
  const int r1 = tgpig.y;
 
4813
  uint tiisg[[thread_index_in_simdgroup]],
4814
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
4815
 
4816
+ kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
4817
  }
4818
 
4819
  [[host_name("kernel_mul_mv_iq1_m_f32")]]
 
4841
  uint tiisg[[thread_index_in_simdgroup]],
4842
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
4843
 
4844
+ kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
4845
  }
4846
 
4847
  [[host_name("kernel_mul_mv_iq4_nl_f32")]]
 
4865
  constant int64_t & ne1,
4866
  constant uint & r2,
4867
  constant uint & r3,
4868
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
4869
  uint3 tgpig[[threadgroup_position_in_grid]],
4870
  uint tiisg[[thread_index_in_simdgroup]],
4871
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
4894
  constant int64_t & ne1,
4895
  constant uint & r2,
4896
  constant uint & r3,
4897
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
4898
  uint3 tgpig[[threadgroup_position_in_grid]],
4899
  uint tiisg[[thread_index_in_simdgroup]],
4900
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
6041
  // matrix-vector multiplication
6042
  //
6043
 
6044
+ typedef void (kernel_mul_mv_impl_t)(
6045
+ device const char * src0,
6046
+ device const char * src1,
6047
+ device float * dst,
6048
+ constant int64_t & ne00,
6049
+ constant int64_t & ne01,
6050
+ constant int64_t & ne02,
6051
+ constant uint64_t & nb00,
6052
+ constant uint64_t & nb01,
6053
+ constant uint64_t & nb02,
6054
+ constant int64_t & ne10,
6055
+ constant int64_t & ne11,
6056
+ constant int64_t & ne12,
6057
+ constant uint64_t & nb10,
6058
+ constant uint64_t & nb11,
6059
+ constant uint64_t & nb12,
6060
+ constant int64_t & ne0,
6061
+ constant int64_t & ne1,
6062
+ constant uint & r2,
6063
+ constant uint & r3,
6064
+ uint3 tgpig[[threadgroup_position_in_grid]],
6065
+ uint tiisg[[thread_index_in_simdgroup]]);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6066
 
6067
+ typedef void (kernel_mul_mv2_impl_t)(
6068
+ device const void * src0,
6069
+ device const float * src1,
6070
+ device float * dst,
6071
+ constant int64_t & ne00,
6072
+ constant int64_t & ne01,
6073
+ constant int64_t & ne02,
6074
+ constant int64_t & ne10,
6075
+ constant int64_t & ne12,
6076
+ constant int64_t & ne0,
6077
+ constant int64_t & ne1,
6078
+ constant uint & r2,
6079
+ constant uint & r3,
6080
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
6081
+ uint3 tgpig[[threadgroup_position_in_grid]],
6082
+ uint tiisg[[thread_index_in_simdgroup]],
6083
+ uint sgitg[[simdgroup_index_in_threadgroup]]);
 
 
 
 
 
 
6084
 
6085
+ template<kernel_mul_mv_impl_t impl_fn>
6086
+ void mmv_fn(
6087
+ device const char * src0,
6088
  device const char * src1,
6089
  device float * dst,
 
 
6090
  constant int64_t & ne00,
6091
  constant int64_t & ne01,
6092
  constant int64_t & ne02,
 
6105
  constant uint64_t & nb1,
6106
  constant uint & r2,
6107
  constant uint & r3,
6108
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
6109
  uint3 tgpig[[threadgroup_position_in_grid]],
6110
  uint tiitg[[thread_index_in_threadgroup]],
6111
  uint tiisg[[thread_index_in_simdgroup]],
6112
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
6113
+ impl_fn(src0,src1,dst,ne00,ne01,ne02,nb00,nb01,nb02,ne10,ne11,ne12,nb10,nb11,nb12,ne0,ne1,r2,r3,tgpig,tiisg);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6114
  }
6115
 
6116
+ template<kernel_mul_mv2_impl_t impl_fn>
6117
+ void mmv_fn(
6118
+ device const char * src0,
6119
  device const char * src1,
6120
  device float * dst,
 
 
6121
  constant int64_t & ne00,
6122
  constant int64_t & ne01,
6123
  constant int64_t & ne02,
 
6136
  constant uint64_t & nb1,
6137
  constant uint & r2,
6138
  constant uint & r3,
6139
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
6140
  uint3 tgpig[[threadgroup_position_in_grid]],
6141
  uint tiitg[[thread_index_in_threadgroup]],
6142
  uint tiisg[[thread_index_in_simdgroup]],
6143
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
6144
+ impl_fn(src0,(const device float *)src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,shared_values,tgpig,tiisg,sgitg);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6145
  }
6146
 
6147
+ typedef void (mul_mv_impl_fn_t)(
6148
+ device const char * src0,
 
6149
  device const char * src1,
6150
  device float * dst,
 
 
6151
  constant int64_t & ne00,
6152
  constant int64_t & ne01,
6153
  constant int64_t & ne02,
 
6166
  constant uint64_t & nb1,
6167
  constant uint & r2,
6168
  constant uint & r3,
6169
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
6170
  uint3 tgpig[[threadgroup_position_in_grid]],
6171
  uint tiitg[[thread_index_in_threadgroup]],
6172
  uint tiisg[[thread_index_in_simdgroup]],
6173
+ uint sgitg[[simdgroup_index_in_threadgroup]]);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6174
 
6175
+ template<mul_mv_impl_fn_t impl_fn>
6176
+ kernel void kernel_mul_mv_id(
6177
  device const char * src0s,
6178
  device const char * src1,
6179
  device float * dst,
 
6198
  constant uint & r2,
6199
  constant uint & r3,
6200
  constant int & idx,
6201
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
6202
  uint3 tgpig[[threadgroup_position_in_grid]],
6203
  uint tiitg[[thread_index_in_threadgroup]],
6204
  uint tiisg[[thread_index_in_simdgroup]],
 
6210
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6211
  device const char * src0 = src0s + id*nb02;
6212
 
6213
+ impl_fn(
6214
  src0,
6215
+ src1 + bid*nb11,
6216
+ dst + bid*ne0,
6217
  ne00,
6218
  ne01,
6219
  ne02,
6220
+ nb00,
6221
+ nb01,
6222
+ nb02,
6223
  ne10,
6224
+ ne11,
6225
  ne12,
6226
+ ne13,
6227
+ nb10,
6228
+ nb11,
6229
+ nb12,
6230
  ne0,
6231
  ne1,
6232
+ nb1,
6233
  r2,
6234
  r3,
6235
+ shared_values,
6236
  tgpig,
6237
+ tiitg,
6238
  tiisg,
6239
  sgitg);
6240
  }
6241
 
6242
+ typedef void (kernel_mul_mv_id_t)(
 
6243
  device const char * src0s,
6244
  device const char * src1,
6245
  device float * dst,
 
6264
  constant uint & r2,
6265
  constant uint & r3,
6266
  constant int & idx,
6267
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
6268
  uint3 tgpig[[threadgroup_position_in_grid]],
6269
  uint tiitg[[thread_index_in_threadgroup]],
6270
  uint tiisg[[thread_index_in_simdgroup]],
6271
+ uint sgitg[[simdgroup_index_in_threadgroup]]);
6272
+
6273
+ template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_f32_f32_impl>>;
6274
+ template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_f16_f32_impl>>;
6275
+ template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl>>;
6276
+ template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
6277
+ template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
6278
+ template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
6279
+ template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
6280
+ template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl>>;
6281
+ template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q3_K_f32_impl>>;
6282
+ template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q4_K_f32_impl>>;
6283
+ template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q5_K_f32_impl>>;
6284
+ template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q6_K_f32_impl>>;
6285
+ template [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_s_f32_impl>>;
6286
+ template [[host_name("kernel_mul_mv_id_iq1_m_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_m_f32_impl>>;
6287
+ template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xxs_f32_impl>>;
6288
+ template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xs_f32_impl>>;
6289
+ template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_xxs_f32_impl>>;
6290
+ template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_s_f32_impl>>;
6291
+ template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl>>;
6292
+ template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl>>;
6293
+ #if QK_K != 64
6294
+ template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl>>;
6295
+ #endif
6296
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ggml.c CHANGED
@@ -11074,7 +11074,6 @@ static void ggml_compute_forward_mul_mat_id(
11074
  }
11075
 
11076
  // initialize matrix_row_counts
11077
- GGML_ASSERT(wdata == wdata_src1_end);
11078
  memset(matrix_row_counts, 0, n_as*sizeof(int64_t));
11079
 
11080
  // group rows by src0 matrix
 
11074
  }
11075
 
11076
  // initialize matrix_row_counts
 
11077
  memset(matrix_row_counts, 0, n_as*sizeof(int64_t));
11078
 
11079
  // group rows by src0 matrix