Spaces:
Running
Running
// kernel argument structs
//
// - element counters (e.g. ne00) typically use int32_t to reduce register usage
//   however, be careful of int overflows when using those in the kernel implementation
//
// - strides (e.g. nb00) use uint64_t
// Arguments for the concat kernel.
// Per the file convention, ne* fields are element counters and nb* fields are strides.
// NOTE(review): the 0x / 1x / bare suffixes presumably refer to the first source,
// the second source and the destination - confirm against the kernel.
typedef struct {
    int32_t  ne00;
    int32_t  ne01;
    int32_t  ne02;
    int32_t  ne03;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int32_t  ne10;
    int32_t  ne11;
    int32_t  ne12;
    int32_t  ne13;
    uint64_t nb10;
    uint64_t nb11;
    uint64_t nb12;
    uint64_t nb13;
    int32_t  ne0;
    int32_t  ne1;
    int32_t  ne2;
    int32_t  ne3;
    uint64_t nb0;
    uint64_t nb1;
    uint64_t nb2;
    uint64_t nb3;
    int32_t  dim; // NOTE(review): presumably the axis to concatenate along - confirm
} ggml_metal_kargs_concat;
// Arguments for the "bin" kernels.
// Per the file convention, ne* fields are element counters and nb* fields are strides.
// NOTE(review): "bin" presumably means element-wise binary ops - confirm against the callers.
typedef struct {
    int32_t  ne00;
    int32_t  ne01;
    int32_t  ne02;
    int32_t  ne03;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int32_t  ne10;
    int32_t  ne11;
    int32_t  ne12;
    int32_t  ne13;
    uint64_t nb10;
    uint64_t nb11;
    uint64_t nb12;
    uint64_t nb13;
    int32_t  ne0;
    int32_t  ne1;
    int32_t  ne2;
    int32_t  ne3;
    uint64_t nb0;
    uint64_t nb1;
    uint64_t nb2;
    uint64_t nb3;
    uint64_t offs; // NOTE(review): presumably an offset into one of the buffers - confirm units
} ggml_metal_kargs_bin;
// Arguments for the repeat kernel.
// Per the file convention, ne* fields are element counters and nb* fields are strides.
typedef struct {
    int32_t  ne00;
    int32_t  ne01;
    int32_t  ne02;
    int32_t  ne03;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int32_t  ne0;
    int32_t  ne1;
    int32_t  ne2;
    int32_t  ne3;
    uint64_t nb0;
    uint64_t nb1;
    uint64_t nb2;
    uint64_t nb3;
} ggml_metal_kargs_repeat;
// Arguments for the cpy kernel.
// ne* fields are element counters (64-bit in this struct) and nb* fields are strides.
typedef struct {
    int64_t  ne00;
    int64_t  ne01;
    int64_t  ne02;
    int64_t  ne03;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int64_t  ne0;
    int64_t  ne1;
    int64_t  ne2;
    int64_t  ne3;
    uint64_t nb0;
    uint64_t nb1;
    uint64_t nb2;
    uint64_t nb3;
} ggml_metal_kargs_cpy;
// Arguments for the set kernel.
// ne* fields are element counters (64-bit in this struct) and nb* fields are strides.
typedef struct {
    int64_t  ne10;
    int64_t  ne11;
    int64_t  ne12;
    uint64_t nb10;
    uint64_t nb11;
    uint64_t nb12;
    uint64_t nb13;
    uint64_t nb1;
    uint64_t nb2;
    uint64_t nb3;
    uint64_t offs;    // NOTE(review): presumably an offset into the destination - confirm units
    bool     inplace; // NOTE(review): presumably selects in-place operation - confirm
} ggml_metal_kargs_set;
// Arguments for the rope kernel.
// Per the file convention, ne* fields are element counters and nb* fields are strides.
// The trailing integer and float fields are scalar parameters forwarded to the kernel;
// their exact meaning is defined by the kernel implementation, not by this struct.
typedef struct {
    int32_t  ne00;
    int32_t  ne01;
    int32_t  ne02;
    int32_t  ne03;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int32_t  ne0;
    int32_t  ne1;
    int32_t  ne2;
    int32_t  ne3;
    uint64_t nb0;
    uint64_t nb1;
    uint64_t nb2;
    uint64_t nb3;
    int32_t  n_past;
    int32_t  n_dims;
    int32_t  n_ctx_orig;
    float    freq_base;
    float    freq_scale;
    float    ext_factor;
    float    attn_factor;
    float    beta_fast;
    float    beta_slow;
} ggml_metal_kargs_rope;
// Arguments for the flash_attn_ext kernel.
// Per the file convention, ne* fields are element counters and nb* fields are strides.
// The ne_12_* / nb_12_* fields are shared between two inputs (see the comment on ne_12_2).
typedef struct {
    int32_t  ne01;
    int32_t  ne02;
    int32_t  ne03;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int32_t  ne11;
    int32_t  ne_12_2; // assume K and V are same shape
    int32_t  ne_12_3;
    uint64_t nb_12_1;
    uint64_t nb_12_2;
    uint64_t nb_12_3;
    uint64_t nb31;
    int32_t  ne1;
    int32_t  ne2;
    float    scale;
    float    max_bias;
    float    m0;
    float    m1;
    // NOTE(review): 16-bit here, while ggml_metal_kargs_soft_max declares n_head_log2 as uint32_t;
    // left unchanged because the kernel side is not visible from this file.
    uint16_t n_head_log2;
    float    logit_softcap;
} ggml_metal_kargs_flash_attn_ext;
// Arguments for the mul_mm kernel.
// Per the file convention, ne* fields are element counters and nb* fields are strides.
typedef struct {
    int32_t  ne00;
    int32_t  ne02;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int32_t  ne12;
    uint64_t nb10;
    uint64_t nb11;
    uint64_t nb12;
    uint64_t nb13;
    int32_t  ne0;
    int32_t  ne1;
    int16_t  r2; // NOTE(review): r2 / r3 look like ratios between input dims - confirm against the caller
    int16_t  r3;
} ggml_metal_kargs_mul_mm;
// Arguments for the mul_mv kernel.
// Per the file convention, ne* fields are element counters and nb* fields are strides.
typedef struct {
    int32_t  ne00;
    int32_t  ne01;
    int32_t  ne02;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int32_t  ne10;
    int32_t  ne11;
    int32_t  ne12;
    uint64_t nb10;
    uint64_t nb11;
    uint64_t nb12;
    uint64_t nb13;
    int32_t  ne0;
    int32_t  ne1;
    int16_t  r2; // NOTE(review): r2 / r3 look like ratios between input dims - confirm against the caller
    int16_t  r3;
} ggml_metal_kargs_mul_mv;
// Arguments for the mul_mv_ext kernel.
// Same leading fields as ggml_metal_kargs_mul_mv, followed by three extra 16-bit parameters.
// Per the file convention, ne* fields are element counters and nb* fields are strides.
typedef struct {
    int32_t  ne00;
    int32_t  ne01;
    int32_t  ne02;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int32_t  ne10;
    int32_t  ne11;
    int32_t  ne12;
    uint64_t nb10;
    uint64_t nb11;
    uint64_t nb12;
    uint64_t nb13;
    int32_t  ne0;
    int32_t  ne1;
    int16_t  r2;
    int16_t  r3;
    int16_t  nsg;   // NOTE(review): nsg / nxpsg / r1ptg appear to be launch-shape parameters - confirm
    int16_t  nxpsg;
    int16_t  r1ptg;
} ggml_metal_kargs_mul_mv_ext;
// Arguments for the mul_mm_id kernel.
// Per the file convention, ne* fields are element counters and nb* fields are strides.
// NOTE(review): nei* / nbi* presumably describe the ids input - confirm against the kernel.
typedef struct {
    int32_t  nei0;
    int32_t  nei1;
    uint64_t nbi1;
    int32_t  ne00;
    int32_t  ne02;
    uint64_t nb01;
    uint64_t nb02;
    int32_t  ne11;
    int32_t  ne12;
    int32_t  ne13;
    uint64_t nb10;
    uint64_t nb11;
    uint64_t nb12;
    int32_t  ne0;
    int32_t  ne1;
} ggml_metal_kargs_mul_mm_id;
// Arguments for the mul_mv_id kernel.
// Per the file convention, ne* fields are element counters and nb* fields are strides.
// NOTE(review): nei* / nbi* presumably describe the ids input - confirm against the kernel.
typedef struct {
    int32_t  nei0;
    int32_t  nei1;
    uint64_t nbi1;
    int32_t  ne00;
    int32_t  ne01;
    int32_t  ne02;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    int32_t  ne10;
    int32_t  ne11;
    int32_t  ne12;
    int32_t  ne13;
    uint64_t nb10;
    uint64_t nb11;
    uint64_t nb12;
    int32_t  ne0;
    int32_t  ne1;
    uint64_t nb1;
} ggml_metal_kargs_mul_mv_id;
// Arguments for the norm kernel.
typedef struct {
    int32_t  ne00;   // element counter
    int32_t  ne00_4; // NOTE(review): presumably ne00 / 4, pre-computed by the caller - confirm
    uint64_t nb01;   // stride
    float    eps;
} ggml_metal_kargs_norm;
// Arguments for the rms_norm kernel (same fields as ggml_metal_kargs_norm).
typedef struct {
    int32_t  ne00;   // element counter
    int32_t  ne00_4; // NOTE(review): presumably ne00 / 4, pre-computed by the caller - confirm
    uint64_t nb01;   // stride
    float    eps;
} ggml_metal_kargs_rms_norm;
// Arguments for the group_norm kernel.
// ne* fields are element counters (64-bit in this struct) and nb* fields are strides.
typedef struct {
    int64_t  ne00;
    int64_t  ne01;
    int64_t  ne02;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    int32_t  n_groups;
    float    eps;
} ggml_metal_kargs_group_norm;
// Arguments for the conv_transpose_1d kernel.
// NOTE(review): IC / IL / K / s0 presumably mean input channels, input length,
// kernel size and stride - confirm against the kernel.
typedef struct {
    int32_t  IC;
    int32_t  IL;
    int32_t  K;
    int32_t  s0;
    uint64_t nb0; // strides, per the file convention for nb* fields
    uint64_t nb1;
} ggml_metal_kargs_conv_transpose_1d;
// Arguments for the im2col kernel.
// NOTE(review): s* / p* / d* presumably mean stride / padding / dilation per axis,
// and IW / IH / KW / KH input and kernel width / height - confirm against the kernel.
typedef struct {
    uint64_t ofs0;
    uint64_t ofs1;
    int32_t  IW;
    int32_t  IH;
    int32_t  CHW;
    int32_t  s0;
    int32_t  s1;
    int32_t  p0;
    int32_t  p1;
    int32_t  d0;
    int32_t  d1;
    int32_t  N;
    int32_t  KH;
    int32_t  KW;
    int32_t  KHW; // KH * KW, pre-computed on CPU to save GPU resources
} ggml_metal_kargs_im2col;
// Arguments for the sum_rows kernel.
// ne* fields are element counters (64-bit in this struct) and nb* fields are strides.
typedef struct {
    int64_t  ne00;
    int64_t  ne01;
    int64_t  ne02;
    int64_t  ne03;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int64_t  ne10;
    int64_t  ne11;
    int64_t  ne12;
    int64_t  ne13;
    uint64_t nb10;
    uint64_t nb11;
    uint64_t nb12;
    uint64_t nb13;
    int64_t  ne0;
    int64_t  ne1;
    int64_t  ne2;
    int64_t  ne3;
    uint64_t nb0;
    uint64_t nb1;
    uint64_t nb2;
    uint64_t nb3;
} ggml_metal_kargs_sum_rows;
// Arguments for the soft_max kernel.
// ne* fields are element counters (64-bit in this struct); the remaining fields are
// scalar parameters whose meaning is defined by the kernel implementation.
typedef struct {
    int64_t  ne00;
    int64_t  ne01;
    int64_t  ne02;
    float    scale;
    float    max_bias;
    float    m0;
    float    m1;
    uint32_t n_head_log2;
} ggml_metal_kargs_soft_max;
// Arguments for the diag_mask_inf kernel.
// ne* fields are element counters (64-bit in this struct).
typedef struct {
    int64_t  ne00;
    int64_t  ne01;
    int32_t  n_past; // fixed-width, matching n_past in ggml_metal_kargs_rope (was plain int)
} ggml_metal_kargs_diag_mask_inf;
// Arguments for the ssm_conv kernel.
// ne* fields are element counters (64-bit in this struct) and nb* fields are strides.
typedef struct {
    int64_t  ne00;
    int64_t  ne01;
    int64_t  ne02;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    int64_t  ne10;
    int64_t  ne11;
    uint64_t nb10;
    uint64_t nb11;
    int64_t  ne0;
    int64_t  ne1;
    int64_t  ne2;
    uint64_t nb0;
    uint64_t nb1;
    uint64_t nb2;
} ggml_metal_kargs_ssm_conv;
// Arguments for the ssm_scan kernel.
// Four 64-bit size parameters followed by strides (nb*) for six inputs.
// NOTE(review): the first digit of each nbXY presumably selects the input and the
// second the dimension - confirm against the kernel.
typedef struct {
    int64_t  d_state;
    int64_t  d_inner;
    int64_t  n_seq_tokens;
    int64_t  n_seqs;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb10;
    uint64_t nb11;
    uint64_t nb12;
    uint64_t nb13;
    uint64_t nb20;
    uint64_t nb21;
    uint64_t nb22;
    uint64_t nb30;
    uint64_t nb31;
    uint64_t nb40;
    uint64_t nb41;
    uint64_t nb42;
    uint64_t nb50;
    uint64_t nb51;
    uint64_t nb52;
} ggml_metal_kargs_ssm_scan;
// Arguments for the get_rows kernel.
// ne* fields are element counters (64-bit in this struct) and nb* fields are strides.
typedef struct {
    int64_t  ne00;
    uint64_t nb01;
    uint64_t nb02;
    int64_t  ne10;
    uint64_t nb10;
    uint64_t nb11;
    uint64_t nb1;
    uint64_t nb2;
} ggml_metal_kargs_get_rows;
// Arguments for the upscale kernel.
// ne* fields are element counters (64-bit in this struct) and nb* fields are strides.
typedef struct {
    int64_t  ne00;
    int64_t  ne01;
    int64_t  ne02;
    int64_t  ne03;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int64_t  ne0;
    int64_t  ne1;
    int64_t  ne2;
    int64_t  ne3;
    uint64_t nb0;
    uint64_t nb1;
    uint64_t nb2;
    uint64_t nb3;
    float    sf0; // NOTE(review): sf0..sf3 are presumably per-dimension scale factors - confirm
    float    sf1;
    float    sf2;
    float    sf3;
} ggml_metal_kargs_upscale;
// Arguments for the pad kernel.
// ne* fields are element counters (64-bit in this struct) and nb* fields are strides.
typedef struct {
    int64_t  ne00;
    int64_t  ne01;
    int64_t  ne02;
    int64_t  ne03;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int64_t  ne0;
    int64_t  ne1;
    int64_t  ne2;
    int64_t  ne3;
    uint64_t nb0;
    uint64_t nb1;
    uint64_t nb2;
    uint64_t nb3;
} ggml_metal_kargs_pad;
// Arguments for the pad_reflect_1d kernel.
// ne* fields are element counters (64-bit in this struct) and nb* fields are strides.
typedef struct {
    int64_t  ne00;
    int64_t  ne01;
    int64_t  ne02;
    int64_t  ne03;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int64_t  ne0;
    int64_t  ne1;
    int64_t  ne2;
    int64_t  ne3;
    uint64_t nb0;
    uint64_t nb1;
    uint64_t nb2;
    uint64_t nb3;
    int32_t  p0; // NOTE(review): p0 / p1 are presumably the padding amounts on each side - confirm
    int32_t  p1;
} ggml_metal_kargs_pad_reflect_1d;
// Arguments for the timestep_embedding kernel.
// The integer fields use int32_t (previously plain int) so their width is explicit,
// matching the fixed-width types used by the other argument structs in this file.
typedef struct {
    uint64_t nb1; // stride, per the file convention for nb* fields
    int32_t  dim;
    int32_t  max_period;
} ggml_metal_kargs_timestep_embedding;
// Arguments for the leaky_relu kernel: a single float parameter.
typedef struct {
    float slope;
} ggml_metal_kargs_leaky_relu;
// Arguments for the argsort kernel.
typedef struct {
    int64_t ncols;
    int64_t ncols_pad; // NOTE(review): presumably ncols rounded up to a padded size - confirm
} ggml_metal_kargs_argsort;
// Arguments for the arange kernel.
typedef struct {
    int64_t ne0; // element counter (64-bit in this struct)
    float   start;
    float   step;
} ggml_metal_kargs_arange;
// Arguments for the pool_2d kernel.
// NOTE(review): k* / s* / p* presumably mean kernel size / stride / padding per axis,
// and IH / IW / OH / OW input and output height / width - confirm against the kernel.
typedef struct {
    int32_t k0;
    int32_t k1;
    int32_t s0;
    int32_t s1;
    int32_t p0;
    int32_t p1;
    int64_t IH;
    int64_t IW;
    int64_t OH;
    int64_t OW;
    int64_t parallel_elements;
} ggml_metal_kargs_pool_2d;