Skip to content

Commit

Permalink
Add sparsity core library which will eventually replace sparsity_core…
Browse files Browse the repository at this point in the history
…_depr.

PiperOrigin-RevId: 625750989
  • Loading branch information
laurentes authored and pax authors committed Apr 17, 2024
1 parent 08fed24 commit 4c3e809
Show file tree
Hide file tree
Showing 11 changed files with 1,488 additions and 24 deletions.
4 changes: 2 additions & 2 deletions praxis/layers/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -559,8 +559,6 @@ pytype_strict_test(
":repeats",
# Implicit absl.testing.absltest dependency.
# Implicit absl.testing.parameterized dependency.
"//third_party/py/aqt/jax_legacy/jax/sparsity:sparsity_hparams",
"//third_party/py/aqt/jax_legacy/jax/sparsity:sparsity_modes",
# Implicit upb python proto dependency.
# Implicit jax dependency.
# Implicit numpy dependency.
Expand All @@ -569,6 +567,8 @@ pytype_strict_test(
"//praxis:py_utils",
"//praxis:test_utils",
"//praxis/layers/quantization:linears",
"//praxis/layers/quantization/sparsity:sparsity_hparams",
"//praxis/layers/quantization/sparsity:sparsity_modes",
],
)

Expand Down
56 changes: 47 additions & 9 deletions praxis/layers/quantization/sparsity/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,53 @@ pytype_strict_library(
srcs = ["__init__.py"],
)

pytype_strict_library(
name = "sparsity_hparams",
srcs = ["sparsity_hparams.py"],
deps = [":sparsity_modes"],
)

pytype_strict_library(
name = "sparsity",
srcs = ["sparsity.py"],
deps = [
":sparsity_hparams",
# Implicit jax dependency.
],
)

pytype_strict_test(
name = "sparsity_test",
srcs = ["sparsity_test.py"],
deps = [
":sparsity",
":sparsity_hparams",
# Implicit absl.testing.absltest dependency.
# Implicit absl.testing.parameterized dependency.
# Implicit upb python proto dependency.
# Implicit jax dependency.
# Implicit numpy dependency.
],
)

pytype_strict_library(
name = "sparsity_modes",
srcs = ["sparsity_modes.py"],
deps = [
# Implicit jax dependency.
"//praxis:pytypes",
],
)

