Add optional gate activation histogram logging during eval #641

Open · wants to merge 9 commits into main

Conversation

@Aphoh (Contributor) commented Jun 21, 2024

The main goal of this PR is to add the ability to log activation statistics for a model's MLPs.

In its current state, this involves one big, slightly inconvenient change: every model's compute_loss function now returns a tuple of (loss: Array, extras: dict), where extras can contain any auxiliary data to log. All of the upstream code therefore had to be modified to accommodate this change.
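For illustration, a minimal, self-contained sketch of the new contract (the real compute_loss takes a model and a batch and uses Levanter's named arrays; the names and shapes here are simplified and hypothetical):

```python
from typing import Any

import jax
import jax.numpy as jnp


def compute_loss(logits: jnp.ndarray, targets: jnp.ndarray) -> tuple[jnp.ndarray, dict[str, Any]]:
    # New contract: return (loss, extras) instead of a bare scalar loss.
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    loss = -jnp.mean(jnp.take_along_axis(log_probs, targets[..., None], axis=-1))
    # extras can carry any auxiliary arrays the caller wants to log,
    # e.g. gate activation histograms collected during the forward pass.
    extras = {"gate_hist": jnp.histogram(logits, bins=8)[0]}
    return loss, extras
```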

Currently there's code to measure the activation statistics of Llama models during eval only, since computing the histograms is incredibly inefficient on TPUs. For Llama 7B, computing the histograms takes roughly 4x as long as the rest of the forward pass. AFAIK there's no faster way to do this, but it only runs during eval, so 🤷.

The code's a little messy, so some review would be appreciated.

@dlwh (Member) left a comment

OK so I kinda want to not make a whole bunch of changes to the model API just yet, and would rather have a guide on how to hack this in, since these things tend to be special snowflakes.

I also wonder if we should just consider using a debug callback (see e.g. jit_log_metrics), which is a bit gross from a functional purity perspective, but for logging I think it's fine?
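For reference, the generic shape of that pattern with jax.debug.callback (jit_log_metrics is Levanter's own helper; this sketch just shows the underlying mechanism):

```python
import jax
import jax.numpy as jnp


def _log_metrics(metrics):
    # Runs on the host with concrete values; impure, but acceptable for logging.
    print({k: float(v) for k, v in metrics.items()})


@jax.jit
def train_step(x: jnp.ndarray) -> jnp.ndarray:
    loss = jnp.mean(x ** 2)
    # Fire a host callback from inside jit to log without threading values out.
    jax.debug.callback(_log_metrics, {"loss": loss})
    return loss
```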

activation_function = hf_config.hidden_act
# This is the implementation in huggingface
# https://github.com/huggingface/transformers/blob/12b1620e615592fbf099d4ec44af7b9f2d1b48aa/src/transformers/models/gemma/modeling_gemma.py#L200
activation_function = "gelu_pytorch_tanh"
@dlwh (Member):

i swore we already did this

@@ -123,6 +125,17 @@ def eval_callback(step: StepInfo):
_join_prefix(prefix, "loading_time"): result.total_eval_loading_time,
_join_prefix(prefix, "total_time"): time_fn(),
}
if (gate_hist := result.extras.get("gate_hist", None)) is not None:
@dlwh (Member):

so i think i'm gonna have a strong preference for

  1. extracting this block (and the part in the loop) into a class (sort of like RunningMean; see the sketch below)
  2. not actually checking the usage of it in TaggedEvaluator (or in the models) into main, but instead
  3. making a little guide on how to add it in, since it's something people want to play with sometimes but it kinda adds a bunch of noise
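
A rough sketch of what that extracted accumulator could look like (the class name and API are hypothetical, loosely modeled on RunningMean):

```python
from typing import Optional

import jax.numpy as jnp


class RunningHistogram:
    """Accumulates per-batch histogram counts; read .value after the eval loop."""

    def __init__(self) -> None:
        self.value: Optional[jnp.ndarray] = None

    def update(self, batch_counts: jnp.ndarray) -> None:
        # Sum counts across batches (the reduction the current code uses).
        self.value = batch_counts if self.value is None else self.value + batch_counts
```

The eval loop would then call update on each batch's "gate_hist" entry and log the accumulated value once at the end.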

@@ -193,16 +203,20 @@ def init(
if isinstance(activation_fn, str):
activation_fn = ACT2FN[activation_fn]
act = activation_fn # type: ignore
return LlamaMlp(gate_proj, up_proj, down_proj, act)
get_bins() # initialize bins
@dlwh (Member):

rm?

if extras:
for key in extras:
curr = total_extras.get(key, jnp.zeros_like(extras[key]))
total_extras[key] = extras[key] + curr
@dlwh (Member):

is summing always going to be the right reduction here?

NBINS = 2 * NSIDE + 3


@jax.jit
@dlwh (Member):

generally speaking it's not worth putting jit around helpers, though sometimes it is

return _BINS


BIN_AX = Axis("bins", NBINS - 1)
@dlwh (Member):

rm?



@jax.jit
def histogram(a: Array, bins: Array) -> Array:
@dlwh (Member):

add a reference to that git issue about why we need this?
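
For context, a hand-rolled histogram under jit typically looks something like this (a sketch of the general technique, not necessarily the exact implementation in this PR):

```python
import jax
import jax.numpy as jnp


@jax.jit
def _histogram(a: jnp.ndarray, bins: jnp.ndarray) -> jnp.ndarray:
    # Bucketize each value against the bin edges, then count by summing one-hot
    # indicators; values outside the edges fall out of range and contribute zero.
    idx = jnp.searchsorted(bins, a.ravel(), side="right") - 1
    return jnp.sum(jax.nn.one_hot(idx, bins.shape[0] - 1, dtype=jnp.int32), axis=0)
```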



@jax.jit
def sharded_histogram(a: Array, bins: Array) -> Array:
@dlwh (Member):

let's maybe just call this histogram and the other thing _histogram?

src/levanter/tracker/histograms.py (review comment resolved)


@jax.jit
def get_bins() -> Array:
@dlwh (Member):

make this take a number of bins (with the current default?)
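
A possible shape for that change (the default constant and the edge spacing here are illustrative; the PR's actual bin scheme may differ):

```python
import jax.numpy as jnp

NSIDE = 16                  # illustrative value; the PR defines its own constant
NBINS = 2 * NSIDE + 3


def get_bins(nbins: int = NBINS) -> jnp.ndarray:
    # Parameterized per the review suggestion; this sketch just builds
    # symmetric, linearly spaced edges around zero.
    half = (nbins - 1) // 2
    pos = jnp.linspace(0.0, 1.0, half + 1)
    return jnp.concatenate([-pos[::-1][:-1], pos])
```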
