PABannier committed
Commit dd775d5 · Parent: 154bbc0

ggml: add `GGML_SET` Metal kernel + i32 CPU kernel (ggml/1037)

* implemented CPU kernel
* added i32 test cases in test-backend-ops
* typedef `ggml_metal_kargs_set`
* implemented `kernel_set`
* memcpy
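
For context, a minimal host-side usage sketch (not part of this commit; the helper name is hypothetical) of the operation these kernels implement. `ggml_set(ctx, a, b, nb1, nb2, nb3, offset)` writes `b` into a byte-strided view of `a`; the calls below are taken from ggml's public headers as of this commit.

#include "ggml.h"
#include "ggml-cpu.h" // assumed: ggml_graph_compute_with_ctx lives here after the CPU-backend split

// example_set_i32 is a hypothetical helper, not part of the commit.
static void example_set_i32(void) {
    struct ggml_init_params ip = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(ip);

    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, 4, 4); // destination (data left uninitialized for brevity)
    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 4);    // row to write

    // overwrite row 2 of `a` with `b`: the view keeps a's byte strides, offset = 2 rows
    struct ggml_tensor * r = ggml_set(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], 2*a->nb[1]);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, r);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 4); // dispatches to ggml_compute_forward_set_i32 on CPU

    ggml_free(ctx);
}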

ggml/src/ggml-cpu/ggml-cpu.c CHANGED
@@ -1374,7 +1374,10 @@ struct ggml_compute_state {
 
 inline static void ggml_vec_set_i8(const int n, int8_t * x, const int8_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
 inline static void ggml_vec_set_i16(const int n, int16_t * x, const int16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
-inline static void ggml_vec_set_i32(const int n, int32_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
+
+inline static void ggml_vec_set_i32(const int n, int32_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
+inline static void ggml_vec_cpy_i32(const int n, int32_t * y, const int32_t * x) { for (int i = 0; i < n; ++i) y[i] = x[i]; }
+
 inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
 inline static void ggml_vec_set_bf16(const int n, ggml_bf16_t * x, const ggml_bf16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
 inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; }
@@ -8248,6 +8251,77 @@ static void ggml_compute_forward_set_f32(
     }
 }
 
+static void ggml_compute_forward_set_i32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
+
+    // view src0 and dst with these strides and data offset in bytes during set
+    // nb0 is implicitly element_size because src0 and dst are contiguous
+    size_t nb1     = ((int32_t *) dst->op_params)[0];
+    size_t nb2     = ((int32_t *) dst->op_params)[1];
+    size_t nb3     = ((int32_t *) dst->op_params)[2];
+    size_t offset  = ((int32_t *) dst->op_params)[3];
+    bool   inplace = (bool) ((int32_t *) dst->op_params)[4];
+
+    if (!inplace) {
+        if (params->ith == 0) {
+            // memcpy needs to be synchronized across threads to avoid race conditions.
+            // => do it in INIT phase
+            memcpy(
+                ((char *)  dst->data),
+                ((char *) src0->data),
+                ggml_nbytes(dst));
+        }
+        ggml_barrier(params->threadpool);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr = ggml_nrows(src1);
+    const int nc = src1->ne[0];
+
+    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
+    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb)
+
+    // src0 and dst as viewed during set
+    const size_t nb0 = ggml_element_size(src0);
+
+    const int im0 = (ne10 == 0 ? 0 : ne10-1);
+    const int im1 = (ne11 == 0 ? 0 : ne11-1);
+    const int im2 = (ne12 == 0 ? 0 : ne12-1);
+    const int im3 = (ne13 == 0 ? 0 : ne13-1);
+
+    GGML_ASSERT(offset + im0*nb0 + im1*nb1 + im2*nb2 + im3*nb3 <= ggml_nbytes(dst));
+
+    GGML_ASSERT(nb10 == sizeof(int32_t));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 and dst are viewed with shape of src1 and offset
+        // => same indices
+        const int i3 = ir/(ne12*ne11);
+        const int i2 = (ir - i3*ne12*ne11)/ne11;
+        const int i1 = (ir - i3*ne12*ne11 - i2*ne11);
+
+        ggml_vec_cpy_i32(nc,
+                (int32_t *) ((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + offset),
+                (int32_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
+    }
+}
+
 static void ggml_compute_forward_set(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
@@ -8259,6 +8333,10 @@ static void ggml_compute_forward_set(
             {
                 ggml_compute_forward_set_f32(params, dst);
             } break;
+        case GGML_TYPE_I32:
+            {
+                ggml_compute_forward_set_i32(params, dst);
+            } break;
         case GGML_TYPE_F16:
        case GGML_TYPE_BF16:
        case GGML_TYPE_Q4_0:
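
The CPU kernel parallelizes over the rows of src1 using the same ceil-divide split as ggml_compute_forward_set_f32. A standalone sketch of that partition arithmetic (illustrative values only, not from the diff):

#include <stdio.h>

#define MIN(a, b) ((a) < (b) ? (a) : (b))

// Illustrative only: how rows are divided among threads in the kernel above.
int main(void) {
    const int nr  = 10; // total rows in src1 (ggml_nrows(src1))
    const int nth = 4;  // number of threads

    const int dr = (nr + nth - 1)/nth; // rows per thread, rounded up -> 3

    for (int ith = 0; ith < nth; ++ith) {
        const int ir0 = dr*ith;            // first row for this thread
        const int ir1 = MIN(ir0 + dr, nr); // one past the last row
        printf("thread %d: rows [%d, %d)\n", ith, ir0, ir1);
    }
    return 0; // prints [0,3) [3,6) [6,9) [9,10)
}
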
ggml/src/ggml-metal/ggml-metal-impl.h CHANGED
@@ -102,6 +102,21 @@ typedef struct {
     uint64_t nb3;
 } ggml_metal_kargs_cpy;
 
+typedef struct {
+    int64_t  ne10;
+    int64_t  ne11;
+    int64_t  ne12;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb13;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+    uint64_t offs;
+    bool     inplace;
+} ggml_metal_kargs_set;
+
 typedef struct {
     int32_t  ne00;
     int32_t  ne01;
ggml/src/ggml-metal/ggml-metal.m CHANGED
@@ -372,6 +372,8 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256,
+    GGML_METAL_KERNEL_TYPE_SET_I32,
+    GGML_METAL_KERNEL_TYPE_SET_F32,
     GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
     GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
     GGML_METAL_KERNEL_TYPE_CPY_F32_BF16,
@@ -940,6 +942,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, flash_attn_ext_vec_q5_1_h256, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_F32,                      set_f32,                      true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_I32,                      set_i32,                      true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32,                  cpy_f32_f32,                  true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16,                  cpy_f32_f16,                  true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_BF16,                 cpy_f32_bf16,                 use_bfloat);
@@ -1159,6 +1163,16 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
                        return false;
                };
            }
+        case GGML_OP_SET:
+            {
+                switch (op->src[0]->type) {
+                    case GGML_TYPE_F32:
+                    case GGML_TYPE_I32:
+                        return true;
+                    default:
+                        return false;
+                };
+            }
         case GGML_OP_DIAG_MASK_INF:
         case GGML_OP_GET_ROWS:
             {
@@ -3824,6 +3838,68 @@ static void ggml_metal_encode_node(
 
                 [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
             } break;
+        case GGML_OP_SET:
+            {
+                GGML_ASSERT(ggml_are_same_shape(src0, dst));
+                GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
+
+                // src0 and dst as viewed during set
+                const size_t dst_nb0 = ggml_element_size(src0);
+
+                const size_t dst_nb1 = ((int32_t *) dst->op_params)[0];
+                const size_t dst_nb2 = ((int32_t *) dst->op_params)[1];
+                const size_t dst_nb3 = ((int32_t *) dst->op_params)[2];
+                const size_t offset  = ((int32_t *) dst->op_params)[3];
+                const bool   inplace = (bool) ((int32_t *) dst->op_params)[4];
+
+                if (!inplace) {
+                    memcpy(((char *) dst->data), ((char *) src0->data), ggml_nbytes(dst));
+                }
+
+                const int im0 = (ne10 == 0 ? 0 : ne10-1);
+                const int im1 = (ne11 == 0 ? 0 : ne11-1);
+                const int im2 = (ne12 == 0 ? 0 : ne12-1);
+                const int im3 = (ne13 == 0 ? 0 : ne13-1);
+
+                GGML_ASSERT(offset + im0*dst_nb0 + im1*dst_nb1 + im2*dst_nb2 + im3*dst_nb3 <= ggml_nbytes(dst));
+
+                id<MTLComputePipelineState> pipeline = nil;
+
+                switch (src0t) {
+                    case GGML_TYPE_F32:
+                        GGML_ASSERT(nb10 == sizeof(float));
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_F32].pipeline; break;
+                    case GGML_TYPE_I32:
+                        GGML_ASSERT(nb10 == sizeof(int32_t));
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_I32].pipeline; break;
+                    default: GGML_ABORT("fatal error");
+                }
+
+                ggml_metal_kargs_set args = {
+                    /*.ne10    =*/ ne10,
+                    /*.ne11    =*/ ne11,
+                    /*.ne12    =*/ ne12,
+                    /*.nb10    =*/ nb10,
+                    /*.nb11    =*/ nb11,
+                    /*.nb12    =*/ nb12,
+                    /*.nb13    =*/ nb13,
+                    /*.nb1     =*/ dst_nb1,
+                    /*.nb2     =*/ dst_nb2,
+                    /*.nb3     =*/ dst_nb3,
+                    /*.offs    =*/ offset,
+                    /*.inplace =*/ inplace,
+                };
+
+                const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne10);
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBytes:&args length:sizeof(args) atIndex:0];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:3];
+
+                [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+            } break;
         case GGML_OP_POOL_2D:
            {
                GGML_ASSERT(ggml_is_contiguous(src0));
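
The encoder launches one threadgroup per src1 row, i.e. a (ne11, ne12, ne13) grid, and caps the threads per group at the row length ne10. A sketch of the launch arithmetic with hypothetical sizes (illustrative only):

#include <stdio.h>

#define MIN(a, b) ((a) < (b) ? (a) : (b))

// Illustrative only: dispatch geometry for the GGML_OP_SET encoder above.
// One threadgroup per src1 row; nth threads cooperate on that row's ne10 elements.
int main(void) {
    const int ne10 = 4096, ne11 = 8, ne12 = 2, ne13 = 1; // src1 shape (hypothetical)
    const int max_threads = 1024; // pipeline.maxTotalThreadsPerThreadgroup (typical value)

    const int nth = MIN(max_threads, ne10);
    printf("threadgroups = (%d, %d, %d), threads per group = %d\n",
           ne11, ne12, ne13, nth); // (8, 2, 1), 1024
    // each thread then handles ne10/nth = 4 elements via the `i10 += ntg.x` loop in kernel_set
    return 0;
}
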
ggml/src/ggml-metal/ggml-metal.metal CHANGED
@@ -3927,6 +3927,38 @@ template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_
 
 #undef FA_TYPES
 
+template<typename T>
+kernel void kernel_set(
+        constant ggml_metal_kargs_set & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    const int i13 = tgpig[2];
+    const int i12 = tgpig[1];
+    const int i11 = tgpig[0];
+
+    const int64_t n = i13*args.ne12*args.ne11*args.ne10 + i12*args.ne11*args.ne10 + i11*args.ne10;
+
+    const int64_t i3 = n / (args.ne12*args.ne11*args.ne10);
+    const int64_t i2 = (n - i3*args.ne12*args.ne11*args.ne10) / (args.ne11*args.ne10);
+    const int64_t i1 = (n - i3*args.ne12*args.ne11*args.ne10 - i2*args.ne11*args.ne10) / args.ne10;
+
+    device T * dst_data = (device T *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + args.offs);
+
+    for (int64_t i10 = tpitg.x; i10 < args.ne10; i10 += ntg.x) {
+        device const T * src = (device T *) (src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10);
+        dst_data[i10] = (T) src[0];
+    }
+}
+
+typedef decltype(kernel_set<float>) kernel_set_t;
+
+template [[host_name("kernel_set_f32")]] kernel kernel_set_t kernel_set<float>;
+template [[host_name("kernel_set_i32")]] kernel kernel_set_t kernel_set<int32_t>;
+
 template<typename T0, typename T1>
 kernel void kernel_cpy(
        constant ggml_metal_kargs_cpy & args,
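
One note on kernel_set: it flattens the threadgroup coordinates (i11, i12, i13) into n and immediately unflattens them into (i1, i2, i3). Because i11 < ne11 and i12 < ne12 always hold, the round-trip is an identity; it mirrors the structure of kernel_cpy, where the same pattern is not an identity since the destination shape can differ there. An illustrative host-side check of that identity (not from the diff):

#include <assert.h>
#include <stdint.h>
#include <stdio.h>

// Illustrative only: the flatten/unflatten in kernel_set returns the
// original threadgroup coordinates for every in-range (i11, i12, i13).
int main(void) {
    const int64_t ne10 = 7, ne11 = 3, ne12 = 5, ne13 = 2;
    for (int64_t i13 = 0; i13 < ne13; ++i13)
    for (int64_t i12 = 0; i12 < ne12; ++i12)
    for (int64_t i11 = 0; i11 < ne11; ++i11) {
        const int64_t n  = i13*ne12*ne11*ne10 + i12*ne11*ne10 + i11*ne10;
        const int64_t i3 = n / (ne12*ne11*ne10);
        const int64_t i2 = (n - i3*ne12*ne11*ne10) / (ne11*ne10);
        const int64_t i1 = (n - i3*ne12*ne11*ne10 - i2*ne11*ne10) / ne10;
        assert(i3 == i13 && i2 == i12 && i1 == i11);
    }
    printf("flatten/unflatten round-trip is the identity\n");
    return 0;
}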