pytype_strict_test(
name = "linears_test",
srcs = ["linears_test.py"],
deps = [
":sparsity_hparams",
":sparsity_modes",
# Implicit absl.logging dependency.
# Implicit absl.testing.absltest dependency.
# Implicit absl.testing.parameterized dependency.
"//third_party/py/aqt/jax_legacy/jax/sparsity:sparsity_hparams",
"//third_party/py/aqt/jax_legacy/jax/sparsity:sparsity_modes",
# Implicit upb python proto dependency.
# Implicit jax dependency.
# Implicit numpy dependency.
Expand All @@ -50,10 +88,10 @@ pytype_strict_test(
name = "attentions_test",
srcs = ["attentions_test.py"],
deps = [
":sparsity_hparams",
":sparsity_modes",
# Implicit absl.testing.absltest dependency.
# Implicit absl.testing.parameterized dependency.
"//third_party/py/aqt/jax_legacy/jax/sparsity:sparsity_hparams",
"//third_party/py/aqt/jax_legacy/jax/sparsity:sparsity_modes",
# Implicit upb python proto dependency.
# Implicit jax dependency.
# Implicit numpy dependency.
Expand All @@ -70,9 +108,9 @@ pytype_strict_library(
name = "sparsifier",
srcs = ["sparsifier.py"],
deps = [
"//third_party/py/aqt/jax_legacy/jax/sparsity:sparsity_core",
"//third_party/py/aqt/jax_legacy/jax/sparsity:sparsity_hparams",
"//third_party/py/aqt/jax_legacy/jax/sparsity:sparsity_modes",
":sparsity",
":sparsity_hparams",
":sparsity_modes",
# Implicit jax dependency.
"//praxis:base_layer",
"//praxis:pytypes",
Expand All @@ -84,10 +122,10 @@ pytype_strict_test(
srcs = ["sparsifier_test.py"],
deps = [
":sparsifier",
":sparsity_hparams",
":sparsity_modes",
# Implicit absl.testing.absltest dependency.
# Implicit absl.testing.parameterized dependency.
"//third_party/py/aqt/jax_legacy/jax/sparsity:sparsity_hparams",
"//third_party/py/aqt/jax_legacy/jax/sparsity:sparsity_modes",
# Implicit upb python proto dependency.
# Implicit jax dependency.
# Implicit numpy dependency.
Expand Down
5 changes: 2 additions & 3 deletions praxis/layers/quantization/sparsity/attentions_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@

from absl.testing import absltest
from absl.testing import parameterized
from aqt.jax_legacy.jax.sparsity import sparsity_hparams
from aqt.jax_legacy.jax.sparsity import sparsity_modes
import jax
from jax import numpy as jnp
import numpy as np
Expand All @@ -31,7 +29,8 @@
from praxis import test_utils
from praxis.layers import attentions
from praxis.layers.quantization import attentions as sattentions

from praxis.layers.quantization.sparsity import sparsity_hparams
from praxis.layers.quantization.sparsity import sparsity_modes

NON_TRAINABLE = base_layer.NON_TRAINABLE
SPARSITY_NAME_POSTFIX = base_layer.SPARSITY_NAME_POSTFIX
Expand Down
5 changes: 2 additions & 3 deletions praxis/layers/quantization/sparsity/linears_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@
from absl import logging
from absl.testing import absltest
from absl.testing import parameterized
from aqt.jax_legacy.jax.sparsity import sparsity_hparams
from aqt.jax_legacy.jax.sparsity import sparsity_modes
import jax
from jax import numpy as jnp
import numpy as np
Expand All @@ -31,7 +29,8 @@
from praxis import test_utils
from praxis.layers import linears
from praxis.layers.quantization import linears as slinears

from praxis.layers.quantization.sparsity import sparsity_hparams
from praxis.layers.quantization.sparsity import sparsity_modes

instantiate = base_layer.instantiate
NON_TRAINABLE = base_layer.NON_TRAINABLE
Expand Down
6 changes: 3 additions & 3 deletions praxis/layers/quantization/sparsity/sparsifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@
import functools
from typing import Optional, Sequence

from aqt.jax_legacy.jax.sparsity import sparsity_core as sparsity
from aqt.jax_legacy.jax.sparsity import sparsity_hparams
from aqt.jax_legacy.jax.sparsity import sparsity_modes
import jax
from jax import numpy as jnp
from praxis import base_layer
from praxis import pytypes
from praxis.layers.quantization.sparsity import sparsity
from praxis.layers.quantization.sparsity import sparsity_hparams
from praxis.layers.quantization.sparsity import sparsity_modes


SparsityType = sparsity_hparams.SparsityType
Expand Down
4 changes: 2 additions & 2 deletions praxis/layers/quantization/sparsity/sparsifier_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
import copy
from absl.testing import absltest
from absl.testing import parameterized
from aqt.jax_legacy.jax.sparsity import sparsity_hparams
from aqt.jax_legacy.jax.sparsity import sparsity_modes
import jax
from jax import numpy as jnp
import numpy as np
Expand All @@ -29,6 +27,8 @@
from praxis.layers import linears
from praxis.layers import quantization
from praxis.layers.quantization.sparsity import sparsifier
from praxis.layers.quantization.sparsity import sparsity_hparams
from praxis.layers.quantization.sparsity import sparsity_modes

instantiate = base_layer.instantiate
NON_TRAINABLE = base_layer.NON_TRAINABLE
Expand Down
Loading

0 comments on commit 4c3e809

Please sign in to comment.