Ztorch Architecture

This document describes the internal architecture of Ztorch.

Overview

Ztorch is a compiler-based ML library. Models are defined at compile time (or at runtime), lowered to an intermediate representation (IR), optimized, and then compiled to backend-specific code.

Model Definition → IR → Optimization → Autograd → Backend Codegen → Execution

Components

1. IR (Intermediate Representation)

The IR is Ztorch's internal graph representation. It's the single source of truth for all transformations.

pub const Graph = struct {
    nodes: []Node,
    edges: []Edge,
    allocator: Allocator,
};

pub const Node = union(enum) {
    matmul: MatMulOp,
    relu: ActivationOp,
    softmax: SoftmaxOp,
    layernorm: LayerNormOp,
    // ... more ops
};

pub const Edge = struct {
    from: NodeId,
    to: NodeId,
    tensor_shape: Shape,
};

Design principles:

  • Immutable after construction (transformations create new graphs)
  • Validates shape compatibility at creation
  • Lightweight - can be copied/cloned cheaply
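
As a sketch of the shape-validation principle, the helper below shows the kind of check a matmul edge needs at construction time; matmulOutputShape and the plain 2-D shapes are illustrative, not part of the real Graph API.

const std = @import("std");

// Hypothetical helper illustrating "validates shape compatibility at
// creation": a matmul edge is only accepted if the inner dimensions agree.
fn matmulOutputShape(a: [2]usize, b: [2]usize) error{ShapeMismatch}![2]usize {
    if (a[1] != b[0]) return error.ShapeMismatch;
    return [2]usize{ a[0], b[1] };
}

test "matmul shape validation" {
    const out = try matmulOutputShape(.{ 784, 128 }, .{ 128, 10 });
    try std.testing.expectEqual(@as(usize, 784), out[0]);
    try std.testing.expectEqual(@as(usize, 10), out[1]);
    try std.testing.expectError(error.ShapeMismatch, matmulOutputShape(.{ 784, 128 }, .{ 64, 10 }));
}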

2. Frontends

Frontends convert external model formats to Ztorch IR.

Native Zig API (v0.1):

const Model = ztorch.Sequential(.{
    ztorch.Linear(784, 128),
    ztorch.ReLU(),
});

This comptime struct is converted to IR during compilation.
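
As a rough sketch of the kind of comptime validation this enables (LayerSpec and checkSequential are hypothetical names, not the real ztorch.Sequential internals), mismatched layer dimensions can be rejected with a compile error:

const LayerSpec = struct { in: usize, out: usize };

fn checkSequential(comptime layers: []const LayerSpec) void {
    comptime var i: usize = 1;
    inline while (i < layers.len) : (i += 1) {
        if (layers[i - 1].out != layers[i].in) {
            @compileError("layer output/input dimensions do not match");
        }
    }
}

comptime {
    // Matches the example above: Linear(784, 128) feeding Linear(128, 10).
    checkSequential(&.{
        .{ .in = 784, .out = 128 },
        .{ .in = 128, .out = 10 },
    });
}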

ONNX Import (v0.2+):

const graph = try ztorch.frontends.onnx.load("model.onnx");

3. Optimization Passes

Optimization passes transform the IR to improve performance.

v0.1 Optimizations:

  • Operator fusion (e.g., MatMul + ReLU → FusedMatMulReLU)
  • Constant folding
  • Dead code elimination
  • Memory layout optimization

Example:

Before: MatMul → ReLU → Softmax (3 kernel launches)
After:  FusedMatMulReLUSoftmax (1 kernel launch)
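
A minimal sketch of the fusion idea, applied to a flat chain of ops rather than the real IR graph; the Op enum and fuseMatMulRelu are illustrative only:

const std = @import("std");

// Illustrative fusion pass: rewrite MatMul followed by ReLU into a single
// fused op. Operates on a flat op list instead of the real Graph.
const Op = enum { matmul, relu, softmax, fused_matmul_relu };

fn fuseMatMulRelu(allocator: std.mem.Allocator, ops: []const Op) ![]Op {
    // Worst case: nothing fuses and the output is as long as the input.
    const out = try allocator.alloc(Op, ops.len);
    var n: usize = 0;
    var i: usize = 0;
    while (i < ops.len) {
        if (i + 1 < ops.len and ops[i] == .matmul and ops[i + 1] == .relu) {
            out[n] = .fused_matmul_relu;
            i += 2; // consume both ops
        } else {
            out[n] = ops[i];
            i += 1;
        }
        n += 1;
    }
    return allocator.realloc(out, n);
}

test "fuse matmul + relu" {
    const fused = try fuseMatMulRelu(std.testing.allocator, &.{ .matmul, .relu, .softmax });
    defer std.testing.allocator.free(fused);
    try std.testing.expectEqualSlices(Op, &.{ .fused_matmul_relu, .softmax }, fused);
}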

4. Autograd

The autograd system generates backward pass operations from forward pass IR.

Each operation has a gradient function:

pub const MatMulOp = struct {
    pub fn forward(a: Tensor, b: Tensor) Tensor { ... }

    pub fn backward(
        d_output: Tensor,
        a: Tensor,
        b: Tensor,
    ) struct { d_a: Tensor, d_b: Tensor } {
        // d_a = d_output @ b.T
        // d_b = a.T @ d_output
        ...
    }
};

The autograd pass walks the forward graph and generates the backward graph.
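
A sketch of that ordering, assuming a flat, topologically sorted op list instead of the real graph; FwdOp, BwdOp, and gradientOf stand in for the per-op backward functions shown above:

const std = @import("std");

// Hypothetical sketch of the autograd pass over a topologically ordered
// forward op list: emit one backward op per forward op, in reverse order.
// Real graphs carry edges and saved activations; this shows only ordering.
const FwdOp = enum { matmul, relu, softmax };
const BwdOp = enum { matmul_grad, relu_grad, softmax_grad };

fn gradientOf(op: FwdOp) BwdOp {
    return switch (op) {
        .matmul => .matmul_grad,
        .relu => .relu_grad,
        .softmax => .softmax_grad,
    };
}

fn buildBackward(allocator: std.mem.Allocator, forward: []const FwdOp) ![]BwdOp {
    const backward = try allocator.alloc(BwdOp, forward.len);
    for (forward, 0..) |op, i| {
        backward[forward.len - 1 - i] = gradientOf(op);
    }
    return backward;
}

test "backward ops are emitted in reverse order" {
    const bwd = try buildBackward(std.testing.allocator, &.{ .matmul, .relu, .softmax });
    defer std.testing.allocator.free(bwd);
    try std.testing.expectEqualSlices(BwdOp, &.{ .softmax_grad, .relu_grad, .matmul_grad }, bwd);
}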

5. Backend Codegen

Backend codegen converts IR operations to executable code.

CPU Scalar (reference):

  • Direct Zig implementation
  • Simple, obviously correct
  • Used for verification

CPU SIMD:

  • Intrinsics for AVX2/AVX512 (x86)
  • Intrinsics for NEON (ARM)
  • Falls back to scalar if unsupported
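
A rough sketch of the SIMD-plus-scalar-fallback pattern using Zig's portable @Vector type; the fixed lane count is an assumption, and the real backend would use architecture-specific intrinsics and pick the width per target:

const std = @import("std");

// In-place ReLU: vectorized main loop with a scalar tail. The lane count is
// a fixed assumption (8 x f32, roughly AVX2); the scalar loop handles the
// remainder and is the full path on targets without SIMD support.
fn reluInPlace(data: []f32) void {
    const lanes = 8;
    const V = @Vector(lanes, f32);
    const zero: V = @splat(0.0);
    var i: usize = 0;
    while (i + lanes <= data.len) : (i += lanes) {
        const chunk: *[lanes]f32 = data[i..][0..lanes];
        const v: V = chunk.*;
        chunk.* = @max(v, zero);
    }
    while (i < data.len) : (i += 1) {
        data[i] = @max(data[i], 0.0);
    }
}

test "relu clamps negatives" {
    var data = [_]f32{ -1.0, 2.0, -3.0, 4.0, -5.0, 6.0, -7.0, 8.0, -9.0 };
    reluInPlace(&data);
    try std.testing.expectEqual(@as(f32, 0.0), data[0]);
    try std.testing.expectEqual(@as(f32, 2.0), data[1]);
    try std.testing.expectEqual(@as(f32, 0.0), data[8]);
}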

CUDA:

  • Generates PTX assembly
  • Comptime specialization for shapes
  • Tensor core utilization

ROCm:

  • Generates LLVM IR
  • Similar to CUDA approach

Vulkan:

  • Generates SPIR-V
  • Portable across vendors

6. Runtime

The runtime manages:

  • Memory allocation (device buffers)
  • Kernel launching
  • Synchronization
  • Error handling

Memory management:

  • Static allocation during model compilation
  • No dynamic allocation during forward/backward
  • Explicit buffer reuse
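
One way to read these rules, as a sketch: carve every activation buffer out of a single block at plan time, so the forward/backward hot path never touches a general-purpose allocator. The planActivations name and buffer sizes below are hypothetical.

const std = @import("std");

// Sketch: all activation buffers come from one fixed block sized up front;
// nothing is allocated during forward/backward.
// Usage: var block: [4096]u8 = undefined; const bufs = try planActivations(&block);
fn planActivations(backing: []u8) ![2][]f32 {
    var fba = std.heap.FixedBufferAllocator.init(backing);
    const a = fba.allocator();
    const hidden = try a.alloc(f32, 128); // hypothetical layer sizes
    const logits = try a.alloc(f32, 10);
    return [2][]f32{ hidden, logits };
}

Buffer reuse then becomes an explicit decision made when the plan is built, rather than something the allocator does implicitly.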

Compilation Flow

Comptime Model Definition

const Model = ztorch.Sequential(.{
    ztorch.Linear(784, 128),
    ztorch.ReLU(),
    ztorch.Linear(128, 10),
});

// At comptime:
// 1. Type-check layer compatibility (128 matches between layers)
// 2. Build IR graph
// 3. Apply optimization passes
// 4. Generate backward pass

Compilation

var model = try Model.compile(.cuda, allocator);

// During compile():
// 1. Finalize IR (if not comptime)
// 2. Allocate device memory
// 3. Generate backend code (PTX)
// 4. Load kernels
// 5. Create execution plan

Execution

const output = try model.forward(input);

// During forward():
// 1. Copy input to device (if needed)
// 2. Launch fused kernels in sequence
// 3. Return output tensor

Data Structures

Tensor

pub const Tensor = struct {
    data: DevicePtr,
    shape: Shape,
    stride: Stride,
    dtype: DType,
    device: Device,
    requires_grad: bool,

    pub fn item(self: Tensor) f32 { ... }
    pub fn reshape(self: Tensor, new_shape: Shape) Tensor { ... }
    // ...
};

Shape

pub const Shape = struct {
    dims: [MAX_DIMS]usize,
    ndim: u8,

    pub fn numel(self: Shape) usize {
        var n: usize = 1;
        for (self.dims[0..self.ndim]) |d| n *= d;
        return n;
    }
};
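
As a companion sketch for the stride field on Tensor, contiguous (row-major) strides can be derived from a Shape like this; treating strides as element counts rather than bytes is an assumption:

// Sketch: row-major strides for a contiguous tensor, in elements.
// Uses the Shape type and MAX_DIMS constant from above.
fn contiguousStrides(shape: Shape) [MAX_DIMS]usize {
    var strides = [_]usize{0} ** MAX_DIMS;
    var acc: usize = 1;
    var i: usize = shape.ndim;
    while (i > 0) {
        i -= 1;
        strides[i] = acc;
        acc *= shape.dims[i];
    }
    return strides;
}

For a shape of (2, 3, 4) this yields strides of (12, 4, 1).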

Backend Interface

All backends implement the same interface:

pub const Backend = struct {
    vtable: *const VTable,
    context: *anyopaque,

    pub const VTable = struct {
        matmul: *const fn (*anyopaque, Tensor, Tensor) Tensor,
        relu: *const fn (*anyopaque, Tensor) Tensor,
        softmax: *const fn (*anyopaque, Tensor, usize) Tensor,
        // ... all ops
    };
};

This allows backends to be selected at runtime and lets backend parity be tested against the CPU reference.
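
Dispatch goes through the vtable; a thin wrapper for the relu entry (using the Backend and Tensor types defined above) would look like the sketch below. A parity test can then run the same op through two Backend values, say CPU scalar and CUDA, on identical inputs and compare the outputs element-wise.

// Sketch: calling an op through the type-erased vtable.
fn backendRelu(backend: Backend, x: Tensor) Tensor {
    return backend.vtable.relu(backend.context, x);
}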

Performance Model

Napkin Math

Before implementing any operation, estimate its cost:

MatMul (M, K) @ (K, N):
- FLOPs: 2 * M * K * N
- Memory: (M*K + K*N + M*N) * sizeof(f32) bytes
- Arithmetic intensity: 2*M*K*N / ((M*K + K*N + M*N) * sizeof(f32)) FLOPs/byte

Example: (1024, 1024) @ (1024, 1024)
- FLOPs: 2.15 GFLOP
- Memory: ~12.6 MB
- Arithmetic intensity: ~170 FLOPs/byte
- On an RTX 4090 (82 TFLOPS FP32, ~1 TB/s):
  - Ridge point: 82 FLOPs/byte; 170 > 82, so compute bound ✓
  - Compute time: 2.15 GFLOP / 82 TFLOPS ≈ 26µs
  - Memory time: 12.6 MB / 1 TB/s ≈ 13µs
  - Actual should be ~26µs
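
The same napkin math as a quick sketch in code, using the peak figures quoted above:

const std = @import("std");

// Napkin-math helpers for a single-precision MatMul of (M, K) @ (K, N).
fn matmulFlops(m: usize, k: usize, n: usize) usize {
    return 2 * m * k * n;
}

fn matmulBytes(m: usize, k: usize, n: usize) usize {
    return (m * k + k * n + m * n) * @sizeOf(f32);
}

test "1024x1024 matmul is compute bound on an RTX 4090" {
    const flops = matmulFlops(1024, 1024, 1024); // ~2.15 GFLOP
    const bytes = matmulBytes(1024, 1024, 1024); // ~12.6 MB
    const intensity = @as(f64, @floatFromInt(flops)) / @as(f64, @floatFromInt(bytes));
    // Ridge point: 82 TFLOPS / 1 TB/s = 82 FLOPs/byte.
    try std.testing.expect(intensity > 82.0);
}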

Benchmarking

Every implementation is benchmarked:

=== MatMul 1024x1024 ===
CPU Scalar:       450ms (4.8 GFLOPS)
CPU AVX2:         112ms (19.2 GFLOPS) - 4.0x speedup
CUDA (RTX 4090):  0.5ms (4300 GFLOPS) - 900x speedup

Testing Strategy

See testing.md for full details.

Levels:

  1. Unit tests (each op)
  2. Backend parity (GPU matches CPU)
  3. Gradient checks (numerical vs autograd)
  4. Integration (full model training)
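
As a sketch of level 3, a gradient check compares an analytical derivative against a central finite difference; the square function below stands in for a real op and its autograd gradient:

const std = @import("std");

// Forward function and its hand-written (i.e. "autograd") derivative.
fn square(x: f64) f64 {
    return x * x;
}

fn squareGrad(x: f64) f64 {
    return 2.0 * x;
}

// Central finite difference: (f(x + eps) - f(x - eps)) / (2 * eps).
fn numericalGrad(f: *const fn (f64) f64, x: f64, eps: f64) f64 {
    return (f(x + eps) - f(x - eps)) / (2.0 * eps);
}

test "gradient check" {
    const x: f64 = 1.5;
    const analytic = squareGrad(x);
    const numeric = numericalGrad(square, x, 1e-5);
    try std.testing.expectApproxEqAbs(analytic, numeric, 1e-6);
}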

Future Architecture

Dynamic Shapes (v0.2)

Support runtime shape variation within bounds:

const Model = ztorch.Sequential(.{
    ztorch.Linear(784, 128),
    // ... batch size determined at runtime
});

Distributed (zbmd integration)

Ztorch provides the compute engine, zbmd provides fault-tolerant distribution.

Quantization (v0.3)

Support int8, fp16, bfloat16 for inference acceleration.