Spaces:
Running
Running
metal : enable shader debugging (cmake option) (llama/4705)
Browse files
* ggml : disable fast-math for Metal (cmake build only)
ggml-ci
* metal : fix Metal API debug warnings
* cmake : add -fno-inline for Metal build (llama/4545)
* metal : fix API debug warnings
* metal : fix compile warnings
* metal : use uint64_t for strides
* cmake : rename option to LLAMA_METAL_SHADER_DEBUG
* metal : fix mat-vec Q8_0 kernel for BS > 1
* metal : normalize mat-vec kernel signatures
* cmake : respect LLAMA_QKK_64 option
* metal : fix mat-vec Q4_K kernel for QK_K == 64
ggml-ci
- ggml-metal.m +19 -9
- ggml-metal.metal +265 -210
ggml-metal.m
CHANGED
|
@@ -257,13 +257,14 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
| 257 |
bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
|
| 258 |
#endif
|
| 259 |
NSError * error = nil;
|
| 260 |
-
NSString * libPath = [bundle pathForResource:@"
|
| 261 |
if (libPath != nil) {
|
|
|
|
| 262 |
NSURL * libURL = [NSURL fileURLWithPath:libPath];
|
| 263 |
GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [libPath UTF8String]);
|
| 264 |
ctx->library = [ctx->device newLibraryWithURL:libURL error:&error];
|
| 265 |
} else {
|
| 266 |
-
GGML_METAL_LOG_INFO("%s:
|
| 267 |
|
| 268 |
NSString * sourcePath;
|
| 269 |
NSString * ggmlMetalPathResources = [[NSProcessInfo processInfo].environment objectForKey:@"GGML_METAL_PATH_RESOURCES"];
|
|
@@ -291,6 +292,13 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
| 291 |
options = [MTLCompileOptions new];
|
| 292 |
options.preprocessorMacros = @{ @"QK_K" : @(64) };
|
| 293 |
#endif
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 294 |
ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error];
|
| 295 |
}
|
| 296 |
|
|
@@ -1230,7 +1238,7 @@ void ggml_metal_graph_compute(
|
|
| 1230 |
// not sure how to avoid this
|
| 1231 |
// TODO: make a simpler cpy_bytes kernel
|
| 1232 |
|
| 1233 |
-
const int nth = MIN(
|
| 1234 |
|
| 1235 |
[encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32];
|
| 1236 |
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
@@ -1285,7 +1293,7 @@ void ggml_metal_graph_compute(
|
|
| 1285 |
[encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26];
|
| 1286 |
[encoder setBytes:&offs length:sizeof(offs) atIndex:27];
|
| 1287 |
|
| 1288 |
-
const int nth = MIN(
|
| 1289 |
|
| 1290 |
[encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
| 1291 |
} break;
|
|
@@ -1785,8 +1793,9 @@ void ggml_metal_graph_compute(
|
|
| 1785 |
[encoder setBytes:&r3 length:sizeof(r3) atIndex:17];
|
| 1786 |
[encoder setBytes:&idx length:sizeof(idx) atIndex:18];
|
| 1787 |
// TODO: how to make this an array? read Metal docs
|
| 1788 |
-
for (int j = 0; j <
|
| 1789 |
-
|
|
|
|
| 1790 |
|
| 1791 |
size_t offs_src_cur = 0;
|
| 1792 |
id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
|
|
@@ -1909,8 +1918,9 @@ void ggml_metal_graph_compute(
|
|
| 1909 |
[encoder setBytes:&r3 length:sizeof(r3) atIndex:21];
|
| 1910 |
[encoder setBytes:&idx length:sizeof(idx) atIndex:22];
|
| 1911 |
// TODO: how to make this an array? read Metal docs
|
| 1912 |
-
for (int j = 0; j <
|
| 1913 |
-
|
|
|
|
| 1914 |
|
| 1915 |
size_t offs_src_cur = 0;
|
| 1916 |
id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
|
|
@@ -2229,7 +2239,7 @@ void ggml_metal_graph_compute(
|
|
| 2229 |
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
|
| 2230 |
[encoder setBytes:&sf length:sizeof(sf) atIndex:18];
|
| 2231 |
|
| 2232 |
-
const int nth = MIN(
|
| 2233 |
|
| 2234 |
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
| 2235 |
} break;
|
|
|
|
| 257 |
bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
|
| 258 |
#endif
|
| 259 |
NSError * error = nil;
|
| 260 |
+
NSString * libPath = [bundle pathForResource:@"ggml" ofType:@"metallib"];
|
| 261 |
if (libPath != nil) {
|
| 262 |
+
// pre-compiled library found
|
| 263 |
NSURL * libURL = [NSURL fileURLWithPath:libPath];
|
| 264 |
GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [libPath UTF8String]);
|
| 265 |
ctx->library = [ctx->device newLibraryWithURL:libURL error:&error];
|
| 266 |
} else {
|
| 267 |
+
GGML_METAL_LOG_INFO("%s: ggml.metallib not found, loading from source\n", __func__);
|
| 268 |
|
| 269 |
NSString * sourcePath;
|
| 270 |
NSString * ggmlMetalPathResources = [[NSProcessInfo processInfo].environment objectForKey:@"GGML_METAL_PATH_RESOURCES"];
|
|
|
|
| 292 |
options = [MTLCompileOptions new];
|
| 293 |
options.preprocessorMacros = @{ @"QK_K" : @(64) };
|
| 294 |
#endif
|
| 295 |
+
// try to disable fast-math
|
| 296 |
+
// NOTE: this seems to have no effect whatsoever
|
| 297 |
+
// instead, in order to disable fast-math, we have to build ggml.metallib from the command line
|
| 298 |
+
// using xcrun -sdk macosx metal -fno-fast-math -c ggml-metal.metal -o ggml-metal.air
|
| 299 |
+
// and go through the "pre-compiled library found" path above
|
| 300 |
+
//[options setFastMathEnabled:false];
|
| 301 |
+
|
| 302 |
ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error];
|
| 303 |
}
|
| 304 |
|
|
|
|
| 1238 |
// not sure how to avoid this
|
| 1239 |
// TODO: make a simpler cpy_bytes kernel
|
| 1240 |
|
| 1241 |
+
const int nth = MIN((int) ctx->pipeline_cpy_f32_f32.maxTotalThreadsPerThreadgroup, ne00);
|
| 1242 |
|
| 1243 |
[encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32];
|
| 1244 |
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
|
|
| 1293 |
[encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26];
|
| 1294 |
[encoder setBytes:&offs length:sizeof(offs) atIndex:27];
|
| 1295 |
|
| 1296 |
+
const int nth = MIN((int) ctx->pipeline_add.maxTotalThreadsPerThreadgroup, ne00);
|
| 1297 |
|
| 1298 |
[encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
| 1299 |
} break;
|
|
|
|
| 1793 |
[encoder setBytes:&r3 length:sizeof(r3) atIndex:17];
|
| 1794 |
[encoder setBytes:&idx length:sizeof(idx) atIndex:18];
|
| 1795 |
// TODO: how to make this an array? read Metal docs
|
| 1796 |
+
for (int j = 0; j < 8; ++j) {
|
| 1797 |
+
// NOTE: this is done like this to avoid uninitialized kernel arguments when n_as < 8
|
| 1798 |
+
struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)];
|
| 1799 |
|
| 1800 |
size_t offs_src_cur = 0;
|
| 1801 |
id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
|
|
|
|
| 1918 |
[encoder setBytes:&r3 length:sizeof(r3) atIndex:21];
|
| 1919 |
[encoder setBytes:&idx length:sizeof(idx) atIndex:22];
|
| 1920 |
// TODO: how to make this an array? read Metal docs
|
| 1921 |
+
for (int j = 0; j < 8; ++j) {
|
| 1922 |
+
// NOTE: this is done like this to avoid uninitialized kernel arguments when n_as < 8
|
| 1923 |
+
struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)];
|
| 1924 |
|
| 1925 |
size_t offs_src_cur = 0;
|
| 1926 |
id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
|
|
|
|
| 2239 |
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
|
| 2240 |
[encoder setBytes:&sf length:sizeof(sf) atIndex:18];
|
| 2241 |
|
| 2242 |
+
const int nth = MIN((int) ctx->pipeline_upscale_f32.maxTotalThreadsPerThreadgroup, ne0);
|
| 2243 |
|
| 2244 |
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
| 2245 |
} break;
|
ggml-metal.metal
CHANGED
|
@@ -59,26 +59,26 @@ kernel void kernel_add(
|
|
| 59 |
constant int64_t & ne01,
|
| 60 |
constant int64_t & ne02,
|
| 61 |
constant int64_t & ne03,
|
| 62 |
-
constant
|
| 63 |
-
constant
|
| 64 |
-
constant
|
| 65 |
-
constant
|
| 66 |
constant int64_t & ne10,
|
| 67 |
constant int64_t & ne11,
|
| 68 |
constant int64_t & ne12,
|
| 69 |
constant int64_t & ne13,
|
| 70 |
-
constant
|
| 71 |
-
constant
|
| 72 |
-
constant
|
| 73 |
-
constant
|
| 74 |
constant int64_t & ne0,
|
| 75 |
constant int64_t & ne1,
|
| 76 |
constant int64_t & ne2,
|
| 77 |
constant int64_t & ne3,
|
| 78 |
-
constant
|
| 79 |
-
constant
|
| 80 |
-
constant
|
| 81 |
-
constant
|
| 82 |
constant int64_t & offs,
|
| 83 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 84 |
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
@@ -109,26 +109,26 @@ kernel void kernel_mul(
|
|
| 109 |
constant int64_t & ne01,
|
| 110 |
constant int64_t & ne02,
|
| 111 |
constant int64_t & ne03,
|
| 112 |
-
constant
|
| 113 |
-
constant
|
| 114 |
-
constant
|
| 115 |
-
constant
|
| 116 |
constant int64_t & ne10,
|
| 117 |
constant int64_t & ne11,
|
| 118 |
constant int64_t & ne12,
|
| 119 |
constant int64_t & ne13,
|
| 120 |
-
constant
|
| 121 |
-
constant
|
| 122 |
-
constant
|
| 123 |
-
constant
|
| 124 |
constant int64_t & ne0,
|
| 125 |
constant int64_t & ne1,
|
| 126 |
constant int64_t & ne2,
|
| 127 |
constant int64_t & ne3,
|
| 128 |
-
constant
|
| 129 |
-
constant
|
| 130 |
-
constant
|
| 131 |
-
constant
|
| 132 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 133 |
uint3 tpitg[[thread_position_in_threadgroup]],
|
| 134 |
uint3 ntg[[threads_per_threadgroup]]) {
|
|
@@ -158,26 +158,26 @@ kernel void kernel_div(
|
|
| 158 |
constant int64_t & ne01,
|
| 159 |
constant int64_t & ne02,
|
| 160 |
constant int64_t & ne03,
|
| 161 |
-
constant
|
| 162 |
-
constant
|
| 163 |
-
constant
|
| 164 |
-
constant
|
| 165 |
constant int64_t & ne10,
|
| 166 |
constant int64_t & ne11,
|
| 167 |
constant int64_t & ne12,
|
| 168 |
constant int64_t & ne13,
|
| 169 |
-
constant
|
| 170 |
-
constant
|
| 171 |
-
constant
|
| 172 |
-
constant
|
| 173 |
constant int64_t & ne0,
|
| 174 |
constant int64_t & ne1,
|
| 175 |
constant int64_t & ne2,
|
| 176 |
constant int64_t & ne3,
|
| 177 |
-
constant
|
| 178 |
-
constant
|
| 179 |
-
constant
|
| 180 |
-
constant
|
| 181 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 182 |
uint3 tpitg[[thread_position_in_threadgroup]],
|
| 183 |
uint3 ntg[[threads_per_threadgroup]]) {
|
|
@@ -205,7 +205,7 @@ kernel void kernel_add_row(
|
|
| 205 |
device const float4 * src0,
|
| 206 |
device const float4 * src1,
|
| 207 |
device float4 * dst,
|
| 208 |
-
constant
|
| 209 |
uint tpig[[thread_position_in_grid]]) {
|
| 210 |
dst[tpig] = src0[tpig] + src1[tpig % nb];
|
| 211 |
}
|
|
@@ -214,7 +214,7 @@ kernel void kernel_mul_row(
|
|
| 214 |
device const float4 * src0,
|
| 215 |
device const float4 * src1,
|
| 216 |
device float4 * dst,
|
| 217 |
-
constant
|
| 218 |
uint tpig[[thread_position_in_grid]]) {
|
| 219 |
dst[tpig] = src0[tpig] * src1[tpig % nb];
|
| 220 |
}
|
|
@@ -223,7 +223,7 @@ kernel void kernel_div_row(
|
|
| 223 |
device const float4 * src0,
|
| 224 |
device const float4 * src1,
|
| 225 |
device float4 * dst,
|
| 226 |
-
constant
|
| 227 |
uint tpig[[thread_position_in_grid]]) {
|
| 228 |
dst[tpig] = src0[tpig] / src1[tpig % nb];
|
| 229 |
}
|
|
@@ -307,26 +307,26 @@ kernel void kernel_sum_rows(
|
|
| 307 |
constant int64_t & ne01,
|
| 308 |
constant int64_t & ne02,
|
| 309 |
constant int64_t & ne03,
|
| 310 |
-
constant
|
| 311 |
-
constant
|
| 312 |
-
constant
|
| 313 |
-
constant
|
| 314 |
constant int64_t & ne10,
|
| 315 |
constant int64_t & ne11,
|
| 316 |
constant int64_t & ne12,
|
| 317 |
constant int64_t & ne13,
|
| 318 |
-
constant
|
| 319 |
-
constant
|
| 320 |
-
constant
|
| 321 |
-
constant
|
| 322 |
constant int64_t & ne0,
|
| 323 |
constant int64_t & ne1,
|
| 324 |
constant int64_t & ne2,
|
| 325 |
constant int64_t & ne3,
|
| 326 |
-
constant
|
| 327 |
-
constant
|
| 328 |
-
constant
|
| 329 |
-
constant
|
| 330 |
uint3 tpig[[thread_position_in_grid]]) {
|
| 331 |
int64_t i3 = tpig.z;
|
| 332 |
int64_t i2 = tpig.y;
|
|
@@ -920,14 +920,21 @@ kernel void kernel_mul_mv_q4_0_f32(
|
|
| 920 |
device const float * src1,
|
| 921 |
device float * dst,
|
| 922 |
constant int64_t & ne00,
|
| 923 |
-
constant int64_t & ne01
|
| 924 |
-
constant int64_t & ne02
|
| 925 |
-
constant
|
| 926 |
-
constant
|
| 927 |
-
constant
|
| 928 |
-
constant int64_t &
|
| 929 |
-
constant
|
| 930 |
-
constant
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 931 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 932 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 933 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
@@ -939,14 +946,21 @@ kernel void kernel_mul_mv_q4_1_f32(
|
|
| 939 |
device const float * src1,
|
| 940 |
device float * dst,
|
| 941 |
constant int64_t & ne00,
|
| 942 |
-
constant int64_t & ne01
|
| 943 |
-
constant int64_t & ne02
|
| 944 |
-
constant
|
| 945 |
-
constant
|
| 946 |
-
constant
|
| 947 |
-
constant int64_t &
|
| 948 |
-
constant
|
| 949 |
-
constant
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 950 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 951 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 952 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
@@ -958,14 +972,21 @@ kernel void kernel_mul_mv_q5_0_f32(
|
|
| 958 |
device const float * src1,
|
| 959 |
device float * dst,
|
| 960 |
constant int64_t & ne00,
|
| 961 |
-
constant int64_t & ne01
|
| 962 |
-
constant int64_t & ne02
|
| 963 |
-
constant
|
| 964 |
-
constant
|
| 965 |
-
constant
|
| 966 |
-
constant int64_t &
|
| 967 |
-
constant
|
| 968 |
-
constant
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 969 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 970 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 971 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
@@ -977,14 +998,21 @@ kernel void kernel_mul_mv_q5_1_f32(
|
|
| 977 |
device const float * src1,
|
| 978 |
device float * dst,
|
| 979 |
constant int64_t & ne00,
|
| 980 |
-
constant int64_t & ne01
|
| 981 |
-
constant int64_t & ne02
|
| 982 |
-
constant
|
| 983 |
-
constant
|
| 984 |
-
constant
|
| 985 |
-
constant int64_t &
|
| 986 |
-
constant
|
| 987 |
-
constant
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 988 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 989 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 990 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
@@ -1071,12 +1099,19 @@ kernel void kernel_mul_mv_q8_0_f32(
|
|
| 1071 |
constant int64_t & ne00,
|
| 1072 |
constant int64_t & ne01,
|
| 1073 |
constant int64_t & ne02,
|
|
|
|
|
|
|
|
|
|
| 1074 |
constant int64_t & ne10,
|
|
|
|
| 1075 |
constant int64_t & ne12,
|
|
|
|
|
|
|
|
|
|
| 1076 |
constant int64_t & ne0,
|
| 1077 |
constant int64_t & ne1,
|
| 1078 |
-
constant uint & r2
|
| 1079 |
-
constant uint & r3
|
| 1080 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 1081 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 1082 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
@@ -1182,8 +1217,8 @@ kernel void kernel_mul_mv_f32_f32(
|
|
| 1182 |
constant uint64_t & nb12,
|
| 1183 |
constant int64_t & ne0,
|
| 1184 |
constant int64_t & ne1,
|
| 1185 |
-
constant uint & r2
|
| 1186 |
-
constant uint & r3
|
| 1187 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 1188 |
uint tiisg[[thread_index_in_simdgroup]]) {
|
| 1189 |
kernel_mul_mv_f32_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
|
|
@@ -1209,8 +1244,8 @@ kernel void kernel_mul_mv_f16_f16(
|
|
| 1209 |
constant uint64_t & nb12,
|
| 1210 |
constant int64_t & ne0,
|
| 1211 |
constant int64_t & ne1,
|
| 1212 |
-
constant uint & r2
|
| 1213 |
-
constant uint & r3
|
| 1214 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 1215 |
uint tiisg[[thread_index_in_simdgroup]]) {
|
| 1216 |
|
|
@@ -1346,8 +1381,8 @@ kernel void kernel_mul_mv_f16_f32_1row(
|
|
| 1346 |
constant uint64_t & nb12,
|
| 1347 |
constant int64_t & ne0,
|
| 1348 |
constant int64_t & ne1,
|
| 1349 |
-
constant uint & r2
|
| 1350 |
-
constant uint & r3
|
| 1351 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 1352 |
uint tiisg[[thread_index_in_simdgroup]]) {
|
| 1353 |
kernel_mul_mv_f16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
|
|
@@ -1452,8 +1487,8 @@ kernel void kernel_mul_mv_f16_f32(
|
|
| 1452 |
constant uint64_t & nb12,
|
| 1453 |
constant int64_t & ne0,
|
| 1454 |
constant int64_t & ne1,
|
| 1455 |
-
constant uint & r2
|
| 1456 |
-
constant uint & r3
|
| 1457 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 1458 |
uint tiisg[[thread_index_in_simdgroup]]) {
|
| 1459 |
kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
|
|
@@ -1478,8 +1513,8 @@ kernel void kernel_mul_mv_f16_f32_l4(
|
|
| 1478 |
constant uint64_t & nb12,
|
| 1479 |
constant int64_t & ne0,
|
| 1480 |
constant int64_t & ne1,
|
| 1481 |
-
constant uint & r2
|
| 1482 |
-
constant uint & r3
|
| 1483 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 1484 |
uint tiisg[[thread_index_in_simdgroup]]) {
|
| 1485 |
|
|
@@ -1543,7 +1578,8 @@ kernel void kernel_alibi_f32(
|
|
| 1543 |
const int64_t i3 = n / (ne2*ne1*ne0);
|
| 1544 |
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
| 1545 |
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
| 1546 |
-
|
|
|
|
| 1547 |
const int64_t k = i3*ne3 + i2;
|
| 1548 |
|
| 1549 |
float m_k;
|
|
@@ -2410,22 +2446,6 @@ typedef struct {
|
|
| 2410 |
} block_q6_K;
|
| 2411 |
// 210 bytes / block
|
| 2412 |
|
| 2413 |
-
static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
|
| 2414 |
-
uchar4 r;
|
| 2415 |
-
if (j < 4) {
|
| 2416 |
-
r[0] = q[j+0] & 63;
|
| 2417 |
-
r[2] = q[j+1] & 63;
|
| 2418 |
-
r[1] = q[j+4] & 63;
|
| 2419 |
-
r[3] = q[j+5] & 63;
|
| 2420 |
-
} else {
|
| 2421 |
-
r[0] = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
|
| 2422 |
-
r[2] = (q[j+5] & 0xF) | ((q[j-3] >> 6) << 4);
|
| 2423 |
-
r[1] = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
|
| 2424 |
-
r[3] = (q[j+5] >> 4) | ((q[j+1] >> 6) << 4);
|
| 2425 |
-
}
|
| 2426 |
-
return r;
|
| 2427 |
-
}
|
| 2428 |
-
|
| 2429 |
//====================================== dot products =========================
|
| 2430 |
|
| 2431 |
void kernel_mul_mv_q2_K_f32_impl(
|
|
@@ -2584,14 +2604,21 @@ kernel void kernel_mul_mv_q2_K_f32(
|
|
| 2584 |
device const float * src1,
|
| 2585 |
device float * dst,
|
| 2586 |
constant int64_t & ne00,
|
| 2587 |
-
constant int64_t & ne01
|
| 2588 |
-
constant int64_t & ne02
|
| 2589 |
-
constant
|
| 2590 |
-
constant
|
| 2591 |
-
constant
|
| 2592 |
-
constant int64_t &
|
| 2593 |
-
constant
|
| 2594 |
-
constant
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2595 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 2596 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 2597 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
@@ -2841,14 +2868,21 @@ kernel void kernel_mul_mv_q3_K_f32(
|
|
| 2841 |
device const float * src1,
|
| 2842 |
device float * dst,
|
| 2843 |
constant int64_t & ne00,
|
| 2844 |
-
constant int64_t & ne01
|
| 2845 |
-
constant int64_t & ne02
|
| 2846 |
-
constant
|
| 2847 |
-
constant
|
| 2848 |
-
constant
|
| 2849 |
-
constant int64_t &
|
| 2850 |
-
constant
|
| 2851 |
-
constant
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2852 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 2853 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 2854 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
@@ -2984,8 +3018,8 @@ void kernel_mul_mv_q4_K_f32_impl(
|
|
| 2984 |
constant uint & r2,
|
| 2985 |
constant uint & r3,
|
| 2986 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 2987 |
-
uint
|
| 2988 |
-
uint
|
| 2989 |
|
| 2990 |
const int ix = tiisg/4; // 0...7
|
| 2991 |
const int it = tiisg%4; // 0...3
|
|
@@ -2994,7 +3028,7 @@ void kernel_mul_mv_q4_K_f32_impl(
|
|
| 2994 |
const int r0 = tgpig.x;
|
| 2995 |
const int r1 = tgpig.y;
|
| 2996 |
const int im = tgpig.z;
|
| 2997 |
-
const int first_row =
|
| 2998 |
const int ib_row = first_row * nb;
|
| 2999 |
|
| 3000 |
const uint i12 = im%ne12;
|
|
@@ -3060,7 +3094,7 @@ void kernel_mul_mv_q4_K_f32_impl(
|
|
| 3060 |
for (int row = 0; row < N_DST; ++row) {
|
| 3061 |
all_sum = simd_sum(sumf[row]);
|
| 3062 |
if (tiisg == 0) {
|
| 3063 |
-
dst[r1*ne0+ im*ne0*ne1 + first_row + row] = all_sum;
|
| 3064 |
}
|
| 3065 |
}
|
| 3066 |
}
|
|
@@ -3072,14 +3106,21 @@ kernel void kernel_mul_mv_q4_K_f32(
|
|
| 3072 |
device const float * src1,
|
| 3073 |
device float * dst,
|
| 3074 |
constant int64_t & ne00,
|
| 3075 |
-
constant int64_t & ne01
|
| 3076 |
-
constant int64_t & ne02
|
| 3077 |
-
constant
|
| 3078 |
-
constant
|
| 3079 |
-
constant
|
| 3080 |
-
constant int64_t &
|
| 3081 |
-
constant
|
| 3082 |
-
constant
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3083 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 3084 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 3085 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
@@ -3271,14 +3312,21 @@ kernel void kernel_mul_mv_q5_K_f32(
|
|
| 3271 |
device const float * src1,
|
| 3272 |
device float * dst,
|
| 3273 |
constant int64_t & ne00,
|
| 3274 |
-
constant int64_t & ne01
|
| 3275 |
-
constant int64_t & ne02
|
| 3276 |
-
constant
|
| 3277 |
-
constant
|
| 3278 |
-
constant
|
| 3279 |
-
constant int64_t &
|
| 3280 |
-
constant
|
| 3281 |
-
constant
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3282 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 3283 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 3284 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
@@ -3398,14 +3446,21 @@ kernel void kernel_mul_mv_q6_K_f32(
|
|
| 3398 |
device const float * src1,
|
| 3399 |
device float * dst,
|
| 3400 |
constant int64_t & ne00,
|
| 3401 |
-
constant int64_t & ne01
|
| 3402 |
-
constant int64_t & ne02
|
| 3403 |
-
constant
|
| 3404 |
-
constant
|
| 3405 |
-
constant
|
| 3406 |
-
constant int64_t &
|
| 3407 |
-
constant
|
| 3408 |
-
constant
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3409 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 3410 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 3411 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
@@ -3523,7 +3578,7 @@ void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg
|
|
| 3523 |
device const int8_t * qs = ((device const int8_t *)xb->qs);
|
| 3524 |
const half d = xb->d;
|
| 3525 |
|
| 3526 |
-
for (int i=0;i<16;i++) {
|
| 3527 |
reg[i/4][i%4] = (qs[i + 16*il] * d);
|
| 3528 |
}
|
| 3529 |
}
|
|
@@ -3792,12 +3847,12 @@ void kernel_mul_mm_impl(device const uchar * src0,
|
|
| 3792 |
device float * dst,
|
| 3793 |
constant int64_t & ne00,
|
| 3794 |
constant int64_t & ne02,
|
| 3795 |
-
constant
|
| 3796 |
-
constant
|
| 3797 |
constant int64_t & ne12,
|
| 3798 |
-
constant
|
| 3799 |
-
constant
|
| 3800 |
-
constant
|
| 3801 |
constant int64_t & ne0,
|
| 3802 |
constant int64_t & ne1,
|
| 3803 |
constant uint & r2,
|
|
@@ -3924,12 +3979,12 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|
| 3924 |
device float * dst,
|
| 3925 |
constant int64_t & ne00,
|
| 3926 |
constant int64_t & ne02,
|
| 3927 |
-
constant
|
| 3928 |
-
constant
|
| 3929 |
constant int64_t & ne12,
|
| 3930 |
-
constant
|
| 3931 |
-
constant
|
| 3932 |
-
constant
|
| 3933 |
constant int64_t & ne0,
|
| 3934 |
constant int64_t & ne1,
|
| 3935 |
constant uint & r2,
|
|
@@ -3965,19 +4020,19 @@ kernel void kernel_mul_mm_id(
|
|
| 3965 |
device const uchar * ids,
|
| 3966 |
device const uchar * src1,
|
| 3967 |
device uchar * dst,
|
| 3968 |
-
constant
|
| 3969 |
constant int64_t & ne00,
|
| 3970 |
constant int64_t & ne02,
|
| 3971 |
-
constant
|
| 3972 |
-
constant
|
| 3973 |
constant int64_t & ne12,
|
| 3974 |
constant int64_t & ne13,
|
| 3975 |
-
constant
|
| 3976 |
-
constant
|
| 3977 |
-
constant
|
| 3978 |
constant int64_t & ne0,
|
| 3979 |
constant int64_t & ne1,
|
| 3980 |
-
constant
|
| 3981 |
constant uint & r2,
|
| 3982 |
constant uint & r3,
|
| 3983 |
constant int & idx,
|
|
@@ -4070,12 +4125,12 @@ typedef void (mat_mm_t)(
|
|
| 4070 |
device float * dst,
|
| 4071 |
constant int64_t & ne00,
|
| 4072 |
constant int64_t & ne02,
|
| 4073 |
-
constant
|
| 4074 |
-
constant
|
| 4075 |
constant int64_t & ne12,
|
| 4076 |
-
constant
|
| 4077 |
-
constant
|
| 4078 |
-
constant
|
| 4079 |
constant int64_t & ne0,
|
| 4080 |
constant int64_t & ne1,
|
| 4081 |
constant uint & r2,
|
|
@@ -4104,19 +4159,19 @@ typedef void (mat_mm_id_t)(
|
|
| 4104 |
device const uchar * ids,
|
| 4105 |
device const uchar * src1,
|
| 4106 |
device uchar * dst,
|
| 4107 |
-
constant
|
| 4108 |
constant int64_t & ne00,
|
| 4109 |
constant int64_t & ne02,
|
| 4110 |
-
constant
|
| 4111 |
-
constant
|
| 4112 |
constant int64_t & ne12,
|
| 4113 |
constant int64_t & ne13,
|
| 4114 |
-
constant
|
| 4115 |
-
constant
|
| 4116 |
-
constant
|
| 4117 |
constant int64_t & ne0,
|
| 4118 |
constant int64_t & ne1,
|
| 4119 |
-
constant
|
| 4120 |
constant uint & r2,
|
| 4121 |
constant uint & r3,
|
| 4122 |
constant int & idx,
|
|
@@ -4153,7 +4208,7 @@ kernel void kernel_mul_mv_id_f32_f32(
|
|
| 4153 |
device const char * ids,
|
| 4154 |
device const char * src1,
|
| 4155 |
device uchar * dst,
|
| 4156 |
-
constant
|
| 4157 |
constant int64_t & ne00,
|
| 4158 |
constant int64_t & ne01,
|
| 4159 |
constant int64_t & ne02,
|
|
@@ -4169,7 +4224,7 @@ kernel void kernel_mul_mv_id_f32_f32(
|
|
| 4169 |
constant uint64_t & nb12,
|
| 4170 |
constant int64_t & ne0,
|
| 4171 |
constant int64_t & ne1,
|
| 4172 |
-
constant
|
| 4173 |
constant uint & r2,
|
| 4174 |
constant uint & r3,
|
| 4175 |
constant int & idx,
|
|
@@ -4222,7 +4277,7 @@ kernel void kernel_mul_mv_id_f16_f32(
|
|
| 4222 |
device const char * ids,
|
| 4223 |
device const char * src1,
|
| 4224 |
device uchar * dst,
|
| 4225 |
-
constant
|
| 4226 |
constant int64_t & ne00,
|
| 4227 |
constant int64_t & ne01,
|
| 4228 |
constant int64_t & ne02,
|
|
@@ -4238,7 +4293,7 @@ kernel void kernel_mul_mv_id_f16_f32(
|
|
| 4238 |
constant uint64_t & nb12,
|
| 4239 |
constant int64_t & ne0,
|
| 4240 |
constant int64_t & ne1,
|
| 4241 |
-
constant
|
| 4242 |
constant uint & r2,
|
| 4243 |
constant uint & r3,
|
| 4244 |
constant int & idx,
|
|
@@ -4291,7 +4346,7 @@ kernel void kernel_mul_mv_id_q8_0_f32(
|
|
| 4291 |
device const char * ids,
|
| 4292 |
device const char * src1,
|
| 4293 |
device uchar * dst,
|
| 4294 |
-
constant
|
| 4295 |
constant int64_t & ne00,
|
| 4296 |
constant int64_t & ne01,
|
| 4297 |
constant int64_t & ne02,
|
|
@@ -4307,7 +4362,7 @@ kernel void kernel_mul_mv_id_q8_0_f32(
|
|
| 4307 |
constant uint64_t & nb12,
|
| 4308 |
constant int64_t & ne0,
|
| 4309 |
constant int64_t & ne1,
|
| 4310 |
-
constant
|
| 4311 |
constant uint & r2,
|
| 4312 |
constant uint & r3,
|
| 4313 |
constant int & idx,
|
|
@@ -4354,7 +4409,7 @@ kernel void kernel_mul_mv_id_q4_0_f32(
|
|
| 4354 |
device const char * ids,
|
| 4355 |
device const char * src1,
|
| 4356 |
device uchar * dst,
|
| 4357 |
-
constant
|
| 4358 |
constant int64_t & ne00,
|
| 4359 |
constant int64_t & ne01,
|
| 4360 |
constant int64_t & ne02,
|
|
@@ -4370,7 +4425,7 @@ kernel void kernel_mul_mv_id_q4_0_f32(
|
|
| 4370 |
constant uint64_t & nb12,
|
| 4371 |
constant int64_t & ne0,
|
| 4372 |
constant int64_t & ne1,
|
| 4373 |
-
constant
|
| 4374 |
constant uint & r2,
|
| 4375 |
constant uint & r3,
|
| 4376 |
constant int & idx,
|
|
@@ -4417,7 +4472,7 @@ kernel void kernel_mul_mv_id_q4_1_f32(
|
|
| 4417 |
device const char * ids,
|
| 4418 |
device const char * src1,
|
| 4419 |
device uchar * dst,
|
| 4420 |
-
constant
|
| 4421 |
constant int64_t & ne00,
|
| 4422 |
constant int64_t & ne01,
|
| 4423 |
constant int64_t & ne02,
|
|
@@ -4433,7 +4488,7 @@ kernel void kernel_mul_mv_id_q4_1_f32(
|
|
| 4433 |
constant uint64_t & nb12,
|
| 4434 |
constant int64_t & ne0,
|
| 4435 |
constant int64_t & ne1,
|
| 4436 |
-
constant
|
| 4437 |
constant uint & r2,
|
| 4438 |
constant uint & r3,
|
| 4439 |
constant int & idx,
|
|
@@ -4480,7 +4535,7 @@ kernel void kernel_mul_mv_id_q5_0_f32(
|
|
| 4480 |
device const char * ids,
|
| 4481 |
device const char * src1,
|
| 4482 |
device uchar * dst,
|
| 4483 |
-
constant
|
| 4484 |
constant int64_t & ne00,
|
| 4485 |
constant int64_t & ne01,
|
| 4486 |
constant int64_t & ne02,
|
|
@@ -4496,7 +4551,7 @@ kernel void kernel_mul_mv_id_q5_0_f32(
|
|
| 4496 |
constant uint64_t & nb12,
|
| 4497 |
constant int64_t & ne0,
|
| 4498 |
constant int64_t & ne1,
|
| 4499 |
-
constant
|
| 4500 |
constant uint & r2,
|
| 4501 |
constant uint & r3,
|
| 4502 |
constant int & idx,
|
|
@@ -4543,7 +4598,7 @@ kernel void kernel_mul_mv_id_q5_1_f32(
|
|
| 4543 |
device const char * ids,
|
| 4544 |
device const char * src1,
|
| 4545 |
device uchar * dst,
|
| 4546 |
-
constant
|
| 4547 |
constant int64_t & ne00,
|
| 4548 |
constant int64_t & ne01,
|
| 4549 |
constant int64_t & ne02,
|
|
@@ -4559,7 +4614,7 @@ kernel void kernel_mul_mv_id_q5_1_f32(
|
|
| 4559 |
constant uint64_t & nb12,
|
| 4560 |
constant int64_t & ne0,
|
| 4561 |
constant int64_t & ne1,
|
| 4562 |
-
constant
|
| 4563 |
constant uint & r2,
|
| 4564 |
constant uint & r3,
|
| 4565 |
constant int & idx,
|
|
@@ -4606,7 +4661,7 @@ kernel void kernel_mul_mv_id_q2_K_f32(
|
|
| 4606 |
device const char * ids,
|
| 4607 |
device const char * src1,
|
| 4608 |
device uchar * dst,
|
| 4609 |
-
constant
|
| 4610 |
constant int64_t & ne00,
|
| 4611 |
constant int64_t & ne01,
|
| 4612 |
constant int64_t & ne02,
|
|
@@ -4622,7 +4677,7 @@ kernel void kernel_mul_mv_id_q2_K_f32(
|
|
| 4622 |
constant uint64_t & nb12,
|
| 4623 |
constant int64_t & ne0,
|
| 4624 |
constant int64_t & ne1,
|
| 4625 |
-
constant
|
| 4626 |
constant uint & r2,
|
| 4627 |
constant uint & r3,
|
| 4628 |
constant int & idx,
|
|
@@ -4669,7 +4724,7 @@ kernel void kernel_mul_mv_id_q3_K_f32(
|
|
| 4669 |
device const char * ids,
|
| 4670 |
device const char * src1,
|
| 4671 |
device uchar * dst,
|
| 4672 |
-
constant
|
| 4673 |
constant int64_t & ne00,
|
| 4674 |
constant int64_t & ne01,
|
| 4675 |
constant int64_t & ne02,
|
|
@@ -4685,7 +4740,7 @@ kernel void kernel_mul_mv_id_q3_K_f32(
|
|
| 4685 |
constant uint64_t & nb12,
|
| 4686 |
constant int64_t & ne0,
|
| 4687 |
constant int64_t & ne1,
|
| 4688 |
-
constant
|
| 4689 |
constant uint & r2,
|
| 4690 |
constant uint & r3,
|
| 4691 |
constant int & idx,
|
|
@@ -4732,7 +4787,7 @@ kernel void kernel_mul_mv_id_q4_K_f32(
|
|
| 4732 |
device const char * ids,
|
| 4733 |
device const char * src1,
|
| 4734 |
device uchar * dst,
|
| 4735 |
-
constant
|
| 4736 |
constant int64_t & ne00,
|
| 4737 |
constant int64_t & ne01,
|
| 4738 |
constant int64_t & ne02,
|
|
@@ -4748,7 +4803,7 @@ kernel void kernel_mul_mv_id_q4_K_f32(
|
|
| 4748 |
constant uint64_t & nb12,
|
| 4749 |
constant int64_t & ne0,
|
| 4750 |
constant int64_t & ne1,
|
| 4751 |
-
constant
|
| 4752 |
constant uint & r2,
|
| 4753 |
constant uint & r3,
|
| 4754 |
constant int & idx,
|
|
@@ -4795,7 +4850,7 @@ kernel void kernel_mul_mv_id_q5_K_f32(
|
|
| 4795 |
device const char * ids,
|
| 4796 |
device const char * src1,
|
| 4797 |
device uchar * dst,
|
| 4798 |
-
constant
|
| 4799 |
constant int64_t & ne00,
|
| 4800 |
constant int64_t & ne01,
|
| 4801 |
constant int64_t & ne02,
|
|
@@ -4811,7 +4866,7 @@ kernel void kernel_mul_mv_id_q5_K_f32(
|
|
| 4811 |
constant uint64_t & nb12,
|
| 4812 |
constant int64_t & ne0,
|
| 4813 |
constant int64_t & ne1,
|
| 4814 |
-
constant
|
| 4815 |
constant uint & r2,
|
| 4816 |
constant uint & r3,
|
| 4817 |
constant int & idx,
|
|
@@ -4858,7 +4913,7 @@ kernel void kernel_mul_mv_id_q6_K_f32(
|
|
| 4858 |
device const char * ids,
|
| 4859 |
device const char * src1,
|
| 4860 |
device uchar * dst,
|
| 4861 |
-
constant
|
| 4862 |
constant int64_t & ne00,
|
| 4863 |
constant int64_t & ne01,
|
| 4864 |
constant int64_t & ne02,
|
|
@@ -4874,7 +4929,7 @@ kernel void kernel_mul_mv_id_q6_K_f32(
|
|
| 4874 |
constant uint64_t & nb12,
|
| 4875 |
constant int64_t & ne0,
|
| 4876 |
constant int64_t & ne1,
|
| 4877 |
-
constant
|
| 4878 |
constant uint & r2,
|
| 4879 |
constant uint & r3,
|
| 4880 |
constant int & idx,
|
|
|
|
| 59 |
constant int64_t & ne01,
|
| 60 |
constant int64_t & ne02,
|
| 61 |
constant int64_t & ne03,
|
| 62 |
+
constant uint64_t & nb00,
|
| 63 |
+
constant uint64_t & nb01,
|
| 64 |
+
constant uint64_t & nb02,
|
| 65 |
+
constant uint64_t & nb03,
|
| 66 |
constant int64_t & ne10,
|
| 67 |
constant int64_t & ne11,
|
| 68 |
constant int64_t & ne12,
|
| 69 |
constant int64_t & ne13,
|
| 70 |
+
constant uint64_t & nb10,
|
| 71 |
+
constant uint64_t & nb11,
|
| 72 |
+
constant uint64_t & nb12,
|
| 73 |
+
constant uint64_t & nb13,
|
| 74 |
constant int64_t & ne0,
|
| 75 |
constant int64_t & ne1,
|
| 76 |
constant int64_t & ne2,
|
| 77 |
constant int64_t & ne3,
|
| 78 |
+
constant uint64_t & nb0,
|
| 79 |
+
constant uint64_t & nb1,
|
| 80 |
+
constant uint64_t & nb2,
|
| 81 |
+
constant uint64_t & nb3,
|
| 82 |
constant int64_t & offs,
|
| 83 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 84 |
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
|
|
| 109 |
constant int64_t & ne01,
|
| 110 |
constant int64_t & ne02,
|
| 111 |
constant int64_t & ne03,
|
| 112 |
+
constant uint64_t & nb00,
|
| 113 |
+
constant uint64_t & nb01,
|
| 114 |
+
constant uint64_t & nb02,
|
| 115 |
+
constant uint64_t & nb03,
|
| 116 |
constant int64_t & ne10,
|
| 117 |
constant int64_t & ne11,
|
| 118 |
constant int64_t & ne12,
|
| 119 |
constant int64_t & ne13,
|
| 120 |
+
constant uint64_t & nb10,
|
| 121 |
+
constant uint64_t & nb11,
|
| 122 |
+
constant uint64_t & nb12,
|
| 123 |
+
constant uint64_t & nb13,
|
| 124 |
constant int64_t & ne0,
|
| 125 |
constant int64_t & ne1,
|
| 126 |
constant int64_t & ne2,
|
| 127 |
constant int64_t & ne3,
|
| 128 |
+
constant uint64_t & nb0,
|
| 129 |
+
constant uint64_t & nb1,
|
| 130 |
+
constant uint64_t & nb2,
|
| 131 |
+
constant uint64_t & nb3,
|
| 132 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 133 |
uint3 tpitg[[thread_position_in_threadgroup]],
|
| 134 |
uint3 ntg[[threads_per_threadgroup]]) {
|
|
|
|
| 158 |
constant int64_t & ne01,
|
| 159 |
constant int64_t & ne02,
|
| 160 |
constant int64_t & ne03,
|
| 161 |
+
constant uint64_t & nb00,
|
| 162 |
+
constant uint64_t & nb01,
|
| 163 |
+
constant uint64_t & nb02,
|
| 164 |
+
constant uint64_t & nb03,
|
| 165 |
constant int64_t & ne10,
|
| 166 |
constant int64_t & ne11,
|
| 167 |
constant int64_t & ne12,
|
| 168 |
constant int64_t & ne13,
|
| 169 |
+
constant uint64_t & nb10,
|
| 170 |
+
constant uint64_t & nb11,
|
| 171 |
+
constant uint64_t & nb12,
|
| 172 |
+
constant uint64_t & nb13,
|
| 173 |
constant int64_t & ne0,
|
| 174 |
constant int64_t & ne1,
|
| 175 |
constant int64_t & ne2,
|
| 176 |
constant int64_t & ne3,
|
| 177 |
+
constant uint64_t & nb0,
|
| 178 |
+
constant uint64_t & nb1,
|
| 179 |
+
constant uint64_t & nb2,
|
| 180 |
+
constant uint64_t & nb3,
|
| 181 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 182 |
uint3 tpitg[[thread_position_in_threadgroup]],
|
| 183 |
uint3 ntg[[threads_per_threadgroup]]) {
|
|
|
|
| 205 |
device const float4 * src0,
|
| 206 |
device const float4 * src1,
|
| 207 |
device float4 * dst,
|
| 208 |
+
constant uint64_t & nb [[buffer(28)]],
|
| 209 |
uint tpig[[thread_position_in_grid]]) {
|
| 210 |
dst[tpig] = src0[tpig] + src1[tpig % nb];
|
| 211 |
}
|
|
|
|
| 214 |
device const float4 * src0,
|
| 215 |
device const float4 * src1,
|
| 216 |
device float4 * dst,
|
| 217 |
+
constant uint64_t & nb [[buffer(28)]],
|
| 218 |
uint tpig[[thread_position_in_grid]]) {
|
| 219 |
dst[tpig] = src0[tpig] * src1[tpig % nb];
|
| 220 |
}
|
|
|
|
| 223 |
device const float4 * src0,
|
| 224 |
device const float4 * src1,
|
| 225 |
device float4 * dst,
|
| 226 |
+
constant uint64_t & nb [[buffer(28)]],
|
| 227 |
uint tpig[[thread_position_in_grid]]) {
|
| 228 |
dst[tpig] = src0[tpig] / src1[tpig % nb];
|
| 229 |
}
|
|
|
|
| 307 |
constant int64_t & ne01,
|
| 308 |
constant int64_t & ne02,
|
| 309 |
constant int64_t & ne03,
|
| 310 |
+
constant uint64_t & nb00,
|
| 311 |
+
constant uint64_t & nb01,
|
| 312 |
+
constant uint64_t & nb02,
|
| 313 |
+
constant uint64_t & nb03,
|
| 314 |
constant int64_t & ne10,
|
| 315 |
constant int64_t & ne11,
|
| 316 |
constant int64_t & ne12,
|
| 317 |
constant int64_t & ne13,
|
| 318 |
+
constant uint64_t & nb10,
|
| 319 |
+
constant uint64_t & nb11,
|
| 320 |
+
constant uint64_t & nb12,
|
| 321 |
+
constant uint64_t & nb13,
|
| 322 |
constant int64_t & ne0,
|
| 323 |
constant int64_t & ne1,
|
| 324 |
constant int64_t & ne2,
|
| 325 |
constant int64_t & ne3,
|
| 326 |
+
constant uint64_t & nb0,
|
| 327 |
+
constant uint64_t & nb1,
|
| 328 |
+
constant uint64_t & nb2,
|
| 329 |
+
constant uint64_t & nb3,
|
| 330 |
uint3 tpig[[thread_position_in_grid]]) {
|
| 331 |
int64_t i3 = tpig.z;
|
| 332 |
int64_t i2 = tpig.y;
|
|
|
|
| 920 |
device const float * src1,
|
| 921 |
device float * dst,
|
| 922 |
constant int64_t & ne00,
|
| 923 |
+
constant int64_t & ne01,
|
| 924 |
+
constant int64_t & ne02,
|
| 925 |
+
constant uint64_t & nb00,
|
| 926 |
+
constant uint64_t & nb01,
|
| 927 |
+
constant uint64_t & nb02,
|
| 928 |
+
constant int64_t & ne10,
|
| 929 |
+
constant int64_t & ne11,
|
| 930 |
+
constant int64_t & ne12,
|
| 931 |
+
constant uint64_t & nb10,
|
| 932 |
+
constant uint64_t & nb11,
|
| 933 |
+
constant uint64_t & nb12,
|
| 934 |
+
constant int64_t & ne0,
|
| 935 |
+
constant int64_t & ne1,
|
| 936 |
+
constant uint & r2,
|
| 937 |
+
constant uint & r3,
|
| 938 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 939 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 940 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
| 946 |
device const float * src1,
|
| 947 |
device float * dst,
|
| 948 |
constant int64_t & ne00,
|
| 949 |
+
constant int64_t & ne01,
|
| 950 |
+
constant int64_t & ne02,
|
| 951 |
+
constant uint64_t & nb00,
|
| 952 |
+
constant uint64_t & nb01,
|
| 953 |
+
constant uint64_t & nb02,
|
| 954 |
+
constant int64_t & ne10,
|
| 955 |
+
constant int64_t & ne11,
|
| 956 |
+
constant int64_t & ne12,
|
| 957 |
+
constant uint64_t & nb10,
|
| 958 |
+
constant uint64_t & nb11,
|
| 959 |
+
constant uint64_t & nb12,
|
| 960 |
+
constant int64_t & ne0,
|
| 961 |
+
constant int64_t & ne1,
|
| 962 |
+
constant uint & r2,
|
| 963 |
+
constant uint & r3,
|
| 964 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 965 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 966 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
| 972 |
device const float * src1,
|
| 973 |
device float * dst,
|
| 974 |
constant int64_t & ne00,
|
| 975 |
+
constant int64_t & ne01,
|
| 976 |
+
constant int64_t & ne02,
|
| 977 |
+
constant uint64_t & nb00,
|
| 978 |
+
constant uint64_t & nb01,
|
| 979 |
+
constant uint64_t & nb02,
|
| 980 |
+
constant int64_t & ne10,
|
| 981 |
+
constant int64_t & ne11,
|
| 982 |
+
constant int64_t & ne12,
|
| 983 |
+
constant uint64_t & nb10,
|
| 984 |
+
constant uint64_t & nb11,
|
| 985 |
+
constant uint64_t & nb12,
|
| 986 |
+
constant int64_t & ne0,
|
| 987 |
+
constant int64_t & ne1,
|
| 988 |
+
constant uint & r2,
|
| 989 |
+
constant uint & r3,
|
| 990 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 991 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 992 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
| 998 |
device const float * src1,
|
| 999 |
device float * dst,
|
| 1000 |
constant int64_t & ne00,
|
| 1001 |
+
constant int64_t & ne01,
|
| 1002 |
+
constant int64_t & ne02,
|
| 1003 |
+
constant uint64_t & nb00,
|
| 1004 |
+
constant uint64_t & nb01,
|
| 1005 |
+
constant uint64_t & nb02,
|
| 1006 |
+
constant int64_t & ne10,
|
| 1007 |
+
constant int64_t & ne11,
|
| 1008 |
+
constant int64_t & ne12,
|
| 1009 |
+
constant uint64_t & nb10,
|
| 1010 |
+
constant uint64_t & nb11,
|
| 1011 |
+
constant uint64_t & nb12,
|
| 1012 |
+
constant int64_t & ne0,
|
| 1013 |
+
constant int64_t & ne1,
|
| 1014 |
+
constant uint & r2,
|
| 1015 |
+
constant uint & r3,
|
| 1016 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 1017 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 1018 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
| 1099 |
constant int64_t & ne00,
|
| 1100 |
constant int64_t & ne01,
|
| 1101 |
constant int64_t & ne02,
|
| 1102 |
+
constant uint64_t & nb00,
|
| 1103 |
+
constant uint64_t & nb01,
|
| 1104 |
+
constant uint64_t & nb02,
|
| 1105 |
constant int64_t & ne10,
|
| 1106 |
+
constant int64_t & ne11,
|
| 1107 |
constant int64_t & ne12,
|
| 1108 |
+
constant uint64_t & nb10,
|
| 1109 |
+
constant uint64_t & nb11,
|
| 1110 |
+
constant uint64_t & nb12,
|
| 1111 |
constant int64_t & ne0,
|
| 1112 |
constant int64_t & ne1,
|
| 1113 |
+
constant uint & r2,
|
| 1114 |
+
constant uint & r3,
|
| 1115 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 1116 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 1117 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
| 1217 |
constant uint64_t & nb12,
|
| 1218 |
constant int64_t & ne0,
|
| 1219 |
constant int64_t & ne1,
|
| 1220 |
+
constant uint & r2,
|
| 1221 |
+
constant uint & r3,
|
| 1222 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 1223 |
uint tiisg[[thread_index_in_simdgroup]]) {
|
| 1224 |
kernel_mul_mv_f32_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
|
|
|
|
| 1244 |
constant uint64_t & nb12,
|
| 1245 |
constant int64_t & ne0,
|
| 1246 |
constant int64_t & ne1,
|
| 1247 |
+
constant uint & r2,
|
| 1248 |
+
constant uint & r3,
|
| 1249 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 1250 |
uint tiisg[[thread_index_in_simdgroup]]) {
|
| 1251 |
|
|
|
|
| 1381 |
constant uint64_t & nb12,
|
| 1382 |
constant int64_t & ne0,
|
| 1383 |
constant int64_t & ne1,
|
| 1384 |
+
constant uint & r2,
|
| 1385 |
+
constant uint & r3,
|
| 1386 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 1387 |
uint tiisg[[thread_index_in_simdgroup]]) {
|
| 1388 |
kernel_mul_mv_f16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
|
|
|
|
| 1487 |
constant uint64_t & nb12,
|
| 1488 |
constant int64_t & ne0,
|
| 1489 |
constant int64_t & ne1,
|
| 1490 |
+
constant uint & r2,
|
| 1491 |
+
constant uint & r3,
|
| 1492 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 1493 |
uint tiisg[[thread_index_in_simdgroup]]) {
|
| 1494 |
kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
|
|
|
|
| 1513 |
constant uint64_t & nb12,
|
| 1514 |
constant int64_t & ne0,
|
| 1515 |
constant int64_t & ne1,
|
| 1516 |
+
constant uint & r2,
|
| 1517 |
+
constant uint & r3,
|
| 1518 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 1519 |
uint tiisg[[thread_index_in_simdgroup]]) {
|
| 1520 |
|
|
|
|
| 1578 |
const int64_t i3 = n / (ne2*ne1*ne0);
|
| 1579 |
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
| 1580 |
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
| 1581 |
+
//const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
|
| 1582 |
+
|
| 1583 |
const int64_t k = i3*ne3 + i2;
|
| 1584 |
|
| 1585 |
float m_k;
|
|
|
|
| 2446 |
} block_q6_K;
|
| 2447 |
// 210 bytes / block
|
| 2448 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2449 |
//====================================== dot products =========================
|
| 2450 |
|
| 2451 |
void kernel_mul_mv_q2_K_f32_impl(
|
|
|
|
| 2604 |
device const float * src1,
|
| 2605 |
device float * dst,
|
| 2606 |
constant int64_t & ne00,
|
| 2607 |
+
constant int64_t & ne01,
|
| 2608 |
+
constant int64_t & ne02,
|
| 2609 |
+
constant uint64_t & nb00,
|
| 2610 |
+
constant uint64_t & nb01,
|
| 2611 |
+
constant uint64_t & nb02,
|
| 2612 |
+
constant int64_t & ne10,
|
| 2613 |
+
constant int64_t & ne11,
|
| 2614 |
+
constant int64_t & ne12,
|
| 2615 |
+
constant uint64_t & nb10,
|
| 2616 |
+
constant uint64_t & nb11,
|
| 2617 |
+
constant uint64_t & nb12,
|
| 2618 |
+
constant int64_t & ne0,
|
| 2619 |
+
constant int64_t & ne1,
|
| 2620 |
+
constant uint & r2,
|
| 2621 |
+
constant uint & r3,
|
| 2622 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 2623 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 2624 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
| 2868 |
device const float * src1,
|
| 2869 |
device float * dst,
|
| 2870 |
constant int64_t & ne00,
|
| 2871 |
+
constant int64_t & ne01,
|
| 2872 |
+
constant int64_t & ne02,
|
| 2873 |
+
constant uint64_t & nb00,
|
| 2874 |
+
constant uint64_t & nb01,
|
| 2875 |
+
constant uint64_t & nb02,
|
| 2876 |
+
constant int64_t & ne10,
|
| 2877 |
+
constant int64_t & ne11,
|
| 2878 |
+
constant int64_t & ne12,
|
| 2879 |
+
constant uint64_t & nb10,
|
| 2880 |
+
constant uint64_t & nb11,
|
| 2881 |
+
constant uint64_t & nb12,
|
| 2882 |
+
constant int64_t & ne0,
|
| 2883 |
+
constant int64_t & ne1,
|
| 2884 |
+
constant uint & r2,
|
| 2885 |
+
constant uint & r3,
|
| 2886 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 2887 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 2888 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
| 3018 |
constant uint & r2,
|
| 3019 |
constant uint & r3,
|
| 3020 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 3021 |
+
uint tiisg[[thread_index_in_simdgroup]],
|
| 3022 |
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 3023 |
|
| 3024 |
const int ix = tiisg/4; // 0...7
|
| 3025 |
const int it = tiisg%4; // 0...3
|
|
|
|
| 3028 |
const int r0 = tgpig.x;
|
| 3029 |
const int r1 = tgpig.y;
|
| 3030 |
const int im = tgpig.z;
|
| 3031 |
+
const int first_row = r0 * N_DST;
|
| 3032 |
const int ib_row = first_row * nb;
|
| 3033 |
|
| 3034 |
const uint i12 = im%ne12;
|
|
|
|
| 3094 |
for (int row = 0; row < N_DST; ++row) {
|
| 3095 |
all_sum = simd_sum(sumf[row]);
|
| 3096 |
if (tiisg == 0) {
|
| 3097 |
+
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
|
| 3098 |
}
|
| 3099 |
}
|
| 3100 |
}
|
|
|
|
| 3106 |
device const float * src1,
|
| 3107 |
device float * dst,
|
| 3108 |
constant int64_t & ne00,
|
| 3109 |
+
constant int64_t & ne01,
|
| 3110 |
+
constant int64_t & ne02,
|
| 3111 |
+
constant uint64_t & nb00,
|
| 3112 |
+
constant uint64_t & nb01,
|
| 3113 |
+
constant uint64_t & nb02,
|
| 3114 |
+
constant int64_t & ne10,
|
| 3115 |
+
constant int64_t & ne11,
|
| 3116 |
+
constant int64_t & ne12,
|
| 3117 |
+
constant uint64_t & nb10,
|
| 3118 |
+
constant uint64_t & nb11,
|
| 3119 |
+
constant uint64_t & nb12,
|
| 3120 |
+
constant int64_t & ne0,
|
| 3121 |
+
constant int64_t & ne1,
|
| 3122 |
+
constant uint & r2,
|
| 3123 |
+
constant uint & r3,
|
| 3124 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 3125 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 3126 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
| 3312 |
device const float * src1,
|
| 3313 |
device float * dst,
|
| 3314 |
constant int64_t & ne00,
|
| 3315 |
+
constant int64_t & ne01,
|
| 3316 |
+
constant int64_t & ne02,
|
| 3317 |
+
constant uint64_t & nb00,
|
| 3318 |
+
constant uint64_t & nb01,
|
| 3319 |
+
constant uint64_t & nb02,
|
| 3320 |
+
constant int64_t & ne10,
|
| 3321 |
+
constant int64_t & ne11,
|
| 3322 |
+
constant int64_t & ne12,
|
| 3323 |
+
constant uint64_t & nb10,
|
| 3324 |
+
constant uint64_t & nb11,
|
| 3325 |
+
constant uint64_t & nb12,
|
| 3326 |
+
constant int64_t & ne0,
|
| 3327 |
+
constant int64_t & ne1,
|
| 3328 |
+
constant uint & r2,
|
| 3329 |
+
constant uint & r3,
|
| 3330 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 3331 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 3332 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
| 3446 |
device const float * src1,
|
| 3447 |
device float * dst,
|
| 3448 |
constant int64_t & ne00,
|
| 3449 |
+
constant int64_t & ne01,
|
| 3450 |
+
constant int64_t & ne02,
|
| 3451 |
+
constant uint64_t & nb00,
|
| 3452 |
+
constant uint64_t & nb01,
|
| 3453 |
+
constant uint64_t & nb02,
|
| 3454 |
+
constant int64_t & ne10,
|
| 3455 |
+
constant int64_t & ne11,
|
| 3456 |
+
constant int64_t & ne12,
|
| 3457 |
+
constant uint64_t & nb10,
|
| 3458 |
+
constant uint64_t & nb11,
|
| 3459 |
+
constant uint64_t & nb12,
|
| 3460 |
+
constant int64_t & ne0,
|
| 3461 |
+
constant int64_t & ne1,
|
| 3462 |
+
constant uint & r2,
|
| 3463 |
+
constant uint & r3,
|
| 3464 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 3465 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 3466 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
| 3578 |
device const int8_t * qs = ((device const int8_t *)xb->qs);
|
| 3579 |
const half d = xb->d;
|
| 3580 |
|
| 3581 |
+
for (int i = 0; i < 16; i++) {
|
| 3582 |
reg[i/4][i%4] = (qs[i + 16*il] * d);
|
| 3583 |
}
|
| 3584 |
}
|
|
|
|
| 3847 |
device float * dst,
|
| 3848 |
constant int64_t & ne00,
|
| 3849 |
constant int64_t & ne02,
|
| 3850 |
+
constant uint64_t & nb01,
|
| 3851 |
+
constant uint64_t & nb02,
|
| 3852 |
constant int64_t & ne12,
|
| 3853 |
+
constant uint64_t & nb10,
|
| 3854 |
+
constant uint64_t & nb11,
|
| 3855 |
+
constant uint64_t & nb12,
|
| 3856 |
constant int64_t & ne0,
|
| 3857 |
constant int64_t & ne1,
|
| 3858 |
constant uint & r2,
|
|
|
|
| 3979 |
device float * dst,
|
| 3980 |
constant int64_t & ne00,
|
| 3981 |
constant int64_t & ne02,
|
| 3982 |
+
constant uint64_t & nb01,
|
| 3983 |
+
constant uint64_t & nb02,
|
| 3984 |
constant int64_t & ne12,
|
| 3985 |
+
constant uint64_t & nb10,
|
| 3986 |
+
constant uint64_t & nb11,
|
| 3987 |
+
constant uint64_t & nb12,
|
| 3988 |
constant int64_t & ne0,
|
| 3989 |
constant int64_t & ne1,
|
| 3990 |
constant uint & r2,
|
|
|
|
| 4020 |
device const uchar * ids,
|
| 4021 |
device const uchar * src1,
|
| 4022 |
device uchar * dst,
|
| 4023 |
+
constant uint64_t & nbi1,
|
| 4024 |
constant int64_t & ne00,
|
| 4025 |
constant int64_t & ne02,
|
| 4026 |
+
constant uint64_t & nb01,
|
| 4027 |
+
constant uint64_t & nb02,
|
| 4028 |
constant int64_t & ne12,
|
| 4029 |
constant int64_t & ne13,
|
| 4030 |
+
constant uint64_t & nb10,
|
| 4031 |
+
constant uint64_t & nb11,
|
| 4032 |
+
constant uint64_t & nb12,
|
| 4033 |
constant int64_t & ne0,
|
| 4034 |
constant int64_t & ne1,
|
| 4035 |
+
constant uint64_t & nb1,
|
| 4036 |
constant uint & r2,
|
| 4037 |
constant uint & r3,
|
| 4038 |
constant int & idx,
|
|
|
|
| 4125 |
device float * dst,
|
| 4126 |
constant int64_t & ne00,
|
| 4127 |
constant int64_t & ne02,
|
| 4128 |
+
constant uint64_t & nb01,
|
| 4129 |
+
constant uint64_t & nb02,
|
| 4130 |
constant int64_t & ne12,
|
| 4131 |
+
constant uint64_t & nb10,
|
| 4132 |
+
constant uint64_t & nb11,
|
| 4133 |
+
constant uint64_t & nb12,
|
| 4134 |
constant int64_t & ne0,
|
| 4135 |
constant int64_t & ne1,
|
| 4136 |
constant uint & r2,
|
|
|
|
| 4159 |
device const uchar * ids,
|
| 4160 |
device const uchar * src1,
|
| 4161 |
device uchar * dst,
|
| 4162 |
+
constant uint64_t & nbi1,
|
| 4163 |
constant int64_t & ne00,
|
| 4164 |
constant int64_t & ne02,
|
| 4165 |
+
constant uint64_t & nb01,
|
| 4166 |
+
constant uint64_t & nb02,
|
| 4167 |
constant int64_t & ne12,
|
| 4168 |
constant int64_t & ne13,
|
| 4169 |
+
constant uint64_t & nb10,
|
| 4170 |
+
constant uint64_t & nb11,
|
| 4171 |
+
constant uint64_t & nb12,
|
| 4172 |
constant int64_t & ne0,
|
| 4173 |
constant int64_t & ne1,
|
| 4174 |
+
constant uint64_t & nb1,
|
| 4175 |
constant uint & r2,
|
| 4176 |
constant uint & r3,
|
| 4177 |
constant int & idx,
|
|
|
|
| 4208 |
device const char * ids,
|
| 4209 |
device const char * src1,
|
| 4210 |
device uchar * dst,
|
| 4211 |
+
constant uint64_t & nbi1,
|
| 4212 |
constant int64_t & ne00,
|
| 4213 |
constant int64_t & ne01,
|
| 4214 |
constant int64_t & ne02,
|
|
|
|
| 4224 |
constant uint64_t & nb12,
|
| 4225 |
constant int64_t & ne0,
|
| 4226 |
constant int64_t & ne1,
|
| 4227 |
+
constant uint64_t & nb1,
|
| 4228 |
constant uint & r2,
|
| 4229 |
constant uint & r3,
|
| 4230 |
constant int & idx,
|
|
|
|
| 4277 |
device const char * ids,
|
| 4278 |
device const char * src1,
|
| 4279 |
device uchar * dst,
|
| 4280 |
+
constant uint64_t & nbi1,
|
| 4281 |
constant int64_t & ne00,
|
| 4282 |
constant int64_t & ne01,
|
| 4283 |
constant int64_t & ne02,
|
|
|
|
| 4293 |
constant uint64_t & nb12,
|
| 4294 |
constant int64_t & ne0,
|
| 4295 |
constant int64_t & ne1,
|
| 4296 |
+
constant uint64_t & nb1,
|
| 4297 |
constant uint & r2,
|
| 4298 |
constant uint & r3,
|
| 4299 |
constant int & idx,
|
|
|
|
| 4346 |
device const char * ids,
|
| 4347 |
device const char * src1,
|
| 4348 |
device uchar * dst,
|
| 4349 |
+
constant uint64_t & nbi1,
|
| 4350 |
constant int64_t & ne00,
|
| 4351 |
constant int64_t & ne01,
|
| 4352 |
constant int64_t & ne02,
|
|
|
|
| 4362 |
constant uint64_t & nb12,
|
| 4363 |
constant int64_t & ne0,
|
| 4364 |
constant int64_t & ne1,
|
| 4365 |
+
constant uint64_t & nb1,
|
| 4366 |
constant uint & r2,
|
| 4367 |
constant uint & r3,
|
| 4368 |
constant int & idx,
|
|
|
|
| 4409 |
device const char * ids,
|
| 4410 |
device const char * src1,
|
| 4411 |
device uchar * dst,
|
| 4412 |
+
constant uint64_t & nbi1,
|
| 4413 |
constant int64_t & ne00,
|
| 4414 |
constant int64_t & ne01,
|
| 4415 |
constant int64_t & ne02,
|
|
|
|
| 4425 |
constant uint64_t & nb12,
|
| 4426 |
constant int64_t & ne0,
|
| 4427 |
constant int64_t & ne1,
|
| 4428 |
+
constant uint64_t & nb1,
|
| 4429 |
constant uint & r2,
|
| 4430 |
constant uint & r3,
|
| 4431 |
constant int & idx,
|
|
|
|
| 4472 |
device const char * ids,
|
| 4473 |
device const char * src1,
|
| 4474 |
device uchar * dst,
|
| 4475 |
+
constant uint64_t & nbi1,
|
| 4476 |
constant int64_t & ne00,
|
| 4477 |
constant int64_t & ne01,
|
| 4478 |
constant int64_t & ne02,
|
|
|
|
| 4488 |
constant uint64_t & nb12,
|
| 4489 |
constant int64_t & ne0,
|
| 4490 |
constant int64_t & ne1,
|
| 4491 |
+
constant uint64_t & nb1,
|
| 4492 |
constant uint & r2,
|
| 4493 |
constant uint & r3,
|
| 4494 |
constant int & idx,
|
|
|
|
| 4535 |
device const char * ids,
|
| 4536 |
device const char * src1,
|
| 4537 |
device uchar * dst,
|
| 4538 |
+
constant uint64_t & nbi1,
|
| 4539 |
constant int64_t & ne00,
|
| 4540 |
constant int64_t & ne01,
|
| 4541 |
constant int64_t & ne02,
|
|
|
|
| 4551 |
constant uint64_t & nb12,
|
| 4552 |
constant int64_t & ne0,
|
| 4553 |
constant int64_t & ne1,
|
| 4554 |
+
constant uint64_t & nb1,
|
| 4555 |
constant uint & r2,
|
| 4556 |
constant uint & r3,
|
| 4557 |
constant int & idx,
|
|
|
|
| 4598 |
device const char * ids,
|
| 4599 |
device const char * src1,
|
| 4600 |
device uchar * dst,
|
| 4601 |
+
constant uint64_t & nbi1,
|
| 4602 |
constant int64_t & ne00,
|
| 4603 |
constant int64_t & ne01,
|
| 4604 |
constant int64_t & ne02,
|
|
|
|
| 4614 |
constant uint64_t & nb12,
|
| 4615 |
constant int64_t & ne0,
|
| 4616 |
constant int64_t & ne1,
|
| 4617 |
+
constant uint64_t & nb1,
|
| 4618 |
constant uint & r2,
|
| 4619 |
constant uint & r3,
|
| 4620 |
constant int & idx,
|
|
|
|
| 4661 |
device const char * ids,
|
| 4662 |
device const char * src1,
|
| 4663 |
device uchar * dst,
|
| 4664 |
+
constant uint64_t & nbi1,
|
| 4665 |
constant int64_t & ne00,
|
| 4666 |
constant int64_t & ne01,
|
| 4667 |
constant int64_t & ne02,
|
|
|
|
| 4677 |
constant uint64_t & nb12,
|
| 4678 |
constant int64_t & ne0,
|
| 4679 |
constant int64_t & ne1,
|
| 4680 |
+
constant uint64_t & nb1,
|
| 4681 |
constant uint & r2,
|
| 4682 |
constant uint & r3,
|
| 4683 |
constant int & idx,
|
|
|
|
| 4724 |
device const char * ids,
|
| 4725 |
device const char * src1,
|
| 4726 |
device uchar * dst,
|
| 4727 |
+
constant uint64_t & nbi1,
|
| 4728 |
constant int64_t & ne00,
|
| 4729 |
constant int64_t & ne01,
|
| 4730 |
constant int64_t & ne02,
|
|
|
|
| 4740 |
constant uint64_t & nb12,
|
| 4741 |
constant int64_t & ne0,
|
| 4742 |
constant int64_t & ne1,
|
| 4743 |
+
constant uint64_t & nb1,
|
| 4744 |
constant uint & r2,
|
| 4745 |
constant uint & r3,
|
| 4746 |
constant int & idx,
|
|
|
|
| 4787 |
device const char * ids,
|
| 4788 |
device const char * src1,
|
| 4789 |
device uchar * dst,
|
| 4790 |
+
constant uint64_t & nbi1,
|
| 4791 |
constant int64_t & ne00,
|
| 4792 |
constant int64_t & ne01,
|
| 4793 |
constant int64_t & ne02,
|
|
|
|
| 4803 |
constant uint64_t & nb12,
|
| 4804 |
constant int64_t & ne0,
|
| 4805 |
constant int64_t & ne1,
|
| 4806 |
+
constant uint64_t & nb1,
|
| 4807 |
constant uint & r2,
|
| 4808 |
constant uint & r3,
|
| 4809 |
constant int & idx,
|
|
|
|
| 4850 |
device const char * ids,
|
| 4851 |
device const char * src1,
|
| 4852 |
device uchar * dst,
|
| 4853 |
+
constant uint64_t & nbi1,
|
| 4854 |
constant int64_t & ne00,
|
| 4855 |
constant int64_t & ne01,
|
| 4856 |
constant int64_t & ne02,
|
|
|
|
| 4866 |
constant uint64_t & nb12,
|
| 4867 |
constant int64_t & ne0,
|
| 4868 |
constant int64_t & ne1,
|
| 4869 |
+
constant uint64_t & nb1,
|
| 4870 |
constant uint & r2,
|
| 4871 |
constant uint & r3,
|
| 4872 |
constant int & idx,
|
|
|
|
| 4913 |
device const char * ids,
|
| 4914 |
device const char * src1,
|
| 4915 |
device uchar * dst,
|
| 4916 |
+
constant uint64_t & nbi1,
|
| 4917 |
constant int64_t & ne00,
|
| 4918 |
constant int64_t & ne01,
|
| 4919 |
constant int64_t & ne02,
|
|
|
|
| 4929 |
constant uint64_t & nb12,
|
| 4930 |
constant int64_t & ne0,
|
| 4931 |
constant int64_t & ne1,
|
| 4932 |
+
constant uint64_t & nb1,
|
| 4933 |
constant uint & r2,
|
| 4934 |
constant uint & r3,
|
| 4935 |
constant int & idx,
|