Spaces:
Running
Running
metal : move mm_id indices to shared mem (llama/5982)
Browse files- ggml-metal.m +3 -3
- ggml-metal.metal +3 -3
ggml-metal.m
CHANGED
|
@@ -1642,8 +1642,8 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
| 1642 |
// TODO: make this more general
|
| 1643 |
GGML_ASSERT(n_as <= 8);
|
| 1644 |
|
| 1645 |
-
// max size of the src1ids array in the kernel
|
| 1646 |
-
GGML_ASSERT(ne11 <=
|
| 1647 |
|
| 1648 |
const int64_t ne20 = src2 ? src2->ne[0] : 0;
|
| 1649 |
const int64_t ne21 = src2 ? src2->ne[1] : 0;
|
|
@@ -1741,7 +1741,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
| 1741 |
[encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j];
|
| 1742 |
}
|
| 1743 |
|
| 1744 |
-
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
|
| 1745 |
|
| 1746 |
[encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne21 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
| 1747 |
} else {
|
|
|
|
| 1642 |
// TODO: make this more general
|
| 1643 |
GGML_ASSERT(n_as <= 8);
|
| 1644 |
|
| 1645 |
+
// max size of the src1ids array in the kernel shared buffer
|
| 1646 |
+
GGML_ASSERT(ne11 <= 4096);
|
| 1647 |
|
| 1648 |
const int64_t ne20 = src2 ? src2->ne[0] : 0;
|
| 1649 |
const int64_t ne21 = src2 ? src2->ne[1] : 0;
|
|
|
|
| 1741 |
[encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j];
|
| 1742 |
}
|
| 1743 |
|
| 1744 |
+
[encoder setThreadgroupMemoryLength:GGML_PAD(8192 + 2*ne11, 16) atIndex:0];
|
| 1745 |
|
| 1746 |
[encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne21 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
| 1747 |
} else {
|
ggml-metal.metal
CHANGED
|
@@ -5386,7 +5386,7 @@ template<typename block_q, short nl, void (*dequantize_func)(device const block_
|
|
| 5386 |
void kernel_mul_mm_id_impl(
|
| 5387 |
device const uchar * src0,
|
| 5388 |
device const uchar * src1,
|
| 5389 |
-
|
| 5390 |
device float * dst,
|
| 5391 |
constant int64_t & ne00,
|
| 5392 |
constant int64_t & ne02,
|
|
@@ -5589,9 +5589,9 @@ kernel void kernel_mul_mm_id(
|
|
| 5589 |
tgpig.z = tgpig.z%(ne12*ne13);
|
| 5590 |
|
| 5591 |
// row indices of src1 for expert id
|
| 5592 |
-
|
| 5593 |
-
short src1ids[512];
|
| 5594 |
|
|
|
|
| 5595 |
for (int64_t i1 = 0; i1 < ne1; i1++) {
|
| 5596 |
if (((device int32_t *) (ids + i1*nbi1))[idx] == id) {
|
| 5597 |
src1ids[_ne1++] = i1;
|
|
|
|
| 5386 |
void kernel_mul_mm_id_impl(
|
| 5387 |
device const uchar * src0,
|
| 5388 |
device const uchar * src1,
|
| 5389 |
+
threadgroup short * src1ids,
|
| 5390 |
device float * dst,
|
| 5391 |
constant int64_t & ne00,
|
| 5392 |
constant int64_t & ne02,
|
|
|
|
| 5589 |
tgpig.z = tgpig.z%(ne12*ne13);
|
| 5590 |
|
| 5591 |
// row indices of src1 for expert id
|
| 5592 |
+
threadgroup short * src1ids = (threadgroup short *)(shared_memory + 8192);
|
|
|
|
| 5593 |
|
| 5594 |
+
int64_t _ne1 = 0;
|
| 5595 |
for (int64_t i1 = 0; i1 < ne1; i1++) {
|
| 5596 |
if (((device int32_t *) (ids + i1*nbi1))[idx] == id) {
|
| 5597 |
src1ids[_ne1++] = i1;
|