Spaces:
Build error
Build error
// kernel argument structs
//
// - element counters (e.g. ne00) typically use int32_t to reduce register usage
//   however, be careful of int overflows when using those in the kernel implementation
//
// - strides (e.g. nb00) use uint64_t
// Argument block for the `concat` kernel.
//
// ne* fields are element counters and nb* fields are strides.
// NOTE(review): the 0x / 1x / unsuffixed groups presumably describe the first
// source, the second source and the destination (usual ggml naming) - confirm
// against the kernel. Keep field order and types stable: they define the layout.
typedef struct {
    // first group: element counters, then strides
    int32_t  ne00;
    int32_t  ne01;
    int32_t  ne02;
    int32_t  ne03;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    // second group: element counters, then strides
    int32_t  ne10;
    int32_t  ne11;
    int32_t  ne12;
    int32_t  ne13;
    uint64_t nb10;
    uint64_t nb11;
    uint64_t nb12;
    uint64_t nb13;
    // unsuffixed group: element counters, then strides
    int32_t  ne0;
    int32_t  ne1;
    int32_t  ne2;
    int32_t  ne3;
    uint64_t nb0;
    uint64_t nb1;
    uint64_t nb2;
    uint64_t nb3;
    int32_t  dim; // NOTE(review): presumably the axis to concatenate along - confirm
} ggml_metal_kargs_concat;
// Argument block for the `bin` kernels.
//
// ne* fields are element counters and nb* fields are strides.
// NOTE(review): "bin" presumably means element-wise binary ops, with the
// 0x / 1x / unsuffixed groups describing the two sources and the destination -
// confirm against the kernel. Keep field order and types stable.
typedef struct {
    // first group: element counters, then strides
    int32_t  ne00;
    int32_t  ne01;
    int32_t  ne02;
    int32_t  ne03;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    // second group: element counters, then strides
    int32_t  ne10;
    int32_t  ne11;
    int32_t  ne12;
    int32_t  ne13;
    uint64_t nb10;
    uint64_t nb11;
    uint64_t nb12;
    uint64_t nb13;
    // unsuffixed group: element counters, then strides
    int32_t  ne0;
    int32_t  ne1;
    int32_t  ne2;
    int32_t  ne3;
    uint64_t nb0;
    uint64_t nb1;
    uint64_t nb2;
    uint64_t nb3;
    uint64_t offs; // NOTE(review): presumably an offset applied by the kernel - units not visible here
} ggml_metal_kargs_bin;
// Argument block for the `repeat` kernel.
//
// ne* fields are element counters and nb* fields are strides.
// NOTE(review): the 0x group presumably describes the source and the
// unsuffixed group the destination - confirm against the kernel.
// Keep field order and types stable: they define the layout.
typedef struct {
    // first group: element counters, then strides
    int32_t  ne00;
    int32_t  ne01;
    int32_t  ne02;
    int32_t  ne03;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    // unsuffixed group: element counters, then strides
    int32_t  ne0;
    int32_t  ne1;
    int32_t  ne2;
    int32_t  ne3;
    uint64_t nb0;
    uint64_t nb1;
    uint64_t nb2;
    uint64_t nb3;
} ggml_metal_kargs_repeat;
// Argument block for the `cpy` kernels.
//
// ne* fields are element counters and nb* fields are strides. Unlike most
// structs in this section, the element counters here are int64_t rather
// than int32_t.
// NOTE(review): the 0x group presumably describes the source and the
// unsuffixed group the destination - confirm against the kernel.
typedef struct {
    // first group: 64-bit element counters, then strides
    int64_t  ne00;
    int64_t  ne01;
    int64_t  ne02;
    int64_t  ne03;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    // unsuffixed group: 64-bit element counters, then strides
    int64_t  ne0;
    int64_t  ne1;
    int64_t  ne2;
    int64_t  ne3;
    uint64_t nb0;
    uint64_t nb1;
    uint64_t nb2;
    uint64_t nb3;
} ggml_metal_kargs_cpy;
// Argument block for the `set` kernel.
//
// ne* fields are element counters (int64_t in this struct) and nb* fields
// are strides.
// NOTE(review): the 1x group presumably describes the tensor being written
// and nb1..nb3 the destination strides - confirm against the kernel.
typedef struct {
    int64_t  ne10;
    int64_t  ne11;
    int64_t  ne12;
    uint64_t nb10;
    uint64_t nb11;
    uint64_t nb12;
    uint64_t nb13;
    uint64_t nb1;
    uint64_t nb2;
    uint64_t nb3;
    uint64_t offs;    // NOTE(review): presumably an offset into the destination - units not visible here
    bool     inplace; // NOTE(review): presumably selects in-place operation - confirm
} ggml_metal_kargs_set;
// Argument block for the `rope` kernels.
//
// ne* fields are element counters and nb* fields are strides; the remaining
// fields are scalar parameters passed through to the kernel.
// NOTE(review): the 0x group presumably describes the source and the
// unsuffixed group the destination; the float parameters presumably carry
// the RoPE frequency-scaling settings - confirm against the kernel.
typedef struct {
    // first group: element counters, then strides
    int32_t  ne00;
    int32_t  ne01;
    int32_t  ne02;
    int32_t  ne03;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    // unsuffixed group: element counters, then strides
    int32_t  ne0;
    int32_t  ne1;
    int32_t  ne2;
    int32_t  ne3;
    uint64_t nb0;
    uint64_t nb1;
    uint64_t nb2;
    uint64_t nb3;
    // integer parameters
    int32_t  n_past;
    int32_t  n_dims;
    int32_t  n_ctx_orig;
    // floating-point parameters
    float    freq_base;
    float    freq_scale;
    float    ext_factor;
    float    attn_factor;
    float    beta_fast;
    float    beta_slow;
} ggml_metal_kargs_rope;
// Argument block for the `flash_attn_ext` kernels.
//
// ne* fields are element counters and nb* fields are strides. The `_12_`
// fields are shared between two inputs (see the note on ne_12_2).
// NOTE(review): the 0x group presumably describes Q, the 3x stride the mask
// and ne1/ne2 the destination - confirm against the kernel.
// Keep field order and types stable: they define the layout.
typedef struct {
    int32_t  ne01;
    int32_t  ne02;
    int32_t  ne03;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int32_t  ne11;
    int32_t  ne_12_2; // assume K and V are same shape
    int32_t  ne_12_3;
    uint64_t nb_12_1;
    uint64_t nb_12_2;
    uint64_t nb_12_3;
    uint64_t nb31;
    int32_t  ne1;
    int32_t  ne2;
    float    scale;
    float    max_bias;
    float    m0;
    float    m1;
    uint16_t n_head_log2; // 16-bit field; the following float is aligned after it
    float    logit_softcap;
} ggml_metal_kargs_flash_attn_ext;
// Argument block for the `mul_mm` kernels.
//
// ne* fields are element counters and nb* fields are strides.
// NOTE(review): the 0x / 1x groups presumably describe the two matrix
// operands and ne0/ne1 the destination; r2/r3 presumably are broadcast
// ratios for the outer dimensions - confirm against the kernel.
typedef struct {
    int32_t  ne00;
    int32_t  ne02;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int32_t  ne12;
    uint64_t nb10;
    uint64_t nb11;
    uint64_t nb12;
    uint64_t nb13;
    int32_t  ne0;
    int32_t  ne1;
    int16_t  r2;
    int16_t  r3;
} ggml_metal_kargs_mul_mm;
// Argument block for the `mul_mv` kernels.
//
// ne* fields are element counters and nb* fields are strides.
// NOTE(review): the 0x / 1x groups presumably describe the matrix and vector
// operands and ne0/ne1 the destination; r2/r3 presumably are broadcast
// ratios for the outer dimensions - confirm against the kernel.
typedef struct {
    // first group: element counters, then strides
    int32_t  ne00;
    int32_t  ne01;
    int32_t  ne02;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    // second group: element counters, then strides
    int32_t  ne10;
    int32_t  ne11;
    int32_t  ne12;
    uint64_t nb10;
    uint64_t nb11;
    uint64_t nb12;
    uint64_t nb13;
    // unsuffixed group and 16-bit parameters
    int32_t  ne0;
    int32_t  ne1;
    int16_t  r2;
    int16_t  r3;
} ggml_metal_kargs_mul_mv;
// Argument block for the `mul_mv_ext` kernels.
//
// ne* fields are element counters and nb* fields are strides. The leading
// fields match ggml_metal_kargs_mul_mv; three extra 16-bit parameters follow.
// NOTE(review): nsg / nxpsg / r1ptg presumably tune how work is split across
// simdgroups and threadgroups - confirm against the kernel.
typedef struct {
    // first group: element counters, then strides
    int32_t  ne00;
    int32_t  ne01;
    int32_t  ne02;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    // second group: element counters, then strides
    int32_t  ne10;
    int32_t  ne11;
    int32_t  ne12;
    uint64_t nb10;
    uint64_t nb11;
    uint64_t nb12;
    uint64_t nb13;
    // unsuffixed group and 16-bit parameters
    int32_t  ne0;
    int32_t  ne1;
    int16_t  r2;
    int16_t  r3;
    int16_t  nsg;
    int16_t  nxpsg;
    int16_t  r1ptg;
} ggml_metal_kargs_mul_mv_ext;
// Argument block for the `mul_mm_id` kernels.
//
// ne* fields are element counters and nb* fields are strides; the `i`-suffixed
// fields (nei0, nei1, nbi1) follow the same counter/stride naming.
// NOTE(review): the nei*/nbi* fields presumably describe the ids tensor used
// for indirect matrix selection - confirm against the kernel.
typedef struct {
    int32_t  nei0;
    int32_t  nei1;
    uint64_t nbi1;
    int32_t  ne00;
    int32_t  ne02;
    uint64_t nb01;
    uint64_t nb02;
    int32_t  ne11;
    int32_t  ne12;
    int32_t  ne13;
    uint64_t nb10;
    uint64_t nb11;
    uint64_t nb12;
    int32_t  ne0;
    int32_t  ne1;
} ggml_metal_kargs_mul_mm_id;
// Argument block for the `mul_mv_id` kernels.
//
// ne* fields are element counters and nb* fields are strides; the `i`-suffixed
// fields (nei0, nei1, nbi1) follow the same counter/stride naming.
// NOTE(review): the nei*/nbi* fields presumably describe the ids tensor used
// for indirect matrix selection - confirm against the kernel.
typedef struct {
    int32_t  nei0;
    int32_t  nei1;
    uint64_t nbi1;
    // first group: element counters, then strides
    int32_t  ne00;
    int32_t  ne01;
    int32_t  ne02;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    // second group: element counters, then strides
    int32_t  ne10;
    int32_t  ne11;
    int32_t  ne12;
    int32_t  ne13;
    uint64_t nb10;
    uint64_t nb11;
    uint64_t nb12;
    // unsuffixed group
    int32_t  ne0;
    int32_t  ne1;
    uint64_t nb1;
} ggml_metal_kargs_mul_mv_id;
// Argument block for the `norm` kernel.
//
// ne* fields are element counters and nb* fields are strides.
typedef struct {
    int32_t  ne00;
    int32_t  ne00_4; // NOTE(review): presumably ne00 / 4 for 4-wide vector access - confirm
    uint64_t nb01;
    float    eps;    // NOTE(review): presumably the epsilon added for numerical stability - confirm
} ggml_metal_kargs_norm;
// Argument block for the `rms_norm` kernel.
//
// ne* fields are element counters and nb* fields are strides. The fields are
// the same as in ggml_metal_kargs_norm.
typedef struct {
    int32_t  ne00;
    int32_t  ne00_4; // NOTE(review): presumably ne00 / 4 for 4-wide vector access - confirm
    uint64_t nb01;
    float    eps;    // NOTE(review): presumably the epsilon added for numerical stability - confirm
} ggml_metal_kargs_rms_norm;