utils.quantization

Utilities for quantization using torchao, including quantization-aware training (QAT) and post-training quantization (PTQ).

Functions

| Name | Description |
|------|-------------|
| convert_qat_model | This function converts a QAT model that has fake-quantized layers back to the original model. |
| get_quantization_config | This function is used to build a post-training quantization config. |
| patch_transformers_skip_quantized_init | Stop from_pretrained from re-initializing torchao-quantized weights. |
| prepare_model_for_qat | This function is used to prepare a model for QAT by swapping the model’s linear layers with fake-quantized linear layers. |
| quantize_model | This function is used to quantize a model. |
| save_quantized_model | Save a quantized model, handling MXTensor serialization. |

convert_qat_model

utils.quantization.convert_qat_model(model, quantize_embedding=False)

This function converts a QAT model that has fake-quantized layers back to the original model.

get_quantization_config

utils.quantization.get_quantization_config(
    weight_dtype,
    activation_dtype=None,
    group_size=None,
)

This function is used to build a post-training quantization config.

Parameters

| Name | Type | Description | Default |
|------|------|-------------|---------|
| weight_dtype | TorchAOQuantDType | The dtype to use for weight quantization. | required |
| activation_dtype | TorchAOQuantDType \| None | The dtype to use for activation quantization. | None |
| group_size | int \| None | The group size to use for weight quantization. | None |

Returns

| Name | Type | Description |
|------|------|-------------|
|      | AOBaseConfig | The post-training quantization config. |

Raises

| Name | Type | Description |
|------|------|-------------|
|      | ValueError | If the activation dtype is not specified and the weight dtype is not int8 or int4, or if the group size is not specified for int8 or int4 weight-only quantization. |
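The conditions in the Raises table can be restated as a small validation function. This is a minimal sketch, not this module's code: `QuantDType` and its members are hypothetical stand-ins for `TorchAOQuantDType` (whose real members are not listed here), and `check_ptq_args` is a hypothetical name. Omitting `activation_dtype` is treated as requesting weight-only quantization.

```python
from enum import Enum
from typing import Optional


class QuantDType(Enum):
    """Hypothetical stand-in for TorchAOQuantDType; real members may differ."""
    int4 = "int4"
    int8 = "int8"
    float8_e4m3fn = "float8_e4m3fn"


# Weight dtypes that the Raises table allows for weight-only quantization.
WEIGHT_ONLY_DTYPES = {QuantDType.int4, QuantDType.int8}


def check_ptq_args(
    weight_dtype: QuantDType,
    activation_dtype: Optional[QuantDType] = None,
    group_size: Optional[int] = None,
) -> None:
    """Raise ValueError for the argument combinations described under Raises."""
    if activation_dtype is None:
        # No activation dtype: this is weight-only quantization.
        if weight_dtype not in WEIGHT_ONLY_DTYPES:
            raise ValueError(
                f"weight-only quantization needs int8 or int4 weights, got {weight_dtype.value}"
            )
        if group_size is None:
            raise ValueError(
                f"group_size is required for {weight_dtype.value} weight-only quantization"
            )


check_ptq_args(QuantDType.int4, group_size=32)  # valid weight-only int4 request
```

Any combination this sketch accepts may still be rejected by the real function for reasons not documented here.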

patch_transformers_skip_quantized_init

utils.quantization.patch_transformers_skip_quantized_init()

Stop from_pretrained from re-initializing torchao-quantized weights.

transformers re-runs _init_weights on every module during loading, and the generic implementation calls init.normal_(module.weight.float(), ...). Calling .float() on a torchao tensor subclass (e.g. MXTensor) returns a new tensor that drops the _is_hf_initialized skip flag and does not implement normal_, so loading an MX checkpoint raises NotImplementedError. Re-initializing an already-loaded quantized weight is never correct, so we skip those modules entirely.
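The patch described above amounts to wrapping an `_init_weights`-style method so that it returns early for modules whose weight is already quantized. The sketch below shows only that wrapping pattern. `skip_quantized_init` and the `is_quantized` predicate are hypothetical; which transformers class the real patch targets, and how it recognizes a torchao tensor subclass, are assumptions not stated here.

```python
import functools
from typing import Any, Callable


def skip_quantized_init(
    init_fn: Callable[[Any, Any], Any],
    is_quantized: Callable[[Any], bool],
) -> Callable[[Any, Any], Any]:
    """Wrap an _init_weights-style method so quantized modules are left untouched."""

    @functools.wraps(init_fn)
    def wrapper(self, module):
        weight = getattr(module, "weight", None)
        if weight is not None and is_quantized(weight):
            # Already loaded from the checkpoint: re-initializing would be wrong,
            # and the generic path would call .float() and normal_ on the subclass.
            return None
        return init_fn(self, module)

    return wrapper
```

In use, the wrapper would be assigned back onto the class whose `_init_weights` transformers calls during loading, with `is_quantized` performing something like an `isinstance` check against the torchao tensor subclass (hypothetical; the real check may differ).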

prepare_model_for_qat

utils.quantization.prepare_model_for_qat(
    model,
    weight_dtype,
    group_size=None,
    activation_dtype=None,
    quantize_embedding=False,
)

This function is used to prepare a model for QAT by swapping the model’s linear layers with fake-quantized linear layers, and optionally the embedding weights with fake-quantized embedding weights.

Parameters

| Name | Type | Description | Default |
|------|------|-------------|---------|
| model |  | The model to quantize. | required |
| weight_dtype | TorchAOQuantDType | The dtype to use for weight quantization. | required |
| group_size | int \| None | The group size to use for weight quantization. | None |
| activation_dtype | TorchAOQuantDType \| None | The dtype to use for activation quantization. | None |
| quantize_embedding | bool | Whether to quantize the model’s embedding weights. | False |

Raises

| Name | Type | Description |
|------|------|-------------|
|      | ValueError | If the activation/weight dtype combination is invalid. |
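A fake-quantized layer keeps its weights in floating point but, in the forward pass, uses weights that have been quantized and immediately dequantized, so training sees the rounding error the final integer format will introduce. The pure-Python sketch below shows that arithmetic, assuming symmetric per-group integer quantization (a common choice for int4 weights; the exact scheme torchao applies for each dtype is not stated here). `fake_quantize` is a hypothetical helper, not part of this module; its `group_size` plays the same role as the parameter above, with each run of `group_size` consecutive weights sharing one scale.

```python
from typing import List


def fake_quantize(values: List[float], bits: int = 4, group_size: int = 4) -> List[float]:
    """Quantize then immediately dequantize, per group, with symmetric scaling.

    The output is still floating point, but every value lies on the grid the
    integer format can represent, which is what a fake-quantized layer exposes.
    """
    qmax = 2 ** (bits - 1) - 1   # 7 for 4-bit
    qmin = -(2 ** (bits - 1))    # -8 for 4-bit
    out: List[float] = []
    for start in range(0, len(values), group_size):
        group = values[start:start + group_size]
        # One scale per group; fall back to 1.0 for an all-zero group.
        scale = max(abs(v) for v in group) / qmax or 1.0
        for v in group:
            q = max(qmin, min(qmax, round(v / scale)))  # integer code
            out.append(q * scale)                        # dequantized value
    return out


weights = [0.7, -0.35, 0.1, 0.0, 2.0, -1.1, 0.05, 0.4]
print(fake_quantize(weights, bits=4, group_size=4))
```

Real QAT also has to pass gradients through the non-differentiable rounding step, typically with a straight-through estimator; this sketch models only the forward-pass values.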

quantize_model

utils.quantization.quantize_model(
    model,
    weight_dtype,
    group_size=None,
    activation_dtype=None,
    quantize_embedding=None,
)

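This function is used to apply post-training quantization to a model: its weights are quantized to weight_dtype, its activations to activation_dtype when one is given, and optionally its embedding weights as well.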

Parameters

| Name | Type | Description | Default |
|------|------|-------------|---------|
| model |  | The model to quantize. | required |
| weight_dtype | TorchAOQuantDType | The dtype to use for weight quantization. | required |
| group_size | int \| None | The group size to use for weight quantization. | None |
| activation_dtype | TorchAOQuantDType \| None | The dtype to use for activation quantization. | None |
| quantize_embedding | bool \| None | Whether to quantize the model’s embedding weights. | None |
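The group_size parameter trades accuracy against storage: a smaller group gives each scale fewer weights to cover (typically lowering quantization error) but adds more scales to store. The back-of-the-envelope sketch below assumes tightly packed integer codes, one 16-bit scale per group, and no zero-points; the actual storage layout torchao uses depends on the config and is not described here. `quantized_size_bytes` is a hypothetical helper, and the 4096 x 4096 layer is an illustrative shape.

```python
def quantized_size_bytes(num_weights: int, bits: int, group_size: int, scale_bytes: int = 2) -> int:
    """Bytes for packed integer codes plus one scale per group (zero-points ignored)."""
    num_groups = -(-num_weights // group_size)   # ceiling division
    code_bytes = (num_weights * bits + 7) // 8   # packed codes, rounded up to whole bytes
    return code_bytes + num_groups * scale_bytes


num_weights = 4096 * 4096        # one hypothetical square linear layer
bf16_bytes = num_weights * 2     # unquantized 16-bit baseline
for group_size in (32, 128):
    size = quantized_size_bytes(num_weights, bits=4, group_size=group_size)
    print(group_size, size, size / bf16_bytes)
```

Under these assumptions, int4 with group_size=32 averages 4.5 bits per weight and group_size=128 averages 4.125 bits, compared with 16 bits for bf16.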

save_quantized_model

utils.quantization.save_quantized_model(model, save_dir, **kwargs)

Save a quantized model, handling MXTensor serialization.

MXTensor does not have a valid storage pointer, which causes save_pretrained to crash, both in remove_tied_weights_from_state_dict (via id_tensor_storage) and during safetensors serialization. Transformers >=5.5 removed the safe_serialization parameter entirely, so save_pretrained can no longer be told to skip safetensors.

For MX-quantized models, we therefore save the config and generation config through the save_pretrained machinery and the weights with torch.save.
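The split save described above can be sketched as follows. This is a minimal sketch under stated assumptions: `save_config_and_weights` is a hypothetical name, the weights filename `pytorch_model.bin` (the legacy transformers weights filename) is an assumption since the name this module writes is not stated, and the model is duck-typed (anything with `config`, an optional `generation_config`, and `state_dict()`). The weight saver defaults to `torch.save` but is injectable so the logic can be exercised without torch.

```python
import os
from typing import Any, Callable, Optional


def save_config_and_weights(
    model: Any,
    save_dir: str,
    weight_saver: Optional[Callable[[Any, str], None]] = None,
    weights_name: str = "pytorch_model.bin",  # assumed filename
) -> str:
    """Save configs via their save_pretrained methods and the weights via torch.save."""
    if weight_saver is None:
        import torch  # deferred so the sketch can run without torch installed
        weight_saver = torch.save
    os.makedirs(save_dir, exist_ok=True)
    # Config objects serialize themselves (config.json, generation_config.json).
    model.config.save_pretrained(save_dir)
    generation_config = getattr(model, "generation_config", None)
    if generation_config is not None:
        generation_config.save_pretrained(save_dir)
    # Writing the state dict directly bypasses save_pretrained's tied-weight
    # check and safetensors path, which the text above says fail on MXTensor.
    weights_path = os.path.join(save_dir, weights_name)
    weight_saver(model.state_dict(), weights_path)
    return weights_path
```

Keeping the weight saver as a parameter is a design choice for testability; in practice it would just be `torch.save`, and loading would use a matching `torch.load`.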