Skip to content

Commit

Permalink
Temporarily disable channel-wise pruning in praxis to align its imple…
Browse files Browse the repository at this point in the history
…mentation with the sparsity/jax channelwise pruning implementation.

PiperOrigin-RevId: 629473113
  • Loading branch information
The praxis Authors committed Apr 30, 2024
1 parent da54dee commit 963189f
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 0 deletions.
4 changes: 4 additions & 0 deletions praxis/layers/quantization/sparsity/sparsifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,10 @@ def sparsifiy(
):
return weight

assert (
self.sparsity.sparsity_type != SparsityType.CHANNELWISE_PRUNING
), 'Channel-wise pruning is temporarily disabled'

step = self.get_var('step')

# Return without updating mask if in we want to do mixed sparsity for
Expand Down
4 changes: 4 additions & 0 deletions praxis/layers/quantization/sparsity/sparsifier_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
"""Tests for sparse_base_layer."""

import copy
import unittest

from absl.testing import absltest
from absl.testing import parameterized
import jax
Expand All @@ -30,6 +32,7 @@
from praxis.layers.quantization.sparsity import sparsity_hparams
from praxis.layers.quantization.sparsity import sparsity_modes


instantiate = base_layer.instantiate
NON_TRAINABLE = base_layer.NON_TRAINABLE
PARAMS = base_layer.PARAMS
Expand Down Expand Up @@ -1222,6 +1225,7 @@ def update(test_layer, params, inputs, updated_weights):
),
)

@unittest.skip('Channel-wise pruning is temporarily disabled')
@parameterized.named_parameters(
('column_wise', -1),
('row_wise', -2),
Expand Down

0 comments on commit 963189f

Please sign in to comment.