Adding configs related to DCLM #663

Open · wants to merge 101 commits into base: fineweb_data
Changes from 6 commits

Commits (101)
4aa2f2c
Adding configs related to DCLM
abhinavg4 Jul 18, 2024
dde9ed0
Adding configs related to DCLM
abhinavg4 Jul 19, 2024
b991e29
Adding Z loss
abhinavg4 Jul 19, 2024
bb674bb
pre commit changes
abhinavg4 Jul 19, 2024
6c99dfb
Adding z_loss as part of train_lm.py
abhinavg4 Jul 19, 2024
24469e7
Reverting changes to llama.py for z_loss
abhinavg4 Jul 19, 2024
76092c4
Address capacity_type and env variables (#665)
Ivan-Zhou Jul 21, 2024
2e55856
fix best effort test (#662)
dlwh Jul 24, 2024
2e64f14
Enable multislice in launch script (#666)
blahBlahhhJ Jul 26, 2024
4950a8e
Fineweb Text + Partial revert of kiloshard (#669)
dlwh Jul 26, 2024
c17f653
log run_progress for a special x axis. Fixes #671 (#674)
dlwh Jul 26, 2024
ac0882d
refactor trainer to always need a loss function, add z_loss (#672)
dlwh Jul 28, 2024
cb3638e
Specify node_count as int in launch.py (#682)
Ivan-Zhou Aug 4, 2024
8111f29
Bump ray[default] from 2.32.0 to 2.34.0 (#683)
dependabot[bot] Aug 11, 2024
04b0904
wandb seems to be broken in latest release (#688)
dlwh Aug 12, 2024
8c10a7a
switch to setup tools and forget the config thing (#691)
dlwh Aug 14, 2024
e8b6003
set logging level to INFO
dlwh Aug 14, 2024
441af5c
update docker image, build it in ci, make the args point to the new v…
dlwh Aug 14, 2024
ef6349c
RE-Allow adding extrenal directory to docker image (#695)
blahBlahhhJ Aug 15, 2024
e12c1b6
Merge remote-tracking branch 'origin/dclm' into dclm
dlwh Aug 20, 2024
c9ebc88
match specs in dclm
dlwh Aug 20, 2024
7727696
publish dev build
dlwh Aug 21, 2024
55e4d98
wip
dlwh Aug 21, 2024
de51236
fix imports and such
dlwh Aug 22, 2024
7863989
get default zone from gcloud config
dlwh Aug 22, 2024
a550bb5
factor out docker command, build
dlwh Aug 22, 2024
6341252
Merge remote-tracking branch 'origin/main' into dclm
dlwh Aug 22, 2024
715a04a
Update beta2=0.95 (#701)
dlwh Aug 22, 2024
c0ae0f9
publish full tpu image (#703)
dlwh Aug 23, 2024
ca7c9a6
fix incremental build on CI (#704)
dlwh Aug 23, 2024
d16482b
sigh
dlwh Aug 23, 2024
c823c75
grr (#705)
dlwh Aug 23, 2024
7ec7bb5
Adding multiple configs (#685)
abhinavg4 Aug 26, 2024
20faff3
Expose infra as a package, publish dev builds (#696)
dlwh Aug 26, 2024
5c53a19
Llama mixture (#706)
abhinavg4 Aug 26, 2024
277e728
Fix base again (#707)
dlwh Aug 26, 2024
0c628d5
Fix tpu vm autoshutdown (#708)
dlwh Aug 27, 2024
97358f9
suppress stderr in describe_tpu since it usually logs a dumb error (#…
dlwh Aug 28, 2024
e9ca517
Merge remote-tracking branch 'origin/main' into dclm
dlwh Aug 28, 2024
d674dd9
wip
dlwh Aug 29, 2024
4913df2
fix pyprojec.toml and pre-commit wandb issues (#712)
dlwh Aug 29, 2024
06dc304
wip
dlwh Aug 29, 2024
ffa8e28
fix device kind for mfu v5e (#713)
dlwh Aug 29, 2024
fd7888d
add haps configuration (cycle lr schedule) (#709)
blahBlahhhJ Sep 2, 2024
8dd32c6
Bump ray[default] from 2.34.0 to 2.35.0 (#714)
dependabot[bot] Sep 4, 2024
ea4ea25
use hf config from checkpoint by default (#715)
dlwh Sep 4, 2024
fbe27bc
Completely rework dataset/cache system: instant resume, perfect shuff…
dlwh Sep 5, 2024
944a19f
unpin ray (#718)
dlwh Sep 5, 2024
f13cfde
bump equinox
dlwh Sep 5, 2024
8d3dfe0
wip
dlwh Sep 6, 2024
8ecb7ea
768
dlwh Sep 6, 2024
9ba6b20
Update gcsfs requirement from <2024.7,>=2024.2 to >=2024.2,<2024.10 (…
dependabot[bot] Sep 10, 2024
78da902
Update fsspec[http] requirement (#722)
dependabot[bot] Sep 10, 2024
5b685c3
Bump equinox from 0.11.4 to 0.11.5 (#721)
dependabot[bot] Sep 10, 2024
a91ef81
fix extra context docker build bug (#724)
blahBlahhhJ Sep 11, 2024
5c18557
Fix eqx (#725)
dlwh Sep 12, 2024
b6f334e
get rid of eraconfig b/c draccus can't handle it
dlwh Sep 13, 2024
e33a905
ugh
dlwh Sep 13, 2024
2645efb
missed some prints
dlwh Sep 13, 2024
5fc4084
attempt at launching small fast in CI, add tqdm_loggable (#719)
dlwh Sep 13, 2024
d05036c
Update datasets requirement from ~=2.18 to >=2.18,<4.0 (#732)
dependabot[bot] Sep 16, 2024
ca16aa0
Bump tensorstore from 0.1.64 to 0.1.65 (#731)
dependabot[bot] Sep 16, 2024
79fa64c
Bump equinox from 0.11.3 to 0.11.6 (#730)
dependabot[bot] Sep 16, 2024
07b3f16
add bits-per-byte calculation to levanter (#729)
dlwh Sep 18, 2024
fe3e2f3
fix sequence parallel attention in splash attention (#738)
dlwh Sep 22, 2024
9fa3aaa
fix llama 3 rotary embeddings (#740)
dlwh Sep 24, 2024
2b42bfb
Support for running in a Ray cluster (#737)
dlwh Sep 24, 2024
8ad3074
see if it's this file in particular (#742)
dlwh Sep 24, 2024
541ff12
Update README.md (#656)
devactivity-team Sep 24, 2024
91be677
bump levanter version (#743)
dlwh Sep 25, 2024
cd82fb3
Make new tokenization ~67% faster (#744)
dlwh Sep 25, 2024
43268e0
Adding supervised data config
TheQuantumFractal Sep 25, 2024
d6ad71f
Fixing linter error
TheQuantumFractal Sep 25, 2024
71bd696
Tweaks to Ray TPU stuff (#747)
dlwh Sep 26, 2024
f5b32cd
Fixing supervised training
TheQuantumFractal Sep 27, 2024
6483b42
Making linter happy
TheQuantumFractal Sep 27, 2024
45d41d8
Making linter happy
TheQuantumFractal Sep 27, 2024
b41838f
Simplify tokenization pipeline, make it work with large numbers of sh…
dlwh Oct 4, 2024
3bae9d3
allow mixture components to override cache_dir (#754)
dlwh Oct 5, 2024
9847728
a few final tweaks for marin runs (#755)
dlwh Oct 5, 2024
36b29fd
Update Audio Data Loader to Support Mixture Dataset (#758)
Helw150 Oct 9, 2024
5370c72
Update src/levanter/data/text.py
ahmeda14960 Oct 9, 2024
1063fd8
Merge remote-tracking branch 'origin/main' into ksalahi/supervised-data
ahmeda14960 Oct 9, 2024
2f625d3
address david's comments
ahmeda14960 Oct 9, 2024
cf2c9e5
lint and minor
ahmeda14960 Oct 9, 2024
8bed0aa
Adding supervised data config (#746)
ahmeda14960 Oct 9, 2024
adf4b6d
Add an actor pool for batch processing, switch to a thread for writin…
dlwh Oct 10, 2024
36459da
pre-commit
dlwh Oct 10, 2024
6499656
flaky hf
dlwh Oct 10, 2024
074477f
Fix actor pool in python 3.11, add better scaling down logic (#760)
dlwh Oct 10, 2024
1c0e10e
Fix ray docs (#761)
blahBlahhhJ Oct 11, 2024
51f9bf1
ensure everything always uses at least some CPU to avoid flooding ray…
dlwh Oct 11, 2024
c3b3dd8
cap the size of the core writer task rather than the number of batche…
dlwh Oct 11, 2024
52bff4f
add parquet support
nikil-ravi Oct 13, 2024
af78281
lint, shard name fix
nikil-ravi Oct 13, 2024
8d09cfd
pre-commit
nikil-ravi Oct 13, 2024
50715e9
read as binary file
nikil-ravi Oct 13, 2024
3fe8995
simplify test
nikil-ravi Oct 14, 2024
fc26c74
Support Parquet files in ShardedDataSource (#764)
nikil-ravi Oct 14, 2024
02f34ac
fix crash in data loader caused by using stale array (#765)
dlwh Oct 14, 2024
0ea3eb4
Merge remote-tracking branch 'origin/main' into dclm
dlwh Oct 14, 2024
74 changes: 74 additions & 0 deletions config/data/dclm_gpt_neo.yaml
@@ -0,0 +1,74 @@
cache_dir: "gs://marin-data/tokenized/dclm/gpt_neo_tokenizer"
tokenizer: "EleutherAI/gpt-neox-20b"
stop_strategy: restart
configs:
  "dclm":
    train_urls:
      - gs://marin-data/datacomp/dclm-baseline-dedup-07-09/*/*/*.jsonl.zstd
  # these are just for eval
  "paloma/4chan":
    validation_urls:
      - gs://levanter-data/paloma/4chan_meta_sep/val/val*.jsonl.gz
  "paloma/c4_100_domains":
    validation_urls:
      - gs://levanter-data/paloma/c4_100_domains/val/val*.jsonl.gz
  "paloma/c4_en":
    validation_urls:
      - gs://levanter-data/paloma/c4_en/val/val*.jsonl.gz
  "paloma/dolma-v1_5":
    validation_urls:
      - gs://levanter-data/paloma/dolma-v1_5/val/val*.jsonl.gz
  "paloma/dolma_100_programing_languages":
    validation_urls:
      - gs://levanter-data/paloma/dolma_100_programing_languages/val/val*.jsonl.gz
  "paloma/dolma_100_subreddits":
    validation_urls:
      - gs://levanter-data/paloma/dolma_100_subreddits/val/val*.jsonl.gz
  "paloma/falcon-refinedweb":
    validation_urls:
      - gs://levanter-data/paloma/falcon-refinedweb/val/val*.jsonl.gz
  "paloma/gab":
    validation_urls:
      - gs://levanter-data/paloma/gab/val/val*.jsonl.gz
  "paloma/m2d2_s2orc_unsplit":
    validation_urls:
      - gs://levanter-data/paloma/m2d2_s2orc_unsplit/val/val*.jsonl.gz
  "paloma/m2d2_wikipedia_unsplit":
    validation_urls:
      - gs://levanter-data/paloma/m2d2_wikipedia_unsplit/val/val*.jsonl.gz
  "paloma/manosphere_meta_sep":
    validation_urls:
      - gs://levanter-data/paloma/manosphere_meta_sep/val/val*.jsonl.gz
  "paloma/mc4":
    validation_urls:
      - gs://levanter-data/paloma/mc4/val/val*.jsonl.gz
  "paloma/ptb":
    validation_urls:
      - gs://levanter-data/paloma/ptb/val/val*.jsonl.gz
  "paloma/redpajama":
    validation_urls:
      - gs://levanter-data/paloma/redpajama/val/val*.jsonl.gz
  "paloma/twitterAAE_HELM_fixed":
    validation_urls:
      - gs://levanter-data/paloma/twitterAAE_HELM_fixed/val/val*.jsonl.gz
  "paloma/wikitext_103":
    validation_urls:
      - gs://levanter-data/paloma/wikitext_103/val/val*.jsonl.gz
train_weights:
  dclm: 1.0
  paloma/4chan: 0.0
  paloma/c4_100_domains: 0.0
  paloma/c4_en: 0.0
  paloma/dolma-v1_5: 0.0
  paloma/dolma_100_programing_languages: 0.0
  paloma/dolma_100_subreddits: 0.0
  paloma/falcon-refinedweb: 0.0
  paloma/gab: 0.0
  paloma/m2d2_s2orc_unsplit: 0.0
  paloma/m2d2_wikipedia_unsplit: 0.0
  paloma/manosphere_meta_sep: 0.0
  paloma/mc4: 0.0
  paloma/ptb: 0.0
  paloma/redpajama: 0.0
  paloma/twitterAAE_HELM_fixed: 0.0
  paloma/wikitext_103: 0.0
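For orientation, `train_weights` above is a mixture specification: only components with positive weight contribute training examples, so the Paloma sets at 0.0 are evaluation-only (the config's own "these are just for eval" comment). Below is a toy sketch of that interpretation, per-example sampling proportional to weight, using made-up helper names; it is not Levanter's actual data loader.

```python
import random

# Toy illustration of how train_weights is commonly interpreted (sampling
# proportional to weight). Names here are made up for the sketch.
train_weights = {"dclm": 1.0, "paloma/c4_en": 0.0}  # weight 0.0 => eval-only

def pick_component(weights: dict[str, float], rng: random.Random) -> str:
    """Choose a mixture component for the next training example."""
    names = [name for name, w in weights.items() if w > 0]
    return rng.choices(names, weights=[weights[n] for n in names], k=1)[0]

rng = random.Random(0)
print(pick_component(train_weights, rng))  # always "dclm" with the weights above
```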
30 changes: 30 additions & 0 deletions config/llama_1b_dclm.yaml
@@ -0,0 +1,30 @@
data: !include data/dclm_gpt_neo.yaml
model: # 1B class model
  type: llama
  seq_len: 2048
  hidden_dim: 2048
  intermediate_dim: 8192
  num_layers: 24
  num_heads: 16
  num_kv_heads: 16
  use_flash_attention: True
  flash_attention_block_size: 1024
trainer:
  tracker:
    type: wandb
    project: "marin"
    tags: ["llama", "fineweb", "markdown"]

  mp: p=f32,c=bfloat16
  train_batch_size: 256 # 2048 * 2048 = 4,194,304
  num_train_steps: 71526 # 300,000,000,000 / 4,194,304 = 71,526
  steps_per_eval: 1000
  tensor_parallel_axes: ["mlp", "heads"]
  fsdp_axis: "embed"
  batch_axis: "batch"
optimizer:
  learning_rate: 3E-3
  weight_decay: 0.033
  min_lr_ratio: 0.1
  warmup: 5000
  cooldown: 3E-5
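For reference, the step-count comment above follows a token-budget calculation: tokens per step = sequences per step × sequence length, and steps = budget / tokens per step. Note that the comments assume 2048 sequences per step (2048 × 2048 = 4,194,304 tokens) even though `train_batch_size` is set to 256 here. A small sketch reproducing the comments' arithmetic:

```python
import math

# Reproduce the arithmetic in the config comments (which assume 2048 sequences/step).
seq_len = 2048
sequences_per_step = 2048                       # the comments' assumption; the config itself sets 256
tokens_per_step = sequences_per_step * seq_len  # 4,194,304
token_budget = 300_000_000_000                  # 300B-token target
num_train_steps = math.ceil(token_budget / tokens_per_step)
print(tokens_per_step, num_train_steps)         # 4194304 71526, matching the config
```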
29 changes: 29 additions & 0 deletions config/llama_7b_with_dclm.yaml
@@ -0,0 +1,29 @@
data: !include data/dclm_gpt_neo.yaml
model: # 7B class model
  type: llama
  seq_len: 2048
  hidden_dim: 4096
  intermediate_dim: 11008
  num_layers: 32
  num_heads: 32
  num_kv_heads: 32
  use_flash_attention: True
  flash_attention_block_size: 1024
trainer:
  tracker:
    type: wandb
    project: "marin"
    tags: ["dclm", "7B", "llama"]

  mp: p=f32,c=bfloat16
  train_batch_size: 2048
  num_train_steps: 750000 # 3,000,000,000,000 / 4,000,000 = 750,000
  steps_per_eval: 1000
  tensor_parallel_axes: ["mlp", "heads"]
  fsdp_axis: "embed"
  batch_axis: "batch"
optimizer:
  learning_rate: 4E-4
  weight_decay: 0.1
  min_lr_ratio: 0.1
  warmup: 0.01
29 changes: 27 additions & 2 deletions src/levanter/main/train_lm.py
@@ -8,7 +8,7 @@
import jax.random as jrandom

import haliax as hax
from haliax import Axis
from haliax import Axis, Scalar
from haliax.partitioning import named_jit, round_axis_for_partitioning

import levanter
@@ -19,12 +19,29 @@
from levanter.models.lm_model import LmConfig
from levanter.optim import AdamConfig, OptimizerConfig
from levanter.trainer import Trainer, TrainerConfig
from levanter.types import ComputeLossFunction, M, X
from levanter.utils.jax_utils import parameter_count


logger = logging.getLogger(__name__)


Review comment (Member): "i still don't like this but I think I can't really articulate what I want. i'm gonna push a change to my fork"

class ModuleComputeZLoss(ComputeLossFunction[M, X]):
    """
    Loss that just delegates to the model's compute_z_loss method.
    """

    def __call__(
        self,
        model,
        *inputs: X,
        reduction: Optional[hax.ReductionFunction] = hax.mean,
        reduction_axis: Optional[hax.AxisSelection] = None,
        **kwargs,
    ) -> Scalar | hax.NamedArray:
        return model.compute_z_loss(*inputs, reduction=reduction, reduction_axis=reduction_axis, **kwargs)


@dataclass
class TrainLmConfig:
    data: Union[LMDatasetConfig, LMMixtureDatasetConfig] = field(default_factory=LMDatasetConfig)
@@ -48,6 +65,7 @@ class TrainLmConfig:

    update_hessian_steps: int = 10
    data_seed: Optional[int] = None  # if provided, will override the data seed from the trainer
    z_loss_weight: float = 0.0


def main(config: TrainLmConfig):
@@ -82,11 +100,18 @@ def main(config: TrainLmConfig):
    levanter.initialize(config)
    optimizer = config.optimizer.build(config.trainer.num_train_steps)

    loss_fn: Optional[ComputeLossFunction] = None

    if config.z_loss_weight > 0:
        loss_fn = ModuleComputeZLoss()
    else:
        loss_fn = None  # It will be automatically set to the default loss function in the model

    # Using the trainer as a context manager does 3 things:
    # 1. Sets the device mesh
    # 2. Sets the axis mapping (for fsdp)
    # 3. Sets the global metrics tracker
    with Trainer(config.trainer, optimizer) as trainer:
    with Trainer(config.trainer, optimizer, loss_fn) as trainer:
        # randomness in jax is tightly controlled by "keys" which are the states of the random number generators
        # this makes deterministic training pretty easy
        seed = config.trainer.seed
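The hunk above only selects which loss the trainer uses: when `z_loss_weight > 0`, `ModuleComputeZLoss` delegates to the model's `compute_z_loss`; otherwise `loss_fn` stays `None` and the trainer falls back to the model's default loss. A minimal, self-contained sketch of that selection pattern follows; the `Model` protocol and the lambda that binds the weight are illustrative, not Levanter's API or what this PR does.

```python
from typing import Callable, Optional, Protocol

class Model(Protocol):
    """Stand-in for a language model with both loss variants (illustrative only)."""
    def compute_loss(self, example, **kwargs): ...
    def compute_z_loss(self, example, z_loss_weight, **kwargs): ...

def select_loss_fn(z_loss_weight: float) -> Optional[Callable]:
    """Mirror the PR's selection logic: use the z-loss variant only when the weight is positive."""
    if z_loss_weight > 0:
        # Bind the weight here so the caller only has to supply (model, example).
        return lambda model, example, **kw: model.compute_z_loss(example, z_loss_weight, **kw)
    return None  # trainer falls back to the model's default loss
```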
28 changes: 28 additions & 0 deletions src/levanter/models/lm_model.py
@@ -9,8 +9,10 @@
import haliax as hax
from haliax import Axis, NamedArray
from haliax.nn import cross_entropy_loss
from haliax.nn.loss import maybe_reduce_loss

from levanter.models.attention import AttentionMask
from levanter.models.loss import cross_entropy_and_logsumexp_penalty


LmConfigT = TypeVar("LmConfigT", bound="LmConfig")
@@ -137,6 +139,32 @@ def compute_loss(

        return loss

    def compute_z_loss(
        self,
        example: LmExample,
        z_loss_weight,
        *,
        key=None,
        reduction: Optional[hax.ReductionFunction] = hax.mean,
        reduction_axis: Optional[hax.AxisSelection] = None,
    ) -> jnp.ndarray | NamedArray:
        """
        Computes the cross-entropy loss for a language modeling example with z_loss.
        If reduction is not None, the loss is reduced
        across the reduction axis (with reduction_axis=None meaning all axes). If reduction is None, the loss is not
        reduced, and the result is a named array with axes (*batch axes, sequence_length).
        """
        logits = self(example.tokens, example.attn_mask, key=key)
        # TODO: would be nice if we made the dtype configurable
        logits = logits.astype(jnp.float32)
        targets = hax.roll(example.tokens, -1, axis=self.Pos.name)
        target_y = hax.nn.one_hot(targets, self.Vocab, dtype=logits.dtype)
        loss = cross_entropy_and_logsumexp_penalty(
            logits, self.Vocab, target_y, logsumexp_weight=z_loss_weight  # penalty weight passed in by the caller
        )
        loss = maybe_reduce_loss(loss, reduction=reduction, reduction_axis=reduction_axis, where=example.loss_mask)
        return loss

    @property
    def vocab_size(self) -> int:
        return self.Vocab.size
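For context, the z-loss added here via `cross_entropy_and_logsumexp_penalty` is, in the usual formulation (e.g. PaLM's auxiliary loss), cross-entropy plus `weight * log(Z)^2`, where `Z` is the softmax partition function. Below is a plain-JAX sketch of that formula under this assumption; the function name and shapes are illustrative, not Levanter's implementation.

```python
import jax.numpy as jnp
from jax.nn import log_softmax
from jax.scipy.special import logsumexp

def cross_entropy_with_z_loss(logits: jnp.ndarray, targets: jnp.ndarray, z_loss_weight: float) -> jnp.ndarray:
    """Per-token cross-entropy plus z-loss penalty.

    logits: [..., vocab] float array; targets: [...] integer token ids.
    """
    log_z = logsumexp(logits, axis=-1)              # log of the softmax partition function
    log_probs = log_softmax(logits, axis=-1)
    nll = -jnp.take_along_axis(log_probs, targets[..., None], axis=-1)[..., 0]
    return nll + z_loss_weight * jnp.square(log_z)  # penalty pushes log Z toward 0 (normalized logits)
```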