Combining Compression Types
===========================

In previous sections on [palettization](opt-palettization), 
[quantization](opt-quantization), and [sparsity](opt-pruning), 
we considered how to apply the various compression 
techniques to the weights and activations of the model independently. 
In this section, we describe how these techniques can be combined, 
which may be beneficial to get even more disk savings and latency improvements. 

We start by looking at how to take an uncompressed `mlpackage` 
and produce a jointly compressed model using the `ct.optimize.coreml.*` APIs. 
As discussed in previous sections, this approach may or 
may not yield a highly accurate model. In some cases, however, 
it is the best way to get the model into the desired format in order to measure 
the expected disk size savings and performance (latency, runtime memory, etc.). 
Once a model has the desired performance characteristics, 
a more accurate model can be generated by applying the various data-based 
optimization methods available in `ct.optimize.torch.*`. 
This last topic is discussed via a few API code snippets in the section below. 
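
As a rough illustration of why combining techniques can save more space than any single one, the sketch below estimates the average number of bits stored per weight for a few combinations. The storage model is an assumption made purely for illustration: a sparse tensor is taken to store a 1-bit mask per element plus the non-zero values, and LUT and scale overheads are ignored. It is not a description of the exact `mlpackage` on-disk layout, and the helper name `bits_per_weight` is hypothetical.

```python
def bits_per_weight(value_bits, sparsity=0.0):
    """Estimate the average number of stored bits per weight element.

    Assumptions (illustrative only): a dense tensor stores `value_bits` per
    element; a sparse tensor stores a 1-bit mask per element plus
    `value_bits` for each non-zero element. LUT/scale overhead is ignored.
    """
    if not 0.0 <= sparsity < 1.0:
        raise ValueError("sparsity must be in [0, 1)")
    if sparsity == 0.0:
        return float(value_bits)
    return 1.0 + (1.0 - sparsity) * value_bits


estimates = {
    "Float16 (uncompressed)": bits_per_weight(16),
    "4-bit palettized": bits_per_weight(4),
    "80% sparse, Float16 non-zeros": bits_per_weight(16, sparsity=0.8),
    "80% sparse, INT8 non-zeros": bits_per_weight(8, sparsity=0.8),
    "80% sparse, 4-bit palettized non-zeros": bits_per_weight(4, sparsity=0.8),
}

for name, bits in estimates.items():
    print(f"{name}: ~{bits:.1f} bits/weight")
```

Under these assumptions, an 80% sparse tensor with INT8 non-zero values needs about 2.6 bits per weight, compared with 16 for the uncompressed Float16 tensor. Actual savings depend on the model and on which layers are compressed.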

## Combining compression types on an mlpackage 

### Joint palettization and quantization 

This means using a lookup table (LUT) whose values are of 
dtype INT8/UINT8 instead of the default Float16. 
This can help speed up inference when combined with INT8 activations. 
For instance, you could take an A16W16 model, quantize the activations to get an A8W16 model, 
and then palettize the weights to a 4-bit LUT with INT8 dtype to yield an A8W4 model, 
where “W4” refers to palettized weights with a LUT that has 2^4 = 16 entries, 
each of dtype INT8. 
When such a model is run on the Neural Engine (on newer SoCs: A17 Pro, M4, and later), 
it will utilize the faster int8-int8 compute path. 

```python
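import coremltools as ct
from coremltools.optimize.coreml import (
    OptimizationConfig,
    OpPalettizerConfig,
    OpLinearQuantizerConfig,
    palettize_weights,
    linear_quantize_weights,
)

# mlmodel: an uncompressed mlpackage, loaded into memory
# ("model.mlpackage" is a placeholder path)
mlmodel = ct.models.MLModel("model.mlpackage")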
                                                                          
# first palettize the model
# this will produce an LUT with Float values
op_config = OpPalettizerConfig(nbits=4)
config = OptimizationConfig(global_config=op_config)
mlmodel_palettized = palettize_weights(mlmodel, config)

# now apply weight quantization on the model, 
# with "joint_compression" set to True. 
# this will result in quantizing the LUT to 8 bits. 
# (granularity must be set to "per_tensor" for this scenario) 
op_config = OpLinearQuantizerConfig(mode="linear_symmetric",  
                                    granularity="per_tensor")
linear_weight_quantize_config = OptimizationConfig(global_config=op_config)

mlmodel_palettized_with_8bit_lut = linear_quantize_weights(mlmodel_palettized, 
                                                           linear_weight_quantize_config, 
                                                           joint_compression=True)
```
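
To make the resulting data layout more concrete, here is a minimal plain-Python sketch of a 4-bit LUT whose entries are stored as INT8 with a single per-tensor scale. It is a conceptual illustration only, not the coremltools implementation: the evenly spaced centroids stand in for whatever clustering the real palettizer performs, and the helper names `palettize` and `quantize_lut_symmetric` are hypothetical.

```python
def palettize(weights, nbits):
    """Map each weight to the nearest of 2**nbits evenly spaced centroids.

    Uniform centroids are a simplification for illustration; a real
    palettizer typically clusters the weights instead.
    """
    n_entries = 2 ** nbits
    lo, hi = min(weights), max(weights)
    step = (hi - lo) / (n_entries - 1)
    lut = [lo + i * step for i in range(n_entries)]
    indices = [
        min(range(n_entries), key=lambda i: abs(lut[i] - w)) for w in weights
    ]
    return lut, indices


def quantize_lut_symmetric(lut):
    """Quantize the LUT entries to INT8 using one (per-tensor) scale."""
    max_abs = max(abs(v) for v in lut)
    scale = max_abs / 127.0 if max_abs > 0 else 1.0
    int8_lut = [max(-127, min(127, round(v / scale))) for v in lut]
    return int8_lut, scale


weights = [-1.0, -0.52, -0.1, 0.0, 0.07, 0.33, 0.8, 1.0]
lut, indices = palettize(weights, nbits=4)        # Float LUT + 4-bit indices
int8_lut, scale = quantize_lut_symmetric(lut)     # INT8 LUT + one scale

# Reconstruct each weight: index -> INT8 LUT entry -> float value
reconstructed = [int8_lut[i] * scale for i in indices]
print(int8_lut)
print(reconstructed)
```

Each weight is then represented by a 4-bit index, and the 16 LUT entries are 8-bit integers that share one scale, which mirrors the reason the `per_tensor` granularity is used in the example above.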

### Joint sparsity and quantization

This means quantizing the non-zero values in the sparse weight tensor to INT8/UINT8 values. 
This could improve inference speed and disk savings.

```python
from coremltools.optimize.coreml import (
    OptimizationConfig,
    OpMagnitudePrunerConfig,
    OpLinearQuantizerConfig,
    prune_weights,
    linear_quantize_weights,
)
 
# first prune the model
op_config = OpMagnitudePrunerConfig(target_sparsity=0.80)
config = OptimizationConfig(global_config=op_config)
mlmodel_pruned = prune_weights(mlmodel, config=config)

# now apply weight quantization on the model,
# with "joint_compression" set to True. 
# this will result in quantizing the non-zero values to 8 bits. 
linear_weight_quantize_config = OptimizationConfig(
    global_config=OpLinearQuantizerConfig(mode="linear_symmetric")
)
mlmodel_pruned_quantized = linear_quantize_weights(mlmodel_pruned, 
                                                   linear_weight_quantize_config, 
                                                   joint_compression=True)
```
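
Conceptually, the jointly compressed weight consists of a mask marking which positions are non-zero, plus the non-zero values stored as 8-bit integers with a scale. The following plain-Python sketch illustrates that idea on a small list of values. It is a simplified illustration under that assumption, not the coremltools implementation, and all helper names are hypothetical.

```python
def magnitude_prune(weights, target_sparsity):
    """Zero out the smallest-magnitude fraction of the weights."""
    n_zero = int(len(weights) * target_sparsity)
    # Indices sorted by magnitude, smallest first
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    to_zero = set(order[:n_zero])
    return [0.0 if i in to_zero else w for i, w in enumerate(weights)]


def quantize_nonzeros(weights):
    """Keep a 0/1 mask and quantize only the non-zero values to INT8."""
    mask = [1 if w != 0.0 else 0 for w in weights]
    nonzeros = [w for w in weights if w != 0.0]
    scale = max(abs(w) for w in nonzeros) / 127.0 if nonzeros else 1.0
    int8_values = [round(w / scale) for w in nonzeros]
    return mask, int8_values, scale


def to_dense(mask, int8_values, scale):
    """Rebuild the dense float tensor from mask, INT8 values, and scale."""
    values = iter(int8_values)
    return [next(values) * scale if m else 0.0 for m in mask]


weights = [0.02, -0.91, 0.10, 0.45, -0.03, 0.01, -0.27, 0.64, 0.05, -0.08]
pruned = magnitude_prune(weights, target_sparsity=0.8)
mask, int8_values, scale = quantize_nonzeros(pruned)
dense = to_dense(mask, int8_values, scale)

print(mask)         # positions of the surviving weights
print(int8_values)  # surviving weights as 8-bit integers
```

With a target sparsity of 0.8, only the two largest-magnitude weights out of ten survive, and only those two values need to be stored as integers.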

### Joint sparsity and palettization 

This means representing the non-zero values in a sparse weight tensor with 
discrete values pointing to a lookup table (i.e. palettized). 

```python
from coremltools.optimize.coreml import (
    OptimizationConfig,
    OpMagnitudePrunerConfig,
    OpPalettizerConfig,
    prune_weights,
    palettize_weights,
)
 
# first prune the model
op_config = OpMagnitudePrunerConfig(target_sparsity=0.80)
pruning_config = OptimizationConfig(global_config=op_config)
mlmodel_pruned = prune_weights(mlmodel, config=pruning_config)

# now apply weight palettization on the model, 
# with "joint_compression" set to True. 
# this will result in palettizing the non-zero values. 
palettization_config = OptimizationConfig(global_config=OpPalettizerConfig(nbits=4))
mlmodel_pruned_palettized = palettize_weights(mlmodel_pruned, 
                                              palettization_config, 
                                              joint_compression=True)

```
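
The idea can be sketched in plain Python: the zeros are recorded in a mask, and only the non-zero values are replaced by indices into a small LUT. This is a conceptual illustration, not the coremltools implementation. The evenly spaced LUT entries are a simplification (a real palettizer typically clusters the values), a 2-bit LUT is used to keep the example small, and the helper names are hypothetical.

```python
def palettize_nonzeros(weights, nbits):
    """Palettize only the non-zero values of an already-sparse weight list."""
    mask = [1 if w != 0.0 else 0 for w in weights]
    nonzeros = [w for w in weights if w != 0.0]
    if not nonzeros:
        return mask, [], []
    n_entries = 2 ** nbits
    lo, hi = min(nonzeros), max(nonzeros)
    step = (hi - lo) / (n_entries - 1)
    # Evenly spaced LUT entries: a stand-in for real clustering
    lut = [lo + i * step for i in range(n_entries)]
    indices = [
        min(range(n_entries), key=lambda i: abs(lut[i] - w)) for w in nonzeros
    ]
    return mask, lut, indices


def to_dense(mask, lut, indices):
    """Rebuild the dense tensor: zeros from the mask, the rest from the LUT."""
    it = iter(indices)
    return [lut[next(it)] if m else 0.0 for m in mask]


sparse_weights = [0.0, -0.9, 0.0, 0.0, 0.62, 0.0, -0.28, 0.0, 0.0, 0.9, 0.0, 0.31]
mask, lut, indices = palettize_nonzeros(sparse_weights, nbits=2)
dense = to_dense(mask, lut, indices)

print(indices)  # one 2-bit index per non-zero weight
print(dense)    # zeros are preserved exactly
```

Note that the zeros never pass through the LUT, so the sparsity pattern produced by pruning is preserved exactly.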


## Combining compression types on a Torch model


### Joint palettization and quantization 

This means using a lookup table (LUT) whose values 
are of the dtype INT8/UINT8 instead of the default Float16. 

```python
import torchvision 
import torch
import coremltools as ct
from coremltools.optimize.torch.palettization import PostTrainingPalettizerConfig,\
                                                     PostTrainingPalettizer

# load a torch model
# e.g. resnet50 
model = torchvision.models.resnet50(weights="IMAGENET1K_V2")
model.eval()

# specify "lut_dtype" as torch.int8
# when not specified, it defaults to None and an FP16 LUT is constructed 
config_dict = {"global_config": {"n_bits": 4, "lut_dtype" : torch.int8}}
palettizer_config = PostTrainingPalettizerConfig.from_dict(config_dict)

compressor = PostTrainingPalettizer(model, palettizer_config)
compressed_model = compressor.compress()   

# convert the compressed model
traced_model = torch.jit.trace(compressed_model, torch.rand(1, 3, 256, 256))

mlmodel = ct.convert(traced_model,
                     inputs=[ct.TensorType(shape=(1, 3, 256, 256))],
                     minimum_deployment_target=ct.target.macOS15,
                     )
mlmodel.save("model_4bit_palettized_with_8bit_quantized_lut.mlpackage")
```


### Joint sparsity and quantization

One way to combine sparsity and quantization is to 
first prune a torch model using the `MagnitudePruner` class, 
export it as an `mlpackage` and then apply weight quantization (A16W8) on 
the `mlpackage`, as shown in the section above. 

However, if we want to apply pruning and weight-only quantization (A16W8) 
to the torch model at training time, this can be done as shown below.

Note that if the `activation_dtype` argument of 
`ModuleLinearQuantizerConfig` is left at its default value of `torch.quint8`, 
then activations will also be quantized, yielding an A8W8 model.

```python
import torchvision 
import torch
import coremltools as ct
from coremltools.optimize.torch.quantization import ModuleLinearQuantizerConfig, \
                                                    LinearQuantizerConfig, \
                                                    LinearQuantizer

from coremltools.optimize.torch.pruning import ModuleMagnitudePrunerConfig, \
                                               MagnitudePrunerConfig, \
                                               MagnitudePruner 
                                               

# Initialize model
# e.g. Resnet50
model = torchvision.models.resnet50(weights="IMAGENET1K_V2")
model.train()

# Prepare model for joint quantization and pruning
quant_config = LinearQuantizerConfig(
    global_config=ModuleLinearQuantizerConfig(
        quantization_scheme="symmetric",
        activation_dtype=torch.float,
    )
)
prune_config = MagnitudePrunerConfig(
    global_config=ModuleMagnitudePrunerConfig(
        target_sparsity=0.8,
    )
)

# The quantizer needs to be applied before the pruner
quantizer = LinearQuantizer(model, quant_config)
# example_inputs is a tuple of positional arguments for the model's forward()
quant_model = quantizer.prepare(example_inputs=(torch.randn(1, 3, 256, 256),))

pruner = MagnitudePruner(quant_model, prune_config)
# in-place is required to ensure quantizer and pruner are 
# operating on the same model 
pruned_quant_model = pruner.prepare(inplace=True)

# Create the optimizer from the prepared model's parameters:
# quantizer.prepare() returns a copy of the model unless inplace=True is passed
optimizer = torch.optim.SGD(pruned_quant_model.parameters(), lr=0.01)


n_classes = 1000
batch_size = 5
# run a couple of training iterations with random data
for i in range(2):
    # Dummy data
    inputs = torch.randn(batch_size, 3, 256, 256)  # Batch of samples
    targets = torch.randint(0, n_classes, (batch_size,))  # Target labels 
    
    # Forward pass
    logits = pruned_quant_model(inputs)
    out = torch.nn.LogSoftmax(dim=1)(logits)
    loss = torch.nn.functional.nll_loss(out, targets)
    print(loss)
    
    # Backward and optimize
    optimizer.zero_grad()
    loss.backward() 
    optimizer.step()
    quantizer.step()
    pruner.step()         
    

# finalize the model for export
# we first finalize the quantizer followed by the pruner 
quant_finalized_model = quantizer.finalize(inplace=True)
finalized_model = pruner.finalize(quant_finalized_model, inplace=True)
finalized_model.eval()                                                                                                                                                                                          


# trace and export to mlpackage
traced_model = torch.jit.trace(finalized_model, torch.rand(1, 3, 256, 256))
mlmodel = ct.convert(traced_model,
                     inputs=[ct.TensorType(shape=(1, 3, 256, 256))],
                     minimum_deployment_target=ct.target.macOS15,
                     )
mlmodel.save("model_torch_pruned_and_quantized.mlpackage")  
```
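
Training-time quantization of this kind is commonly implemented by simulating quantization in the forward pass ("fake quantization"): weights are rounded to the integer grid and immediately mapped back to floating point, so that training can adapt to the rounding. The sketch below shows that round trip for symmetric 8-bit quantization applied to weights that have already been masked by a pruner. It is a simplified plain-Python illustration, not the `LinearQuantizer` or `MagnitudePruner` implementation, and the helper names are hypothetical.

```python
def apply_mask(weights, mask):
    """Zero out the weights whose mask entry is 0."""
    return [w if keep else 0.0 for w, keep in zip(weights, mask)]


def fake_quantize_symmetric(weights, n_bits=8):
    """Round weights to a symmetric integer grid, then map back to float."""
    q_max = 2 ** (n_bits - 1) - 1  # 127 for 8 bits
    max_abs = max(abs(w) for w in weights)
    scale = max_abs / q_max if max_abs > 0 else 1.0
    return [
        max(-q_max, min(q_max, round(w / scale))) * scale for w in weights
    ]


weights = [0.503, -0.251, 0.004, 1.27, -0.999]
mask = [1, 0, 0, 1, 1]  # e.g. produced by a magnitude pruner

pruned = apply_mask(weights, mask)
effective = fake_quantize_symmetric(pruned)
print(effective)
```

Because symmetric quantization maps zero to exactly zero, the pruned entries remain zero after quantization, which is what allows the two techniques to be combined without destroying the sparsity pattern.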


### Joint sparsity and palettization 

Here we apply magnitude pruning to the torch model, followed by data-free palettization.

Note: to apply training-time pruning and 
palettization (e.g. [DKM](opt-palettization-algos.md#differentiable-k-means)), follow
the same pattern as in the section above, replacing the `LinearQuantizer` with `DKMPalettizer`.

```python
import torchvision 
import torch
import coremltools as ct


from coremltools.optimize.torch.pruning import ModuleMagnitudePrunerConfig, \
                                               MagnitudePrunerConfig, \
                                               MagnitudePruner 
                                               
from coremltools.optimize.torch.palettization import PostTrainingPalettizer, \
                                                 PostTrainingPalettizerConfig, \
                                                 ModulePostTrainingPalettizerConfig                                               


# Apply pruning 

# Initialize model
# e.g. Resnet50
model = torchvision.models.resnet50(weights="IMAGENET1K_V2")
model.train()

# Prepare model for pruning
prune_config = MagnitudePrunerConfig(
    global_config=ModuleMagnitudePrunerConfig(
        target_sparsity=0.8,
    )
)

pruner = MagnitudePruner(model, prune_config)
pruned_model = pruner.prepare()

# Create the optimizer from the prepared model's parameters:
# prepare() returns a copy of the model unless inplace=True is passed
optimizer = torch.optim.SGD(pruned_model.parameters(), lr=0.01)

# run a couple of training iterations with random data
n_classes = 1000
batch_size = 5
for i in range(2):
    inputs = torch.randn(batch_size, 3, 256, 256)  # Batch of samples
    targets = torch.randint(0, n_classes, (batch_size,))  # Target labels 
    logits = pruned_model(inputs)
    out = torch.nn.LogSoftmax(dim=1)(logits)
    loss = torch.nn.functional.nll_loss(out, targets)
    optimizer.zero_grad()
    loss.backward() 
    optimizer.step()
    pruner.step()

# finalize model 
pruned_model = pruner.finalize(pruned_model, inplace=True)



# Apply palettization 
palettization_config = PostTrainingPalettizerConfig(
    global_config=ModulePostTrainingPalettizerConfig(
         n_bits=4,       
    )
)

palettizer = PostTrainingPalettizer(pruned_model, palettization_config)
joint_compressed_model = palettizer.compress()

# convert the compressed model
joint_compressed_model.eval()
traced_model = torch.jit.trace(joint_compressed_model, torch.rand(1, 3, 256, 256))

mlmodel = ct.convert(traced_model,
                     inputs=[ct.TensorType(shape=(1, 3, 256, 256))],
                     minimum_deployment_target=ct.target.macOS15,
                     )
mlmodel.save("model_torch_pruned_and_palettized.mlpackage")
```
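
After joint compression it can be useful to sanity-check that a weight tensor has the expected structure: roughly the target fraction of zeros, and no more distinct non-zero values than an `n_bits` LUT can hold. The helper below is a hypothetical check written in plain Python over a flat list of values. With a real torch model, one way to obtain such a list (an assumption about how the weights are accessed) is `module.weight.flatten().tolist()` for each compressed layer.

```python
def summarize_weights(values, n_bits):
    """Report sparsity and whether the non-zero values fit in a 2**n_bits LUT."""
    total = len(values)
    n_zero = sum(1 for v in values if v == 0.0)
    unique_nonzero = {v for v in values if v != 0.0}
    return {
        "sparsity": n_zero / total if total else 0.0,
        "unique_nonzero": len(unique_nonzero),
        "fits_lut": len(unique_nonzero) <= 2 ** n_bits,
    }


# Toy values standing in for one flattened weight tensor
values = [0.0, 0.5, 0.0, -0.25, 0.0, 0.0, 0.5, 0.0, 0.0, 0.0]
summary = summarize_weights(values, n_bits=4)
print(summary)
```

A layer that reports far lower sparsity than the target, or more distinct values than the LUT size, is a hint that one of the compression steps was not applied to it as intended.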