Functions
Details
row_max
fn row_max(x: i64, base: i64, cols: i64) -> i64
softmax_row
fn softmax_row(x: i64, out: i64, base: i64, cols: i64) -> i64
softmax_rows
fn softmax_rows(x: i64, out: i64, rows: i64, cols: i64) -> i64
compute_mean
fn compute_mean(x: i64) -> i64
compute_var
fn compute_var(x: i64, mean: i64) -> i64
norm_apply
fn norm_apply(x: i64, out: i64, mean: i64, std: i64, gamma: i64, beta: i64) -> i64
layer_norm_fwd
fn layer_norm_fwd(x: i64, out: i64, gamma: i64, beta: i64) -> i64
attn_scores
fn attn_scores(q: i64, k: i64, scores: i64, seq: i64, dk: i64) -> i64
attn_apply
fn attn_apply(attn: i64, v: i64, out: i64, seq: i64, dk: i64) -> i64
attention_fwd
fn attention_fwd(q: i64, k: i64, v: i64, out: i64, seq: i64, dk: i64) -> i64
gelu_fwd
fn gelu_fwd(x: i64, out: i64) -> i64
embed_fwd
fn embed_fwd(tokens: i64, table: i64, out: i64, seq: i64, d_model: i64) -> i64
pos_encode
fn pos_encode(x: i64, seq: i64, d_model: i64) -> i64
linear_fwd
fn linear_fwd(x: i64, w: i64, bias: i64, out: i64) -> i64
residual_add
fn residual_add(a: i64, b: i64, out: i64) -> i64
relu_fwd
fn relu_fwd(x: i64, out: i64) -> i64