Transform-Aware LLM Training Guidelines
This guide explains how to implement and integrate a new compute paradigm simulation into NewComputeBench. Both continual pretraining (starting from a pretrained checkpoint) and pretraining from scratch are supported.
Note
If the transform is highly lossy, pretraining from scratch may be ineffective — the model may fail to learn at all. Continual pretraining is recommended in those cases.
Continual Pretraining
Continual pretraining adapts a pretrained HuggingFace checkpoint by injecting the hardware simulation transform before training begins.
The Optical Transformer is used as a worked example below. The same pattern applies to any new compute paradigm.
Step 1: Implement the Simulated Layers
Place your implementation in src/aixsim_models/<paradigm>/.
For the Optical Transformer, the key components are in
src/aixsim_models/optical_compute/optical_transformer:
Simulated linear layer and matmul:
- OpticalTransformerLinear (from mase-triton): replaces all nn.Linear layers except lm_head.
- HFOpticalTransformerLlamaAttention (in layers.py): wraps the attention module to intercept the Q-K and Attention-V matmuls.
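As a rough illustration of what a simulated linear layer can look like, the sketch below defines a hypothetical `SimulatedLinear` in plain PyTorch. It is not the mase-triton `OpticalTransformerLinear`: the hardware effect is replaced by symmetric per-tensor fake quantization with a straight-through estimator, and the names `SimulatedLinear`, `fake_quantize`, and `num_bits` are assumptions made for this example.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def fake_quantize(x: torch.Tensor, num_bits: int = 8) -> torch.Tensor:
    """Symmetric per-tensor fake quantization (stand-in for a hardware model).

    Forward returns x rounded onto a uniform grid; backward passes the
    gradient straight through because the rounding error is detached.
    """
    qmax = 2 ** (num_bits - 1) - 1
    scale = x.detach().abs().amax().clamp(min=1e-8) / qmax
    x_q = torch.clamp(torch.round(x / scale), -qmax, qmax) * scale
    return x + (x_q - x).detach()


class SimulatedLinear(nn.Linear):
    """Hypothetical drop-in replacement for nn.Linear with simulated low precision."""

    def __init__(self, in_features, out_features, bias=True, num_bits=8):
        super().__init__(in_features, out_features, bias=bias)
        self.num_bits = num_bits

    @classmethod
    def from_linear(cls, linear: nn.Linear, num_bits: int = 8) -> "SimulatedLinear":
        # Copy the pretrained weights so continual pretraining starts from the checkpoint.
        new = cls(linear.in_features, linear.out_features,
                  bias=linear.bias is not None, num_bits=num_bits)
        new = new.to(device=linear.weight.device, dtype=linear.weight.dtype)
        new.load_state_dict(linear.state_dict())
        return new

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Both operands pass through the simulated device before the matmul.
        return F.linear(fake_quantize(x, self.num_bits),
                        fake_quantize(self.weight, self.num_bits), self.bias)
```

Subclassing `nn.Linear` keeps the `weight` and `bias` parameter names, so `from_linear` can copy a pretrained layer with `load_state_dict` and the checkpoint needs no key remapping.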
Note
When to use Triton vs. PyTorch built-ins:
The Optical Transformer uses Triton kernels (ot_qlinear_fn, ot_qmatmul_fn)
primarily to save GPU memory. If your simulation can be memory-efficiently implemented
using PyTorch built-in operations, there is no need to write a Triton kernel.
Also note: HuggingFace Trainer does not work with autotuned Triton kernels. Autotuning is therefore disabled in mase-triton by default.
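One way to keep a pure-PyTorch simulation memory-friendly is a custom `torch.autograd.Function` that saves only tensors which already exist (the input and the weight parameter) and recomputes the quantized operands in `backward`, instead of letting autograd keep the quantized copies alive until the backward pass. The sketch below uses straight-through fake quantization as a stand-in for a real hardware model; `QuantLinearFn`, `quantize`, and the fixed 8-bit grid are hypothetical, and the scales are passed in as Python floats for simplicity.

```python
import torch

QMAX = 127  # symmetric 8-bit grid (assumption for this sketch)


def quantize(t: torch.Tensor, scale: float) -> torch.Tensor:
    """Round t onto a uniform grid with step `scale`, clipped to [-QMAX, QMAX] steps."""
    return torch.clamp(torch.round(t / scale), -QMAX, QMAX) * scale


class QuantLinearFn(torch.autograd.Function):
    """Fake-quantized linear op with straight-through gradients.

    Only `x` and `weight` are saved for backward; the quantized copies are
    recomputed there instead of being kept alive between forward and backward.
    """

    @staticmethod
    def forward(ctx, x, weight, bias, scale_x: float, scale_w: float):
        ctx.save_for_backward(x, weight)
        ctx.scales = (scale_x, scale_w)
        ctx.has_bias = bias is not None
        out = quantize(x, scale_x) @ quantize(weight, scale_w).t()
        return out + bias if bias is not None else out

    @staticmethod
    def backward(ctx, grad_out):
        x, weight = ctx.saved_tensors
        scale_x, scale_w = ctx.scales
        # Recompute the quantized operands: extra compute in exchange for less memory.
        xq, wq = quantize(x, scale_x), quantize(weight, scale_w)
        # Straight-through estimator: treat d(quantize(t))/dt as the identity.
        grad_x = grad_out @ wq
        g2d = grad_out.reshape(-1, grad_out.shape[-1])
        grad_w = g2d.t() @ xq.reshape(-1, xq.shape[-1])
        grad_b = g2d.sum(dim=0) if ctx.has_bias else None
        # One gradient per forward input; the float scales get None.
        return grad_x, grad_w, grad_b, None, None
```

The trade-off is compute: `quantize` runs twice per operand per step. A naive version built from `x + (quantize(x) - x).detach()` followed by a matmul produces the same outputs and gradients, but autograd then keeps the quantized weight copy around until backward.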
Step 2: Implement the Transform Pass
Implement a transform_hf_model function that iterates over the model’s modules and
replaces them with your simulated equivalents.
See transform.py
for the Optical Transformer example.
The typical pattern uses two loops:
1. Replace attention modules with your custom attention class (to intercept matmuls).
2. Replace the remaining nn.Linear modules with your simulated linear layer.
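The two-loop pattern can be sketched generically as follows. `transform_model`, `make_attn`, and `make_linear` are hypothetical names, not the repository's `transform_hf_model`; the factories stand in for code that builds your simulated modules from the per-layer config, and `get_submodule` requires PyTorch 1.9 or later.

```python
import torch.nn as nn


def _set_submodule(root: nn.Module, name: str, new: nn.Module) -> None:
    # Resolve the parent of "a.b.c" and rebind its last attribute to the new module.
    parent_name, _, child_name = name.rpartition(".")
    parent = root.get_submodule(parent_name) if parent_name else root
    setattr(parent, child_name, new)


def transform_model(model, attn_cls, make_attn, make_linear, skip=("lm_head",)):
    """Hypothetical two-loop transform pass; returns {layer_name: new_type_name}."""
    transformed = {}
    # Loop 1: swap attention modules so their internal matmuls can be intercepted.
    for name, module in list(model.named_modules()):
        if isinstance(module, attn_cls):
            new = make_attn(module)
            _set_submodule(model, name, new)
            transformed[name] = type(new).__name__
    # Loop 2: swap the remaining plain nn.Linear layers, except excluded names.
    for name, module in list(model.named_modules()):
        if type(module) is nn.Linear and name.split(".")[-1] not in skip:
            new = make_linear(module)
            _set_submodule(model, name, new)
            transformed[name] = type(new).__name__
    return transformed
```

Both loops iterate over a snapshot (`list(model.named_modules())`) because modules are replaced during the walk, and the `type(module) is nn.Linear` check skips layers that are already simulated subclasses of `nn.Linear`. If the attention wrapper keeps the original projection layers, they are still plain `nn.Linear` after the first loop, so the second loop replaces them too.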
A TransformConfigManager reads the YAML config and maps layer names to their
transform configurations.
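A name-to-config lookup of this kind can be as small as the class below. `SimpleTransformConfigManager`, its `layer_configs`/`default` arguments, and the `q_bits` key used in the test are hypothetical; the real `TransformConfigManager` may use a different schema and matching rule. Here, glob-style patterns are tried in insertion order and the first match wins.

```python
from fnmatch import fnmatchcase
from typing import Any, Dict, Optional


class SimpleTransformConfigManager:
    """Hypothetical stand-in for TransformConfigManager.

    `layer_configs` maps glob-style layer-name patterns to transform configs.
    Patterns are tried in insertion order; `default` applies when none match.
    """

    def __init__(self, layer_configs: Dict[str, Dict[str, Any]],
                 default: Optional[Dict[str, Any]] = None):
        self.layer_configs = dict(layer_configs)  # insertion order = match priority
        self.default = default

    def get_layer_config(self, layer_name: str) -> Optional[Dict[str, Any]]:
        for pattern, cfg in self.layer_configs.items():
            if fnmatchcase(layer_name, pattern):
                return cfg
        # None means "leave this layer untransformed".
        return self.default
```

Because it takes keyword arguments, a YAML mapping with `layer_configs` and `default` keys could be passed as `SimpleTransformConfigManager(**transform_args)`, mirroring how the training script constructs the real manager.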
Step 3: Write a Transform Config
Define a YAML file specifying the transform parameters layer-by-layer.
See experiments/llm-optical-transformer/continual_pretraining/configs/default.yaml
for an example.
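For illustration only, the snippet below shows what a layer-by-layer YAML config might look like and how it can be loaded and sanity-checked with PyYAML before training. The keys (`layer_configs`, `default`, `q_bits`) and the function `load_transform_config` are an assumed schema and helper, not the contents of `default.yaml`.

```python
import yaml

# Hypothetical schema: keys are illustrative, not those of configs/default.yaml.
EXAMPLE_CONFIG = """
layer_configs:
  "*.self_attn.q_proj":
    q_bits: 4
  "model.layers.*":
    q_bits: 8
default: null
"""


def load_transform_config(text: str) -> dict:
    """Parse a YAML transform config and check its basic shape."""
    cfg = yaml.safe_load(text)
    if not isinstance(cfg, dict) or "layer_configs" not in cfg:
        raise ValueError("transform config must be a mapping with a 'layer_configs' key")
    for pattern, layer_cfg in cfg["layer_configs"].items():
        if not isinstance(layer_cfg, dict):
            raise ValueError(f"config for {pattern!r} must be a mapping")
    return cfg
```

Patterns that begin with `*` must be quoted, since an unquoted leading `*` is parsed as a YAML alias. Validating the file up front makes a malformed config fail immediately rather than after the model and dataset have loaded.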
Step 4: Adapt the Training Script
Copy HuggingFace’s run_clm.py and insert the following snippet just before training starts:
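```python
# Requires `import yaml` plus imports of TransformConfigManager, transform_hf_model,
# and make_transform_histogram from your paradigm's package.
if model_args.transform_config is not None:
    # Load the per-layer transform settings from the YAML file.
    with open(model_args.transform_config, "r") as f:
        transform_args = yaml.safe_load(f)
    config_manager = TransformConfigManager(**transform_args)
    # Replace the model's modules in place with their simulated equivalents.
    transformed_layers = transform_hf_model(model, config_manager)
    # Summarize which layers were replaced, for logging.
    transform_histogram = make_transform_histogram(transformed_layers=transformed_layers)
    logger.info(f"Transformed layers:\n{transform_histogram}")
else:
    logger.info("No transform config provided. Using the original model.")
```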
Replace transform_hf_model with your own transform pass.
The adapted script for the Optical Transformer is at
experiments/llm-optical-transformer/continual_pretraining/run_clm.py.
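The snippet in this step reads `model_args.transform_config`, and the launch command passes `--transform_config`, so the copied `run_clm.py` also needs a matching field on its `ModelArguments` dataclass; `HfArgumentParser` turns every dataclass field into a command-line flag. A minimal sketch, showing only one existing field for brevity (the help strings are illustrative):

```python
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class ModelArguments:
    # Existing run_clm.py fields (only one shown here) stay unchanged.
    model_name_or_path: Optional[str] = field(
        default=None, metadata={"help": "The model checkpoint for weights initialization."}
    )
    # New field: exposed on the command line as --transform_config.
    transform_config: Optional[str] = field(
        default=None,
        metadata={"help": "Path to a YAML transform config; if omitted, no transform is applied."},
    )


parser = HfArgumentParser(ModelArguments)
```

In the real script this dataclass is parsed together with `DataTrainingArguments` and `TrainingArguments`; only the new field needs adding.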
Step 5: Launch Training
Example using the Optical Transformer:
```bash
python run_clm.py \
    --model_name_or_path AICrossSim/clm-60m \
    --dataset_name HuggingFaceFW/fineweb-edu \
    --dataset_config_name "sample-10BT" \
    --per_device_train_batch_size 12 \
    --per_device_eval_batch_size 12 \
    --gradient_accumulation_steps 50 \
    --do_train \
    --report_to "wandb" \
    --learning_rate 5e-5 \
    --max_steps 100 \
    --save_strategy "steps" \
    --save_steps 500 \
    --save_total_limit 2 \
    --bf16 \
    --dataloader_num_workers 16 \
    --preprocessing_num_workers 32 \
    --tokenizer_name HuggingFaceTB/cosmo2-tokenizer \
    --output_dir ./output/test-clm-trainer \
    --transform_config ./configs/default.yaml \
    --seed 42
```
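The batch-related flags combine multiplicatively. As a quick check of the scale of this example run, the helpers below (hypothetical, written for this guide) compute sequences per optimizer step and total tokens. The block size is not set in the command; `run_clm.py` derives it from the tokenizer and model limits unless `--block_size` is given, so the 1024 used in the check is an assumption.

```python
def effective_batch_size(per_device_batch: int, grad_accum_steps: int, world_size: int = 1) -> int:
    """Sequences consumed per optimizer step with data-parallel HF Trainer."""
    return per_device_batch * grad_accum_steps * world_size


def tokens_seen(per_device_batch: int, grad_accum_steps: int, max_steps: int,
                block_size: int, world_size: int = 1) -> int:
    # After run_clm.py packs the text, every training sequence is block_size tokens long.
    seqs_per_step = effective_batch_size(per_device_batch, grad_accum_steps, world_size)
    return seqs_per_step * max_steps * block_size
```

On one GPU that is 12 × 50 = 600 sequences per step and 60,000 sequences over 100 steps. With data parallelism across N GPUs the global batch grows N-fold, so `--gradient_accumulation_steps` is typically reduced when scaling out to keep the global batch constant.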
Pretraining from Scratch
We use torchtitan as the distributed training backend for transform-aware pretraining from scratch.
See experiments/llm-optical-transformer/pretrain/run.py for the Optical Transformer example.
The interface is identical to the standard CLM pretraining script
(see LLM Pretraining & Evaluation), with an additional
transform config argument passed to generate-cfg.