Spaces:
Running
Running
ggml: add `GGML_SET` Metal kernel + i32 CPU kernel (ggml/1037)
Browse files* implemented cpu kernel
* add i32 test cases in test-backend-ops
* typedef `ggml_metal_kargs_set`
* implemented `kernel_set`
* memcpy
ggml/src/ggml-cpu/ggml-cpu.c
CHANGED
|
@@ -1374,7 +1374,10 @@ struct ggml_compute_state {
|
|
| 1374 |
|
| 1375 |
inline static void ggml_vec_set_i8(const int n, int8_t * x, const int8_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
|
| 1376 |
inline static void ggml_vec_set_i16(const int n, int16_t * x, const int16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
|
| 1377 |
-
|
|
|
|
|
|
|
|
|
|
| 1378 |
inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
|
| 1379 |
inline static void ggml_vec_set_bf16(const int n, ggml_bf16_t * x, const ggml_bf16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
|
| 1380 |
inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; }
|
|
@@ -8248,6 +8251,77 @@ static void ggml_compute_forward_set_f32(
|
|
| 8248 |
}
|
| 8249 |
}
|
| 8250 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8251 |
static void ggml_compute_forward_set(
|
| 8252 |
const struct ggml_compute_params * params,
|
| 8253 |
struct ggml_tensor * dst) {
|
|
@@ -8259,6 +8333,10 @@ static void ggml_compute_forward_set(
|
|
| 8259 |
{
|
| 8260 |
ggml_compute_forward_set_f32(params, dst);
|
| 8261 |
} break;
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8262 |
case GGML_TYPE_F16:
|
| 8263 |
case GGML_TYPE_BF16:
|
| 8264 |
case GGML_TYPE_Q4_0:
|
|
|
|
| 1374 |
|
| 1375 |
inline static void ggml_vec_set_i8(const int n, int8_t * x, const int8_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
|
| 1376 |
inline static void ggml_vec_set_i16(const int n, int16_t * x, const int16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
|
| 1377 |
+
|
| 1378 |
+
inline static void ggml_vec_set_i32(const int n, int32_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
|
| 1379 |
+
inline static void ggml_vec_cpy_i32(const int n, int32_t * y, const int32_t * x) { for (int i = 0; i < n; ++i) y[i] = x[i]; }
|
| 1380 |
+
|
| 1381 |
inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
|
| 1382 |
inline static void ggml_vec_set_bf16(const int n, ggml_bf16_t * x, const ggml_bf16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
|
| 1383 |
inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; }
|
|
|
|
| 8251 |
}
|
| 8252 |
}
|
| 8253 |
|
| 8254 |
+
static void ggml_compute_forward_set_i32(
|
| 8255 |
+
const struct ggml_compute_params * params,
|
| 8256 |
+
struct ggml_tensor * dst) {
|
| 8257 |
+
|
| 8258 |
+
const struct ggml_tensor * src0 = dst->src[0];
|
| 8259 |
+
const struct ggml_tensor * src1 = dst->src[1];
|
| 8260 |
+
|
| 8261 |
+
GGML_ASSERT(ggml_are_same_shape(src0, dst));
|
| 8262 |
+
GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
|
| 8263 |
+
|
| 8264 |
+
// view src0 and dst with these strides and data offset inbytes during set
|
| 8265 |
+
// nb0 is implicitly element_size because src0 and dst are contiguous
|
| 8266 |
+
size_t nb1 = ((int32_t *) dst->op_params)[0];
|
| 8267 |
+
size_t nb2 = ((int32_t *) dst->op_params)[1];
|
| 8268 |
+
size_t nb3 = ((int32_t *) dst->op_params)[2];
|
| 8269 |
+
size_t offset = ((int32_t *) dst->op_params)[3];
|
| 8270 |
+
bool inplace = (bool) ((int32_t *) dst->op_params)[4];
|
| 8271 |
+
|
| 8272 |
+
if (!inplace) {
|
| 8273 |
+
if (params->ith == 0) {
|
| 8274 |
+
// memcpy needs to be synchronized across threads to avoid race conditions.
|
| 8275 |
+
// => do it in INIT phase
|
| 8276 |
+
memcpy(
|
| 8277 |
+
((char *) dst->data),
|
| 8278 |
+
((char *) src0->data),
|
| 8279 |
+
ggml_nbytes(dst));
|
| 8280 |
+
}
|
| 8281 |
+
ggml_barrier(params->threadpool);
|
| 8282 |
+
}
|
| 8283 |
+
|
| 8284 |
+
const int ith = params->ith;
|
| 8285 |
+
const int nth = params->nth;
|
| 8286 |
+
|
| 8287 |
+
const int nr = ggml_nrows(src1);
|
| 8288 |
+
const int nc = src1->ne[0];
|
| 8289 |
+
|
| 8290 |
+
GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
|
| 8291 |
+
GGML_TENSOR_LOCALS(size_t, nb1, src1, nb)
|
| 8292 |
+
|
| 8293 |
+
// src0 and dst as viewed during set
|
| 8294 |
+
const size_t nb0 = ggml_element_size(src0);
|
| 8295 |
+
|
| 8296 |
+
const int im0 = (ne10 == 0 ? 0 : ne10-1);
|
| 8297 |
+
const int im1 = (ne11 == 0 ? 0 : ne11-1);
|
| 8298 |
+
const int im2 = (ne12 == 0 ? 0 : ne12-1);
|
| 8299 |
+
const int im3 = (ne13 == 0 ? 0 : ne13-1);
|
| 8300 |
+
|
| 8301 |
+
GGML_ASSERT(offset + im0*nb0 + im1*nb1 + im2*nb2 + im3*nb3 <= ggml_nbytes(dst));
|
| 8302 |
+
|
| 8303 |
+
GGML_ASSERT(nb10 == sizeof(int32_t));
|
| 8304 |
+
|
| 8305 |
+
// rows per thread
|
| 8306 |
+
const int dr = (nr + nth - 1)/nth;
|
| 8307 |
+
|
| 8308 |
+
// row range for this thread
|
| 8309 |
+
const int ir0 = dr*ith;
|
| 8310 |
+
const int ir1 = MIN(ir0 + dr, nr);
|
| 8311 |
+
|
| 8312 |
+
for (int ir = ir0; ir < ir1; ++ir) {
|
| 8313 |
+
// src0 and dst are viewed with shape of src1 and offset
|
| 8314 |
+
// => same indices
|
| 8315 |
+
const int i3 = ir/(ne12*ne11);
|
| 8316 |
+
const int i2 = (ir - i3*ne12*ne11)/ne11;
|
| 8317 |
+
const int i1 = (ir - i3*ne12*ne11 - i2*ne11);
|
| 8318 |
+
|
| 8319 |
+
ggml_vec_cpy_i32(nc,
|
| 8320 |
+
(int32_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + offset),
|
| 8321 |
+
(int32_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
|
| 8322 |
+
}
|
| 8323 |
+
}
|
| 8324 |
+
|
| 8325 |
static void ggml_compute_forward_set(
|
| 8326 |
const struct ggml_compute_params * params,
|
| 8327 |
struct ggml_tensor * dst) {
|
|
|
|
| 8333 |
{
|
| 8334 |
ggml_compute_forward_set_f32(params, dst);
|
| 8335 |
} break;
|
| 8336 |
+
case GGML_TYPE_I32:
|
| 8337 |
+
{
|
| 8338 |
+
ggml_compute_forward_set_i32(params, dst);
|
| 8339 |
+
} break;
|
| 8340 |
case GGML_TYPE_F16:
|
| 8341 |
case GGML_TYPE_BF16:
|
| 8342 |
case GGML_TYPE_Q4_0:
|
ggml/src/ggml-metal/ggml-metal-impl.h
CHANGED
|
@@ -102,6 +102,21 @@ typedef struct {
|
|
| 102 |
uint64_t nb3;
|
| 103 |
} ggml_metal_kargs_cpy;
|
| 104 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
typedef struct {
|
| 106 |
int32_t ne00;
|
| 107 |
int32_t ne01;
|
|
|
|
| 102 |
uint64_t nb3;
|
| 103 |
} ggml_metal_kargs_cpy;
|
| 104 |
|
| 105 |
+
typedef struct {
|
| 106 |
+
int64_t ne10;
|
| 107 |
+
int64_t ne11;
|
| 108 |
+
int64_t ne12;
|
| 109 |
+
uint64_t nb10;
|
| 110 |
+
uint64_t nb11;
|
| 111 |
+
uint64_t nb12;
|
| 112 |
+
uint64_t nb13;
|
| 113 |
+
uint64_t nb1;
|
| 114 |
+
uint64_t nb2;
|
| 115 |
+
uint64_t nb3;
|
| 116 |
+
uint64_t offs;
|
| 117 |
+
bool inplace;
|
| 118 |
+
} ggml_metal_kargs_set;
|
| 119 |
+
|
| 120 |
typedef struct {
|
| 121 |
int32_t ne00;
|
| 122 |
int32_t ne01;
|
ggml/src/ggml-metal/ggml-metal.m
CHANGED
|
@@ -372,6 +372,8 @@ enum ggml_metal_kernel_type {
|
|
| 372 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,
|
| 373 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256,
|
| 374 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256,
|
|
|
|
|
|
|
| 375 |
GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
|
| 376 |
GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
|
| 377 |
GGML_METAL_KERNEL_TYPE_CPY_F32_BF16,
|
|
@@ -940,6 +942,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
| 940 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, has_simdgroup_reduction);
|
| 941 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, flash_attn_ext_vec_q5_1_h256, has_simdgroup_reduction);
|
| 942 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, has_simdgroup_reduction);
|
|
|
|
|
|
|
| 943 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
|
| 944 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
|
| 945 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_BF16, cpy_f32_bf16, use_bfloat);
|
|
@@ -1159,6 +1163,16 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
|
| 1159 |
return false;
|
| 1160 |
};
|
| 1161 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1162 |
case GGML_OP_DIAG_MASK_INF:
|
| 1163 |
case GGML_OP_GET_ROWS:
|
| 1164 |
{
|
|
@@ -3824,6 +3838,68 @@ static void ggml_metal_encode_node(
|
|
| 3824 |
|
| 3825 |
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
| 3826 |
} break;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3827 |
case GGML_OP_POOL_2D:
|
| 3828 |
{
|
| 3829 |
GGML_ASSERT(ggml_is_contiguous(src0));
|
|
|
|
| 372 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,
|
| 373 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256,
|
| 374 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256,
|
| 375 |
+
GGML_METAL_KERNEL_TYPE_SET_I32,
|
| 376 |
+
GGML_METAL_KERNEL_TYPE_SET_F32,
|
| 377 |
GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
|
| 378 |
GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
|
| 379 |
GGML_METAL_KERNEL_TYPE_CPY_F32_BF16,
|
|
|
|
| 942 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, has_simdgroup_reduction);
|
| 943 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, flash_attn_ext_vec_q5_1_h256, has_simdgroup_reduction);
|
| 944 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, has_simdgroup_reduction);
|
| 945 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_F32, set_f32, true);
|
| 946 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_I32, set_i32, true);
|
| 947 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
|
| 948 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
|
| 949 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_BF16, cpy_f32_bf16, use_bfloat);
|
|
|
|
| 1163 |
return false;
|
| 1164 |
};
|
| 1165 |
}
|
| 1166 |
+
case GGML_OP_SET:
|
| 1167 |
+
{
|
| 1168 |
+
switch (op->src[0]->type) {
|
| 1169 |
+
case GGML_TYPE_F32:
|
| 1170 |
+
case GGML_TYPE_I32:
|
| 1171 |
+
return true;
|
| 1172 |
+
default:
|
| 1173 |
+
return false;
|
| 1174 |
+
};
|
| 1175 |
+
}
|
| 1176 |
case GGML_OP_DIAG_MASK_INF:
|
| 1177 |
case GGML_OP_GET_ROWS:
|
| 1178 |
{
|
|
|
|
| 3838 |
|
| 3839 |
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
| 3840 |
} break;
|
| 3841 |
+
case GGML_OP_SET:
|
| 3842 |
+
{
|
| 3843 |
+
GGML_ASSERT(ggml_are_same_shape(src0, dst));
|
| 3844 |
+
GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
|
| 3845 |
+
|
| 3846 |
+
// src0 and dst as viewed during set
|
| 3847 |
+
const size_t dst_nb0 = ggml_element_size(src0);
|
| 3848 |
+
|
| 3849 |
+
const size_t dst_nb1 = ((int32_t *) dst->op_params)[0];
|
| 3850 |
+
const size_t dst_nb2 = ((int32_t *) dst->op_params)[1];
|
| 3851 |
+
const size_t dst_nb3 = ((int32_t *) dst->op_params)[2];
|
| 3852 |
+
const size_t offset = ((int32_t *) dst->op_params)[3];
|
| 3853 |
+
const bool inplace = (bool) ((int32_t *) dst->op_params)[4];
|
| 3854 |
+
|
| 3855 |
+
if (!inplace) {
|
| 3856 |
+
memcpy(((char *) dst->data), ((char *) src0->data), ggml_nbytes(dst));
|
| 3857 |
+
}
|
| 3858 |
+
|
| 3859 |
+
const int im0 = (ne10 == 0 ? 0 : ne10-1);
|
| 3860 |
+
const int im1 = (ne11 == 0 ? 0 : ne11-1);
|
| 3861 |
+
const int im2 = (ne12 == 0 ? 0 : ne12-1);
|
| 3862 |
+
const int im3 = (ne13 == 0 ? 0 : ne13-1);
|
| 3863 |
+
|
| 3864 |
+
GGML_ASSERT(offset + im0*dst_nb0 + im1*dst_nb1 + im2*dst_nb2 + im3*dst_nb3 <= ggml_nbytes(dst));
|
| 3865 |
+
|
| 3866 |
+
id<MTLComputePipelineState> pipeline = nil;
|
| 3867 |
+
|
| 3868 |
+
switch (src0t) {
|
| 3869 |
+
case GGML_TYPE_F32:
|
| 3870 |
+
GGML_ASSERT(nb10 == sizeof(float));
|
| 3871 |
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_F32].pipeline; break;
|
| 3872 |
+
case GGML_TYPE_I32:
|
| 3873 |
+
GGML_ASSERT(nb10 == sizeof(int32_t));
|
| 3874 |
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_I32].pipeline; break;
|
| 3875 |
+
default: GGML_ABORT("fatal error");
|
| 3876 |
+
}
|
| 3877 |
+
|
| 3878 |
+
ggml_metal_kargs_set args = {
|
| 3879 |
+
/*.ne10 =*/ ne10,
|
| 3880 |
+
/*.ne11 =*/ ne11,
|
| 3881 |
+
/*.ne12 =*/ ne12,
|
| 3882 |
+
/*.nb10 =*/ nb10,
|
| 3883 |
+
/*.nb11 =*/ nb11,
|
| 3884 |
+
/*.nb12 =*/ nb12,
|
| 3885 |
+
/*.nb13 =*/ nb13,
|
| 3886 |
+
/*.nb1 =*/ dst_nb1,
|
| 3887 |
+
/*.nb2 =*/ dst_nb2,
|
| 3888 |
+
/*.nb3 =*/ dst_nb3,
|
| 3889 |
+
/*.offs =*/ offset,
|
| 3890 |
+
/*.inplace =*/ inplace,
|
| 3891 |
+
};
|
| 3892 |
+
|
| 3893 |
+
const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne10);
|
| 3894 |
+
|
| 3895 |
+
[encoder setComputePipelineState:pipeline];
|
| 3896 |
+
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
| 3897 |
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
| 3898 |
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
|
| 3899 |
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
|
| 3900 |
+
|
| 3901 |
+
[encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
| 3902 |
+
} break;
|
| 3903 |
case GGML_OP_POOL_2D:
|
| 3904 |
{
|
| 3905 |
GGML_ASSERT(ggml_is_contiguous(src0));
|
ggml/src/ggml-metal/ggml-metal.metal
CHANGED
|
@@ -3927,6 +3927,38 @@ template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_
|
|
| 3927 |
|
| 3928 |
#undef FA_TYPES
|
| 3929 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3930 |
template<typename T0, typename T1>
|
| 3931 |
kernel void kernel_cpy(
|
| 3932 |
constant ggml_metal_kargs_cpy & args,
|
|
|
|
| 3927 |
|
| 3928 |
#undef FA_TYPES
|
| 3929 |
|
| 3930 |
+
template<typename T>
|
| 3931 |
+
kernel void kernel_set(
|
| 3932 |
+
constant ggml_metal_kargs_set & args,
|
| 3933 |
+
device const char * src0,
|
| 3934 |
+
device const char * src1,
|
| 3935 |
+
device char * dst,
|
| 3936 |
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 3937 |
+
ushort3 tpitg[[thread_position_in_threadgroup]],
|
| 3938 |
+
ushort3 ntg[[threads_per_threadgroup]]) {
|
| 3939 |
+
const int i13 = tgpig[2];
|
| 3940 |
+
const int i12 = tgpig[1];
|
| 3941 |
+
const int i11 = tgpig[0];
|
| 3942 |
+
|
| 3943 |
+
const int64_t n = i13*args.ne12*args.ne11*args.ne10 + i12*args.ne11*args.ne10 + i11*args.ne10;
|
| 3944 |
+
|
| 3945 |
+
const int64_t i3 = n / (args.ne12*args.ne11*args.ne10);
|
| 3946 |
+
const int64_t i2 = (n - i3*args.ne12*args.ne11*args.ne10) / (args.ne11*args.ne10);
|
| 3947 |
+
const int64_t i1 = (n - i3*args.ne12*args.ne11*args.ne10 - i2*args.ne11*args.ne10) / args.ne10;
|
| 3948 |
+
|
| 3949 |
+
device T * dst_data = (device T *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + args.offs);
|
| 3950 |
+
|
| 3951 |
+
for (int64_t i10 = tpitg.x; i10 < args.ne10; i10 += ntg.x) {
|
| 3952 |
+
device const T * src = (device T *) (src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10);
|
| 3953 |
+
dst_data[i10] = (T) src[0];
|
| 3954 |
+
}
|
| 3955 |
+
}
|
| 3956 |
+
|
| 3957 |
+
typedef decltype(kernel_set<float>) kernel_set_t;
|
| 3958 |
+
|
| 3959 |
+
template [[host_name("kernel_set_f32")]] kernel kernel_set_t kernel_set<float>;
|
| 3960 |
+
template [[host_name("kernel_set_i32")]] kernel kernel_set_t kernel_set<int32_t>;
|
| 3961 |
+
|
| 3962 |
template<typename T0, typename T1>
|
| 3963 |
kernel void kernel_cpy(
|
| 3964 |
constant ggml_metal_kargs_cpy & args,
|