# MASE-Triton
MASE-Triton is a PyTorch extension library that provides efficient implementations of operations used to simulate emerging compute paradigms and to support our PLENA project, including random bitflip, the optical transformer, MXFP (Microscaling Formats), and minifloat. It leverages the Triton language to speed up these simulations on CUDA-enabled GPUs.
## Functionality
### Random Bitflip

Simulate random bit flips in neural network computations.

- functional APIs: random bitflip functions with backward support.
  - `random_bitflip_fn`: perform random bit flipping on tensors with configurable exponent and fraction bit-flip probabilities.
  - `calculate_flip_probability`: calculate the flip probability from a number of halves.
  - `find_nearest_prob_n_halves`: find the nearest probability in terms of halves.
- layers: PyTorch modules for random bitflip operations.
  - `RandomBitFlipDropout`: random bit-flip layer with dropout functionality.
  - `RandomBitFlipLinear`: linear layer with random bit flipping.
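To make the bit-flip semantics concrete, here is a minimal pure-PyTorch sketch of the idea. It is not the library's fused Triton kernel, and it is simplified to a single flip probability for all 32 bits rather than separate exponent and fraction probabilities:

```python
import torch

def naive_random_bitflip(x: torch.Tensor, p: float) -> torch.Tensor:
    """Flip each bit of a float32 tensor independently with probability p.

    Conceptual reference only: `random_bitflip_fn` uses separate
    probabilities for exponent and fraction bits and runs as a fused
    Triton kernel with backward support.
    """
    bits = x.view(torch.int32)                  # reinterpret the raw bits
    mask = torch.zeros_like(bits)
    for b in range(32):
        flip = torch.rand_like(x) < p           # Bernoulli(p) per element
        mask |= flip.to(torch.int32) << b       # set bit b where flip fired
    return (bits ^ mask).view(torch.float32)    # XOR applies the flips

x = torch.randn(4, 8)
y = naive_random_bitflip(x, p=2**-13)           # rare corruptions
```

Flipping an exponent bit can change a value's magnitude drastically while a fraction-bit flip only perturbs it, which is why the library exposes separate probabilities for the two fields.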
### Optical Transformer

Simulate optical computing for transformers.

- functional APIs: optical transformer functions with backward support.
  - `ot_quantize_fn`: quantize tensors for optical transformer operations.
  - `ot_qlinear_fn`: quantized linear transformation for optical computing.
  - `ot_qmatmul_fn`: quantized matrix multiplication for optical computing.
- layers: PyTorch modules for optical computing.
  - `OpticalTransformerLinear`: linear layer with optical transformer quantization.
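The core of these operations is quantizing activations and weights before they enter a simulated optical matmul. The plain-PyTorch sketch below illustrates that flavor; the `fake_quantize` and `fake_ot_matmul` helpers are hypothetical stand-ins, not the signatures or quantization scheme of `ot_quantize_fn`/`ot_qmatmul_fn`:

```python
import torch

def fake_quantize(x: torch.Tensor, bits: int = 8) -> torch.Tensor:
    # Uniform symmetric fake-quantization onto a (2**bits - 1)-level grid;
    # an illustrative stand-in for whatever scheme ot_quantize_fn implements.
    qmax = 2 ** (bits - 1) - 1
    scale = x.abs().max().clamp(min=1e-8) / qmax
    return (x / scale).round().clamp(-qmax, qmax) * scale

def fake_ot_matmul(a: torch.Tensor, b: torch.Tensor, bits: int = 8) -> torch.Tensor:
    # Quantize both operands first, mimicking the limited precision of an
    # optical dot-product engine, then multiply in regular floating point.
    return fake_quantize(a, bits) @ fake_quantize(b, bits)

a, w = torch.randn(4, 16), torch.randn(16, 8)
out = fake_ot_matmul(a, w)
```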
### MXFP

Simulate MXFP (Microscaling Formats) on CPU & GPU using PyTorch & Triton.

- functional: MXFP format conversion and operations.
  - `extract_mxfp_components`: extract MXFP components (shared exponent and elements) from tensors.
  - `compose_mxfp_tensor`: compose MXFP components back into standard floating-point tensors.
  - `quantize_dequantize`: quantize and dequantize tensors using an MXFP format.
  - `flatten_for_quantize`: flatten tensors for quantization operations.
  - `permute_for_dequantize`: permute tensors for dequantization operations.
  - `mxfp_linear`: linear operation with MXFP support.
  - `mxfp_matmul`: matrix multiplication with MXFP support.
  - `parse_mxfp_linear_type`: parse MXFP linear layer types.
- layers: PyTorch modules with MXFP support.
  - `MXFPLinearPTQ`: linear layer with MXFP support for post-training quantization (no backpropagation support).
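In MXFP, elements are grouped into fixed-size blocks (e.g. 32 elements in the OCP Microscaling spec) that share one power-of-two scale, the shared exponent, while each element keeps only a few bits of its own. The sketch below shows that quantize-dequantize round trip in plain PyTorch; it is a conceptual approximation (a uniform element grid instead of a true FP4/FP6/FP8 element format), not the library's `quantize_dequantize`:

```python
import torch

def mxfp_qdq_sketch(x: torch.Tensor, block: int = 32,
                    elem_bits: int = 4, elem_max: float = 6.0) -> torch.Tensor:
    # Quantize-dequantize with one shared power-of-two scale per block.
    # Uniform levels stand in for the real element format (6.0 is the
    # largest magnitude of FP4 E2M1, used here as an example).
    shape = x.shape
    x = x.reshape(-1, block)
    amax = x.abs().amax(dim=1, keepdim=True).clamp(min=1e-12)
    # Shared exponent chosen so the block max fits in [-elem_max, elem_max].
    shared_exp = torch.ceil(torch.log2(amax / elem_max))
    scale = torch.exp2(shared_exp)
    levels = 2 ** (elem_bits - 1) - 1
    q = (x / (scale * elem_max) * levels).round().clamp(-levels, levels)
    return (q * elem_max / levels * scale).reshape(shape)

x = torch.randn(2, 64)
x_hat = mxfp_qdq_sketch(x)   # x_hat approximates x block by block
```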
### Minifloat

Simulate minifloat formats on CPU & GPU using PyTorch & Triton.

- functional: minifloat format operations.
  - `extract_minifloat_component`: extract minifloat components from tensors.
  - `compose_minifloat_component`: compose minifloat components back into tensors.
  - `quantize_dequantize`: quantize and dequantize tensors using a minifloat format.
  - `minifloat_linear`: linear operation with minifloat support.
  - `minifloat_matmul`: matrix multiplication with minifloat support.
- layers: PyTorch modules with minifloat support.
  - `MinifloatLinearPTQ`: linear layer with minifloat support for post-training quantization (no backpropagation support).
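A minifloat is a floating-point format with only a few exponent and mantissa bits (E4M3, E5M2, and similar). Below is a simplified quantize-dequantize sketch in plain PyTorch that shows what rounding onto such a grid means; it covers only the normal range (no subnormals, NaN/Inf, or format-specific saturation rules) and is not the library's implementation:

```python
import torch

def minifloat_qdq_sketch(x: torch.Tensor, exp_bits: int = 4,
                         man_bits: int = 3) -> torch.Tensor:
    # Round x onto an E{exp_bits}M{man_bits}-like grid: pick a power-of-two
    # exponent per element, then keep man_bits fractional significand bits.
    bias = 2 ** (exp_bits - 1) - 1
    sign = torch.sign(x)
    mag = x.abs().clamp(min=2.0 ** (1 - bias))               # stay in normal range
    e = torch.floor(torch.log2(mag)).clamp(1 - bias, bias)   # exponent value
    m = torch.round(mag / torch.exp2(e) * 2 ** man_bits) / 2 ** man_bits
    max_normal = (2.0 - 2.0 ** -man_bits) * 2.0 ** bias
    return (sign * m * torch.exp2(e)).clamp(-max_normal, max_normal)

x = torch.randn(8)
x_q = minifloat_qdq_sketch(x, exp_bits=4, man_bits=3)        # E4M3-like rounding
```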
### Utilities & Management

- `manager.py`: kernel management and autotune control.
  - `KernelManager`: enable/disable autotune for Triton kernels (see the sketch below).
- `utils/`: utility functions for PyTorch modules, debugging, and training.
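Autotuning lets Triton benchmark several kernel configurations on first launch, improving steady-state speed at the cost of warm-up time. The usage sketch below is hypothetical; the import path and method names are assumptions, so check `manager.py` for the real interface:

```python
# Hypothetical usage; names are assumptions, see manager.py for the real API.
from mase_triton.manager import KernelManager  # assumed import path

KernelManager.enable_autotune()    # benchmark configs on first launch (slower warm-up)
# ... run simulations ...
KernelManager.disable_autotune()   # default configs for fast, deterministic startup
```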