ggerganov HF Staff commited on
Commit
1350705
·
unverified ·
1 Parent(s): cb8bbaa

metal : move mm_id indices to shared mem (llama/5982)

Browse files
Files changed (2) hide show
  1. ggml-metal.m +3 -3
  2. ggml-metal.metal +3 -3
ggml-metal.m CHANGED
@@ -1642,8 +1642,8 @@ static enum ggml_status ggml_metal_graph_compute(
1642
  // TODO: make this more general
1643
  GGML_ASSERT(n_as <= 8);
1644
 
1645
- // max size of the src1ids array in the kernel stack
1646
- GGML_ASSERT(ne11 <= 512);
1647
 
1648
  const int64_t ne20 = src2 ? src2->ne[0] : 0;
1649
  const int64_t ne21 = src2 ? src2->ne[1] : 0;
@@ -1741,7 +1741,7 @@ static enum ggml_status ggml_metal_graph_compute(
1741
  [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j];
1742
  }
1743
 
1744
- [encoder setThreadgroupMemoryLength:8192 atIndex:0];
1745
 
1746
  [encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne21 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
1747
  } else {
 
1642
  // TODO: make this more general
1643
  GGML_ASSERT(n_as <= 8);
1644
 
1645
+ // max size of the src1ids array in the kernel shared buffer
1646
+ GGML_ASSERT(ne11 <= 4096);
1647
 
1648
  const int64_t ne20 = src2 ? src2->ne[0] : 0;
1649
  const int64_t ne21 = src2 ? src2->ne[1] : 0;
 
1741
  [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j];
1742
  }
1743
 
1744
+ [encoder setThreadgroupMemoryLength:GGML_PAD(8192 + 2*ne11, 16) atIndex:0];
1745
 
1746
  [encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne21 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
1747
  } else {
ggml-metal.metal CHANGED
@@ -5386,7 +5386,7 @@ template<typename block_q, short nl, void (*dequantize_func)(device const block_
5386
  void kernel_mul_mm_id_impl(
5387
  device const uchar * src0,
5388
  device const uchar * src1,
5389
- thread short * src1ids,
5390
  device float * dst,
5391
  constant int64_t & ne00,
5392
  constant int64_t & ne02,
@@ -5589,9 +5589,9 @@ kernel void kernel_mul_mm_id(
5589
  tgpig.z = tgpig.z%(ne12*ne13);
5590
 
5591
  // row indices of src1 for expert id
5592
- int64_t _ne1 = 0;
5593
- short src1ids[512];
5594
 
 
5595
  for (int64_t i1 = 0; i1 < ne1; i1++) {
5596
  if (((device int32_t *) (ids + i1*nbi1))[idx] == id) {
5597
  src1ids[_ne1++] = i1;
 
5386
  void kernel_mul_mm_id_impl(
5387
  device const uchar * src0,
5388
  device const uchar * src1,
5389
+ threadgroup short * src1ids,
5390
  device float * dst,
5391
  constant int64_t & ne00,
5392
  constant int64_t & ne02,
 
5589
  tgpig.z = tgpig.z%(ne12*ne13);
5590
 
5591
  // row indices of src1 for expert id
5592
+ threadgroup short * src1ids = (threadgroup short *)(shared_memory + 8192);
 
5593
 
5594
+ int64_t _ne1 = 0;
5595
  for (int64_t i1 = 0; i1 < ne1; i1++) {
5596
  if (((device int32_t *) (ids + i1*nbi1))[idx] == id) {
5597
  src1ids[_ne1++] = i1;