Transforms
Transforms allow you to post-process model outputs after ONNX inference and before returning results. They run inside the model binary, operating directly on tensors for high performance.
Transforms run on Lua 5.4 in a sandboxed environment. LuaJIT is not currently supported.
Why Use Transforms?
Common use cases:
- Normalize embeddings for cosine similarity
- Apply softmax to convert logits to probabilities
- Pool embeddings to create sentence representations
- Scale outputs for specific downstream tasks
Getting Started
A transform is a Lua script that defines a Postprocess function:
---@param arr Tensor
---@return Tensor
function Postprocess(arr, ...)
  -- your postprocessing logic goes here; this placeholder returns the input unchanged
  return arr
end
With two exceptions, the Postprocess function must return a Tensor with exactly the same shape as the input Tensor provided for that model type. The exceptions are:
- Embedding and sentence embedding models can modify the length of hidden (useful for matryoshka embeddings).
- Sentence embedding models are given a Tensor of shape [batch_size, seq_len, hidden] and an attention mask of shape [batch_size, seq_len], and must return a Tensor of shape [batch_size, hidden]. In other words, a pooling operation along dimension seq_len is expected.
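The shape rules above can be summarized in a small plain-Lua sketch. The helper expected_output_shape and the model-type strings are hypothetical names used only for illustration; they are not part of the transform runtime or the Tensor API.

```lua
-- Hypothetical helper: returns the output shape Postprocess is expected
-- to produce for a given model type and input shape.
-- The model-type strings are illustrative, not real configuration values.
local function expected_output_shape(model_type, input_shape)
  if model_type == "sentence_embedding" then
    -- [batch_size, seq_len, hidden] -> [batch_size, hidden]
    return { input_shape[1], input_shape[3] }
  end
  -- All other model types: same shape as the input.
  -- (Embedding models may additionally shorten the hidden dimension.)
  local out = {}
  for i, dim in ipairs(input_shape) do
    out[i] = dim
  end
  return out
end

local shape = expected_output_shape("sentence_embedding", { 4, 16, 384 })
print(table.concat(shape, ", ")) -- prints "4, 384"
```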
Note on indexing
Lua is 1-indexed: counting starts at 1 instead of 0. The Tensor API follows this convention, so you must count your axes and indices starting at 1 instead of 0.
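If you are porting logic from a 0-indexed library such as NumPy or PyTorch, a small conversion helper can make the off-by-one explicit. This is a plain-Lua sketch; to_lua_axis is a hypothetical name and not part of the Tensor API.

```lua
-- Hypothetical helper: convert a 0-indexed axis (negatives count from
-- the end, as in NumPy or PyTorch) to the 1-indexed axis used in Lua.
local function to_lua_axis(axis0, rank)
  if axis0 < 0 then
    axis0 = axis0 + rank
  end
  assert(axis0 >= 0 and axis0 < rank, "axis out of range")
  return axis0 + 1
end

-- For a tensor of shape [batch_size, seq_len, hidden] (rank 3):
print(to_lua_axis(0, 3))  -- prints 1 (batch_size)
print(to_lua_axis(-1, 3)) -- prints 3 (hidden)
```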
We provide a built-in API for standard tensor operations. To learn more, check out our Tensor API reference page. You can find the stub file here.
If you don't see an op that you need, please don't hesitate to create an issue on GitHub.
Input Signatures
The input signature of Postprocess depends on the type of model being used.
Embedding
--- input: 3d tensor of shape [batch_size, seq_len, hidden]
---@param arr Tensor
--- output: 3d tensor of shape [batch_size, seq_len, hidden]
---@return Tensor
function Postprocess(arr)
  -- your postprocessing logic goes here; this placeholder returns the input unchanged
  return arr
end
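As a minimal sketch of an embedding transform, the following L2-normalizes every token vector. It assumes that Tensor:lp_normalize(p, axis), which appears under Typical Transform Patterns, takes the norm order first and a 1-indexed axis second; check the Tensor API reference before relying on it.

```lua
---@param arr Tensor 3d tensor of shape [batch_size, seq_len, hidden]
---@return Tensor 3d tensor of shape [batch_size, seq_len, hidden]
function Postprocess(arr)
  -- p = 2 (Euclidean norm), axis = 3 (hidden, counting from 1)
  return arr:lp_normalize(2, 3)
end
```

Normalization rescales values along one axis without adding or removing elements, so the output shape matches the input shape.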
Sequence Classification
--- input: 2d tensor of shape [batch_size, n_labels]
---@param arr Tensor
--- output: 2d tensor of shape [batch_size, n_labels]
---@return Tensor
function Postprocess(arr)
  -- your postprocessing logic goes here; this placeholder returns the input unchanged
  return arr
end
Token Classification
--- input: 3d tensor of shape [batch_size, seq_len, n_labels]
---@param arr Tensor
--- output: 3d tensor of shape [batch_size, seq_len, n_labels]
---@return Tensor
function Postprocess(arr)
  -- your postprocessing logic goes here; this placeholder returns the input unchanged
  return arr
end
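For token classification, the label dimension is the third axis rather than the second. A minimal sketch that converts per-token logits to probabilities, assuming Tensor:softmax(axis) takes a 1-indexed axis as in the tensor:softmax(2) example on this page:

```lua
---@param arr Tensor 3d tensor of shape [batch_size, seq_len, n_labels]
---@return Tensor 3d tensor of shape [batch_size, seq_len, n_labels]
function Postprocess(arr)
  -- axis 3 is n_labels, so the probabilities for each token sum to 1
  return arr:softmax(3)
end
```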
Sentence Embedding
Mean Pooling
To mean-pool embeddings, you can use the Tensor:mean_pool function like this: tensor:mean_pool(mask).
--- input: 3d tensor of shape [batch_size, seq_len, hidden]
---@param arr Tensor
--- input: 2d tensor of shape [batch_size, seq_len]
--- This is provided to the function automatically and is equivalent to the attention_mask in 🤗 Transformers.
---@param mask Tensor
--- output: 2d tensor of shape [batch_size, hidden]
---@return Tensor
function Postprocess(arr, mask)
  -- your postprocessing logic goes here; mean pooling is shown as a placeholder
  return arr:mean_pool(mask)
end
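To make the pooling contract concrete, here is a plain-Lua reference for masked mean pooling over nested tables. It reflects an assumption about what Tensor:mean_pool(mask) computes (the usual attention-mask-weighted average) and is meant for illustration only; mean_pool_reference is a hypothetical name, and inside a transform you should use the built-in op.

```lua
-- Reference masked mean pooling on nested Lua tables (not the Tensor API).
-- arr:  [batch_size][seq_len][hidden] numbers
-- mask: [batch_size][seq_len] with 1 for real tokens and 0 for padding
local function mean_pool_reference(arr, mask)
  local out = {}
  for b = 1, #arr do
    local hidden = #arr[b][1]
    local sum, count = {}, 0
    for h = 1, hidden do
      sum[h] = 0
    end
    for s = 1, #arr[b] do
      if mask[b][s] ~= 0 then
        count = count + 1
        for h = 1, hidden do
          sum[h] = sum[h] + arr[b][s][h]
        end
      end
    end
    for h = 1, hidden do
      -- Guard against an all-padding row to avoid dividing by zero.
      sum[h] = sum[h] / math.max(count, 1)
    end
    out[b] = sum
  end
  return out
end

local pooled = mean_pool_reference(
  { { { 1, 2 }, { 3, 4 }, { 100, 100 } } }, -- one sequence, three tokens
  { { 1, 1, 0 } }                           -- third token is padding
)
print(pooled[1][1], pooled[1][2]) -- prints 2.0 and 3.0
```

The padded token is ignored, so the result is the average of the first two token vectors only.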
Typical Transform Patterns
Most transforms fall into one of three patterns:
1. Elementwise Transforms
Safe: they preserve shape automatically.
Examples:
- scaling (tensor * 1.5)
- activation functions (tensor:exp())
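The two examples above can be combined into a single transform. This sketch assumes that multiplying a Tensor by a number returns a new Tensor, so the result can be passed on to exp; the factor 1.5 is taken from the example above and has no special meaning.

```lua
---@param arr Tensor
---@return Tensor
function Postprocess(arr)
  -- Both steps act on each element independently, so the output
  -- keeps the shape of arr.
  local scaled = arr * 1.5
  return scaled:exp()
end
```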
2. Normalization Across Axis
These also preserve shape.
Examples:
- Lp normalization (tensor:lp_normalize(p, axis))
- subtracting mean per batch or per token
- applying softmax across a specific dimension (tensor:softmax(2))
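For intuition about what an axis-wise softmax does to each slice, here is a plain-Lua reference over a single row of logits. It is an illustration of the standard softmax formula, not the Tensor API implementation; softmax_row is a hypothetical name.

```lua
-- Reference softmax over one row of logits (plain Lua, not the Tensor API).
local function softmax_row(logits)
  -- Subtract the maximum first so math.exp cannot overflow.
  local max = -math.huge
  for _, v in ipairs(logits) do
    if v > max then
      max = v
    end
  end
  local exps, total = {}, 0
  for i, v in ipairs(logits) do
    exps[i] = math.exp(v - max)
    total = total + exps[i]
  end
  for i = 1, #exps do
    exps[i] = exps[i] / total
  end
  return exps
end

local probs = softmax_row({ 2.0, 1.0, 0.1 })
print(string.format("%.3f %.3f %.3f", probs[1], probs[2], probs[3]))
-- prints "0.659 0.242 0.099"
```

The row has the same number of entries before and after, which is why this pattern preserves shape.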
3. Mask-Aware Adjustments
When working with sentence embedding models:
function Postprocess(arr, mask)
  -- arr: [batch, seq, hidden]
  -- mask: [batch, seq]
  -- operations here must output [batch, hidden]
  return arr:mean_pool(mask)
end
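A common mask-aware pipeline pools first and normalizes second. The sketch below assumes that Tensor:mean_pool returns a Tensor that can be chained, and that Tensor:lp_normalize(p, axis) takes the norm order followed by a 1-indexed axis.

```lua
---@param arr Tensor 3d tensor of shape [batch_size, seq_len, hidden]
---@param mask Tensor 2d tensor of shape [batch_size, seq_len]
---@return Tensor 2d tensor of shape [batch_size, hidden]
function Postprocess(arr, mask)
  -- [batch_size, seq_len, hidden] -> [batch_size, hidden]
  local pooled = arr:mean_pool(mask)
  -- After pooling, hidden is axis 2.
  return pooled:lp_normalize(2, 2)
end
```

Note that the axis passed to lp_normalize changes from 3 to 2 because pooling removes the seq_len dimension.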
Best Practices
Performance Implications
Transforms run synchronously during inference, so expensive Lua-side loops will increase latency. Prefer the built-in tensor ops where possible; if you don't see an op that you need, please don't hesitate to create an issue on GitHub.
A typical transform follows this structure:
function Postprocess(arr, ...)
  -- Step 1: apply elementwise or axis-based operations
  local modified = arr:exp() -- example
  -- Step 2: ensure the output shape matches what the model type expects
  -- (elementwise ops such as exp preserve shape; pooling ops such as mean_pool do not)
  return modified
end
Debugging Transforms
You can inspect shape and values using:
Errors typically fall into:
- axis out of range → axis must be 1-indexed and ≤ tensor rank
- broadcasting errors → the two shapes are incompatible
- returned value is not a tensor → must return a Tensor userdata object
- shape mismatch → you modified rank or dimensions
Configuration
Transforms are embedded at build time. You can specify them in your config.yml either as a file path or inline.
Or, they can be passed inline: