// nanochat-d24-CPU / nanochat.cpp (source: Nekochu)
// pre-built llama-server + Q4_K_M GGUF, rev 71137d0
#include "models.h"
// Nanochat d24: ReLU^2, QK-norm after RoPE, logit softcap 15,
// per-layer residual scalars (x = rl*x + xl*x0), value embeddings
// on alternating layers, backout (subtract mid-layer residual).
// All norms are unweighted RMSNorm (pass NULL weight).
// Scalar params read as float from model struct (not ggml tensors,
// because ggml_mul with {1} tensors causes precision issues on CPU).
// Builds the full forward graph for the nanochat d24 architecture.
//
// Distinctive features handled here (see the header comment above):
//   - unweighted RMSNorm everywhere (NULL norm weights)
//   - per-layer residual scaling against the normalized input embeddings x0
//   - value embeddings mixed into V on layers that carry them
//   - RoPE applied BEFORE QK-norm, followed by a 1.15 sharpening scale
//   - mid-layer residual "backout" subtracted before the final norm
//   - logit softcap: cap * tanh(logits / cap)
llm_build_nanochat::llm_build_nanochat(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
    const int64_t n_embd_head = hparams.n_embd_head_v();
    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());

    ggml_tensor * cur;
    ggml_tensor * inpL;

    inpL = build_inp_embd(model.tok_embd);

    // token ids are needed again inside the layer loop to fetch the
    // per-token value embeddings
    ggml_tensor * inp_tokens = res->t_inp_tokens;

    // embedding norm (unweighted)
    inpL = build_norm(inpL, NULL, NULL, LLM_NORM_RMS, -1);
    cb(inpL, "inp_norm", -1);

    // x0: normalized input embeddings, re-injected into every layer's
    // residual stream (explicit copy for lifetime tracking)
    ggml_tensor * x0 = ggml_cont(ctx0, inpL);
    ggml_set_name(x0, "x0");
    ggml_build_forward_expand(gf, x0);

    ggml_tensor * inp_pos = build_inp_pos();
    auto * inp_attn = build_attn_inp_kv();

    const float kq_scale = 1.0f / sqrtf(float(n_embd_head));

    // the residual captured after this layer is subtracted again before
    // the final norm (the "backout")
    const int backout_layer = n_layer / 2;
    ggml_tensor * x_backout = nullptr;

    for (int il = 0; il < n_layer; ++il) {
        const auto & layer = model.layers[il];

        // per-layer residual scaling: x = resid_lambda * x + x0_lambda * x0
        // (scalars come from the model struct as plain floats — see header note)
        {
            const float rl = model.nanochat_resid_lambda[il];
            const float xl = model.nanochat_x0_lambda[il];
            inpL = ggml_add(ctx0, ggml_scale(ctx0, inpL, rl), ggml_scale(ctx0, x0, xl));
        }

        ggml_tensor * inpSA = inpL;

        // pre-attention norm (unweighted)
        cur = build_norm(inpL, NULL, NULL, LLM_NORM_RMS, il);
        cb(cur, "attn_norm", il);

        // Q, K, V projections
        ggml_tensor * Qcur = build_lora_mm(layer.wq, cur);
        ggml_tensor * Kcur = build_lora_mm(layer.wk, cur);
        ggml_tensor * Vcur = build_lora_mm(layer.wv, cur);

        Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
        Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
        Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);

        // value embeddings on alternating layers: V += sigmoid(3*gate) * ve
        if (layer.value_embd && layer.wqkv_gate) {
            ggml_tensor * ve = ggml_get_rows(ctx0, layer.value_embd, inp_tokens);
            ve = ggml_reshape_3d(ctx0, ve, n_embd_head, n_head_kv, n_tokens);

            // the gate reads only the leading features of the hidden state;
            // ggml_mul_mat requires gate_in->ne[0] == wqkv_gate->ne[0], so
            // derive the view width from the weight instead of hard-coding it
            const int64_t n_gate_in = layer.wqkv_gate->ne[0];
            ggml_tensor * gate_in = ggml_view_2d(ctx0, cur, n_gate_in, n_tokens, cur->nb[1], 0);
            ggml_tensor * gate = build_lora_mm(layer.wqkv_gate, gate_in);
            gate = ggml_sigmoid(ctx0, ggml_scale(ctx0, gate, 3.0f));
            gate = ggml_reshape_3d(ctx0, gate, 1, n_head_kv, n_tokens);
            Vcur = ggml_add(ctx0, Vcur, ggml_mul(ctx0, ve, gate));
        }

        // RoPE is applied BEFORE QK-norm (nanochat order)
        ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
        Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, rope_factors,
                n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                ext_factor, attn_factor, beta_fast, beta_slow);
        Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, rope_factors,
                n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                ext_factor, attn_factor, beta_fast, beta_slow);

        // QK-norm (after RoPE) + 1.15 sharpening scale
        Qcur = ggml_scale(ctx0, build_norm(Qcur, NULL, NULL, LLM_NORM_RMS, il), 1.15f);
        Kcur = ggml_scale(ctx0, build_norm(Kcur, NULL, NULL, LLM_NORM_RMS, il), 1.15f);
        cb(Qcur, "Qcur", il);
        cb(Kcur, "Kcur", il);
        cb(Vcur, "Vcur", il);

        // attention + output projection
        cur = build_attn(inp_attn,
                layer.wo, NULL,
                Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);

        // attention residual
        cur = ggml_add(ctx0, cur, inpSA);

        // remember the mid-layer residual for the final backout
        if (il == backout_layer) {
            x_backout = cur;
        }

        ggml_tensor * ffn_inp = cur;

        // pre-FFN norm (unweighted)
        cur = build_norm(cur, NULL, NULL, LLM_NORM_RMS, il);
        cb(cur, "ffn_norm", il);

        // FFN with ReLU^2 activation (no gate branch)
        cur = build_ffn(cur,
                layer.ffn_up,   NULL, NULL,
                NULL,           NULL, NULL,
                layer.ffn_down, NULL, NULL,
                NULL,
                LLM_FFN_RELU_SQR, LLM_FFN_SEQ, il);

        // FFN residual
        cur = ggml_add(ctx0, cur, ffn_inp);
        cb(cur, "l_out", il);

        inpL = cur;
    }

    cur = inpL;

    // backout: subtract the scaled mid-layer residual
    if (x_backout && model.nanochat_backout != 0.0f) {
        cur = ggml_sub(ctx0, cur, ggml_scale(ctx0, x_backout, model.nanochat_backout));
    }

    // final norm (unweighted)
    cur = build_norm(cur, NULL, NULL, LLM_NORM_RMS, -1);
    cb(cur, "result_norm", -1);
    res->t_embd = cur;

    // lm_head
    cur = build_lora_mm(model.output, cur);

    // logit softcap: cap * tanh(logits / cap); explicit compare instead of
    // implicit float->bool conversion
    if (hparams.f_final_logit_softcapping != 0.0f) {
        cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping);
        cur = ggml_tanh(ctx0, cur);
        cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping);
    }
    cb(cur, "result_output", -1);
    res->t_logits = cur;

    ggml_build_forward_expand(gf, cur);
}