transformer

Functions

Function | Description
row_max
softmax_row
softmax_rows
compute_mean
compute_var
norm_apply
layer_norm_fwd
attn_scores
attn_apply
attention_fwd
gelu_fwd
embed_fwd
pos_encode
linear_fwd
residual_add
relu_fwd

Details

row_max

fn row_max(x: i64, base: i64, cols: i64) -> i64

softmax_row

fn softmax_row(x: i64, out: i64, base: i64, cols: i64) -> i64

softmax_rows

fn softmax_rows(x: i64, out: i64, rows: i64, cols: i64) -> i64

compute_mean

fn compute_mean(x: i64) -> i64

compute_var

fn compute_var(x: i64, mean: i64) -> i64

norm_apply

fn norm_apply(x: i64, out: i64, mean: i64, std: i64, gamma: i64, beta: i64) -> i64

layer_norm_fwd

fn layer_norm_fwd(x: i64, out: i64, gamma: i64, beta: i64) -> i64

attn_scores

fn attn_scores(q: i64, k: i64, scores: i64, seq: i64, dk: i64) -> i64

attn_apply

fn attn_apply(attn: i64, v: i64, out: i64, seq: i64, dk: i64) -> i64

attention_fwd

fn attention_fwd(q: i64, k: i64, v: i64, out: i64, seq: i64, dk: i64) -> i64

gelu_fwd

fn gelu_fwd(x: i64, out: i64) -> i64

embed_fwd

fn embed_fwd(tokens: i64, table: i64, out: i64, seq: i64, d_model: i64) -> i64

pos_encode

fn pos_encode(x: i64, seq: i64, d_model: i64) -> i64

linear_fwd

fn linear_fwd(x: i64, w: i64, bias: i64, out: i64) -> i64

residual_add

fn residual_add(a: i64, b: i64, out: i64) -> i64

relu_fwd

fn relu_fwd(x: i64, out: i64) -> i64