Spaces:
Running
Running
slaren committed on
Commit ·
e9910b5
1
Parent(s): 1706870
metal : unify mul_mv_id kernels (llama/6556)
Browse files
- ggml-metal.m +5 -0
- ggml-metal.metal +135 -1054
- ggml.c +0 -1
ggml-metal.m
CHANGED
|
@@ -1941,7 +1941,12 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
| 1941 |
{
|
| 1942 |
nth0 = 4;
|
| 1943 |
nth1 = 16;
|
|
|
|
|
|
|
|
|
|
| 1944 |
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline;
|
|
|
|
|
|
|
| 1945 |
} break;
|
| 1946 |
default:
|
| 1947 |
{
|
|
|
|
| 1941 |
{
|
| 1942 |
nth0 = 4;
|
| 1943 |
nth1 = 16;
|
| 1944 |
+
#if QK_K == 64
|
| 1945 |
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline;
|
| 1946 |
+
#else
|
| 1947 |
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline;
|
| 1948 |
+
#endif
|
| 1949 |
+
|
| 1950 |
} break;
|
| 1951 |
default:
|
| 1952 |
{
|
ggml-metal.metal
CHANGED
|
@@ -864,15 +864,16 @@ void mul_vec_q_n_f32_impl(
|
|
| 864 |
device const void * src0,
|
| 865 |
device const float * src1,
|
| 866 |
device float * dst,
|
| 867 |
-
int64_t ne00,
|
| 868 |
-
int64_t ne01,
|
| 869 |
-
int64_t ne02,
|
| 870 |
-
int64_t ne10,
|
| 871 |
-
int64_t ne12,
|
| 872 |
-
int64_t ne0,
|
| 873 |
-
int64_t ne1,
|
| 874 |
-
uint r2,
|
| 875 |
-
uint r3,
|
|
|
|
| 876 |
uint3 tgpig, uint tiisg, uint sgitg) {
|
| 877 |
const int nb = ne00/QK4_0;
|
| 878 |
|
|
@@ -949,7 +950,7 @@ kernel void kernel_mul_mv_q4_0_f32(
|
|
| 949 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 950 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 951 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 952 |
-
mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
|
| 953 |
}
|
| 954 |
|
| 955 |
kernel void kernel_mul_mv_q4_1_f32(
|
|
@@ -975,7 +976,7 @@ kernel void kernel_mul_mv_q4_1_f32(
|
|
| 975 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 976 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 977 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 978 |
-
mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
|
| 979 |
}
|
| 980 |
|
| 981 |
kernel void kernel_mul_mv_q5_0_f32(
|
|
@@ -1001,7 +1002,7 @@ kernel void kernel_mul_mv_q5_0_f32(
|
|
| 1001 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 1002 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 1003 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 1004 |
-
mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
|
| 1005 |
}
|
| 1006 |
|
| 1007 |
kernel void kernel_mul_mv_q5_1_f32(
|
|
@@ -1027,7 +1028,7 @@ kernel void kernel_mul_mv_q5_1_f32(
|
|
| 1027 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 1028 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 1029 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 1030 |
-
mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
|
| 1031 |
}
|
| 1032 |
|
| 1033 |
|
|
@@ -1046,6 +1047,7 @@ void kernel_mul_mv_q8_0_f32_impl(
|
|
| 1046 |
constant int64_t & ne1,
|
| 1047 |
constant uint & r2,
|
| 1048 |
constant uint & r3,
|
|
|
|
| 1049 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 1050 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 1051 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
@@ -1126,7 +1128,7 @@ kernel void kernel_mul_mv_q8_0_f32(
|
|
| 1126 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 1127 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 1128 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 1129 |
-
kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
|
| 1130 |
}
|
| 1131 |
|
| 1132 |
#define N_F32_F32 4
|
|
@@ -2716,6 +2718,7 @@ void kernel_mul_mv_q2_K_f32_impl(
|
|
| 2716 |
constant int64_t & ne1,
|
| 2717 |
constant uint & r2,
|
| 2718 |
constant uint & r3,
|
|
|
|
| 2719 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 2720 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 2721 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
@@ -2878,7 +2881,7 @@ kernel void kernel_mul_mv_q2_K_f32(
|
|
| 2878 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 2879 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 2880 |
|
| 2881 |
-
kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
| 2882 |
}
|
| 2883 |
|
| 2884 |
#if QK_K == 256
|
|
@@ -2895,6 +2898,7 @@ void kernel_mul_mv_q3_K_f32_impl(
|
|
| 2895 |
constant int64_t & ne1,
|
| 2896 |
constant uint & r2,
|
| 2897 |
constant uint & r3,
|
|
|
|
| 2898 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 2899 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 2900 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
@@ -3053,6 +3057,7 @@ void kernel_mul_mv_q3_K_f32_impl(
|
|
| 3053 |
constant int64_t & ne1,
|
| 3054 |
constant uint & r2,
|
| 3055 |
constant uint & r3,
|
|
|
|
| 3056 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 3057 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 3058 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
@@ -3142,7 +3147,7 @@ kernel void kernel_mul_mv_q3_K_f32(
|
|
| 3142 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 3143 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 3144 |
|
| 3145 |
-
kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
| 3146 |
}
|
| 3147 |
|
| 3148 |
#if QK_K == 256
|
|
@@ -3159,6 +3164,7 @@ void kernel_mul_mv_q4_K_f32_impl(
|
|
| 3159 |
constant int64_t & ne1,
|
| 3160 |
constant uint & r2,
|
| 3161 |
constant uint & r3,
|
|
|
|
| 3162 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 3163 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 3164 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
@@ -3272,6 +3278,7 @@ void kernel_mul_mv_q4_K_f32_impl(
|
|
| 3272 |
constant int64_t & ne1,
|
| 3273 |
constant uint & r2,
|
| 3274 |
constant uint & r3,
|
|
|
|
| 3275 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 3276 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 3277 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
@@ -3380,7 +3387,7 @@ kernel void kernel_mul_mv_q4_K_f32(
|
|
| 3380 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 3381 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 3382 |
|
| 3383 |
-
kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
| 3384 |
}
|
| 3385 |
|
| 3386 |
void kernel_mul_mv_q5_K_f32_impl(
|
|
@@ -3396,6 +3403,7 @@ void kernel_mul_mv_q5_K_f32_impl(
|
|
| 3396 |
constant int64_t & ne1,
|
| 3397 |
constant uint & r2,
|
| 3398 |
constant uint & r3,
|
|
|
|
| 3399 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 3400 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 3401 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
@@ -3586,7 +3594,7 @@ kernel void kernel_mul_mv_q5_K_f32(
|
|
| 3586 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 3587 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 3588 |
|
| 3589 |
-
kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
| 3590 |
}
|
| 3591 |
|
| 3592 |
void kernel_mul_mv_q6_K_f32_impl(
|
|
@@ -3602,6 +3610,7 @@ void kernel_mul_mv_q6_K_f32_impl(
|
|
| 3602 |
constant int64_t & ne1,
|
| 3603 |
constant uint & r2,
|
| 3604 |
constant uint & r3,
|
|
|
|
| 3605 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 3606 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 3607 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
@@ -3720,7 +3729,7 @@ kernel void kernel_mul_mv_q6_K_f32(
|
|
| 3720 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 3721 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 3722 |
|
| 3723 |
-
kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
| 3724 |
}
|
| 3725 |
|
| 3726 |
// ======================= "True" 2-bit
|
|
@@ -4403,6 +4412,7 @@ void kernel_mul_mv_iq1_s_f32_impl(
|
|
| 4403 |
constant int64_t & ne1,
|
| 4404 |
constant uint & r2,
|
| 4405 |
constant uint & r3,
|
|
|
|
| 4406 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 4407 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 4408 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
@@ -4492,6 +4502,7 @@ void kernel_mul_mv_iq1_m_f32_impl(
|
|
| 4492 |
constant int64_t & ne1,
|
| 4493 |
constant uint & r2,
|
| 4494 |
constant uint & r3,
|
|
|
|
| 4495 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 4496 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 4497 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
@@ -4600,11 +4611,12 @@ void kernel_mul_mv_iq4_nl_f32_impl(
|
|
| 4600 |
constant int64_t & ne1,
|
| 4601 |
constant uint & r2,
|
| 4602 |
constant uint & r3,
|
| 4603 |
-
threadgroup
|
| 4604 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 4605 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 4606 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 4607 |
|
|
|
|
| 4608 |
const int nb = ne00/QK4_NL;
|
| 4609 |
const int r0 = tgpig.x;
|
| 4610 |
const int r1 = tgpig.y;
|
|
@@ -4694,11 +4706,11 @@ void kernel_mul_mv_iq4_xs_f32_impl(
|
|
| 4694 |
constant int64_t & ne1,
|
| 4695 |
constant uint & r2,
|
| 4696 |
constant uint & r3,
|
| 4697 |
-
threadgroup
|
| 4698 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 4699 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 4700 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 4701 |
-
|
| 4702 |
const int nb = ne00/QK_K;
|
| 4703 |
const int r0 = tgpig.x;
|
| 4704 |
const int r1 = tgpig.y;
|
|
@@ -4801,7 +4813,7 @@ kernel void kernel_mul_mv_iq1_s_f32(
|
|
| 4801 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 4802 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 4803 |
|
| 4804 |
-
kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
| 4805 |
}
|
| 4806 |
|
| 4807 |
[[host_name("kernel_mul_mv_iq1_m_f32")]]
|
|
@@ -4829,7 +4841,7 @@ kernel void kernel_mul_mv_iq1_m_f32(
|
|
| 4829 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 4830 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 4831 |
|
| 4832 |
-
kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
| 4833 |
}
|
| 4834 |
|
| 4835 |
[[host_name("kernel_mul_mv_iq4_nl_f32")]]
|
|
@@ -4853,7 +4865,7 @@ kernel void kernel_mul_mv_iq4_nl_f32(
|
|
| 4853 |
constant int64_t & ne1,
|
| 4854 |
constant uint & r2,
|
| 4855 |
constant uint & r3,
|
| 4856 |
-
threadgroup
|
| 4857 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 4858 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 4859 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
@@ -4882,7 +4894,7 @@ kernel void kernel_mul_mv_iq4_xs_f32(
|
|
| 4882 |
constant int64_t & ne1,
|
| 4883 |
constant uint & r2,
|
| 4884 |
constant uint & r3,
|
| 4885 |
-
threadgroup
|
| 4886 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 4887 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 4888 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
@@ -6029,135 +6041,52 @@ template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel
|
|
| 6029 |
// matrix-vector multiplication
|
| 6030 |
//
|
| 6031 |
|
| 6032 |
-
|
| 6033 |
-
|
| 6034 |
-
device const
|
| 6035 |
-
device
|
| 6036 |
-
|
| 6037 |
-
|
| 6038 |
-
constant
|
| 6039 |
-
constant
|
| 6040 |
-
constant
|
| 6041 |
-
constant
|
| 6042 |
-
constant
|
| 6043 |
-
constant
|
| 6044 |
-
constant
|
| 6045 |
-
constant
|
| 6046 |
-
constant
|
| 6047 |
-
constant
|
| 6048 |
-
constant
|
| 6049 |
-
constant
|
| 6050 |
-
constant
|
| 6051 |
-
constant
|
| 6052 |
-
|
| 6053 |
-
|
| 6054 |
-
constant uint64_t & nb1,
|
| 6055 |
-
constant uint & r2,
|
| 6056 |
-
constant uint & r3,
|
| 6057 |
-
constant int & idx,
|
| 6058 |
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 6059 |
-
uint tiitg[[thread_index_in_threadgroup]],
|
| 6060 |
-
uint tiisg[[thread_index_in_simdgroup]],
|
| 6061 |
-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 6062 |
-
const int64_t bid = tgpig.z/(ne12*ne13);
|
| 6063 |
-
|
| 6064 |
-
tgpig.z = tgpig.z%(ne12*ne13);
|
| 6065 |
-
|
| 6066 |
-
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
| 6067 |
-
device const char * src0 = src0s + id*nb02;
|
| 6068 |
-
|
| 6069 |
-
kernel_mul_mv_f32_f32_impl(
|
| 6070 |
-
src0,
|
| 6071 |
-
src1 + bid*nb11,
|
| 6072 |
-
dst + bid*ne0,
|
| 6073 |
-
ne00,
|
| 6074 |
-
ne01,
|
| 6075 |
-
ne02,
|
| 6076 |
-
nb00,
|
| 6077 |
-
nb01,
|
| 6078 |
-
nb02,
|
| 6079 |
-
ne10,
|
| 6080 |
-
ne11,
|
| 6081 |
-
ne12,
|
| 6082 |
-
nb10,
|
| 6083 |
-
nb11,
|
| 6084 |
-
nb12,
|
| 6085 |
-
ne0,
|
| 6086 |
-
ne1,
|
| 6087 |
-
r2,
|
| 6088 |
-
r3,
|
| 6089 |
-
tgpig,
|
| 6090 |
-
tiisg);
|
| 6091 |
-
}
|
| 6092 |
-
|
| 6093 |
-
[[host_name("kernel_mul_mv_id_f16_f32")]]
|
| 6094 |
-
kernel void kernel_mul_mv_id_f16_f32(
|
| 6095 |
-
device const char * src0s,
|
| 6096 |
-
device const char * src1,
|
| 6097 |
-
device float * dst,
|
| 6098 |
-
device const char * ids,
|
| 6099 |
-
constant uint64_t & nbi1,
|
| 6100 |
-
constant int64_t & ne00,
|
| 6101 |
-
constant int64_t & ne01,
|
| 6102 |
-
constant int64_t & ne02,
|
| 6103 |
-
constant uint64_t & nb00,
|
| 6104 |
-
constant uint64_t & nb01,
|
| 6105 |
-
constant uint64_t & nb02,
|
| 6106 |
-
constant int64_t & ne10,
|
| 6107 |
-
constant int64_t & ne11,
|
| 6108 |
-
constant int64_t & ne12,
|
| 6109 |
-
constant int64_t & ne13,
|
| 6110 |
-
constant uint64_t & nb10,
|
| 6111 |
-
constant uint64_t & nb11,
|
| 6112 |
-
constant uint64_t & nb12,
|
| 6113 |
-
constant int64_t & ne0,
|
| 6114 |
-
constant int64_t & ne1,
|
| 6115 |
-
constant uint64_t & nb1,
|
| 6116 |
-
constant uint & r2,
|
| 6117 |
-
constant uint & r3,
|
| 6118 |
-
constant int & idx,
|
| 6119 |
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 6120 |
-
uint tiitg[[thread_index_in_threadgroup]],
|
| 6121 |
-
uint tiisg[[thread_index_in_simdgroup]],
|
| 6122 |
-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 6123 |
-
const int64_t bid = tgpig.z/(ne12*ne13);
|
| 6124 |
-
|
| 6125 |
-
tgpig.z = tgpig.z%(ne12*ne13);
|
| 6126 |
-
|
| 6127 |
-
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
| 6128 |
-
device const char * src0 = src0s + id*nb02;
|
| 6129 |
|
| 6130 |
-
|
| 6131 |
-
src0,
|
| 6132 |
-
|
| 6133 |
-
|
| 6134 |
-
ne00,
|
| 6135 |
-
ne01,
|
| 6136 |
-
ne02,
|
| 6137 |
-
|
| 6138 |
-
|
| 6139 |
-
|
| 6140 |
-
|
| 6141 |
-
|
| 6142 |
-
|
| 6143 |
-
|
| 6144 |
-
|
| 6145 |
-
|
| 6146 |
-
|
| 6147 |
-
ne1,
|
| 6148 |
-
r2,
|
| 6149 |
-
r3,
|
| 6150 |
-
tgpig,
|
| 6151 |
-
tiisg);
|
| 6152 |
-
}
|
| 6153 |
|
| 6154 |
-
|
| 6155 |
-
|
| 6156 |
-
device const char *
|
| 6157 |
device const char * src1,
|
| 6158 |
device float * dst,
|
| 6159 |
-
device const char * ids,
|
| 6160 |
-
constant uint64_t & nbi1,
|
| 6161 |
constant int64_t & ne00,
|
| 6162 |
constant int64_t & ne01,
|
| 6163 |
constant int64_t & ne02,
|
|
@@ -6176,43 +6105,19 @@ kernel void kernel_mul_mv_id_q8_0_f32(
|
|
| 6176 |
constant uint64_t & nb1,
|
| 6177 |
constant uint & r2,
|
| 6178 |
constant uint & r3,
|
| 6179 |
-
|
| 6180 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 6181 |
uint tiitg[[thread_index_in_threadgroup]],
|
| 6182 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 6183 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 6184 |
-
|
| 6185 |
-
|
| 6186 |
-
tgpig.z = tgpig.z%(ne12*ne13);
|
| 6187 |
-
|
| 6188 |
-
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
| 6189 |
-
device const char * src0 = src0s + id*nb02;
|
| 6190 |
-
|
| 6191 |
-
kernel_mul_mv_q8_0_f32_impl(
|
| 6192 |
-
src0,
|
| 6193 |
-
(device const float *) (src1 + bid*nb11),
|
| 6194 |
-
dst + bid*ne0,
|
| 6195 |
-
ne00,
|
| 6196 |
-
ne01,
|
| 6197 |
-
ne02,
|
| 6198 |
-
ne10,
|
| 6199 |
-
ne12,
|
| 6200 |
-
ne0,
|
| 6201 |
-
ne1,
|
| 6202 |
-
r2,
|
| 6203 |
-
r3,
|
| 6204 |
-
tgpig,
|
| 6205 |
-
tiisg,
|
| 6206 |
-
sgitg);
|
| 6207 |
}
|
| 6208 |
|
| 6209 |
-
|
| 6210 |
-
|
| 6211 |
-
device const char *
|
| 6212 |
device const char * src1,
|
| 6213 |
device float * dst,
|
| 6214 |
-
device const char * ids,
|
| 6215 |
-
constant uint64_t & nbi1,
|
| 6216 |
constant int64_t & ne00,
|
| 6217 |
constant int64_t & ne01,
|
| 6218 |
constant int64_t & ne02,
|
|
@@ -6231,43 +6136,18 @@ kernel void kernel_mul_mv_id_q4_0_f32(
|
|
| 6231 |
constant uint64_t & nb1,
|
| 6232 |
constant uint & r2,
|
| 6233 |
constant uint & r3,
|
| 6234 |
-
|
| 6235 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 6236 |
uint tiitg[[thread_index_in_threadgroup]],
|
| 6237 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 6238 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 6239 |
-
const
|
| 6240 |
-
|
| 6241 |
-
tgpig.z = tgpig.z%(ne12*ne13);
|
| 6242 |
-
|
| 6243 |
-
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
| 6244 |
-
device const char * src0 = src0s + id*nb02;
|
| 6245 |
-
|
| 6246 |
-
mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
| 6247 |
-
src0,
|
| 6248 |
-
(device const float *) (src1 + bid*nb11),
|
| 6249 |
-
dst + bid*ne0,
|
| 6250 |
-
ne00,
|
| 6251 |
-
ne01,
|
| 6252 |
-
ne02,
|
| 6253 |
-
ne10,
|
| 6254 |
-
ne12,
|
| 6255 |
-
ne0,
|
| 6256 |
-
ne1,
|
| 6257 |
-
r2,
|
| 6258 |
-
r3,
|
| 6259 |
-
tgpig,
|
| 6260 |
-
tiisg,
|
| 6261 |
-
sgitg);
|
| 6262 |
}
|
| 6263 |
|
| 6264 |
-
|
| 6265 |
-
|
| 6266 |
-
device const char * src0s,
|
| 6267 |
device const char * src1,
|
| 6268 |
device float * dst,
|
| 6269 |
-
device const char * ids,
|
| 6270 |
-
constant uint64_t & nbi1,
|
| 6271 |
constant int64_t & ne00,
|
| 6272 |
constant int64_t & ne01,
|
| 6273 |
constant int64_t & ne02,
|
|
@@ -6286,38 +6166,14 @@ kernel void kernel_mul_mv_id_q4_1_f32(
|
|
| 6286 |
constant uint64_t & nb1,
|
| 6287 |
constant uint & r2,
|
| 6288 |
constant uint & r3,
|
| 6289 |
-
|
| 6290 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 6291 |
uint tiitg[[thread_index_in_threadgroup]],
|
| 6292 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 6293 |
-
uint sgitg[[simdgroup_index_in_threadgroup]])
|
| 6294 |
-
const int64_t bid = tgpig.z/(ne12*ne13);
|
| 6295 |
-
|
| 6296 |
-
tgpig.z = tgpig.z%(ne12*ne13);
|
| 6297 |
-
|
| 6298 |
-
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
| 6299 |
-
device const char * src0 = src0s + id*nb02;
|
| 6300 |
-
|
| 6301 |
-
mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
| 6302 |
-
src0,
|
| 6303 |
-
(device const float *) (src1 + bid*nb11),
|
| 6304 |
-
dst + bid*ne0,
|
| 6305 |
-
ne00,
|
| 6306 |
-
ne01,
|
| 6307 |
-
ne02,
|
| 6308 |
-
ne10,
|
| 6309 |
-
ne12,
|
| 6310 |
-
ne0,
|
| 6311 |
-
ne1,
|
| 6312 |
-
r2,
|
| 6313 |
-
r3,
|
| 6314 |
-
tgpig,
|
| 6315 |
-
tiisg,
|
| 6316 |
-
sgitg);
|
| 6317 |
-
}
|
| 6318 |
|
| 6319 |
-
|
| 6320 |
-
kernel void
|
| 6321 |
device const char * src0s,
|
| 6322 |
device const char * src1,
|
| 6323 |
device float * dst,
|
|
@@ -6342,6 +6198,7 @@ kernel void kernel_mul_mv_id_q5_0_f32(
|
|
| 6342 |
constant uint & r2,
|
| 6343 |
constant uint & r3,
|
| 6344 |
constant int & idx,
|
|
|
|
| 6345 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 6346 |
uint tiitg[[thread_index_in_threadgroup]],
|
| 6347 |
uint tiisg[[thread_index_in_simdgroup]],
|
|
@@ -6353,26 +6210,36 @@ kernel void kernel_mul_mv_id_q5_0_f32(
|
|
| 6353 |
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
| 6354 |
device const char * src0 = src0s + id*nb02;
|
| 6355 |
|
| 6356 |
-
|
| 6357 |
src0,
|
| 6358 |
-
|
| 6359 |
-
dst
|
| 6360 |
ne00,
|
| 6361 |
ne01,
|
| 6362 |
ne02,
|
|
|
|
|
|
|
|
|
|
| 6363 |
ne10,
|
|
|
|
| 6364 |
ne12,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6365 |
ne0,
|
| 6366 |
ne1,
|
|
|
|
| 6367 |
r2,
|
| 6368 |
r3,
|
|
|
|
| 6369 |
tgpig,
|
|
|
|
| 6370 |
tiisg,
|
| 6371 |
sgitg);
|
| 6372 |
}
|
| 6373 |
|
| 6374 |
-
|
| 6375 |
-
kernel void kernel_mul_mv_id_q5_1_f32(
|
| 6376 |
device const char * src0s,
|
| 6377 |
device const char * src1,
|
| 6378 |
device float * dst,
|
|
@@ -6397,819 +6264,33 @@ kernel void kernel_mul_mv_id_q5_1_f32(
|
|
| 6397 |
constant uint & r2,
|
| 6398 |
constant uint & r3,
|
| 6399 |
constant int & idx,
|
|
|
|
| 6400 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 6401 |
uint tiitg[[thread_index_in_threadgroup]],
|
| 6402 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 6403 |
-
uint sgitg[[simdgroup_index_in_threadgroup]])
|
| 6404 |
-
|
| 6405 |
-
|
| 6406 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6407 |
|
| 6408 |
-
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
| 6409 |
-
device const char * src0 = src0s + id*nb02;
|
| 6410 |
-
|
| 6411 |
-
mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
| 6412 |
-
src0,
|
| 6413 |
-
(device const float *) (src1 + bid*nb11),
|
| 6414 |
-
dst + bid*ne0,
|
| 6415 |
-
ne00,
|
| 6416 |
-
ne01,
|
| 6417 |
-
ne02,
|
| 6418 |
-
ne10,
|
| 6419 |
-
ne12,
|
| 6420 |
-
ne0,
|
| 6421 |
-
ne1,
|
| 6422 |
-
r2,
|
| 6423 |
-
r3,
|
| 6424 |
-
tgpig,
|
| 6425 |
-
tiisg,
|
| 6426 |
-
sgitg);
|
| 6427 |
-
}
|
| 6428 |
-
|
| 6429 |
-
[[host_name("kernel_mul_mv_id_q2_K_f32")]]
|
| 6430 |
-
kernel void kernel_mul_mv_id_q2_K_f32(
|
| 6431 |
-
device const char * src0s,
|
| 6432 |
-
device const char * src1,
|
| 6433 |
-
device float * dst,
|
| 6434 |
-
device const char * ids,
|
| 6435 |
-
constant uint64_t & nbi1,
|
| 6436 |
-
constant int64_t & ne00,
|
| 6437 |
-
constant int64_t & ne01,
|
| 6438 |
-
constant int64_t & ne02,
|
| 6439 |
-
constant uint64_t & nb00,
|
| 6440 |
-
constant uint64_t & nb01,
|
| 6441 |
-
constant uint64_t & nb02,
|
| 6442 |
-
constant int64_t & ne10,
|
| 6443 |
-
constant int64_t & ne11,
|
| 6444 |
-
constant int64_t & ne12,
|
| 6445 |
-
constant int64_t & ne13,
|
| 6446 |
-
constant uint64_t & nb10,
|
| 6447 |
-
constant uint64_t & nb11,
|
| 6448 |
-
constant uint64_t & nb12,
|
| 6449 |
-
constant int64_t & ne0,
|
| 6450 |
-
constant int64_t & ne1,
|
| 6451 |
-
constant uint64_t & nb1,
|
| 6452 |
-
constant uint & r2,
|
| 6453 |
-
constant uint & r3,
|
| 6454 |
-
constant int & idx,
|
| 6455 |
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 6456 |
-
uint tiitg[[thread_index_in_threadgroup]],
|
| 6457 |
-
uint tiisg[[thread_index_in_simdgroup]],
|
| 6458 |
-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 6459 |
-
const int64_t bid = tgpig.z/(ne12*ne13);
|
| 6460 |
-
|
| 6461 |
-
tgpig.z = tgpig.z%(ne12*ne13);
|
| 6462 |
-
|
| 6463 |
-
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
| 6464 |
-
device const char * src0 = src0s + id*nb02;
|
| 6465 |
-
|
| 6466 |
-
kernel_mul_mv_q2_K_f32_impl(
|
| 6467 |
-
src0,
|
| 6468 |
-
(device const float *) (src1 + bid*nb11),
|
| 6469 |
-
dst + bid*ne0,
|
| 6470 |
-
ne00,
|
| 6471 |
-
ne01,
|
| 6472 |
-
ne02,
|
| 6473 |
-
ne10,
|
| 6474 |
-
ne12,
|
| 6475 |
-
ne0,
|
| 6476 |
-
ne1,
|
| 6477 |
-
r2,
|
| 6478 |
-
r3,
|
| 6479 |
-
tgpig,
|
| 6480 |
-
tiisg,
|
| 6481 |
-
sgitg);
|
| 6482 |
-
}
|
| 6483 |
-
|
| 6484 |
-
[[host_name("kernel_mul_mv_id_q3_K_f32")]]
|
| 6485 |
-
kernel void kernel_mul_mv_id_q3_K_f32(
|
| 6486 |
-
device const char * src0s,
|
| 6487 |
-
device const char * src1,
|
| 6488 |
-
device float * dst,
|
| 6489 |
-
device const char * ids,
|
| 6490 |
-
constant uint64_t & nbi1,
|
| 6491 |
-
constant int64_t & ne00,
|
| 6492 |
-
constant int64_t & ne01,
|
| 6493 |
-
constant int64_t & ne02,
|
| 6494 |
-
constant uint64_t & nb00,
|
| 6495 |
-
constant uint64_t & nb01,
|
| 6496 |
-
constant uint64_t & nb02,
|
| 6497 |
-
constant int64_t & ne10,
|
| 6498 |
-
constant int64_t & ne11,
|
| 6499 |
-
constant int64_t & ne12,
|
| 6500 |
-
constant int64_t & ne13,
|
| 6501 |
-
constant uint64_t & nb10,
|
| 6502 |
-
constant uint64_t & nb11,
|
| 6503 |
-
constant uint64_t & nb12,
|
| 6504 |
-
constant int64_t & ne0,
|
| 6505 |
-
constant int64_t & ne1,
|
| 6506 |
-
constant uint64_t & nb1,
|
| 6507 |
-
constant uint & r2,
|
| 6508 |
-
constant uint & r3,
|
| 6509 |
-
constant int & idx,
|
| 6510 |
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 6511 |
-
uint tiitg[[thread_index_in_threadgroup]],
|
| 6512 |
-
uint tiisg[[thread_index_in_simdgroup]],
|
| 6513 |
-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 6514 |
-
const int64_t bid = tgpig.z/(ne12*ne13);
|
| 6515 |
-
|
| 6516 |
-
tgpig.z = tgpig.z%(ne12*ne13);
|
| 6517 |
-
|
| 6518 |
-
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
| 6519 |
-
device const char * src0 = src0s + id*nb02;
|
| 6520 |
-
|
| 6521 |
-
kernel_mul_mv_q3_K_f32_impl(
|
| 6522 |
-
src0,
|
| 6523 |
-
(device const float *) (src1 + bid*nb11),
|
| 6524 |
-
dst + bid*ne0,
|
| 6525 |
-
ne00,
|
| 6526 |
-
ne01,
|
| 6527 |
-
ne02,
|
| 6528 |
-
ne10,
|
| 6529 |
-
ne12,
|
| 6530 |
-
ne0,
|
| 6531 |
-
ne1,
|
| 6532 |
-
r2,
|
| 6533 |
-
r3,
|
| 6534 |
-
tgpig,
|
| 6535 |
-
tiisg,
|
| 6536 |
-
sgitg);
|
| 6537 |
-
}
|
| 6538 |
-
|
| 6539 |
-
[[host_name("kernel_mul_mv_id_q4_K_f32")]]
|
| 6540 |
-
kernel void kernel_mul_mv_id_q4_K_f32(
|
| 6541 |
-
device const char * src0s,
|
| 6542 |
-
device const char * src1,
|
| 6543 |
-
device float * dst,
|
| 6544 |
-
device const char * ids,
|
| 6545 |
-
constant uint64_t & nbi1,
|
| 6546 |
-
constant int64_t & ne00,
|
| 6547 |
-
constant int64_t & ne01,
|
| 6548 |
-
constant int64_t & ne02,
|
| 6549 |
-
constant uint64_t & nb00,
|
| 6550 |
-
constant uint64_t & nb01,
|
| 6551 |
-
constant uint64_t & nb02,
|
| 6552 |
-
constant int64_t & ne10,
|
| 6553 |
-
constant int64_t & ne11,
|
| 6554 |
-
constant int64_t & ne12,
|
| 6555 |
-
constant int64_t & ne13,
|
| 6556 |
-
constant uint64_t & nb10,
|
| 6557 |
-
constant uint64_t & nb11,
|
| 6558 |
-
constant uint64_t & nb12,
|
| 6559 |
-
constant int64_t & ne0,
|
| 6560 |
-
constant int64_t & ne1,
|
| 6561 |
-
constant uint64_t & nb1,
|
| 6562 |
-
constant uint & r2,
|
| 6563 |
-
constant uint & r3,
|
| 6564 |
-
constant int & idx,
|
| 6565 |
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 6566 |
-
uint tiitg[[thread_index_in_threadgroup]],
|
| 6567 |
-
uint tiisg[[thread_index_in_simdgroup]],
|
| 6568 |
-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 6569 |
-
const int64_t bid = tgpig.z/(ne12*ne13);
|
| 6570 |
-
|
| 6571 |
-
tgpig.z = tgpig.z%(ne12*ne13);
|
| 6572 |
-
|
| 6573 |
-
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
| 6574 |
-
device const char * src0 = src0s + id*nb02;
|
| 6575 |
-
|
| 6576 |
-
kernel_mul_mv_q4_K_f32_impl(
|
| 6577 |
-
src0,
|
| 6578 |
-
(device const float *) (src1 + bid*nb11),
|
| 6579 |
-
dst + bid*ne0,
|
| 6580 |
-
ne00,
|
| 6581 |
-
ne01,
|
| 6582 |
-
ne02,
|
| 6583 |
-
ne10,
|
| 6584 |
-
ne12,
|
| 6585 |
-
ne0,
|
| 6586 |
-
ne1,
|
| 6587 |
-
r2,
|
| 6588 |
-
r3,
|
| 6589 |
-
tgpig,
|
| 6590 |
-
tiisg,
|
| 6591 |
-
sgitg);
|
| 6592 |
-
}
|
| 6593 |
-
|
| 6594 |
-
[[host_name("kernel_mul_mv_id_q5_K_f32")]]
|
| 6595 |
-
kernel void kernel_mul_mv_id_q5_K_f32(
|
| 6596 |
-
device const char * src0s,
|
| 6597 |
-
device const char * src1,
|
| 6598 |
-
device float * dst,
|
| 6599 |
-
device const char * ids,
|
| 6600 |
-
constant uint64_t & nbi1,
|
| 6601 |
-
constant int64_t & ne00,
|
| 6602 |
-
constant int64_t & ne01,
|
| 6603 |
-
constant int64_t & ne02,
|
| 6604 |
-
constant uint64_t & nb00,
|
| 6605 |
-
constant uint64_t & nb01,
|
| 6606 |
-
constant uint64_t & nb02,
|
| 6607 |
-
constant int64_t & ne10,
|
| 6608 |
-
constant int64_t & ne11,
|
| 6609 |
-
constant int64_t & ne12,
|
| 6610 |
-
constant int64_t & ne13,
|
| 6611 |
-
constant uint64_t & nb10,
|
| 6612 |
-
constant uint64_t & nb11,
|
| 6613 |
-
constant uint64_t & nb12,
|
| 6614 |
-
constant int64_t & ne0,
|
| 6615 |
-
constant int64_t & ne1,
|
| 6616 |
-
constant uint64_t & nb1,
|
| 6617 |
-
constant uint & r2,
|
| 6618 |
-
constant uint & r3,
|
| 6619 |
-
constant int & idx,
|
| 6620 |
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 6621 |
-
uint tiitg[[thread_index_in_threadgroup]],
|
| 6622 |
-
uint tiisg[[thread_index_in_simdgroup]],
|
| 6623 |
-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 6624 |
-
const int64_t bid = tgpig.z/(ne12*ne13);
|
| 6625 |
-
|
| 6626 |
-
tgpig.z = tgpig.z%(ne12*ne13);
|
| 6627 |
-
|
| 6628 |
-
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
| 6629 |
-
device const char * src0 = src0s + id*nb02;
|
| 6630 |
-
|
| 6631 |
-
kernel_mul_mv_q5_K_f32_impl(
|
| 6632 |
-
src0,
|
| 6633 |
-
(device const float *) (src1 + bid*nb11),
|
| 6634 |
-
dst + bid*ne0,
|
| 6635 |
-
ne00,
|
| 6636 |
-
ne01,
|
| 6637 |
-
ne02,
|
| 6638 |
-
ne10,
|
| 6639 |
-
ne12,
|
| 6640 |
-
ne0,
|
| 6641 |
-
ne1,
|
| 6642 |
-
r2,
|
| 6643 |
-
r3,
|
| 6644 |
-
tgpig,
|
| 6645 |
-
tiisg,
|
| 6646 |
-
sgitg);
|
| 6647 |
-
}
|
| 6648 |
-
|
| 6649 |
-
[[host_name("kernel_mul_mv_id_q6_K_f32")]]
|
| 6650 |
-
kernel void kernel_mul_mv_id_q6_K_f32(
|
| 6651 |
-
device const char * src0s,
|
| 6652 |
-
device const char * src1,
|
| 6653 |
-
device float * dst,
|
| 6654 |
-
device const char * ids,
|
| 6655 |
-
constant uint64_t & nbi1,
|
| 6656 |
-
constant int64_t & ne00,
|
| 6657 |
-
constant int64_t & ne01,
|
| 6658 |
-
constant int64_t & ne02,
|
| 6659 |
-
constant uint64_t & nb00,
|
| 6660 |
-
constant uint64_t & nb01,
|
| 6661 |
-
constant uint64_t & nb02,
|
| 6662 |
-
constant int64_t & ne10,
|
| 6663 |
-
constant int64_t & ne11,
|
| 6664 |
-
constant int64_t & ne12,
|
| 6665 |
-
constant int64_t & ne13,
|
| 6666 |
-
constant uint64_t & nb10,
|
| 6667 |
-
constant uint64_t & nb11,
|
| 6668 |
-
constant uint64_t & nb12,
|
| 6669 |
-
constant int64_t & ne0,
|
| 6670 |
-
constant int64_t & ne1,
|
| 6671 |
-
constant uint64_t & nb1,
|
| 6672 |
-
constant uint & r2,
|
| 6673 |
-
constant uint & r3,
|
| 6674 |
-
constant int & idx,
|
| 6675 |
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 6676 |
-
uint tiitg[[thread_index_in_threadgroup]],
|
| 6677 |
-
uint tiisg[[thread_index_in_simdgroup]],
|
| 6678 |
-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 6679 |
-
const int64_t bid = tgpig.z/(ne12*ne13);
|
| 6680 |
-
|
| 6681 |
-
tgpig.z = tgpig.z%(ne12*ne13);
|
| 6682 |
-
|
| 6683 |
-
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
| 6684 |
-
device const char * src0 = src0s + id*nb02;
|
| 6685 |
-
|
| 6686 |
-
kernel_mul_mv_q6_K_f32_impl(
|
| 6687 |
-
src0,
|
| 6688 |
-
(device const float *) (src1 + bid*nb11),
|
| 6689 |
-
dst + bid*ne0,
|
| 6690 |
-
ne00,
|
| 6691 |
-
ne01,
|
| 6692 |
-
ne02,
|
| 6693 |
-
ne10,
|
| 6694 |
-
ne12,
|
| 6695 |
-
ne0,
|
| 6696 |
-
ne1,
|
| 6697 |
-
r2,
|
| 6698 |
-
r3,
|
| 6699 |
-
tgpig,
|
| 6700 |
-
tiisg,
|
| 6701 |
-
sgitg);
|
| 6702 |
-
}
|
| 6703 |
-
|
| 6704 |
-
[[host_name("kernel_mul_mv_id_iq2_xxs_f32")]]
|
| 6705 |
-
kernel void kernel_mul_mv_id_iq2_xxs_f32(
|
| 6706 |
-
device const char * src0s,
|
| 6707 |
-
device const char * src1,
|
| 6708 |
-
device float * dst,
|
| 6709 |
-
device const char * ids,
|
| 6710 |
-
constant uint64_t & nbi1,
|
| 6711 |
-
constant int64_t & ne00,
|
| 6712 |
-
constant int64_t & ne01,
|
| 6713 |
-
constant int64_t & ne02,
|
| 6714 |
-
constant uint64_t & nb00,
|
| 6715 |
-
constant uint64_t & nb01,
|
| 6716 |
-
constant uint64_t & nb02,
|
| 6717 |
-
constant int64_t & ne10,
|
| 6718 |
-
constant int64_t & ne11,
|
| 6719 |
-
constant int64_t & ne12,
|
| 6720 |
-
constant int64_t & ne13,
|
| 6721 |
-
constant uint64_t & nb10,
|
| 6722 |
-
constant uint64_t & nb11,
|
| 6723 |
-
constant uint64_t & nb12,
|
| 6724 |
-
constant int64_t & ne0,
|
| 6725 |
-
constant int64_t & ne1,
|
| 6726 |
-
constant uint64_t & nb1,
|
| 6727 |
-
constant uint & r2,
|
| 6728 |
-
constant uint & r3,
|
| 6729 |
-
constant int & idx,
|
| 6730 |
-
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
| 6731 |
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 6732 |
-
uint tiitg[[thread_index_in_threadgroup]],
|
| 6733 |
-
uint tiisg[[thread_index_in_simdgroup]],
|
| 6734 |
-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 6735 |
-
const int64_t bid = tgpig.z/(ne12*ne13);
|
| 6736 |
-
|
| 6737 |
-
tgpig.z = tgpig.z%(ne12*ne13);
|
| 6738 |
-
|
| 6739 |
-
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
| 6740 |
-
device const char * src0 = src0s + id*nb02;
|
| 6741 |
-
|
| 6742 |
-
kernel_mul_mv_iq2_xxs_f32_impl(
|
| 6743 |
-
src0,
|
| 6744 |
-
(device const float *) (src1 + bid*nb11),
|
| 6745 |
-
dst + bid*ne0,
|
| 6746 |
-
ne00,
|
| 6747 |
-
ne01,
|
| 6748 |
-
ne02,
|
| 6749 |
-
ne10,
|
| 6750 |
-
ne12,
|
| 6751 |
-
ne0,
|
| 6752 |
-
ne1,
|
| 6753 |
-
r2,
|
| 6754 |
-
r3,
|
| 6755 |
-
shared_values,
|
| 6756 |
-
tgpig,
|
| 6757 |
-
tiisg,
|
| 6758 |
-
sgitg);
|
| 6759 |
-
}
|
| 6760 |
-
|
| 6761 |
-
[[host_name("kernel_mul_mv_id_iq2_xs_f32")]]
|
| 6762 |
-
kernel void kernel_mul_mv_id_iq2_xs_f32(
|
| 6763 |
-
device const char * src0s,
|
| 6764 |
-
device const char * src1,
|
| 6765 |
-
device float * dst,
|
| 6766 |
-
device const char * ids,
|
| 6767 |
-
constant uint64_t & nbi1,
|
| 6768 |
-
constant int64_t & ne00,
|
| 6769 |
-
constant int64_t & ne01,
|
| 6770 |
-
constant int64_t & ne02,
|
| 6771 |
-
constant uint64_t & nb00,
|
| 6772 |
-
constant uint64_t & nb01,
|
| 6773 |
-
constant uint64_t & nb02,
|
| 6774 |
-
constant int64_t & ne10,
|
| 6775 |
-
constant int64_t & ne11,
|
| 6776 |
-
constant int64_t & ne12,
|
| 6777 |
-
constant int64_t & ne13,
|
| 6778 |
-
constant uint64_t & nb10,
|
| 6779 |
-
constant uint64_t & nb11,
|
| 6780 |
-
constant uint64_t & nb12,
|
| 6781 |
-
constant int64_t & ne0,
|
| 6782 |
-
constant int64_t & ne1,
|
| 6783 |
-
constant uint64_t & nb1,
|
| 6784 |
-
constant uint & r2,
|
| 6785 |
-
constant uint & r3,
|
| 6786 |
-
constant int & idx,
|
| 6787 |
-
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
| 6788 |
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 6789 |
-
uint tiitg[[thread_index_in_threadgroup]],
|
| 6790 |
-
uint tiisg[[thread_index_in_simdgroup]],
|
| 6791 |
-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 6792 |
-
const int64_t bid = tgpig.z/(ne12*ne13);
|
| 6793 |
-
|
| 6794 |
-
tgpig.z = tgpig.z%(ne12*ne13);
|
| 6795 |
-
|
| 6796 |
-
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
| 6797 |
-
device const char * src0 = src0s + id*nb02;
|
| 6798 |
-
|
| 6799 |
-
kernel_mul_mv_iq2_xs_f32_impl(
|
| 6800 |
-
src0,
|
| 6801 |
-
(device const float *) (src1 + bid*nb11),
|
| 6802 |
-
dst + bid*ne0,
|
| 6803 |
-
ne00,
|
| 6804 |
-
ne01,
|
| 6805 |
-
ne02,
|
| 6806 |
-
ne10,
|
| 6807 |
-
ne12,
|
| 6808 |
-
ne0,
|
| 6809 |
-
ne1,
|
| 6810 |
-
r2,
|
| 6811 |
-
r3,
|
| 6812 |
-
shared_values,
|
| 6813 |
-
tgpig,
|
| 6814 |
-
tiisg,
|
| 6815 |
-
sgitg);
|
| 6816 |
-
}
|
| 6817 |
-
|
| 6818 |
-
[[host_name("kernel_mul_mv_id_iq3_xxs_f32")]]
|
| 6819 |
-
kernel void kernel_mul_mv_id_iq3_xxs_f32(
|
| 6820 |
-
device const char * src0s,
|
| 6821 |
-
device const char * src1,
|
| 6822 |
-
device float * dst,
|
| 6823 |
-
device const char * ids,
|
| 6824 |
-
constant uint64_t & nbi1,
|
| 6825 |
-
constant int64_t & ne00,
|
| 6826 |
-
constant int64_t & ne01,
|
| 6827 |
-
constant int64_t & ne02,
|
| 6828 |
-
constant uint64_t & nb00,
|
| 6829 |
-
constant uint64_t & nb01,
|
| 6830 |
-
constant uint64_t & nb02,
|
| 6831 |
-
constant int64_t & ne10,
|
| 6832 |
-
constant int64_t & ne11,
|
| 6833 |
-
constant int64_t & ne12,
|
| 6834 |
-
constant int64_t & ne13,
|
| 6835 |
-
constant uint64_t & nb10,
|
| 6836 |
-
constant uint64_t & nb11,
|
| 6837 |
-
constant uint64_t & nb12,
|
| 6838 |
-
constant int64_t & ne0,
|
| 6839 |
-
constant int64_t & ne1,
|
| 6840 |
-
constant uint64_t & nb1,
|
| 6841 |
-
constant uint & r2,
|
| 6842 |
-
constant uint & r3,
|
| 6843 |
-
constant int & idx,
|
| 6844 |
-
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
| 6845 |
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 6846 |
-
uint tiitg[[thread_index_in_threadgroup]],
|
| 6847 |
-
uint tiisg[[thread_index_in_simdgroup]],
|
| 6848 |
-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 6849 |
-
const int64_t bid = tgpig.z/(ne12*ne13);
|
| 6850 |
-
|
| 6851 |
-
tgpig.z = tgpig.z%(ne12*ne13);
|
| 6852 |
-
|
| 6853 |
-
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
| 6854 |
-
device const char * src0 = src0s + id*nb02;
|
| 6855 |
-
|
| 6856 |
-
kernel_mul_mv_iq3_xxs_f32_impl(
|
| 6857 |
-
src0,
|
| 6858 |
-
(device const float *) (src1 + bid*nb11),
|
| 6859 |
-
dst + bid*ne0,
|
| 6860 |
-
ne00,
|
| 6861 |
-
ne01,
|
| 6862 |
-
ne02,
|
| 6863 |
-
ne10,
|
| 6864 |
-
ne12,
|
| 6865 |
-
ne0,
|
| 6866 |
-
ne1,
|
| 6867 |
-
r2,
|
| 6868 |
-
r3,
|
| 6869 |
-
shared_values,
|
| 6870 |
-
tgpig,
|
| 6871 |
-
tiisg,
|
| 6872 |
-
sgitg);
|
| 6873 |
-
}
|
| 6874 |
-
|
| 6875 |
-
[[host_name("kernel_mul_mv_id_iq3_s_f32")]]
|
| 6876 |
-
kernel void kernel_mul_mv_id_iq3_s_f32(
|
| 6877 |
-
device const char * src0s,
|
| 6878 |
-
device const char * src1,
|
| 6879 |
-
device float * dst,
|
| 6880 |
-
device const char * ids,
|
| 6881 |
-
constant uint64_t & nbi1,
|
| 6882 |
-
constant int64_t & ne00,
|
| 6883 |
-
constant int64_t & ne01,
|
| 6884 |
-
constant int64_t & ne02,
|
| 6885 |
-
constant uint64_t & nb00,
|
| 6886 |
-
constant uint64_t & nb01,
|
| 6887 |
-
constant uint64_t & nb02,
|
| 6888 |
-
constant int64_t & ne10,
|
| 6889 |
-
constant int64_t & ne11,
|
| 6890 |
-
constant int64_t & ne12,
|
| 6891 |
-
constant int64_t & ne13,
|
| 6892 |
-
constant uint64_t & nb10,
|
| 6893 |
-
constant uint64_t & nb11,
|
| 6894 |
-
constant uint64_t & nb12,
|
| 6895 |
-
constant int64_t & ne0,
|
| 6896 |
-
constant int64_t & ne1,
|
| 6897 |
-
constant uint64_t & nb1,
|
| 6898 |
-
constant uint & r2,
|
| 6899 |
-
constant uint & r3,
|
| 6900 |
-
constant int & idx,
|
| 6901 |
-
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
| 6902 |
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 6903 |
-
uint tiitg[[thread_index_in_threadgroup]],
|
| 6904 |
-
uint tiisg[[thread_index_in_simdgroup]],
|
| 6905 |
-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 6906 |
-
const int64_t bid = tgpig.z/(ne12*ne13);
|
| 6907 |
-
|
| 6908 |
-
tgpig.z = tgpig.z%(ne12*ne13);
|
| 6909 |
-
|
| 6910 |
-
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
| 6911 |
-
device const char * src0 = src0s + id*nb02;
|
| 6912 |
-
|
| 6913 |
-
kernel_mul_mv_iq3_s_f32_impl(
|
| 6914 |
-
src0,
|
| 6915 |
-
(device const float *) (src1 + bid*nb11),
|
| 6916 |
-
dst + bid*ne0,
|
| 6917 |
-
ne00,
|
| 6918 |
-
ne01,
|
| 6919 |
-
ne02,
|
| 6920 |
-
ne10,
|
| 6921 |
-
ne12,
|
| 6922 |
-
ne0,
|
| 6923 |
-
ne1,
|
| 6924 |
-
r2,
|
| 6925 |
-
r3,
|
| 6926 |
-
shared_values,
|
| 6927 |
-
tgpig,
|
| 6928 |
-
tiisg,
|
| 6929 |
-
sgitg);
|
| 6930 |
-
}
|
| 6931 |
-
|
| 6932 |
-
[[host_name("kernel_mul_mv_id_iq2_s_f32")]]
|
| 6933 |
-
kernel void kernel_mul_mv_id_iq2_s_f32(
|
| 6934 |
-
device const char * src0s,
|
| 6935 |
-
device const char * src1,
|
| 6936 |
-
device float * dst,
|
| 6937 |
-
device const char * ids,
|
| 6938 |
-
constant uint64_t & nbi1,
|
| 6939 |
-
constant int64_t & ne00,
|
| 6940 |
-
constant int64_t & ne01,
|
| 6941 |
-
constant int64_t & ne02,
|
| 6942 |
-
constant uint64_t & nb00,
|
| 6943 |
-
constant uint64_t & nb01,
|
| 6944 |
-
constant uint64_t & nb02,
|
| 6945 |
-
constant int64_t & ne10,
|
| 6946 |
-
constant int64_t & ne11,
|
| 6947 |
-
constant int64_t & ne12,
|
| 6948 |
-
constant int64_t & ne13,
|
| 6949 |
-
constant uint64_t & nb10,
|
| 6950 |
-
constant uint64_t & nb11,
|
| 6951 |
-
constant uint64_t & nb12,
|
| 6952 |
-
constant int64_t & ne0,
|
| 6953 |
-
constant int64_t & ne1,
|
| 6954 |
-
constant uint64_t & nb1,
|
| 6955 |
-
constant uint & r2,
|
| 6956 |
-
constant uint & r3,
|
| 6957 |
-
constant int & idx,
|
| 6958 |
-
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
| 6959 |
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 6960 |
-
uint tiitg[[thread_index_in_threadgroup]],
|
| 6961 |
-
uint tiisg[[thread_index_in_simdgroup]],
|
| 6962 |
-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 6963 |
-
const int64_t bid = tgpig.z/(ne12*ne13);
|
| 6964 |
-
|
| 6965 |
-
tgpig.z = tgpig.z%(ne12*ne13);
|
| 6966 |
-
|
| 6967 |
-
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
| 6968 |
-
device const char * src0 = src0s + id*nb02;
|
| 6969 |
-
|
| 6970 |
-
kernel_mul_mv_iq2_s_f32_impl(
|
| 6971 |
-
src0,
|
| 6972 |
-
(device const float *) (src1 + bid*nb11),
|
| 6973 |
-
dst + bid*ne0,
|
| 6974 |
-
ne00,
|
| 6975 |
-
ne01,
|
| 6976 |
-
ne02,
|
| 6977 |
-
ne10,
|
| 6978 |
-
ne12,
|
| 6979 |
-
ne0,
|
| 6980 |
-
ne1,
|
| 6981 |
-
r2,
|
| 6982 |
-
r3,
|
| 6983 |
-
shared_values,
|
| 6984 |
-
tgpig,
|
| 6985 |
-
tiisg,
|
| 6986 |
-
sgitg);
|
| 6987 |
-
}
|
| 6988 |
-
|
| 6989 |
-
[[host_name("kernel_mul_mv_id_iq1_s_f32")]]
|
| 6990 |
-
kernel void kernel_mul_mv_id_iq1_s_f32(
|
| 6991 |
-
device const char * src0s,
|
| 6992 |
-
device const char * src1,
|
| 6993 |
-
device float * dst,
|
| 6994 |
-
device const char * ids,
|
| 6995 |
-
constant uint64_t & nbi1,
|
| 6996 |
-
constant int64_t & ne00,
|
| 6997 |
-
constant int64_t & ne01,
|
| 6998 |
-
constant int64_t & ne02,
|
| 6999 |
-
constant uint64_t & nb00,
|
| 7000 |
-
constant uint64_t & nb01,
|
| 7001 |
-
constant uint64_t & nb02,
|
| 7002 |
-
constant int64_t & ne10,
|
| 7003 |
-
constant int64_t & ne11,
|
| 7004 |
-
constant int64_t & ne12,
|
| 7005 |
-
constant int64_t & ne13,
|
| 7006 |
-
constant uint64_t & nb10,
|
| 7007 |
-
constant uint64_t & nb11,
|
| 7008 |
-
constant uint64_t & nb12,
|
| 7009 |
-
constant int64_t & ne0,
|
| 7010 |
-
constant int64_t & ne1,
|
| 7011 |
-
constant uint64_t & nb1,
|
| 7012 |
-
constant uint & r2,
|
| 7013 |
-
constant uint & r3,
|
| 7014 |
-
constant int & idx,
|
| 7015 |
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 7016 |
-
uint tiitg[[thread_index_in_threadgroup]],
|
| 7017 |
-
uint tiisg[[thread_index_in_simdgroup]],
|
| 7018 |
-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 7019 |
-
const int64_t bid = tgpig.z/(ne12*ne13);
|
| 7020 |
-
|
| 7021 |
-
tgpig.z = tgpig.z%(ne12*ne13);
|
| 7022 |
-
|
| 7023 |
-
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
| 7024 |
-
device const char * src0 = src0s + id*nb02;
|
| 7025 |
-
|
| 7026 |
-
kernel_mul_mv_iq1_s_f32_impl(
|
| 7027 |
-
src0,
|
| 7028 |
-
(device const float *) (src1 + bid*nb11),
|
| 7029 |
-
dst + bid*ne0,
|
| 7030 |
-
ne00,
|
| 7031 |
-
ne01,
|
| 7032 |
-
ne02,
|
| 7033 |
-
ne10,
|
| 7034 |
-
ne12,
|
| 7035 |
-
ne0,
|
| 7036 |
-
ne1,
|
| 7037 |
-
r2,
|
| 7038 |
-
r3,
|
| 7039 |
-
tgpig,
|
| 7040 |
-
tiisg,
|
| 7041 |
-
sgitg);
|
| 7042 |
-
}
|
| 7043 |
-
|
| 7044 |
-
[[host_name("kernel_mul_mv_id_iq1_m_f32")]]
|
| 7045 |
-
kernel void kernel_mul_mv_id_iq1_m_f32(
|
| 7046 |
-
device const char * src0s,
|
| 7047 |
-
device const char * src1,
|
| 7048 |
-
device float * dst,
|
| 7049 |
-
device const char * ids,
|
| 7050 |
-
constant uint64_t & nbi1,
|
| 7051 |
-
constant int64_t & ne00,
|
| 7052 |
-
constant int64_t & ne01,
|
| 7053 |
-
constant int64_t & ne02,
|
| 7054 |
-
constant uint64_t & nb00,
|
| 7055 |
-
constant uint64_t & nb01,
|
| 7056 |
-
constant uint64_t & nb02,
|
| 7057 |
-
constant int64_t & ne10,
|
| 7058 |
-
constant int64_t & ne11,
|
| 7059 |
-
constant int64_t & ne12,
|
| 7060 |
-
constant int64_t & ne13,
|
| 7061 |
-
constant uint64_t & nb10,
|
| 7062 |
-
constant uint64_t & nb11,
|
| 7063 |
-
constant uint64_t & nb12,
|
| 7064 |
-
constant int64_t & ne0,
|
| 7065 |
-
constant int64_t & ne1,
|
| 7066 |
-
constant uint64_t & nb1,
|
| 7067 |
-
constant uint & r2,
|
| 7068 |
-
constant uint & r3,
|
| 7069 |
-
constant int & idx,
|
| 7070 |
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 7071 |
-
uint tiitg[[thread_index_in_threadgroup]],
|
| 7072 |
-
uint tiisg[[thread_index_in_simdgroup]],
|
| 7073 |
-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 7074 |
-
const int64_t bid = tgpig.z/(ne12*ne13);
|
| 7075 |
-
|
| 7076 |
-
tgpig.z = tgpig.z%(ne12*ne13);
|
| 7077 |
-
|
| 7078 |
-
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
| 7079 |
-
device const char * src0 = src0s + id*nb02;
|
| 7080 |
-
|
| 7081 |
-
kernel_mul_mv_iq1_m_f32_impl(
|
| 7082 |
-
src0,
|
| 7083 |
-
(device const float *) (src1 + bid*nb11),
|
| 7084 |
-
dst + bid*ne0,
|
| 7085 |
-
ne00,
|
| 7086 |
-
ne01,
|
| 7087 |
-
ne02,
|
| 7088 |
-
ne10,
|
| 7089 |
-
ne12,
|
| 7090 |
-
ne0,
|
| 7091 |
-
ne1,
|
| 7092 |
-
r2,
|
| 7093 |
-
r3,
|
| 7094 |
-
tgpig,
|
| 7095 |
-
tiisg,
|
| 7096 |
-
sgitg);
|
| 7097 |
-
}
|
| 7098 |
-
|
| 7099 |
-
[[host_name("kernel_mul_mv_id_iq4_nl_f32")]]
|
| 7100 |
-
kernel void kernel_mul_mv_id_iq4_nl_f32(
|
| 7101 |
-
device const char * src0s,
|
| 7102 |
-
device const char * src1,
|
| 7103 |
-
device float * dst,
|
| 7104 |
-
device const char * ids,
|
| 7105 |
-
constant uint64_t & nbi1,
|
| 7106 |
-
constant int64_t & ne00,
|
| 7107 |
-
constant int64_t & ne01,
|
| 7108 |
-
constant int64_t & ne02,
|
| 7109 |
-
constant uint64_t & nb00,
|
| 7110 |
-
constant uint64_t & nb01,
|
| 7111 |
-
constant uint64_t & nb02,
|
| 7112 |
-
constant int64_t & ne10,
|
| 7113 |
-
constant int64_t & ne11,
|
| 7114 |
-
constant int64_t & ne12,
|
| 7115 |
-
constant int64_t & ne13,
|
| 7116 |
-
constant uint64_t & nb10,
|
| 7117 |
-
constant uint64_t & nb11,
|
| 7118 |
-
constant uint64_t & nb12,
|
| 7119 |
-
constant int64_t & ne0,
|
| 7120 |
-
constant int64_t & ne1,
|
| 7121 |
-
constant uint64_t & nb1,
|
| 7122 |
-
constant uint & r2,
|
| 7123 |
-
constant uint & r3,
|
| 7124 |
-
constant int & idx,
|
| 7125 |
-
threadgroup float * shared_values [[threadgroup(0)]],
|
| 7126 |
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 7127 |
-
uint tiitg[[thread_index_in_threadgroup]],
|
| 7128 |
-
uint tiisg[[thread_index_in_simdgroup]],
|
| 7129 |
-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 7130 |
-
const int64_t bid = tgpig.z/(ne12*ne13);
|
| 7131 |
-
|
| 7132 |
-
tgpig.z = tgpig.z%(ne12*ne13);
|
| 7133 |
-
|
| 7134 |
-
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
| 7135 |
-
device const char * src0 = src0s + id*nb02;
|
| 7136 |
-
|
| 7137 |
-
kernel_mul_mv_iq4_nl_f32_impl(
|
| 7138 |
-
src0,
|
| 7139 |
-
(device const float *) (src1 + bid*nb11),
|
| 7140 |
-
dst + bid*ne0,
|
| 7141 |
-
ne00,
|
| 7142 |
-
ne01,
|
| 7143 |
-
ne02,
|
| 7144 |
-
ne10,
|
| 7145 |
-
ne12,
|
| 7146 |
-
ne0,
|
| 7147 |
-
ne1,
|
| 7148 |
-
r2,
|
| 7149 |
-
r3,
|
| 7150 |
-
shared_values,
|
| 7151 |
-
tgpig,
|
| 7152 |
-
tiisg,
|
| 7153 |
-
sgitg);
|
| 7154 |
-
}
|
| 7155 |
-
|
| 7156 |
-
[[host_name("kernel_mul_mv_id_iq4_xs_f32")]]
|
| 7157 |
-
kernel void kernel_mul_mv_id_iq4_xs_f32(
|
| 7158 |
-
device const char * src0s,
|
| 7159 |
-
device const char * src1,
|
| 7160 |
-
device float * dst,
|
| 7161 |
-
device const char * ids,
|
| 7162 |
-
constant uint64_t & nbi1,
|
| 7163 |
-
constant int64_t & ne00,
|
| 7164 |
-
constant int64_t & ne01,
|
| 7165 |
-
constant int64_t & ne02,
|
| 7166 |
-
constant uint64_t & nb00,
|
| 7167 |
-
constant uint64_t & nb01,
|
| 7168 |
-
constant uint64_t & nb02,
|
| 7169 |
-
constant int64_t & ne10,
|
| 7170 |
-
constant int64_t & ne11,
|
| 7171 |
-
constant int64_t & ne12,
|
| 7172 |
-
constant int64_t & ne13,
|
| 7173 |
-
constant uint64_t & nb10,
|
| 7174 |
-
constant uint64_t & nb11,
|
| 7175 |
-
constant uint64_t & nb12,
|
| 7176 |
-
constant int64_t & ne0,
|
| 7177 |
-
constant int64_t & ne1,
|
| 7178 |
-
constant uint64_t & nb1,
|
| 7179 |
-
constant uint & r2,
|
| 7180 |
-
constant uint & r3,
|
| 7181 |
-
constant int & idx,
|
| 7182 |
-
threadgroup float * shared_values [[threadgroup(0)]],
|
| 7183 |
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 7184 |
-
uint tiitg[[thread_index_in_threadgroup]],
|
| 7185 |
-
uint tiisg[[thread_index_in_simdgroup]],
|
| 7186 |
-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 7187 |
-
const int64_t bid = tgpig.z/(ne12*ne13);
|
| 7188 |
-
|
| 7189 |
-
tgpig.z = tgpig.z%(ne12*ne13);
|
| 7190 |
-
|
| 7191 |
-
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
| 7192 |
-
device const char * src0 = src0s + id*nb02;
|
| 7193 |
-
|
| 7194 |
-
#if QK_K == 64
|
| 7195 |
-
kernel_mul_mv_iq4_nl_f32_impl(
|
| 7196 |
-
#else
|
| 7197 |
-
kernel_mul_mv_iq4_xs_f32_impl(
|
| 7198 |
-
#endif
|
| 7199 |
-
src0,
|
| 7200 |
-
(device const float *) (src1 + bid*nb11),
|
| 7201 |
-
dst + bid*ne0,
|
| 7202 |
-
ne00,
|
| 7203 |
-
ne01,
|
| 7204 |
-
ne02,
|
| 7205 |
-
ne10,
|
| 7206 |
-
ne12,
|
| 7207 |
-
ne0,
|
| 7208 |
-
ne1,
|
| 7209 |
-
r2,
|
| 7210 |
-
r3,
|
| 7211 |
-
shared_values,
|
| 7212 |
-
tgpig,
|
| 7213 |
-
tiisg,
|
| 7214 |
-
sgitg);
|
| 7215 |
-
}
|
|
|
|
| 864 |
device const void * src0,
|
| 865 |
device const float * src1,
|
| 866 |
device float * dst,
|
| 867 |
+
constant int64_t & ne00,
|
| 868 |
+
constant int64_t & ne01,
|
| 869 |
+
constant int64_t & ne02,
|
| 870 |
+
constant int64_t & ne10,
|
| 871 |
+
constant int64_t & ne12,
|
| 872 |
+
constant int64_t & ne0,
|
| 873 |
+
constant int64_t & ne1,
|
| 874 |
+
constant uint & r2,
|
| 875 |
+
constant uint & r3,
|
| 876 |
+
threadgroup int8_t * shared_values,
|
| 877 |
uint3 tgpig, uint tiisg, uint sgitg) {
|
| 878 |
const int nb = ne00/QK4_0;
|
| 879 |
|
|
|
|
| 950 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 951 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 952 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 953 |
+
mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
|
| 954 |
}
|
| 955 |
|
| 956 |
kernel void kernel_mul_mv_q4_1_f32(
|
|
|
|
| 976 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 977 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 978 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 979 |
+
mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
|
| 980 |
}
|
| 981 |
|
| 982 |
kernel void kernel_mul_mv_q5_0_f32(
|
|
|
|
| 1002 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 1003 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 1004 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 1005 |
+
mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
|
| 1006 |
}
|
| 1007 |
|
| 1008 |
kernel void kernel_mul_mv_q5_1_f32(
|
|
|
|
| 1028 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 1029 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 1030 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 1031 |
+
mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
|
| 1032 |
}
|
| 1033 |
|
| 1034 |
|
|
|
|
| 1047 |
constant int64_t & ne1,
|
| 1048 |
constant uint & r2,
|
| 1049 |
constant uint & r3,
|
| 1050 |
+
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
| 1051 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 1052 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 1053 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
| 1128 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 1129 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 1130 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 1131 |
+
kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
|
| 1132 |
}
|
| 1133 |
|
| 1134 |
#define N_F32_F32 4
|
|
|
|
| 2718 |
constant int64_t & ne1,
|
| 2719 |
constant uint & r2,
|
| 2720 |
constant uint & r3,
|
| 2721 |
+
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
| 2722 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 2723 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 2724 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
| 2881 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 2882 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 2883 |
|
| 2884 |
+
kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
|
| 2885 |
}
|
| 2886 |
|
| 2887 |
#if QK_K == 256
|
|
|
|
| 2898 |
constant int64_t & ne1,
|
| 2899 |
constant uint & r2,
|
| 2900 |
constant uint & r3,
|
| 2901 |
+
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
| 2902 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 2903 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 2904 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
| 3057 |
constant int64_t & ne1,
|
| 3058 |
constant uint & r2,
|
| 3059 |
constant uint & r3,
|
| 3060 |
+
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
| 3061 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 3062 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 3063 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
| 3147 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 3148 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 3149 |
|
| 3150 |
+
kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
|
| 3151 |
}
|
| 3152 |
|
| 3153 |
#if QK_K == 256
|
|
|
|
| 3164 |
constant int64_t & ne1,
|
| 3165 |
constant uint & r2,
|
| 3166 |
constant uint & r3,
|
| 3167 |
+
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
| 3168 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 3169 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 3170 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
| 3278 |
constant int64_t & ne1,
|
| 3279 |
constant uint & r2,
|
| 3280 |
constant uint & r3,
|
| 3281 |
+
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
| 3282 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 3283 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 3284 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
| 3387 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 3388 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 3389 |
|
| 3390 |
+
kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
|
| 3391 |
}
|
| 3392 |
|
| 3393 |
void kernel_mul_mv_q5_K_f32_impl(
|
|
|
|
| 3403 |
constant int64_t & ne1,
|
| 3404 |
constant uint & r2,
|
| 3405 |
constant uint & r3,
|
| 3406 |
+
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
| 3407 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 3408 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 3409 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
| 3594 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 3595 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 3596 |
|
| 3597 |
+
kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
|
| 3598 |
}
|
| 3599 |
|
| 3600 |
void kernel_mul_mv_q6_K_f32_impl(
|
|
|
|
| 3610 |
constant int64_t & ne1,
|
| 3611 |
constant uint & r2,
|
| 3612 |
constant uint & r3,
|
| 3613 |
+
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
| 3614 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 3615 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 3616 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
| 3729 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 3730 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 3731 |
|
| 3732 |
+
kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
|
| 3733 |
}
|
| 3734 |
|
| 3735 |
// ======================= "True" 2-bit
|
|
|
|
| 4412 |
constant int64_t & ne1,
|
| 4413 |
constant uint & r2,
|
| 4414 |
constant uint & r3,
|
| 4415 |
+
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
| 4416 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 4417 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 4418 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
| 4502 |
constant int64_t & ne1,
|
| 4503 |
constant uint & r2,
|
| 4504 |
constant uint & r3,
|
| 4505 |
+
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
| 4506 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 4507 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 4508 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
| 4611 |
constant int64_t & ne1,
|
| 4612 |
constant uint & r2,
|
| 4613 |
constant uint & r3,
|
| 4614 |
+
threadgroup int8_t * shared_values_i8 [[threadgroup(0)]],
|
| 4615 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 4616 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 4617 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 4618 |
|
| 4619 |
+
threadgroup float * shared_values = (threadgroup float *)shared_values_i8;
|
| 4620 |
const int nb = ne00/QK4_NL;
|
| 4621 |
const int r0 = tgpig.x;
|
| 4622 |
const int r1 = tgpig.y;
|
|
|
|
| 4706 |
constant int64_t & ne1,
|
| 4707 |
constant uint & r2,
|
| 4708 |
constant uint & r3,
|
| 4709 |
+
threadgroup int8_t * shared_values_i8 [[threadgroup(0)]],
|
| 4710 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 4711 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 4712 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 4713 |
+
threadgroup float * shared_values = (threadgroup float *)shared_values_i8;
|
| 4714 |
const int nb = ne00/QK_K;
|
| 4715 |
const int r0 = tgpig.x;
|
| 4716 |
const int r1 = tgpig.y;
|
|
|
|
| 4813 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 4814 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 4815 |
|
| 4816 |
+
kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
|
| 4817 |
}
|
| 4818 |
|
| 4819 |
[[host_name("kernel_mul_mv_iq1_m_f32")]]
|
|
|
|
| 4841 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 4842 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 4843 |
|
| 4844 |
+
kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
|
| 4845 |
}
|
| 4846 |
|
| 4847 |
[[host_name("kernel_mul_mv_iq4_nl_f32")]]
|
|
|
|
| 4865 |
constant int64_t & ne1,
|
| 4866 |
constant uint & r2,
|
| 4867 |
constant uint & r3,
|
| 4868 |
+
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
| 4869 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 4870 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 4871 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
| 4894 |
constant int64_t & ne1,
|
| 4895 |
constant uint & r2,
|
| 4896 |
constant uint & r3,
|
| 4897 |
+
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
| 4898 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 4899 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 4900 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
| 6041 |
// matrix-vector multiplication
|
| 6042 |
//
|
| 6043 |
|
| 6044 |
+
typedef void (kernel_mul_mv_impl_t)(
|
| 6045 |
+
device const char * src0,
|
| 6046 |
+
device const char * src1,
|
| 6047 |
+
device float * dst,
|
| 6048 |
+
constant int64_t & ne00,
|
| 6049 |
+
constant int64_t & ne01,
|
| 6050 |
+
constant int64_t & ne02,
|
| 6051 |
+
constant uint64_t & nb00,
|
| 6052 |
+
constant uint64_t & nb01,
|
| 6053 |
+
constant uint64_t & nb02,
|
| 6054 |
+
constant int64_t & ne10,
|
| 6055 |
+
constant int64_t & ne11,
|
| 6056 |
+
constant int64_t & ne12,
|
| 6057 |
+
constant uint64_t & nb10,
|
| 6058 |
+
constant uint64_t & nb11,
|
| 6059 |
+
constant uint64_t & nb12,
|
| 6060 |
+
constant int64_t & ne0,
|
| 6061 |
+
constant int64_t & ne1,
|
| 6062 |
+
constant uint & r2,
|
| 6063 |
+
constant uint & r3,
|
| 6064 |
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 6065 |
+
uint tiisg[[thread_index_in_simdgroup]]);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6066 |
|
| 6067 |
+
typedef void (kernel_mul_mv2_impl_t)(
|
| 6068 |
+
device const void * src0,
|
| 6069 |
+
device const float * src1,
|
| 6070 |
+
device float * dst,
|
| 6071 |
+
constant int64_t & ne00,
|
| 6072 |
+
constant int64_t & ne01,
|
| 6073 |
+
constant int64_t & ne02,
|
| 6074 |
+
constant int64_t & ne10,
|
| 6075 |
+
constant int64_t & ne12,
|
| 6076 |
+
constant int64_t & ne0,
|
| 6077 |
+
constant int64_t & ne1,
|
| 6078 |
+
constant uint & r2,
|
| 6079 |
+
constant uint & r3,
|
| 6080 |
+
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
| 6081 |
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 6082 |
+
uint tiisg[[thread_index_in_simdgroup]],
|
| 6083 |
+
uint sgitg[[simdgroup_index_in_threadgroup]]);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6084 |
|
| 6085 |
+
template<kernel_mul_mv_impl_t impl_fn>
|
| 6086 |
+
void mmv_fn(
|
| 6087 |
+
device const char * src0,
|
| 6088 |
device const char * src1,
|
| 6089 |
device float * dst,
|
|
|
|
|
|
|
| 6090 |
constant int64_t & ne00,
|
| 6091 |
constant int64_t & ne01,
|
| 6092 |
constant int64_t & ne02,
|
|
|
|
| 6105 |
constant uint64_t & nb1,
|
| 6106 |
constant uint & r2,
|
| 6107 |
constant uint & r3,
|
| 6108 |
+
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
| 6109 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 6110 |
uint tiitg[[thread_index_in_threadgroup]],
|
| 6111 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 6112 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 6113 |
+
impl_fn(src0,src1,dst,ne00,ne01,ne02,nb00,nb01,nb02,ne10,ne11,ne12,nb10,nb11,nb12,ne0,ne1,r2,r3,tgpig,tiisg);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6114 |
}
|
| 6115 |
|
| 6116 |
+
template<kernel_mul_mv2_impl_t impl_fn>
|
| 6117 |
+
void mmv_fn(
|
| 6118 |
+
device const char * src0,
|
| 6119 |
device const char * src1,
|
| 6120 |
device float * dst,
|
|
|
|
|
|
|
| 6121 |
constant int64_t & ne00,
|
| 6122 |
constant int64_t & ne01,
|
| 6123 |
constant int64_t & ne02,
|
|
|
|
| 6136 |
constant uint64_t & nb1,
|
| 6137 |
constant uint & r2,
|
| 6138 |
constant uint & r3,
|
| 6139 |
+
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
| 6140 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 6141 |
uint tiitg[[thread_index_in_threadgroup]],
|
| 6142 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 6143 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 6144 |
+
impl_fn(src0,(const device float *)src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,shared_values,tgpig,tiisg,sgitg);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6145 |
}
|
| 6146 |
|
| 6147 |
+
typedef void (mul_mv_impl_fn_t)(
|
| 6148 |
+
device const char * src0,
|
|
|
|
| 6149 |
device const char * src1,
|
| 6150 |
device float * dst,
|
|
|
|
|
|
|
| 6151 |
constant int64_t & ne00,
|
| 6152 |
constant int64_t & ne01,
|
| 6153 |
constant int64_t & ne02,
|
|
|
|
| 6166 |
constant uint64_t & nb1,
|
| 6167 |
constant uint & r2,
|
| 6168 |
constant uint & r3,
|
| 6169 |
+
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
| 6170 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 6171 |
uint tiitg[[thread_index_in_threadgroup]],
|
| 6172 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 6173 |
+
uint sgitg[[simdgroup_index_in_threadgroup]]);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6174 |
|
| 6175 |
+
template<mul_mv_impl_fn_t impl_fn>
|
| 6176 |
+
kernel void kernel_mul_mv_id(
|
| 6177 |
device const char * src0s,
|
| 6178 |
device const char * src1,
|
| 6179 |
device float * dst,
|
|
|
|
| 6198 |
constant uint & r2,
|
| 6199 |
constant uint & r3,
|
| 6200 |
constant int & idx,
|
| 6201 |
+
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
| 6202 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 6203 |
uint tiitg[[thread_index_in_threadgroup]],
|
| 6204 |
uint tiisg[[thread_index_in_simdgroup]],
|
|
|
|
| 6210 |
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
| 6211 |
device const char * src0 = src0s + id*nb02;
|
| 6212 |
|
| 6213 |
+
impl_fn(
|
| 6214 |
src0,
|
| 6215 |
+
src1 + bid*nb11,
|
| 6216 |
+
dst + bid*ne0,
|
| 6217 |
ne00,
|
| 6218 |
ne01,
|
| 6219 |
ne02,
|
| 6220 |
+
nb00,
|
| 6221 |
+
nb01,
|
| 6222 |
+
nb02,
|
| 6223 |
ne10,
|
| 6224 |
+
ne11,
|
| 6225 |
ne12,
|
| 6226 |
+
ne13,
|
| 6227 |
+
nb10,
|
| 6228 |
+
nb11,
|
| 6229 |
+
nb12,
|
| 6230 |
ne0,
|
| 6231 |
ne1,
|
| 6232 |
+
nb1,
|
| 6233 |
r2,
|
| 6234 |
r3,
|
| 6235 |
+
shared_values,
|
| 6236 |
tgpig,
|
| 6237 |
+
tiitg,
|
| 6238 |
tiisg,
|
| 6239 |
sgitg);
|
| 6240 |
}
|
| 6241 |
|
| 6242 |
+
typedef void (kernel_mul_mv_id_t)(
|
|
|
|
| 6243 |
device const char * src0s,
|
| 6244 |
device const char * src1,
|
| 6245 |
device float * dst,
|
|
|
|
| 6264 |
constant uint & r2,
|
| 6265 |
constant uint & r3,
|
| 6266 |
constant int & idx,
|
| 6267 |
+
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
| 6268 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 6269 |
uint tiitg[[thread_index_in_threadgroup]],
|
| 6270 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 6271 |
+
uint sgitg[[simdgroup_index_in_threadgroup]]);
|
| 6272 |
+
|
| 6273 |
+
template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_f32_f32_impl>>;
|
| 6274 |
+
template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_f16_f32_impl>>;
|
| 6275 |
+
template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl>>;
|
| 6276 |
+
template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
|
| 6277 |
+
template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
|
| 6278 |
+
template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
|
| 6279 |
+
template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
|
| 6280 |
+
template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl>>;
|
| 6281 |
+
template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q3_K_f32_impl>>;
|
| 6282 |
+
template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q4_K_f32_impl>>;
|
| 6283 |
+
template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q5_K_f32_impl>>;
|
| 6284 |
+
template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q6_K_f32_impl>>;
|
| 6285 |
+
template [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_s_f32_impl>>;
|
| 6286 |
+
template [[host_name("kernel_mul_mv_id_iq1_m_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_m_f32_impl>>;
|
| 6287 |
+
template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xxs_f32_impl>>;
|
| 6288 |
+
template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xs_f32_impl>>;
|
| 6289 |
+
template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_xxs_f32_impl>>;
|
| 6290 |
+
template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_s_f32_impl>>;
|
| 6291 |
+
template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl>>;
|
| 6292 |
+
template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl>>;
|
| 6293 |
+
#if QK_K != 64
|
| 6294 |
+
template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl>>;
|
| 6295 |
+
#endif
|
| 6296 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ggml.c
CHANGED
|
@@ -11074,7 +11074,6 @@ static void ggml_compute_forward_mul_mat_id(
|
|
| 11074 |
}
|
| 11075 |
|
| 11076 |
// initialize matrix_row_counts
|
| 11077 |
-
GGML_ASSERT(wdata == wdata_src1_end);
|
| 11078 |
memset(matrix_row_counts, 0, n_as*sizeof(int64_t));
|
| 11079 |
|
| 11080 |
// group rows by src0 matrix
|
|
|
|
| 11074 |
}
|
| 11075 |
|
| 11076 |
// initialize matrix_row_counts
|
|
|
|
| 11077 |
memset(matrix_row_counts, 0, n_as*sizeof(int64_t));
|
| 11078 |
|
| 11079 |
// group rows by src0 matrix
|