From fff4dfb954c9495e85bd9bc6863d353b90732fb5 Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 25 Oct 2023 13:07:08 -0700 Subject: [PATCH 001/205] wip --- src/levanter/logging.py | 91 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 90 insertions(+), 1 deletion(-) diff --git a/src/levanter/logging.py b/src/levanter/logging.py index 4906d0484..465bebeb4 100644 --- a/src/levanter/logging.py +++ b/src/levanter/logging.py @@ -1,3 +1,4 @@ +import abc import contextlib import dataclasses import logging as pylogging @@ -7,7 +8,7 @@ import warnings from dataclasses import dataclass from pathlib import Path -from typing import List, Optional, Union +from typing import Any, List, Optional, Union import draccus import jax @@ -22,6 +23,92 @@ logger = pylogging.getLogger(__name__) +class LoggerSink(abc.ABC): + + @abc.abstractmethod + def init(self, run_id: Optional[str]): + pass + + @abc.abstractmethod + def log_hyperparameters(self, hparams: dict[str, Any]): + pass + + + @abc.abstractmethod + def log(self, metrics: dict[str, Any], *, step): + """ + Log metrics to the logger. Step is always required. + """ + pass + + @abc.abstractmethod + def log_summary(self, metrics: dict[str, Any]): + pass + + @abc.abstractmethod + def log_artifact(self, artifact, *, name: Optional[str] = None, type: Optional[str] = None): + pass + +class WandbLoggerSink(LoggerSink): + def __init__(self, config: 'WandbConfig'): + self.config = config + self._run = None + + def init(self, run_id: Optional[str]): + self._run = self.config.init(run_id) + + def log_hyperparameters(self, hparams: dict[str, Any]): + if self._run is None: + raise RuntimeError("Must call init before logging hyperparameters") + self._run.config.update(hparams) + + def log(self, metrics: dict[str, Any], *, step): + if self._run is None: + raise RuntimeError("Must call init before logging metrics") + self._run.log(metrics, step=step) + + def log_summary(self, metrics: dict[str, Any]): + if self._run is None: + raise RuntimeError("Must call init before logging summary") + self._run.summary.update(metrics) + + def log_artifact(self, artifact, *, name: Optional[str] = None, type: Optional[str] = None): + if self._run is None: + raise RuntimeError("Must call init before logging artifacts") + self._run.log_artifact(artifact, name=name, type=type) + + +class TensorboardLoggerSink(LoggerSink): + + def __init__(self, logdir: Union[str, Path]): + self.logdir = logdir + self.writer = None + + def init(self, run_id: Optional[str]): + from tensorboardX import SummaryWriter + dir_to_write = self.logdir + if run_id is not None: + dir_to_write = os.path.join(dir_to_write, run_id) + self.writer = SummaryWriter(dir_to_write) + + def log_hyperparameters(self, hparams: dict[str, Any]): + self.writer.add_hparams(hparams, {"dummy": 0}) + + def log(self, metrics: dict[str, Any], *, step): + for k, v in metrics.items(): + self.writer.add_scalar(k, v, step) + + def log_summary(self, metrics: dict[str, Any]): + for k, v in metrics.items(): + self.writer.add_scalar(k, v, 0) + + + def log_artifact(self, artifact, *, name: Optional[str] = None, type: Optional[str] = None): + pass + + + + def log_optimizer_hyperparams(opt_state, prefix: Optional[str] = None, *, step=None): if isinstance(opt_state, MultiStepsState): @@ -260,6 +347,8 @@ def init(self, run_id: Optional[str], hparams=None, **extra_hparams): wandb.summary["num_hosts"] = jax.process_count() wandb.summary["backend"] = jax.default_backend() + return r + @staticmethod def _infer_experiment_git_root() -> Optional[str | 
os.PathLike[str]]: # sniff out the main directory (since we typically don't run from the root of the repo) From 740ad6898d157815a9fea60c81b5c0677dc146a8 Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 7 Nov 2023 08:13:20 -0800 Subject: [PATCH 002/205] wip --- src/levanter/logging.py | 154 +++++++++++++++++++++++------ src/levanter/main/cache_dataset.py | 4 +- src/levanter/trainer.py | 2 +- src/levanter/utils/jax_utils.py | 5 + 4 files changed, 134 insertions(+), 31 deletions(-) diff --git a/src/levanter/logging.py b/src/levanter/logging.py index 465bebeb4..bb28be6d1 100644 --- a/src/levanter/logging.py +++ b/src/levanter/logging.py @@ -5,6 +5,7 @@ import os import tempfile import time +import typing import warnings from dataclasses import dataclass from pathlib import Path @@ -12,7 +13,6 @@ import draccus import jax -import wandb from draccus import field from git import InvalidGitRepositoryError, NoSuchPathError, Repo from optax import MultiStepsState @@ -21,10 +21,73 @@ from levanter.utils.jax_utils import jnp_to_python -logger = pylogging.getLogger(__name__) +pylogger = pylogging.getLogger(__name__) -class LoggerSink(abc.ABC): +_global_logger: Optional["MetricsLogger"] = None + +def log_metrics(metrics: dict[str, Any], *, step): + """ + Log metrics to the global logger. + + :param metrics: Metrics to log + :param step: Step to log metrics at + """ + global _global_logger + if _global_logger is None: + raise RuntimeError("No global logger set") + + _global_logger.log(metrics, step=step) + + +def jit_log_metrics(metrics, *, step=None): + """uses jax effect callback to log to wandb from the host""" + jax.debug.callback(log_metrics, metrics, step=step) + + +def log_summary(metrics: dict[str, Any]): + """ + Log summary metrics to the global logger. + + :param metrics: Metrics to log + """ + global _global_logger + if _global_logger is None: + raise RuntimeError("No global logger set") + _global_logger.log_summary(metrics) + +@typing.overload +def global_logger() -> "MetricsLogger": + ... + + +@typing.overload +def global_logger(logger: "MetricsLogger") -> contextlib.AbstractContextManager: + """Context manager for setting the global logger""" + ... + + +def global_logger(logger: Optional["MetricsLogger"] = None) -> Union["MetricsLogger", contextlib.AbstractContextManager]: + """ + Get or set the global logger. + + :param logger: If provided, sets the global logger to this value. + :return: The global logger, or a context manager for setting the global logger. + """ + global _global_logger + if logger is None: + if _global_logger is None: + raise RuntimeError("No global logger set") + return _global_logger + else: + return _GlobalLoggerContextManager(logger) + + +class MetricsLogger(abc.ABC): + """ + A logger for logging metrics to some backend(s). + Meant to be used with the [global_logger][] context manager, but can also be used directly. 
+ """ @abc.abstractmethod def init(self, run_id: Optional[str]): pass @@ -33,7 +96,6 @@ def init(self, run_id: Optional[str]): def log_hyperparameters(self, hparams: dict[str, Any]): pass - @abc.abstractmethod def log(self, metrics: dict[str, Any], *, step): """ @@ -49,7 +111,47 @@ def log_summary(self, metrics: dict[str, Any]): def log_artifact(self, artifact, *, name: Optional[str] = None, type: Optional[str] = None): pass -class WandbLoggerSink(LoggerSink): + +class CompositeLogger(MetricsLogger): + def __init__(self, loggers: List[MetricsLogger]): + self.loggers = loggers + + def init(self, run_id: Optional[str]): + for logger in self.loggers: + logger.init(run_id) + + def log_hyperparameters(self, hparams: dict[str, Any]): + for logger in self.loggers: + logger.log_hyperparameters(hparams) + + def log(self, metrics: dict[str, Any], *, step): + for logger in self.loggers: + logger.log(metrics, step=step) + + def log_summary(self, metrics: dict[str, Any]): + for logger in self.loggers: + logger.log_summary(metrics) + + def log_artifact(self, artifact, *, name: Optional[str] = None, type: Optional[str] = None): + for logger in self.loggers: + logger.log_artifact(artifact, name=name, type=type) + + +class _GlobalLoggerContextManager(contextlib.AbstractContextManager): + def __init__(self, logger: "MetricsLogger"): + self.logger = logger + + def __enter__(self): + global _global_logger + self.old_logger = _global_logger + _global_logger = self.logger + + def __exit__(self, exc_type, exc_val, exc_tb): + global _global_logger + _global_logger = self.old_logger + + +class WandbLogger(MetricsLogger): def __init__(self, config: 'WandbConfig'): self.config = config self._run = None @@ -78,7 +180,7 @@ def log_artifact(self, artifact, *, name: Optional[str] = None, type: Optional[s self._run.log_artifact(artifact, name=name, type=type) -class TensorboardLoggerSink(LoggerSink): +class TensorboardLogger(MetricsLogger): def __init__(self, logdir: Union[str, Path]): self.logdir = logdir @@ -102,14 +204,11 @@ def log_summary(self, metrics: dict[str, Any]): for k, v in metrics.items(): self.writer.add_scalar(k, v, 0) - def log_artifact(self, artifact, *, name: Optional[str] = None, type: Optional[str] = None): + pylogger.warning("TensorboardLoggerSink does not support logging artifacts yet") pass - - - def log_optimizer_hyperparams(opt_state, prefix: Optional[str] = None, *, step=None): if isinstance(opt_state, MultiStepsState): opt_state = opt_state.inner_opt_state @@ -121,10 +220,10 @@ def wrap_key(key): if hasattr(opt_state, "hyperparams"): params = {wrap_key(k): jnp_to_python(v) for k, v in opt_state.hyperparams.items()} - wandb.log(params, step=step) + log_metrics(params, step=step) -def init_logger(path: Union[str, Path], level: int = pylogging.INFO) -> None: +def init_logging(path: Union[str, Path], level: int = pylogging.INFO) -> None: """ Initialize logging.Logger with the appropriate name, console, and file handlers. 
@@ -147,6 +246,11 @@ def init_logger(path: Union[str, Path], level: int = pylogging.INFO) -> None: def save_xla_dumps_to_wandb(initial_time: float): import os + if not is_wandb_available(): + pylogger.warning("Wandb is not available, so we can't save XLA dumps") + return + + import wandb # attempt to parse xla_flags to see if we're dumping assembly files flags = os.getenv("XLA_FLAGS", None) @@ -154,7 +258,7 @@ def save_xla_dumps_to_wandb(initial_time: float): # parse the path # this isn't robust to quotes path = flags.split("xla_dump_to=")[1].split(" ")[0] - logger.info(f"Found xla_dump_to={path}, logging to wandb") + pylogger.info(f"Found xla_dump_to={path}, logging to wandb") if wandb.run: # only want to save the files that were generated during this run # XLA_FLAGS has to be set before the first jax call, so we can't just set it in the middle of the run @@ -166,7 +270,7 @@ def include_file(path: str): wandb.run.log_code(root=path, name="xla_dumps", include_fn=include_file) else: - logger.warning("XLA_FLAGS is not set to dump to a path, so we can't save the dumps to wandb") + pylogger.warning("XLA_FLAGS is not set to dump to a path, so we can't save the dumps to wandb") @contextlib.contextmanager @@ -184,20 +288,14 @@ def fn(): end = time.time() -@contextlib.contextmanager -def log_time_to_wandb(name: str, *, step=None): - with capture_time() as fn: - yield fn - wandb.log({name: fn()}, step=step) - -def jittable_wandb_log(data, *, step=None): - """uses jax effect callback to log to wandb from the host""" - if is_wandb_available(): - jax.debug.callback(wandb.log, data, step=step) def is_wandb_available(): + try: + import wandb + except ImportError: + return False return wandb is not None and wandb.run is not None @@ -278,7 +376,7 @@ def init(self, run_id: Optional[str], hparams=None, **extra_hparams): other_settings = dict() if code_dir is not None: - logger.info(f"Setting wandb code_dir to {code_dir}") + pylogger.info(f"Setting wandb code_dir to {code_dir}") other_settings["code_dir"] = code_dir other_settings["git_root"] = code_dir # for some reason, wandb isn't populating the git commit, so we do it here @@ -287,7 +385,7 @@ def init(self, run_id: Optional[str], hparams=None, **extra_hparams): other_settings["git_commit"] = repo.head.commit.hexsha hparams_to_save["git_commit"] = repo.head.commit.hexsha except (NoSuchPathError, InvalidGitRepositoryError): - logger.warning(f"Could not find git repo at {code_dir}") + pylogger.warning(f"Could not find git repo at {code_dir}") pass r = wandb.init( @@ -324,7 +422,7 @@ def init(self, run_id: Optional[str], hparams=None, **extra_hparams): for k, v in metadata_to_share.items(): setattr(r, k, v) - logger.info(f"Synced wandb run information from process 0: {r.name} {r.id}") + pylogger.info(f"Synced wandb run information from process 0: {r.name} {r.id}") if dataclasses.is_dataclass(hparams): with tempfile.TemporaryDirectory() as tmpdir: @@ -370,7 +468,7 @@ def _infer_experiment_git_root() -> Optional[str | os.PathLike[str]]: top_git_root = repo.working_dir break except (NoSuchPathError, InvalidGitRepositoryError): - logger.debug(f"Skipping {dirname} since it's not a git root") + pylogger.debug(f"Skipping {dirname} since it's not a git root") pass return top_git_root diff --git a/src/levanter/main/cache_dataset.py b/src/levanter/main/cache_dataset.py index 84fad654c..ae30b19df 100644 --- a/src/levanter/main/cache_dataset.py +++ b/src/levanter/main/cache_dataset.py @@ -8,7 +8,7 @@ from levanter.data.shard_cache import RichMetricsMonitor, 
WandbMetricsMonitor, build_cache from levanter.data.text import BatchTokenizer, LMDatasetConfig from levanter.distributed import RayConfig -from levanter.logging import init_logger +from levanter.logging import init_logging logger = logging.getLogger(__name__) @@ -22,7 +22,7 @@ class RayCachedLMDatasetConfig(LMDatasetConfig, RayConfig): @levanter.config.main() def main(args: RayCachedLMDatasetConfig): """Caches two different kinds of datasets. It can cache a dataset from a list of urls, or a dataset from a hf dataset""" - init_logger("cache_dataset.log") + init_logging("cache_dataset.log") args.initialize() tokenizer = args.the_tokenizer diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index aadeb97a8..26935aaca 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -608,7 +608,7 @@ def _initialize_jax_config(self): def _initialize_logging(self): self.log_dir.mkdir(parents=True, exist_ok=True) - levanter.logging.init_logger(self.log_dir / f"{self.id}.log") + levanter.logging.init_logging(self.log_dir / f"{self.id}.log") def _maybe_set_id(self): # always do this so we don't get weird hangs if the id isn't set right diff --git a/src/levanter/utils/jax_utils.py b/src/levanter/utils/jax_utils.py index cb9cd915d..038c5e9b5 100644 --- a/src/levanter/utils/jax_utils.py +++ b/src/levanter/utils/jax_utils.py @@ -41,6 +41,11 @@ def use_cpu_device(): yield +def is_inside_jit(): + """Returns True if we're currently inside a jit""" + return isinstance(jnp.zeros(()), jax.core.Tracer) + + def flops_estimate(fn, *args): """Estimates the flop count of a function using XLA/HLO fanciness. See https://github.com/google/flax/discussions/1854""" From 4ad74a69abf2e27e789101d24af366e67b731583 Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 7 Nov 2023 15:53:29 -0800 Subject: [PATCH 003/205] almost got new logger working --- src/levanter/callbacks.py | 26 +++--- src/levanter/data/shard_cache.py | 26 +++--- src/levanter/data/text.py | 4 +- src/levanter/logging.py | 60 ++++++++++---- src/levanter/main/cache_dataset.py | 4 +- src/levanter/main/train_lm.py | 6 +- src/levanter/trainer.py | 129 ++++++++++++++++++----------- 7 files changed, 160 insertions(+), 95 deletions(-) diff --git a/src/levanter/callbacks.py b/src/levanter/callbacks.py index 2292c714a..339f4c206 100644 --- a/src/levanter/callbacks.py +++ b/src/levanter/callbacks.py @@ -1,5 +1,5 @@ import copy -import logging +import logging as pylogging import os import re import subprocess @@ -14,13 +14,14 @@ import wandb from tqdm import tqdm +from levanter import logging from levanter.logging import WandbConfig, log_optimizer_hyperparams, save_xla_dumps_to_wandb from levanter.trainer import StepInfo from levanter.utils.jax_utils import jnp_to_python from levanter.visualization import compute_and_visualize_log_probs as viz_probs -logger = logging.getLogger(__name__) +logger = pylogging.getLogger(__name__) def eval_loss_loop(loss_fn, model, dataset, max_batches: Optional[int] = None, name: Optional[str] = None): @@ -61,7 +62,7 @@ def compute_loss(info: StepInfo): prefix = "eval" if name: prefix += "/" + name - wandb.log({f"{prefix}/loss": loss}, step=info.step) + logging.log_metrics({f"{prefix}/loss": loss}, step=info.step) if name: logger.info(f"{name} validation loss: {loss:.3f}") @@ -73,8 +74,8 @@ def compute_loss(info: StepInfo): return compute_loss -def log_to_wandb(step: StepInfo): - wandb.log({"train/loss": step.loss, "global_step": step.step}, step=step.step) +def log_step_info(step: StepInfo): + 
logging.log_metrics({"train/loss": step.loss, "global_step": step.step}, step=step.step) log_optimizer_hyperparams(step.opt_state, step=step.step, prefix="optim") @@ -108,14 +109,14 @@ def log_performance_stats(step_info: StepInfo): # log these totals because it's useful for comparing different seqlens, batch sizes, etc total_tokens = tokens_per_example * batch_size * step_info.step - wandb.log({wrap_key("total_tokens"): total_tokens}, step=step_info.step) + logging.log_metrics({wrap_key("total_tokens"): total_tokens}, step=step_info.step) if flops_per_example: total_flops = flops_per_example * batch_size * step_info.step - wandb.log({wrap_key("total_gflops"): total_flops / 1e9}, step=step_info.step) + logging.log_metrics({wrap_key("total_gflops"): total_flops / 1e9}, step=step_info.step) if step_info.step_duration != 0.0: - wandb.log( + logging.log_metrics( { wrap_key("examples_per_second"): float(batch_size) / step_info.step_duration, wrap_key("tokens_per_second"): float(tokens_per_example) / step_info.step_duration * batch_size, @@ -125,7 +126,7 @@ def log_performance_stats(step_info: StepInfo): ) if flops_per_example is not None: - wandb.log( + logging.log_metrics( { wrap_key("gflops_per_second"): flops_per_example / 1e9 / step_info.step_duration * batch_size, }, @@ -218,7 +219,7 @@ def log_memory_usage(step: StepInfo): match = regex.search(by_kind) if match: memory_usage = humanfriendly.parse_size(match.group(1)) - wandb.log({"memory/total": memory_usage / 1e6}, step=step.step) + logging.log_metrics({"memory/total": memory_usage / 1e6}, step=step.step) # this works for the "kind" and the individual devices regex = re.compile(r"([\d.]+[a-zA-Z]+) \(([\d.]+)%\): ([\w\d:_]+)") @@ -229,14 +230,14 @@ def log_memory_usage(step: StepInfo): for match in regex.finditer(per_device): memory_usage = humanfriendly.parse_size(match.group(1)) device_name = match.group(3) - wandb.log({f"memory/device/{device_name}": memory_usage / 1e6}, step=step.step) + logging.log_metrics({f"memory/device/{device_name}": memory_usage / 1e6}, step=step.step) # now, get the memory usage per kind. # same regex as above for match in regex.finditer(by_kind): memory_usage = match.group(1) memory_usage = humanfriendly.parse_size(memory_usage) - wandb.log({f"memory/{match.group(3)}": memory_usage / 1e6}, step=step.step) + logging.log_metrics({f"memory/{match.group(3)}": memory_usage / 1e6}, step=step.step) return log_memory_usage @@ -262,6 +263,7 @@ def compute_and_viz_log_probs(step: StepInfo): path = os.path.join(html_dir, f"step_{step}.html") viz_probs(path, model, tokenizer, log_prob_fn, test_data, max_docs=max_docs) + # TODO: convert to generic logging wandb.log({"log_probs": wandb.Html(path)}, step=step.step) return compute_and_viz_log_probs diff --git a/src/levanter/data/shard_cache.py b/src/levanter/data/shard_cache.py index f2cc7d1ca..e71fdf630 100644 --- a/src/levanter/data/shard_cache.py +++ b/src/levanter/data/shard_cache.py @@ -1,7 +1,7 @@ # Dataset for preprocessing data, tokenizing, and caching to disk. import asyncio import dataclasses -import logging +import logging as pylogging import os import sys import threading @@ -30,7 +30,6 @@ import pyarrow.parquet as pq import ray import tblib -import wandb from dataclasses_json import dataclass_json from fsspec import AbstractFileSystem from ray.actor import ActorHandle @@ -45,6 +44,7 @@ TimeRemainingColumn, ) +from .. import logging from . 
import ShardableDataset from ._preprocessor import BatchProcessor, as_record_batch, dict_from_record_batch from .sharded_dataset import ShardedDataset @@ -54,7 +54,7 @@ T_co = TypeVar("T_co", covariant=True) _ExcInfo = Tuple[Optional[BaseException], tblib.Traceback] -logger = logging.getLogger(__name__) +logger = pylogging.getLogger(__name__) DEFAULT_ROWS_PER_CHUNK = 1024 * 32 LEDGER_FILE_NAME = "cache_ledger.json" @@ -205,7 +205,7 @@ def _produce_cache_for_shard( """Produces chunks of preprocessed data from a single shard and writes them to disk. Chunks are written to sink, which is an actor of ChunkCacheBuilder.""" # TODO: thread logging level through calls - logging.basicConfig(level=logging.INFO) + pylogging.basicConfig(level=pylogging.INFO) # load or create shard metadata (for recovery) try: shard_name = source.shard_names[shard_idx] @@ -415,7 +415,7 @@ def _init_progress(self, metrics): self.progress.start() -class WandbMetricsMonitor(MetricsMonitor): +class LoggingMetricsMonitor(MetricsMonitor): last_metrics: Optional[InProgressCacheMetrics] last_time: Optional[float] @@ -457,16 +457,16 @@ def __call__(self, metrics: InProgressCacheMetrics): self.last_metrics = metrics self.last_time = time.time() - wandb.log(to_log, commit=self.commit) + logging.log_metrics(to_log, step=None, commit=self.commit) class LoggerMetricsMonitor(MetricsMonitor): # TODO: I'd like to get the trainer pbar migrated to rich and just use rich everywhere, but until then, # we have separate logging - def __init__(self, logger: Optional[Union[logging.Logger, str]] = None, level=logging.INFO): + def __init__(self, logger: Optional[Union[pylogging.Logger, str]] = None, level=pylogging.INFO): if isinstance(logger, str): - logger = logging.getLogger(logger) - self.logger = logger or logging.getLogger(__name__) + logger = pylogging.getLogger(logger) + self.logger = logger or pylogging.getLogger(__name__) self.level = level def __call__(self, metrics: InProgressCacheMetrics): @@ -510,7 +510,7 @@ def is_producing(self): def _mk_process_task(processor: BatchProcessor[T]): @ray.remote(num_cpus=processor.num_cpus, num_gpus=processor.num_gpus, resources=processor.resources) def process_task(batch: List[T]) -> pa.RecordBatch: - logging.basicConfig(level=logging.INFO) + pylogging.basicConfig(level=pylogging.INFO) return processor(batch) return process_task @@ -519,7 +519,7 @@ def process_task(batch: List[T]) -> pa.RecordBatch: def _mk_queue_aware_process_task(processor: BatchProcessor[T], queue: ActorHandle): @ray.remote(num_cpus=processor.num_cpus, num_gpus=processor.num_gpus, resources=processor.resources) def process_task(batch: List[T]) -> pa.RecordBatch: - logging.basicConfig(level=logging.INFO) + pylogging.basicConfig(level=pylogging.INFO) ray.get(queue.task_running.remote()) result = processor(batch) del batch @@ -614,7 +614,7 @@ def __init__( processor: BatchProcessor[T], rows_per_chunk: int, ): - logging.basicConfig(level=logging.INFO) + pylogging.basicConfig(level=pylogging.INFO) self.broker_ref = broker_ref self.shard_status: Dict[str, _ShardStatus] = dict() self._current_round_robin = [] @@ -753,7 +753,7 @@ class ChunkCacheBroker: _finished_promise: asyncio.Future[None] def __init__(self, cache_dir: str, source: ShardedDataset[T], processor: BatchProcessor[T], rows_per_chunk: int): - logging.basicConfig(level=logging.INFO) + pylogging.basicConfig(level=pylogging.INFO) self.chunks = [] self._reader_promises = {} self._is_finished = False diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index 
5a4890efb..4ad535114 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -45,9 +45,9 @@ from levanter.data.shard_cache import ( # noqa ChunkMetadata, LoggerMetricsMonitor, + LoggingMetricsMonitor, MetricsMonitor, ShardCache, - WandbMetricsMonitor, _serialize_json_and_commit, build_cache, ) @@ -604,7 +604,7 @@ def build_or_load_cache( if monitors is True: monitors = [ - WandbMetricsMonitor(prefix=f"preprocessing/{split}", commit=False), + LoggingMetricsMonitor(prefix=f"preprocessing/{split}", commit=False), LoggerMetricsMonitor(f"preprocessing.{split}"), ] elif monitors is False: diff --git a/src/levanter/logging.py b/src/levanter/logging.py index bb28be6d1..e534bdcc7 100644 --- a/src/levanter/logging.py +++ b/src/levanter/logging.py @@ -18,7 +18,11 @@ from optax import MultiStepsState from levanter.utils import jax_utils -from levanter.utils.jax_utils import jnp_to_python +from levanter.utils.jax_utils import is_inside_jit, jnp_to_python + + +if typing.TYPE_CHECKING: + import wandb pylogger = pylogging.getLogger(__name__) @@ -26,7 +30,7 @@ _global_logger: Optional["MetricsLogger"] = None -def log_metrics(metrics: dict[str, Any], *, step): +def log_metrics(metrics: dict[str, Any], *, step, commit: Optional[bool] = None): """ Log metrics to the global logger. @@ -37,7 +41,14 @@ def log_metrics(metrics: dict[str, Any], *, step): if _global_logger is None: raise RuntimeError("No global logger set") - _global_logger.log(metrics, step=step) + if is_inside_jit(): + # we're inside a jit, so we need to log from the host + if commit: + raise ValueError("Cannot commit from inside a jit") + jit_log_metrics(metrics, step=step) + else: + # TODO: do we need to coerce to np here? + _global_logger.log(metrics, step=step) def jit_log_metrics(metrics, *, step=None): @@ -56,6 +67,7 @@ def log_summary(metrics: dict[str, Any]): raise RuntimeError("No global logger set") _global_logger.log_summary(metrics) + @typing.overload def global_logger() -> "MetricsLogger": ... @@ -67,7 +79,9 @@ def global_logger(logger: "MetricsLogger") -> contextlib.AbstractContextManager: ... -def global_logger(logger: Optional["MetricsLogger"] = None) -> Union["MetricsLogger", contextlib.AbstractContextManager]: +def global_logger( + logger: Optional["MetricsLogger"] = None, +) -> Union["MetricsLogger", contextlib.AbstractContextManager]: """ Get or set the global logger. @@ -88,6 +102,7 @@ class MetricsLogger(abc.ABC): A logger for logging metrics to some backend(s). Meant to be used with the [global_logger][] context manager, but can also be used directly. """ + @abc.abstractmethod def init(self, run_id: Optional[str]): pass @@ -97,9 +112,12 @@ def log_hyperparameters(self, hparams: dict[str, Any]): pass @abc.abstractmethod - def log(self, metrics: dict[str, Any], *, step): + def log(self, metrics: dict[str, typing.Any], *, step, commit: Optional[bool] = None): """ Log metrics to the logger. Step is always required. 
+ + Args: + commit: """ pass @@ -124,9 +142,9 @@ def log_hyperparameters(self, hparams: dict[str, Any]): for logger in self.loggers: logger.log_hyperparameters(hparams) - def log(self, metrics: dict[str, Any], *, step): + def log(self, metrics: dict[str, Any], *, step, commit=None): for logger in self.loggers: - logger.log(metrics, step=step) + logger.log(metrics, step=step, commit=commit) def log_summary(self, metrics: dict[str, Any]): for logger in self.loggers: @@ -152,7 +170,9 @@ def __exit__(self, exc_type, exc_val, exc_tb): class WandbLogger(MetricsLogger): - def __init__(self, config: 'WandbConfig'): + _run: Optional["wandb.sdk.wandb_run.Run"] + + def __init__(self, config: "WandbConfig"): self.config = config self._run = None @@ -164,10 +184,10 @@ def log_hyperparameters(self, hparams: dict[str, Any]): raise RuntimeError("Must call init before logging hyperparameters") self._run.config.update(hparams) - def log(self, metrics: dict[str, Any], *, step): + def log(self, metrics: dict[str, Any], *, step, commit=None): if self._run is None: raise RuntimeError("Must call init before logging metrics") - self._run.log(metrics, step=step) + self._run.log(metrics, step=step, commit=commit) def log_summary(self, metrics: dict[str, Any]): if self._run is None: @@ -181,31 +201,41 @@ def log_artifact(self, artifact, *, name: Optional[str] = None, type: Optional[s class TensorboardLogger(MetricsLogger): - def __init__(self, logdir: Union[str, Path]): self.logdir = logdir self.writer = None def init(self, run_id: Optional[str]): from tensorboardX import SummaryWriter + dir_to_write = self.logdir if run_id is not None: dir_to_write = os.path.join(dir_to_write, run_id) self.writer = SummaryWriter(dir_to_write) def log_hyperparameters(self, hparams: dict[str, Any]): + if self.writer is None: + raise RuntimeError("Must call init before logging metrics") + self.writer.add_hparams(hparams, {"dummy": 0}) - def log(self, metrics: dict[str, Any], *, step): + def log(self, metrics: dict[str, Any], *, step, commit=None): + del commit + if self.writer is None: + raise RuntimeError("Must call init before logging metrics") + for k, v in metrics.items(): self.writer.add_scalar(k, v, step) def log_summary(self, metrics: dict[str, Any]): + if self.writer is None: + raise RuntimeError("Must call init before logging metrics") + for k, v in metrics.items(): self.writer.add_scalar(k, v, 0) def log_artifact(self, artifact, *, name: Optional[str] = None, type: Optional[str] = None): - pylogger.warning("TensorboardLoggerSink does not support logging artifacts yet") + pylogger.warning("TensorboardLogger does not support logging artifacts yet") pass @@ -246,6 +276,7 @@ def init_logging(path: Union[str, Path], level: int = pylogging.INFO) -> None: def save_xla_dumps_to_wandb(initial_time: float): import os + if not is_wandb_available(): pylogger.warning("Wandb is not available, so we can't save XLA dumps") return @@ -288,9 +319,6 @@ def fn(): end = time.time() - - - def is_wandb_available(): try: import wandb diff --git a/src/levanter/main/cache_dataset.py b/src/levanter/main/cache_dataset.py index ae30b19df..dfbbf800a 100644 --- a/src/levanter/main/cache_dataset.py +++ b/src/levanter/main/cache_dataset.py @@ -5,7 +5,7 @@ import wandb import levanter -from levanter.data.shard_cache import RichMetricsMonitor, WandbMetricsMonitor, build_cache +from levanter.data.shard_cache import LoggingMetricsMonitor, RichMetricsMonitor, build_cache from levanter.data.text import BatchTokenizer, LMDatasetConfig from levanter.distributed 
import RayConfig from levanter.logging import init_logging @@ -40,7 +40,7 @@ def main(args: RayCachedLMDatasetConfig): logger.warning(f"Skipping {split} because it is empty.") continue - monitors = [RichMetricsMonitor(source.num_shards), WandbMetricsMonitor("preprocess/" + split, commit=True)] + monitors = [RichMetricsMonitor(source.num_shards), LoggingMetricsMonitor("preprocess/" + split, commit=True)] cache = build_cache( cache_dir=split_cache_dir, diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py index 187a8d92d..85d951a0a 100644 --- a/src/levanter/main/train_lm.py +++ b/src/levanter/main/train_lm.py @@ -99,10 +99,10 @@ def compute_loss(model: LmHeadModel, example: LmExample, key=None): # Our trainer is a wrapper around the optimizer and compute_loss function that handles checkpointing and fsdp trainer = Trainer(config.trainer, optimizer, compute_loss) - eval_datasets = config.data.validation_sets(Pos.size) - train_dataset = CausalLmDataset(config.data.train_set(Pos.size), Pos, KeyPos) + with trainer: + eval_datasets = config.data.validation_sets(Pos.size) + train_dataset = CausalLmDataset(config.data.train_set(Pos.size), Pos, KeyPos) - with trainer.device_mesh: # to do partitioning, our dimensions have to be divisible by the size of the physical axes they're mapped to # For most things, we just insist you specify the config right, but tokenizers often have strange numbers of # tokens: gpt-2 has 50257, for example. So we round up. diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index 26935aaca..729a6ad7d 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -30,6 +30,7 @@ from haliax.partitioning import ResourceAxis, ResourceMapping, named_jit import levanter.logging +from levanter import logging from levanter.checkpoint import CheckpointerConfig from levanter.config import JsonAtom from levanter.data import Dataset, ReplicatedBatchLoader, ShardableDataset, ShardedBatchLoader @@ -114,8 +115,10 @@ class Trainer: config: "TrainerConfig" optimizer: GradientTransformation hooks: TrainerHooks + _logger: logging.MetricsLogger is_trainable_param: Optional[PyTree[FilterSpec]] _raw_loss_function: Callable + _cmanagers: List[typing.ContextManager] = [] def __init__( self, @@ -141,6 +144,10 @@ def __init__( self._raw_loss_function = loss_fn self.optimizer = optimizer self.is_trainable_param = is_trainable + self._logger = logging.WandbLogger(self.config.wandb) + # TODO: hacky hack + self._logger._run = wandb.run + self._cmanagers = [] @cached_property def loss_fn(self): @@ -202,6 +209,34 @@ def TrainBatch(self): def EvalBatch(self): return self.config.EvalBatch + def __enter__(self): + if len(self._cmanagers) > 0: + raise RuntimeError("Trainer is already entered") + + self._cmanagers = [ + logging.global_logger(self._logger), + self.device_mesh, + hax.axis_mapping(self.parameter_axis_mapping), + ] + + for cmanager in self._cmanagers: + cmanager.__enter__() + + return self + + def __exit__(self, *args): + problems = [] + for cmanager in reversed(self._cmanagers): + try: + cmanager.__exit__(*args) + except Exception as e: + problems.append(e) + + self._cmanagers = [] + + if len(problems) > 0: + raise RuntimeError("Exception(s) occurred while exiting trainer", problems) from problems[0] + def initial_state( self, training_key: PRNGKeyArray, model: Optional[M] = None, model_init: Optional[Callable[[], M]] = None ) -> TrainerState: @@ -211,51 +246,51 @@ def initial_state( Returns: model, opt_state, key, resume_step """ + with 
logging.global_logger(self._logger): + if model is not None and model_init is not None: + raise ValueError("only one of model and model_init should be specified") + elif model is None and model_init is None: + raise ValueError("one of model and model_init must be specified") - if model is not None and model_init is not None: - raise ValueError("only one of model and model_init should be specified") - elif model is None and model_init is None: - raise ValueError("one of model and model_init must be specified") - - if model is not None: - # we can't just use `lambda: model` because JAX jit can't see captures, but it can see partials - # We can't use plain partials because they aren't pytrees - model_init = jax.tree_util.Partial(lambda m: m, model) + if model is not None: + # we can't just use `lambda: model` because JAX jit can't see captures, but it can see partials + # We can't use plain partials because they aren't pytrees + model_init = jax.tree_util.Partial(lambda m: m, model) - model_shape, opt_state_shape = eqx.filter_eval_shape(self._init_model_and_opt_state, model_init) + model_shape, opt_state_shape = eqx.filter_eval_shape(self._init_model_and_opt_state, model_init) - # we only checkpoint the trainable parameters, so we need to filter out the non-trainable ones - trainable_model_shape = self.trainable_params_only(model_shape) + # we only checkpoint the trainable parameters, so we need to filter out the non-trainable ones + trainable_model_shape = self.trainable_params_only(model_shape) - ckpt = self.maybe_load_checkpoint( - trainable_model_shape, - (opt_state_shape, training_key), - axis_mapping=self.parameter_axis_mapping, - mesh=self.device_mesh, - ) + ckpt = self.maybe_load_checkpoint( + trainable_model_shape, + (opt_state_shape, training_key), + axis_mapping=self.parameter_axis_mapping, + mesh=self.device_mesh, + ) - if ckpt is not None: - trainable_model, (opt_state, training_key), completed_step = ckpt - if model is not None: - model = eqx.combine(trainable_model, model) - elif any(isinstance(leaf, ShapeDtypeStruct) for leaf in jax.tree_leaves(trainable_model)): - # if we're resuming, we need to re-initialize the non-trainable parameters to their original values - non_trainable = named_jit(self._init_non_trainable_params, self.parameter_axis_mapping)(model_init) - model = eqx.combine(trainable_model, non_trainable) + if ckpt is not None: + trainable_model, (opt_state, training_key), completed_step = ckpt + if model is not None: + model = eqx.combine(trainable_model, model) + elif any(isinstance(leaf, ShapeDtypeStruct) for leaf in jax.tree_leaves(trainable_model)): + # if we're resuming, we need to re-initialize the non-trainable parameters to their original values + non_trainable = named_jit(self._init_non_trainable_params, self.parameter_axis_mapping)(model_init) + model = eqx.combine(trainable_model, non_trainable) + else: + model = trainable_model + step = completed_step + 1 else: - model = trainable_model - step = completed_step + 1 - else: - model, opt_state = named_jit(self._init_model_and_opt_state, self.parameter_axis_mapping)(model_init) - step = 0 + model, opt_state = named_jit(self._init_model_and_opt_state, self.parameter_axis_mapping)(model_init) + step = 0 - return TrainerState(step, model, opt_state, training_key) + return TrainerState(step, model, opt_state, training_key) def train_step(self, state: TrainerState[M], *batch: X, **batch_kwargs) -> StepInfo[M]: """ Performs a single training step. 
""" - with capture_time() as step_time: + with capture_time() as step_time, logging.global_logger(self._logger): key, new_key = jax.random.split(state.training_key) loss, new_model, new_optstate = self._train_step_fn( state.model, state.opt_state, *batch, **batch_kwargs, key=key @@ -272,24 +307,24 @@ def training_steps( Generator that yields training steps and runs hooks. """ iter_data = iter(train_loader) + with logging.global_logger(self._logger): + while state.step < self.config.num_train_steps: + with capture_time() as loading_time: + example = next(iter_data) - while state.step < self.config.num_train_steps: - with capture_time() as loading_time: - example = next(iter_data) - - # TODO: refactor logging - wandb.log({"throughput/loading_time": loading_time()}, step=state.step) + # TODO: refactor logging + logging.log_metrics({"throughput/loading_time": loading_time()}, step=state.step) - info = self.train_step(state, example) - state = info.state + info = self.train_step(state, example) + state = info.state - if run_hooks: - with capture_time() as hook_time: - self.run_hooks(info) + if run_hooks: + with capture_time() as hook_time: + self.run_hooks(info) - wandb.log({"throughput/hook_time": hook_time()}, step=state.step) + logging.log_metrics({"throughput/hook_time": hook_time()}, step=state.step) - yield info + yield info def train(self, state: TrainerState[M], train_loader: Iterable[X], run_hooks: bool = True) -> StepInfo[M]: """ @@ -308,7 +343,7 @@ def add_default_hooks(self, eval_dataset: Optional[Iterable[X]] = None): from levanter import callbacks self.add_hook(callbacks.pbar_logger(total=self.config.num_train_steps), every=1) - self.add_hook(callbacks.log_to_wandb, every=1) + self.add_hook(callbacks.log_step_info, every=1) if eval_dataset is not None: self.add_eval_hook(eval_dataset) self.add_hook(callbacks.wandb_xla_logger(self.config.wandb), every=self.config.steps_per_eval) From ad708e38e25dbbfd03d8ca9be9c361c2ed0d0980 Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 7 Nov 2023 21:16:01 -0800 Subject: [PATCH 004/205] move the metrics stuff to its own file --- src/levanter/callbacks.py | 23 +- src/levanter/data/shard_cache.py | 4 +- src/levanter/logging.py | 425 +----------------------------- src/levanter/main/train_lm.py | 8 +- src/levanter/metrics.py | 431 +++++++++++++++++++++++++++++++ src/levanter/trainer.py | 20 +- tests/test_eval_lm.py | 2 +- tests/test_logging.py | 2 +- tests/test_train_lm.py | 2 +- tests/test_viz_lm.py | 2 +- 10 files changed, 469 insertions(+), 450 deletions(-) create mode 100644 src/levanter/metrics.py diff --git a/src/levanter/callbacks.py b/src/levanter/callbacks.py index 339f4c206..181fc2317 100644 --- a/src/levanter/callbacks.py +++ b/src/levanter/callbacks.py @@ -14,8 +14,9 @@ import wandb from tqdm import tqdm -from levanter import logging -from levanter.logging import WandbConfig, log_optimizer_hyperparams, save_xla_dumps_to_wandb +import levanter.metrics +from levanter.logging import save_xla_dumps_to_wandb +from levanter.metrics import WandbConfig, log_optimizer_hyperparams from levanter.trainer import StepInfo from levanter.utils.jax_utils import jnp_to_python from levanter.visualization import compute_and_visualize_log_probs as viz_probs @@ -62,7 +63,7 @@ def compute_loss(info: StepInfo): prefix = "eval" if name: prefix += "/" + name - logging.log_metrics({f"{prefix}/loss": loss}, step=info.step) + levanter.metrics.log_metrics({f"{prefix}/loss": loss}, step=info.step) if name: logger.info(f"{name} validation loss: {loss:.3f}") @@ -75,7 
+76,7 @@ def compute_loss(info: StepInfo): def log_step_info(step: StepInfo): - logging.log_metrics({"train/loss": step.loss, "global_step": step.step}, step=step.step) + levanter.metrics.log_metrics({"train/loss": step.loss, "global_step": step.step}, step=step.step) log_optimizer_hyperparams(step.opt_state, step=step.step, prefix="optim") @@ -109,14 +110,14 @@ def log_performance_stats(step_info: StepInfo): # log these totals because it's useful for comparing different seqlens, batch sizes, etc total_tokens = tokens_per_example * batch_size * step_info.step - logging.log_metrics({wrap_key("total_tokens"): total_tokens}, step=step_info.step) + levanter.metrics.log_metrics({wrap_key("total_tokens"): total_tokens}, step=step_info.step) if flops_per_example: total_flops = flops_per_example * batch_size * step_info.step - logging.log_metrics({wrap_key("total_gflops"): total_flops / 1e9}, step=step_info.step) + levanter.metrics.log_metrics({wrap_key("total_gflops"): total_flops / 1e9}, step=step_info.step) if step_info.step_duration != 0.0: - logging.log_metrics( + levanter.metrics.log_metrics( { wrap_key("examples_per_second"): float(batch_size) / step_info.step_duration, wrap_key("tokens_per_second"): float(tokens_per_example) / step_info.step_duration * batch_size, @@ -126,7 +127,7 @@ def log_performance_stats(step_info: StepInfo): ) if flops_per_example is not None: - logging.log_metrics( + levanter.metrics.log_metrics( { wrap_key("gflops_per_second"): flops_per_example / 1e9 / step_info.step_duration * batch_size, }, @@ -219,7 +220,7 @@ def log_memory_usage(step: StepInfo): match = regex.search(by_kind) if match: memory_usage = humanfriendly.parse_size(match.group(1)) - logging.log_metrics({"memory/total": memory_usage / 1e6}, step=step.step) + levanter.metrics.log_metrics({"memory/total": memory_usage / 1e6}, step=step.step) # this works for the "kind" and the individual devices regex = re.compile(r"([\d.]+[a-zA-Z]+) \(([\d.]+)%\): ([\w\d:_]+)") @@ -230,14 +231,14 @@ def log_memory_usage(step: StepInfo): for match in regex.finditer(per_device): memory_usage = humanfriendly.parse_size(match.group(1)) device_name = match.group(3) - logging.log_metrics({f"memory/device/{device_name}": memory_usage / 1e6}, step=step.step) + levanter.metrics.log_metrics({f"memory/device/{device_name}": memory_usage / 1e6}, step=step.step) # now, get the memory usage per kind. # same regex as above for match in regex.finditer(by_kind): memory_usage = match.group(1) memory_usage = humanfriendly.parse_size(memory_usage) - logging.log_metrics({f"memory/{match.group(3)}": memory_usage / 1e6}, step=step.step) + levanter.metrics.log_metrics({f"memory/{match.group(3)}": memory_usage / 1e6}, step=step.step) return log_memory_usage diff --git a/src/levanter/data/shard_cache.py b/src/levanter/data/shard_cache.py index e71fdf630..aca26bc4c 100644 --- a/src/levanter/data/shard_cache.py +++ b/src/levanter/data/shard_cache.py @@ -44,6 +44,8 @@ TimeRemainingColumn, ) +import levanter.metrics + from .. import logging from . 
import ShardableDataset from ._preprocessor import BatchProcessor, as_record_batch, dict_from_record_batch @@ -457,7 +459,7 @@ def __call__(self, metrics: InProgressCacheMetrics): self.last_metrics = metrics self.last_time = time.time() - logging.log_metrics(to_log, step=None, commit=self.commit) + levanter.metrics.log_metrics(to_log, step=None, commit=self.commit) class LoggerMetricsMonitor(MetricsMonitor): diff --git a/src/levanter/logging.py b/src/levanter/logging.py index e534bdcc7..c9c7258df 100644 --- a/src/levanter/logging.py +++ b/src/levanter/logging.py @@ -1,257 +1,14 @@ -import abc import contextlib -import dataclasses import logging as pylogging -import os -import tempfile import time -import typing -import warnings -from dataclasses import dataclass from pathlib import Path -from typing import Any, List, Optional, Union +from typing import List, Union -import draccus import jax -from draccus import field -from git import InvalidGitRepositoryError, NoSuchPathError, Repo -from optax import MultiStepsState - -from levanter.utils import jax_utils -from levanter.utils.jax_utils import is_inside_jit, jnp_to_python - - -if typing.TYPE_CHECKING: - import wandb pylogger = pylogging.getLogger(__name__) -_global_logger: Optional["MetricsLogger"] = None - - -def log_metrics(metrics: dict[str, Any], *, step, commit: Optional[bool] = None): - """ - Log metrics to the global logger. - - :param metrics: Metrics to log - :param step: Step to log metrics at - """ - global _global_logger - if _global_logger is None: - raise RuntimeError("No global logger set") - - if is_inside_jit(): - # we're inside a jit, so we need to log from the host - if commit: - raise ValueError("Cannot commit from inside a jit") - jit_log_metrics(metrics, step=step) - else: - # TODO: do we need to coerce to np here? - _global_logger.log(metrics, step=step) - - -def jit_log_metrics(metrics, *, step=None): - """uses jax effect callback to log to wandb from the host""" - jax.debug.callback(log_metrics, metrics, step=step) - - -def log_summary(metrics: dict[str, Any]): - """ - Log summary metrics to the global logger. - - :param metrics: Metrics to log - """ - global _global_logger - if _global_logger is None: - raise RuntimeError("No global logger set") - _global_logger.log_summary(metrics) - - -@typing.overload -def global_logger() -> "MetricsLogger": - ... - - -@typing.overload -def global_logger(logger: "MetricsLogger") -> contextlib.AbstractContextManager: - """Context manager for setting the global logger""" - ... - - -def global_logger( - logger: Optional["MetricsLogger"] = None, -) -> Union["MetricsLogger", contextlib.AbstractContextManager]: - """ - Get or set the global logger. - - :param logger: If provided, sets the global logger to this value. - :return: The global logger, or a context manager for setting the global logger. - """ - global _global_logger - if logger is None: - if _global_logger is None: - raise RuntimeError("No global logger set") - return _global_logger - else: - return _GlobalLoggerContextManager(logger) - - -class MetricsLogger(abc.ABC): - """ - A logger for logging metrics to some backend(s). - Meant to be used with the [global_logger][] context manager, but can also be used directly. 
- """ - - @abc.abstractmethod - def init(self, run_id: Optional[str]): - pass - - @abc.abstractmethod - def log_hyperparameters(self, hparams: dict[str, Any]): - pass - - @abc.abstractmethod - def log(self, metrics: dict[str, typing.Any], *, step, commit: Optional[bool] = None): - """ - Log metrics to the logger. Step is always required. - - Args: - commit: - """ - pass - - @abc.abstractmethod - def log_summary(self, metrics: dict[str, Any]): - pass - - @abc.abstractmethod - def log_artifact(self, artifact, *, name: Optional[str] = None, type: Optional[str] = None): - pass - - -class CompositeLogger(MetricsLogger): - def __init__(self, loggers: List[MetricsLogger]): - self.loggers = loggers - - def init(self, run_id: Optional[str]): - for logger in self.loggers: - logger.init(run_id) - - def log_hyperparameters(self, hparams: dict[str, Any]): - for logger in self.loggers: - logger.log_hyperparameters(hparams) - - def log(self, metrics: dict[str, Any], *, step, commit=None): - for logger in self.loggers: - logger.log(metrics, step=step, commit=commit) - - def log_summary(self, metrics: dict[str, Any]): - for logger in self.loggers: - logger.log_summary(metrics) - - def log_artifact(self, artifact, *, name: Optional[str] = None, type: Optional[str] = None): - for logger in self.loggers: - logger.log_artifact(artifact, name=name, type=type) - - -class _GlobalLoggerContextManager(contextlib.AbstractContextManager): - def __init__(self, logger: "MetricsLogger"): - self.logger = logger - - def __enter__(self): - global _global_logger - self.old_logger = _global_logger - _global_logger = self.logger - - def __exit__(self, exc_type, exc_val, exc_tb): - global _global_logger - _global_logger = self.old_logger - - -class WandbLogger(MetricsLogger): - _run: Optional["wandb.sdk.wandb_run.Run"] - - def __init__(self, config: "WandbConfig"): - self.config = config - self._run = None - - def init(self, run_id: Optional[str]): - self._run = self.config.init(run_id) - - def log_hyperparameters(self, hparams: dict[str, Any]): - if self._run is None: - raise RuntimeError("Must call init before logging hyperparameters") - self._run.config.update(hparams) - - def log(self, metrics: dict[str, Any], *, step, commit=None): - if self._run is None: - raise RuntimeError("Must call init before logging metrics") - self._run.log(metrics, step=step, commit=commit) - - def log_summary(self, metrics: dict[str, Any]): - if self._run is None: - raise RuntimeError("Must call init before logging summary") - self._run.summary.update(metrics) - - def log_artifact(self, artifact, *, name: Optional[str] = None, type: Optional[str] = None): - if self._run is None: - raise RuntimeError("Must call init before logging artifacts") - self._run.log_artifact(artifact, name=name, type=type) - - -class TensorboardLogger(MetricsLogger): - def __init__(self, logdir: Union[str, Path]): - self.logdir = logdir - self.writer = None - - def init(self, run_id: Optional[str]): - from tensorboardX import SummaryWriter - - dir_to_write = self.logdir - if run_id is not None: - dir_to_write = os.path.join(dir_to_write, run_id) - self.writer = SummaryWriter(dir_to_write) - - def log_hyperparameters(self, hparams: dict[str, Any]): - if self.writer is None: - raise RuntimeError("Must call init before logging metrics") - - self.writer.add_hparams(hparams, {"dummy": 0}) - - def log(self, metrics: dict[str, Any], *, step, commit=None): - del commit - if self.writer is None: - raise RuntimeError("Must call init before logging metrics") - - for k, v in 
metrics.items(): - self.writer.add_scalar(k, v, step) - - def log_summary(self, metrics: dict[str, Any]): - if self.writer is None: - raise RuntimeError("Must call init before logging metrics") - - for k, v in metrics.items(): - self.writer.add_scalar(k, v, 0) - - def log_artifact(self, artifact, *, name: Optional[str] = None, type: Optional[str] = None): - pylogger.warning("TensorboardLogger does not support logging artifacts yet") - pass - - -def log_optimizer_hyperparams(opt_state, prefix: Optional[str] = None, *, step=None): - if isinstance(opt_state, MultiStepsState): - opt_state = opt_state.inner_opt_state - - def wrap_key(key): - if prefix: - return f"{prefix}/{key}" - return key - - if hasattr(opt_state, "hyperparams"): - params = {wrap_key(k): jnp_to_python(v) for k, v in opt_state.hyperparams.items()} - log_metrics(params, step=step) - def init_logging(path: Union[str, Path], level: int = pylogging.INFO) -> None: """ @@ -277,6 +34,8 @@ def init_logging(path: Union[str, Path], level: int = pylogging.INFO) -> None: def save_xla_dumps_to_wandb(initial_time: float): import os + from levanter.metrics import is_wandb_available + if not is_wandb_available(): pylogger.warning("Wandb is not available, so we can't save XLA dumps") return @@ -319,14 +78,6 @@ def fn(): end = time.time() -def is_wandb_available(): - try: - import wandb - except ImportError: - return False - return wandb is not None and wandb.run is not None - - def silence_transformer_nag(): # this is a hack to silence the transformers' "None of PyTorch, TensorFlow 2.0 or Flax have been found..." thing # which is annoying and not useful @@ -336,173 +87,3 @@ def silence_transformer_nag(): # log propagation bites us here when using ray logger.propagate = False - - -@dataclass -class WandbConfig: - """ - Configuration for wandb. - """ - - entity: Optional[str] = None # An entity is a username or team name where you send runs - project: Optional[str] = None # The name of the project where you are sending the enw run. - name: Optional[str] = None # A short display name for this run, which is how you'll identify this run in the UI. - tags: List[str] = field(default_factory=list) # Will populate the list of tags on this run in the UI. - id: Optional[str] = None # A unique ID for this run, used for resuming. It must be unique in the project - group: Optional[str] = None # Specify a group to organize individual runs into a larger experiment. - mode: Optional[str] = None # Can be "online", "offline" or "disabled". If None, it will be online. - resume: Optional[Union[bool, str]] = None # - """ - Set the resume behavior. Options: "allow", "must", "never", "auto" or None. - By default, if the new run has the same ID as a previous run, this run overwrites that data. - Please refer to [init](https://docs.wandb.ai/ref/python/init) and [resume](https://docs.wandb.ai/guides/runs/resuming) - document for more details. - """ - - save_code: Union[bool, str] = True - """If string, will save code from that directory. If True, will attempt to sniff out the main directory (since we - typically don't run from the root of the repo).""" - - save_xla_dumps: bool = False - """If True, will save the XLA code to wandb (as configured by XLA_FLAGS). This is useful for debugging.""" - - def init(self, run_id: Optional[str], hparams=None, **extra_hparams): - import wandb - - if run_id is not None and self.id is not None and run_id != self.id: - warnings.warn( - f"Both trainer's id {run_id} and WandB's id {self.id} are set. 
WandB will use the id set in its" - " config." - ) - - id = self.id - if id is None: - id = run_id - - if hparams is None: - hparams_to_save = {} - elif dataclasses.is_dataclass(hparams): - hparams_to_save = dataclasses.asdict(hparams) - else: - hparams_to_save = dict(hparams) - - if extra_hparams: - hparams_to_save.update(extra_hparams) - - # for distributed runs, we only want the primary worker to use wandb, so we make everyone else be disabled - # however, we do share information about the run id, so that we can link to it from the other workers - mode = self.mode - if jax.process_index() != 0: - mode = "disabled" - - if isinstance(self.save_code, str): - code_dir = self.save_code - elif self.save_code: - code_dir = WandbConfig._infer_experiment_git_root() or "." # type: ignore - else: - code_dir = None - - other_settings = dict() - if code_dir is not None: - pylogger.info(f"Setting wandb code_dir to {code_dir}") - other_settings["code_dir"] = code_dir - other_settings["git_root"] = code_dir - # for some reason, wandb isn't populating the git commit, so we do it here - try: - repo = Repo(code_dir) - other_settings["git_commit"] = repo.head.commit.hexsha - hparams_to_save["git_commit"] = repo.head.commit.hexsha - except (NoSuchPathError, InvalidGitRepositoryError): - pylogger.warning(f"Could not find git repo at {code_dir}") - pass - - r = wandb.init( - entity=self.entity, - project=self.project, - name=self.name, - tags=self.tags, - id=id, - group=self.group, - resume=self.resume, - mode=mode, - config=hparams_to_save, - settings=other_settings, - ) - - assert r is not None - - if jax.process_count() > 1: - # we need to share wandb run information across all hosts, because we use it for checkpoint paths and things - metadata_to_share = dict( - entity=r.entity, - project=r.project, - name=r.name, - tags=r.tags, - id=r.id, - group=r.group, - ) - metadata_to_share = jax_utils.multihost_broadcast_sync( - metadata_to_share, is_source=jax.process_index() == 0 - ) - - if jax.process_index() != 0: - assert r.mode == "disabled" - for k, v in metadata_to_share.items(): - setattr(r, k, v) - - pylogger.info(f"Synced wandb run information from process 0: {r.name} {r.id}") - - if dataclasses.is_dataclass(hparams): - with tempfile.TemporaryDirectory() as tmpdir: - config_path = os.path.join(tmpdir, "config.yaml") - with open(config_path, "w") as f: - draccus.dump(hparams, f, encoding="utf-8") - if wandb.run is not None: - wandb.run.log_artifact(str(config_path), name="config.yaml", type="config") - - # generate a pip freeze - with tempfile.TemporaryDirectory() as tmpdir: - requirements_path = os.path.join(tmpdir, "requirements.txt") - requirements = _generate_pip_freeze() - with open(requirements_path, "w") as f: - f.write(requirements) - if wandb.run is not None: - wandb.run.log_artifact(str(requirements_path), name="requirements.txt", type="requirements") - - wandb.summary["num_devices"] = jax.device_count() - wandb.summary["num_hosts"] = jax.process_count() - wandb.summary["backend"] = jax.default_backend() - - return r - - @staticmethod - def _infer_experiment_git_root() -> Optional[str | os.PathLike[str]]: - # sniff out the main directory (since we typically don't run from the root of the repo) - # we'll walk the stack and directories for the files in the stack the until we're at a git root - import os - import traceback - - stack = traceback.extract_stack() - # start from the top of the stack and work our way down since we want to hit the main file first - top_git_root = None - for frame in 
stack: - dirname = os.path.dirname(frame.filename) - # bit hacky but we want to skip anything that's in the python env - if any(x in dirname for x in ["site-packages", "dist-packages", "venv", "opt/homebrew", "conda", "pyenv"]): - continue - # see if it's under a git root - try: - repo = Repo(dirname, search_parent_directories=True) - top_git_root = repo.working_dir - break - except (NoSuchPathError, InvalidGitRepositoryError): - pylogger.debug(f"Skipping {dirname} since it's not a git root") - pass - return top_git_root - - -def _generate_pip_freeze(): - from importlib.metadata import distributions - - dists = distributions() - return "\n".join(f"{dist.name}=={dist.version}" for dist in dists) diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py index 85d951a0a..bb6dc057c 100644 --- a/src/levanter/main/train_lm.py +++ b/src/levanter/main/train_lm.py @@ -97,9 +97,11 @@ def compute_loss(model: LmHeadModel, example: LmExample, key=None): optimizer = config.optimizer.build(config.trainer.num_train_steps) # Our trainer is a wrapper around the optimizer and compute_loss function that handles checkpointing and fsdp - trainer = Trainer(config.trainer, optimizer, compute_loss) - - with trainer: + # Using the trainer as a context manager does 3 things: + # 1. Sets the device mesh + # 2. Sets the axis mapping (for fsdp) + # 3. Sets the global metrics logger + with Trainer(config.trainer, optimizer, compute_loss) as trainer: eval_datasets = config.data.validation_sets(Pos.size) train_dataset = CausalLmDataset(config.data.train_set(Pos.size), Pos, KeyPos) diff --git a/src/levanter/metrics.py b/src/levanter/metrics.py new file mode 100644 index 000000000..1f36e95ec --- /dev/null +++ b/src/levanter/metrics.py @@ -0,0 +1,431 @@ +import abc +import contextlib +import dataclasses +import os +import tempfile +import typing +import warnings +from dataclasses import dataclass +from pathlib import Path +from typing import Any, List, Optional, Union + +import draccus +import jax +import wandb +from draccus import field +from git import InvalidGitRepositoryError, NoSuchPathError, Repo +from optax._src.wrappers import MultiStepsState + +from levanter.logging import pylogger +from levanter.utils import jax_utils +from levanter.utils.jax_utils import is_inside_jit, jnp_to_python + + +_global_logger: Optional["MetricsLogger"] = None + + +def log_metrics(metrics: dict[str, Any], *, step, commit: Optional[bool] = None): + """ + Log metrics to the global logger. + + :param metrics: Metrics to log + :param step: Step to log metrics at + """ + global _global_logger + if _global_logger is None: + raise RuntimeError("No global logger set") + + if is_inside_jit(): + # we're inside a jit, so we need to log from the host + if commit: + raise ValueError("Cannot commit from inside a jit") + jit_log_metrics(metrics, step=step) + else: + # TODO: do we need to coerce to np here? + _global_logger.log(metrics, step=step) + + +def jit_log_metrics(metrics, *, step=None): + """uses jax effect callback to log to wandb from the host""" + jax.debug.callback(log_metrics, metrics, step=step) + + +def log_summary(metrics: dict[str, Any]): + """ + Log summary metrics to the global logger. + + :param metrics: Metrics to log + """ + global _global_logger + if _global_logger is None: + raise RuntimeError("No global logger set") + _global_logger.log_summary(metrics) + + +@typing.overload +def global_logger() -> "MetricsLogger": + ... 
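+# NB: the no-arg overload (above) reads the currently-installed logger; the
+# one-arg overload (below) returns a context manager that installs `logger`
+# globally and restores the previous logger on exit.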
+ + +@typing.overload +def global_logger(logger: "MetricsLogger") -> contextlib.AbstractContextManager: + """Context manager for setting the global logger""" + ... + + +def global_logger( + logger: Optional["MetricsLogger"] = None, +) -> Union["MetricsLogger", contextlib.AbstractContextManager]: + """ + Get or set the global logger. + + :param logger: If provided, sets the global logger to this value. + :return: The global logger, or a context manager for setting the global logger. + """ + global _global_logger + if logger is None: + if _global_logger is None: + raise RuntimeError("No global logger set") + return _global_logger + else: + return _GlobalLoggerContextManager(logger) + + +class MetricsLogger(abc.ABC): + """ + A logger for logging metrics to some backend(s). + Meant to be used with the [global_logger][] context manager, but can also be used directly. + + We call it a "metrics" logger because it's meant to be used for logging metrics, but it can also be used for + logging artifacts and such. We're mostly trying to distinguish it from python's built-in logging module. + """ + + @abc.abstractmethod + def init(self, run_id: Optional[str]): + """ + Initialize the logger with a run id. + """ + pass + + @abc.abstractmethod + def log_hyperparameters(self, hparams: dict[str, Any]): + pass + + @abc.abstractmethod + def log(self, metrics: dict[str, typing.Any], *, step, commit: Optional[bool] = None): + """ + Log metrics to the logger. Step is always required. + + Args: + commit: + """ + pass + + @abc.abstractmethod + def log_summary(self, metrics: dict[str, Any]): + pass + + @abc.abstractmethod + def log_artifact(self, artifact, *, name: Optional[str] = None, type: Optional[str] = None): + pass + + +class CompositeLogger(MetricsLogger): + def __init__(self, loggers: List[MetricsLogger]): + self.loggers = loggers + + def init(self, run_id: Optional[str]): + for logger in self.loggers: + logger.init(run_id) + + def log_hyperparameters(self, hparams: dict[str, Any]): + for logger in self.loggers: + logger.log_hyperparameters(hparams) + + def log(self, metrics: dict[str, Any], *, step, commit=None): + for logger in self.loggers: + logger.log(metrics, step=step, commit=commit) + + def log_summary(self, metrics: dict[str, Any]): + for logger in self.loggers: + logger.log_summary(metrics) + + def log_artifact(self, artifact, *, name: Optional[str] = None, type: Optional[str] = None): + for logger in self.loggers: + logger.log_artifact(artifact, name=name, type=type) + + +class _GlobalLoggerContextManager(contextlib.AbstractContextManager): + def __init__(self, logger: "MetricsLogger"): + self.logger = logger + + def __enter__(self): + global _global_logger + self.old_logger = _global_logger + _global_logger = self.logger + + def __exit__(self, exc_type, exc_val, exc_tb): + global _global_logger + _global_logger = self.old_logger + + +class WandbLogger(MetricsLogger): + _run: Optional["wandb.sdk.wandb_run.Run"] + + def __init__(self, config: "WandbConfig"): + self.config = config + self._run = None + + def init(self, run_id: Optional[str]): + self._run = self.config.init(run_id) + + def log_hyperparameters(self, hparams: dict[str, Any]): + if self._run is None: + raise RuntimeError("Must call init before logging hyperparameters") + self._run.config.update(hparams) + + def log(self, metrics: dict[str, Any], *, step, commit=None): + if self._run is None: + raise RuntimeError("Must call init before logging metrics") + self._run.log(metrics, step=step, commit=commit) + + def log_summary(self, 
metrics: dict[str, Any]):
+        if self._run is None:
+            raise RuntimeError("Must call init before logging summary")
+        self._run.summary.update(metrics)
+
+    def log_artifact(self, artifact, *, name: Optional[str] = None, type: Optional[str] = None):
+        if self._run is None:
+            raise RuntimeError("Must call init before logging artifacts")
+        self._run.log_artifact(artifact, name=name, type=type)
+
+
+class TensorboardLogger(MetricsLogger):
+    def __init__(self, logdir: Union[str, Path]):
+        self.logdir = logdir
+        self.writer = None
+
+    def init(self, run_id: Optional[str]):
+        from tensorboardX import SummaryWriter
+
+        dir_to_write = self.logdir
+        if run_id is not None:
+            dir_to_write = os.path.join(dir_to_write, run_id)
+        self.writer = SummaryWriter(dir_to_write)
+
+    def log_hyperparameters(self, hparams: dict[str, Any]):
+        if self.writer is None:
+            raise RuntimeError("Must call init before logging metrics")
+
+        self.writer.add_hparams(hparams, {"dummy": 0})
+
+    def log(self, metrics: dict[str, Any], *, step, commit=None):
+        del commit
+        if self.writer is None:
+            raise RuntimeError("Must call init before logging metrics")
+
+        for k, v in metrics.items():
+            self.writer.add_scalar(k, v, step)
+
+    def log_summary(self, metrics: dict[str, Any]):
+        if self.writer is None:
+            raise RuntimeError("Must call init before logging metrics")
+
+        for k, v in metrics.items():
+            self.writer.add_scalar(k, v, 0)
+
+    def log_artifact(self, artifact, *, name: Optional[str] = None, type: Optional[str] = None):
+        pylogger.warning("TensorboardLogger does not support logging artifacts yet")
+        pass
+
+
+def log_optimizer_hyperparams(opt_state, prefix: Optional[str] = None, *, step=None):
+    if isinstance(opt_state, MultiStepsState):
+        opt_state = opt_state.inner_opt_state
+
+    def wrap_key(key):
+        if prefix:
+            return f"{prefix}/{key}"
+        return key
+
+    if hasattr(opt_state, "hyperparams"):
+        params = {wrap_key(k): jnp_to_python(v) for k, v in opt_state.hyperparams.items()}
+        log_metrics(params, step=step)
+
+
+def is_wandb_available():
+    try:
+        import wandb
+    except ImportError:
+        return False
+    return wandb is not None and wandb.run is not None
+
+
+@dataclass
+class WandbConfig:
+    """
+    Configuration for wandb.
+    """
+
+    entity: Optional[str] = None  # An entity is a username or team name where you send runs
+    project: Optional[str] = None  # The name of the project where you are sending the new run.
+    name: Optional[str] = None  # A short display name for this run, which is how you'll identify this run in the UI.
+    tags: List[str] = field(default_factory=list)  # Will populate the list of tags on this run in the UI.
+    id: Optional[str] = None  # A unique ID for this run, used for resuming. It must be unique in the project
+    group: Optional[str] = None  # Specify a group to organize individual runs into a larger experiment.
+    mode: Optional[str] = None  # Can be "online", "offline" or "disabled". If None, it will be online.
+    resume: Optional[Union[bool, str]] = None  #
+    """
+    Set the resume behavior. Options: "allow", "must", "never", "auto" or None.
+    By default, if the new run has the same ID as a previous run, this run overwrites that data.
+    Please refer to [init](https://docs.wandb.ai/ref/python/init) and [resume](https://docs.wandb.ai/guides/runs/resuming)
+    document for more details.
+    """
+
+    save_code: Union[bool, str] = True
+    """If string, will save code from that directory. 
If True, will attempt to sniff out the main directory (since we + typically don't run from the root of the repo).""" + + save_xla_dumps: bool = False + """If True, will save the XLA code to wandb (as configured by XLA_FLAGS). This is useful for debugging.""" + + def init(self, run_id: Optional[str], hparams=None, **extra_hparams): + import wandb + + if run_id is not None and self.id is not None and run_id != self.id: + warnings.warn( + f"Both trainer's id {run_id} and WandB's id {self.id} are set. WandB will use the id set in its" + " config." + ) + + id = self.id + if id is None: + id = run_id + + if hparams is None: + hparams_to_save = {} + elif dataclasses.is_dataclass(hparams): + hparams_to_save = dataclasses.asdict(hparams) + else: + hparams_to_save = dict(hparams) + + if extra_hparams: + hparams_to_save.update(extra_hparams) + + # for distributed runs, we only want the primary worker to use wandb, so we make everyone else be disabled + # however, we do share information about the run id, so that we can link to it from the other workers + mode = self.mode + if jax.process_index() != 0: + mode = "disabled" + + if isinstance(self.save_code, str): + code_dir = self.save_code + elif self.save_code: + code_dir = WandbConfig._infer_experiment_git_root() or "." # type: ignore + else: + code_dir = None + + other_settings = dict() + if code_dir is not None: + pylogger.info(f"Setting wandb code_dir to {code_dir}") + other_settings["code_dir"] = code_dir + other_settings["git_root"] = code_dir + # for some reason, wandb isn't populating the git commit, so we do it here + try: + repo = Repo(code_dir) + other_settings["git_commit"] = repo.head.commit.hexsha + hparams_to_save["git_commit"] = repo.head.commit.hexsha + except (NoSuchPathError, InvalidGitRepositoryError): + pylogger.warning(f"Could not find git repo at {code_dir}") + pass + + r = wandb.init( + entity=self.entity, + project=self.project, + name=self.name, + tags=self.tags, + id=id, + group=self.group, + resume=self.resume, + mode=mode, + config=hparams_to_save, + settings=other_settings, + ) + + assert r is not None + + if jax.process_count() > 1: + # we need to share wandb run information across all hosts, because we use it for checkpoint paths and things + metadata_to_share = dict( + entity=r.entity, + project=r.project, + name=r.name, + tags=r.tags, + id=r.id, + group=r.group, + ) + metadata_to_share = jax_utils.multihost_broadcast_sync( + metadata_to_share, is_source=jax.process_index() == 0 + ) + + if jax.process_index() != 0: + assert r.mode == "disabled" + for k, v in metadata_to_share.items(): + setattr(r, k, v) + + pylogger.info(f"Synced wandb run information from process 0: {r.name} {r.id}") + + if dataclasses.is_dataclass(hparams): + with tempfile.TemporaryDirectory() as tmpdir: + config_path = os.path.join(tmpdir, "config.yaml") + with open(config_path, "w") as f: + draccus.dump(hparams, f, encoding="utf-8") + if wandb.run is not None: + wandb.run.log_artifact(str(config_path), name="config.yaml", type="config") + + # generate a pip freeze + with tempfile.TemporaryDirectory() as tmpdir: + requirements_path = os.path.join(tmpdir, "requirements.txt") + requirements = _generate_pip_freeze() + with open(requirements_path, "w") as f: + f.write(requirements) + if wandb.run is not None: + wandb.run.log_artifact(str(requirements_path), name="requirements.txt", type="requirements") + + wandb.summary["num_devices"] = jax.device_count() + wandb.summary["num_hosts"] = jax.process_count() + wandb.summary["backend"] = 
jax.default_backend()
+
+        return r
+
+    @staticmethod
+    def _infer_experiment_git_root() -> Optional[str | os.PathLike[str]]:
+        # sniff out the main directory (since we typically don't run from the root of the repo)
+        # we'll walk the stack and directories for the files in the stack until we're at a git root
+        import os
+        import traceback
+
+        stack = traceback.extract_stack()
+        # start from the top of the stack and work our way down since we want to hit the main file first
+        top_git_root = None
+        for frame in stack:
+            dirname = os.path.dirname(frame.filename)
+            # bit hacky but we want to skip anything that's in the python env
+            if any(x in dirname for x in ["site-packages", "dist-packages", "venv", "opt/homebrew", "conda", "pyenv"]):
+                continue
+            # see if it's under a git root
+            try:
+                repo = Repo(dirname, search_parent_directories=True)
+                top_git_root = repo.working_dir
+                break
+            except (NoSuchPathError, InvalidGitRepositoryError):
+                pylogger.debug(f"Skipping {dirname} since it's not a git root")
+                pass
+        return top_git_root
+
+
+def _generate_pip_freeze():
+    from importlib.metadata import distributions
+
+    dists = distributions()
+    return "\n".join(f"{dist.name}=={dist.version}" for dist in dists)
diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py
index 729a6ad7d..b2669df2e 100644
--- a/src/levanter/trainer.py
+++ b/src/levanter/trainer.py
@@ -30,13 +30,15 @@
 from haliax.partitioning import ResourceAxis, ResourceMapping, named_jit
 
 import levanter.logging
+import levanter.metrics
 from levanter import logging
 from levanter.checkpoint import CheckpointerConfig
 from levanter.config import JsonAtom
 from levanter.data import Dataset, ReplicatedBatchLoader, ShardableDataset, ShardedBatchLoader
 from levanter.distributed import DistributedConfig, RayConfig
 from levanter.grad_accum import accumulate_gradients_sharded
-from levanter.logging import WandbConfig, capture_time
+from levanter.logging import capture_time
+from levanter.metrics import WandbConfig
 from levanter.types import FilterSpec
 from levanter.utils import cloud_utils
 from levanter.utils.jax_utils import is_inexact_arrayish
@@ -115,7 +117,7 @@ class Trainer:
     config: "TrainerConfig"
     optimizer: GradientTransformation
     hooks: TrainerHooks
-    _logger: logging.MetricsLogger
+    _logger: levanter.metrics.MetricsLogger
     is_trainable_param: Optional[PyTree[FilterSpec]]
     _raw_loss_function: Callable
     _cmanagers: List[typing.ContextManager] = []
@@ -144,7 +146,7 @@ def __init__(
         self._raw_loss_function = loss_fn
         self.optimizer = optimizer
         self.is_trainable_param = is_trainable
-        self._logger = logging.WandbLogger(self.config.wandb)
+        self._logger = levanter.metrics.WandbLogger(self.config.wandb)
         # TODO: hacky hack
         self._logger._run = wandb.run
         self._cmanagers = []
@@ -214,7 +216,7 @@ def __enter__(self):
             raise RuntimeError("Trainer is already entered")
 
         self._cmanagers = [
-            logging.global_logger(self._logger),
+            levanter.metrics.global_logger(self._logger),
             self.device_mesh,
             hax.axis_mapping(self.parameter_axis_mapping),
         ]
@@ -246,7 +248,7 @@ def initial_state(
             Returns: model, opt_state, key, resume_step
         """
-        with logging.global_logger(self._logger):
+        with levanter.metrics.global_logger(self._logger):
             if model is not None and model_init is not None:
                 raise ValueError("only one of model and model_init should be specified")
             elif model is None and model_init is None:
@@ -290,7 +292,7 @@ def train_step(self, state: TrainerState[M], *batch: X, **batch_kwargs) -> StepI
         """
         Performs a single training step.
""" - with capture_time() as step_time, logging.global_logger(self._logger): + with capture_time() as step_time, levanter.metrics.global_logger(self._logger): key, new_key = jax.random.split(state.training_key) loss, new_model, new_optstate = self._train_step_fn( state.model, state.opt_state, *batch, **batch_kwargs, key=key @@ -307,13 +309,13 @@ def training_steps( Generator that yields training steps and runs hooks. """ iter_data = iter(train_loader) - with logging.global_logger(self._logger): + with levanter.metrics.global_logger(self._logger): while state.step < self.config.num_train_steps: with capture_time() as loading_time: example = next(iter_data) # TODO: refactor logging - logging.log_metrics({"throughput/loading_time": loading_time()}, step=state.step) + levanter.metrics.log_metrics({"throughput/loading_time": loading_time()}, step=state.step) info = self.train_step(state, example) state = info.state @@ -322,7 +324,7 @@ def training_steps( with capture_time() as hook_time: self.run_hooks(info) - logging.log_metrics({"throughput/hook_time": hook_time()}, step=state.step) + levanter.metrics.log_metrics({"throughput/hook_time": hook_time()}, step=state.step) yield info diff --git a/tests/test_eval_lm.py b/tests/test_eval_lm.py index f1193f4f4..7fb64ed5b 100644 --- a/tests/test_eval_lm.py +++ b/tests/test_eval_lm.py @@ -11,7 +11,7 @@ import tiny_test_corpus from levanter.checkpoint import save_checkpoint from levanter.distributed import RayConfig -from levanter.logging import WandbConfig +from levanter.metrics import WandbConfig from levanter.models.gpt2 import Gpt2LMHeadModel from levanter.utils.py_utils import logical_cpu_core_count diff --git a/tests/test_logging.py b/tests/test_logging.py index dc74c78ed..cfaf39350 100644 --- a/tests/test_logging.py +++ b/tests/test_logging.py @@ -3,7 +3,7 @@ import pytest from git import InvalidGitRepositoryError, NoSuchPathError, Repo -from levanter.logging import WandbConfig +from levanter.metrics import WandbConfig def test_infer_experiment_git_root(): diff --git a/tests/test_train_lm.py b/tests/test_train_lm.py index 3cd762d8b..33476b7c4 100644 --- a/tests/test_train_lm.py +++ b/tests/test_train_lm.py @@ -8,7 +8,7 @@ import levanter.main.train_lm as train_lm import tiny_test_corpus from levanter.distributed import RayConfig -from levanter.logging import WandbConfig +from levanter.metrics import WandbConfig from levanter.utils.py_utils import logical_cpu_core_count diff --git a/tests/test_viz_lm.py b/tests/test_viz_lm.py index 665c98772..0711c31a9 100644 --- a/tests/test_viz_lm.py +++ b/tests/test_viz_lm.py @@ -11,7 +11,7 @@ import tiny_test_corpus from levanter.checkpoint import save_checkpoint from levanter.distributed import RayConfig -from levanter.logging import WandbConfig +from levanter.metrics import WandbConfig from levanter.models.gpt2 import Gpt2Config, Gpt2LMHeadModel from levanter.utils.py_utils import logical_cpu_core_count From 6930fa9f1ff19309a663d62128d94e6f3278df38 Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 7 Nov 2023 21:42:04 -0800 Subject: [PATCH 005/205] refactor and move stuff around --- src/levanter/callbacks.py | 22 ++--- src/levanter/data/shard_cache.py | 4 +- src/levanter/logging.py | 2 +- src/levanter/{metrics.py => tracker.py} | 119 ++++++++++++------------ src/levanter/trainer.py | 22 ++--- tests/test_eval_lm.py | 2 +- tests/test_logging.py | 2 +- tests/test_train_lm.py | 2 +- tests/test_viz_lm.py | 2 +- 9 files changed, 90 insertions(+), 87 deletions(-) rename src/levanter/{metrics.py => tracker.py} 
(81%) diff --git a/src/levanter/callbacks.py b/src/levanter/callbacks.py index 181fc2317..9a41eab1c 100644 --- a/src/levanter/callbacks.py +++ b/src/levanter/callbacks.py @@ -14,9 +14,9 @@ import wandb from tqdm import tqdm -import levanter.metrics +import levanter.tracker from levanter.logging import save_xla_dumps_to_wandb -from levanter.metrics import WandbConfig, log_optimizer_hyperparams +from levanter.tracker import WandbConfig, log_optimizer_hyperparams from levanter.trainer import StepInfo from levanter.utils.jax_utils import jnp_to_python from levanter.visualization import compute_and_visualize_log_probs as viz_probs @@ -63,7 +63,7 @@ def compute_loss(info: StepInfo): prefix = "eval" if name: prefix += "/" + name - levanter.metrics.log_metrics({f"{prefix}/loss": loss}, step=info.step) + levanter.tracker.log_metrics({f"{prefix}/loss": loss}, step=info.step) if name: logger.info(f"{name} validation loss: {loss:.3f}") @@ -76,7 +76,7 @@ def compute_loss(info: StepInfo): def log_step_info(step: StepInfo): - levanter.metrics.log_metrics({"train/loss": step.loss, "global_step": step.step}, step=step.step) + levanter.tracker.log_metrics({"train/loss": step.loss, "global_step": step.step}, step=step.step) log_optimizer_hyperparams(step.opt_state, step=step.step, prefix="optim") @@ -110,14 +110,14 @@ def log_performance_stats(step_info: StepInfo): # log these totals because it's useful for comparing different seqlens, batch sizes, etc total_tokens = tokens_per_example * batch_size * step_info.step - levanter.metrics.log_metrics({wrap_key("total_tokens"): total_tokens}, step=step_info.step) + levanter.tracker.log_metrics({wrap_key("total_tokens"): total_tokens}, step=step_info.step) if flops_per_example: total_flops = flops_per_example * batch_size * step_info.step - levanter.metrics.log_metrics({wrap_key("total_gflops"): total_flops / 1e9}, step=step_info.step) + levanter.tracker.log_metrics({wrap_key("total_gflops"): total_flops / 1e9}, step=step_info.step) if step_info.step_duration != 0.0: - levanter.metrics.log_metrics( + levanter.tracker.log_metrics( { wrap_key("examples_per_second"): float(batch_size) / step_info.step_duration, wrap_key("tokens_per_second"): float(tokens_per_example) / step_info.step_duration * batch_size, @@ -127,7 +127,7 @@ def log_performance_stats(step_info: StepInfo): ) if flops_per_example is not None: - levanter.metrics.log_metrics( + levanter.tracker.log_metrics( { wrap_key("gflops_per_second"): flops_per_example / 1e9 / step_info.step_duration * batch_size, }, @@ -220,7 +220,7 @@ def log_memory_usage(step: StepInfo): match = regex.search(by_kind) if match: memory_usage = humanfriendly.parse_size(match.group(1)) - levanter.metrics.log_metrics({"memory/total": memory_usage / 1e6}, step=step.step) + levanter.tracker.log_metrics({"memory/total": memory_usage / 1e6}, step=step.step) # this works for the "kind" and the individual devices regex = re.compile(r"([\d.]+[a-zA-Z]+) \(([\d.]+)%\): ([\w\d:_]+)") @@ -231,14 +231,14 @@ def log_memory_usage(step: StepInfo): for match in regex.finditer(per_device): memory_usage = humanfriendly.parse_size(match.group(1)) device_name = match.group(3) - levanter.metrics.log_metrics({f"memory/device/{device_name}": memory_usage / 1e6}, step=step.step) + levanter.tracker.log_metrics({f"memory/device/{device_name}": memory_usage / 1e6}, step=step.step) # now, get the memory usage per kind. 
# same regex as above
         for match in regex.finditer(by_kind):
             memory_usage = match.group(1)
             memory_usage = humanfriendly.parse_size(memory_usage)
-            levanter.metrics.log_metrics({f"memory/{match.group(3)}": memory_usage / 1e6}, step=step.step)
+            levanter.tracker.log_metrics({f"memory/{match.group(3)}": memory_usage / 1e6}, step=step.step)
 
     return log_memory_usage
 
diff --git a/src/levanter/data/shard_cache.py b/src/levanter/data/shard_cache.py
index aca26bc4c..a983cbdad 100644
--- a/src/levanter/data/shard_cache.py
+++ b/src/levanter/data/shard_cache.py
@@ -44,7 +44,7 @@
     TimeRemainingColumn,
 )
 
-import levanter.metrics
+import levanter.tracker
 
 from .. import logging
 from . import ShardableDataset
@@ -459,7 +459,7 @@ def __call__(self, metrics: InProgressCacheMetrics):
         self.last_metrics = metrics
         self.last_time = time.time()
 
-        levanter.metrics.log_metrics(to_log, step=None, commit=self.commit)
+        levanter.tracker.log_metrics(to_log, step=None, commit=self.commit)
 
 
 class LoggerMetricsMonitor(MetricsMonitor):
diff --git a/src/levanter/logging.py b/src/levanter/logging.py
index c9c7258df..bcd7440de 100644
--- a/src/levanter/logging.py
+++ b/src/levanter/logging.py
@@ -34,7 +34,7 @@ def init_logging(path: Union[str, Path], level: int = pylogging.INFO) -> None:
 def save_xla_dumps_to_wandb(initial_time: float):
     import os
 
-    from levanter.metrics import is_wandb_available
+    from levanter.tracker import is_wandb_available
 
     if not is_wandb_available():
         pylogger.warning("Wandb is not available, so we can't save XLA dumps")
diff --git a/src/levanter/metrics.py b/src/levanter/tracker.py
similarity index 81%
rename from src/levanter/metrics.py
rename to src/levanter/tracker.py
index 1f36e95ec..ddc0a6082 100644
--- a/src/levanter/metrics.py
+++ b/src/levanter/tracker.py
@@ -21,28 +21,30 @@
 from levanter.utils.jax_utils import is_inside_jit, jnp_to_python
 
 
-_global_logger: Optional["MetricsLogger"] = None
+_global_tracker: Optional["Tracker"] = None
 
 
 def log_metrics(metrics: dict[str, Any], *, step, commit: Optional[bool] = None):
     """
-    Log metrics to the global logger.
+    Log metrics to the global tracker.
 
-    :param metrics: Metrics to log
-    :param step: Step to log metrics at
+    Args:
+        metrics: Metrics to log
+        step: Step to log at
+        commit: Whether to commit the metrics. If None, uses the default for the tracker.
     """
-    global _global_logger
-    if _global_logger is None:
-        raise RuntimeError("No global logger set")
+    global _global_tracker
+    if _global_tracker is None:
+        raise RuntimeError("No global tracker set")
 
     if is_inside_jit():
         # we're inside a jit, so we need to log from the host
         if commit:
-            raise ValueError("Cannot commit from inside a jit")
+            raise ValueError("Cannot commit from inside jit")
         jit_log_metrics(metrics, step=step)
     else:
         # TODO: do we need to coerce to np here?
-        _global_logger.log(metrics, step=step)
+        _global_tracker.log(metrics, step=step)
 
 
 def jit_log_metrics(metrics, *, step=None):
@@ -52,58 +54,57 @@ def jit_log_metrics(metrics, *, step=None):
 
 def log_summary(metrics: dict[str, Any]):
     """
-    Log summary metrics to the global logger.
+    Log summary metrics to the global tracker.
 
     :param metrics: Metrics to log
     """
-    global _global_logger
-    if _global_logger is None:
-        raise RuntimeError("No global logger set")
-    _global_logger.log_summary(metrics)
+    global _global_tracker
+    if _global_tracker is None:
+        raise RuntimeError("No global tracker set")
+    _global_tracker.log_summary(metrics)
 
 
 @typing.overload
-def global_logger() -> "MetricsLogger":
+def current_tracker() -> "Tracker":
     ...
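+# Illustrative usage (names here are placeholders, not defined in this diff):
+#
+#     with current_tracker(tracker):
+#         log_metrics({"train/loss": loss}, step=step)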
@typing.overload -def global_logger(logger: "MetricsLogger") -> contextlib.AbstractContextManager: - """Context manager for setting the global logger""" +def current_tracker(tracker: "Tracker") -> contextlib.AbstractContextManager: + """Context manager for setting the global tracker""" ... -def global_logger( - logger: Optional["MetricsLogger"] = None, -) -> Union["MetricsLogger", contextlib.AbstractContextManager]: +def current_tracker( + tracker: Optional["Tracker"] = None, +) -> Union["Tracker", contextlib.AbstractContextManager]: """ - Get or set the global logger. + Get or set the global tracker. - :param logger: If provided, sets the global logger to this value. - :return: The global logger, or a context manager for setting the global logger. + :param tracker: If provided, returns a context manager that sets the global tracker to the provided tracker when used. + :return: The global tracker, or a context manager for setting the global tracker. """ - global _global_logger - if logger is None: - if _global_logger is None: - raise RuntimeError("No global logger set") - return _global_logger + global _global_tracker + if tracker is None: + if _global_tracker is None: + raise RuntimeError("No global tracker set") + return _global_tracker else: - return _GlobalLoggerContextManager(logger) + return _GlobalLoggerContextManager(tracker) -class MetricsLogger(abc.ABC): +class Tracker(abc.ABC): """ - A logger for logging metrics to some backend(s). - Meant to be used with the [global_logger][] context manager, but can also be used directly. + A tracker is responsible for logging metrics, hyperparameters, and artifacts. + Meant to be used with the [current_tracker][] context manager, but can also be used directly. - We call it a "metrics" logger because it's meant to be used for logging metrics, but it can also be used for - logging artifacts and such. We're mostly trying to distinguish it from python's built-in logging module. + The name is borrowed from Accelerate. """ @abc.abstractmethod def init(self, run_id: Optional[str]): """ - Initialize the logger with a run id. + Initialize the tracker with a run id. """ pass @@ -114,10 +115,12 @@ def log_hyperparameters(self, hparams: dict[str, Any]): @abc.abstractmethod def log(self, metrics: dict[str, typing.Any], *, step, commit: Optional[bool] = None): """ - Log metrics to the logger. Step is always required. + Log metrics to the tracker. Step is always required. Args: - commit: + metrics: Metrics to log + step: Step to log at + commit: Whether to commit the metrics. If None, uses the default for the tracker. 
""" pass @@ -130,46 +133,46 @@ def log_artifact(self, artifact, *, name: Optional[str] = None, type: Optional[s pass -class CompositeLogger(MetricsLogger): - def __init__(self, loggers: List[MetricsLogger]): +class CompositeTracker(Tracker): + def __init__(self, loggers: List[Tracker]): self.loggers = loggers def init(self, run_id: Optional[str]): - for logger in self.loggers: - logger.init(run_id) + for tracker in self.loggers: + tracker.init(run_id) def log_hyperparameters(self, hparams: dict[str, Any]): - for logger in self.loggers: - logger.log_hyperparameters(hparams) + for tracker in self.loggers: + tracker.log_hyperparameters(hparams) def log(self, metrics: dict[str, Any], *, step, commit=None): - for logger in self.loggers: - logger.log(metrics, step=step, commit=commit) + for tracker in self.loggers: + tracker.log(metrics, step=step, commit=commit) def log_summary(self, metrics: dict[str, Any]): - for logger in self.loggers: - logger.log_summary(metrics) + for tracker in self.loggers: + tracker.log_summary(metrics) def log_artifact(self, artifact, *, name: Optional[str] = None, type: Optional[str] = None): - for logger in self.loggers: - logger.log_artifact(artifact, name=name, type=type) + for tracker in self.loggers: + tracker.log_artifact(artifact, name=name, type=type) class _GlobalLoggerContextManager(contextlib.AbstractContextManager): - def __init__(self, logger: "MetricsLogger"): - self.logger = logger + def __init__(self, tracker: "Tracker"): + self.tracker = tracker def __enter__(self): - global _global_logger - self.old_logger = _global_logger - _global_logger = self.logger + global _global_tracker + self.old_tracker = _global_tracker + _global_tracker = self.tracker def __exit__(self, exc_type, exc_val, exc_tb): - global _global_logger - _global_logger = self.old_logger + global _global_tracker + _global_tracker = self.old_tracker -class WandbLogger(MetricsLogger): +class WandbTracker(Tracker): _run: Optional["wandb.sdk.wandb_run.Run"] def __init__(self, config: "WandbConfig"): @@ -200,7 +203,7 @@ def log_artifact(self, artifact, *, name: Optional[str] = None, type: Optional[s self._run.log_artifact(artifact, name=name, type=type) -class TensorboardLogger(MetricsLogger): +class TensorboardTracker(Tracker): def __init__(self, logdir: Union[str, Path]): self.logdir = logdir self.writer = None diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index b2669df2e..648c46016 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -30,7 +30,7 @@ from haliax.partitioning import ResourceAxis, ResourceMapping, named_jit import levanter.logging -import levanter.metrics +import levanter.tracker from levanter import logging from levanter.checkpoint import CheckpointerConfig from levanter.config import JsonAtom @@ -38,7 +38,7 @@ from levanter.distributed import DistributedConfig, RayConfig from levanter.grad_accum import accumulate_gradients_sharded from levanter.logging import capture_time -from levanter.metrics import WandbConfig +from levanter.tracker import WandbConfig from levanter.types import FilterSpec from levanter.utils import cloud_utils from levanter.utils.jax_utils import is_inexact_arrayish @@ -117,7 +117,7 @@ class Trainer: config: "TrainerConfig" optimizer: GradientTransformation hooks: TrainerHooks - _logger: levanter.metrics.MetricsLogger + _tracker: levanter.tracker.Tracker is_trainable_param: Optional[PyTree[FilterSpec]] _raw_loss_function: Callable _cmanagers: List[typing.ContextManager] = [] @@ -146,9 +146,9 @@ def __init__( 
self._raw_loss_function = loss_fn self.optimizer = optimizer self.is_trainable_param = is_trainable - self._logger = levanter.metrics.WandbLogger(self.config.wandb) + self._tracker = levanter.tracker.WandbTracker(self.config.wandb) # TODO: hacky hack - self._logger._run = wandb.run + self._tracker._run = wandb.run self._cmanagers = [] @cached_property @@ -216,7 +216,7 @@ def __enter__(self): raise RuntimeError("Trainer is already entered") self._cmanagers = [ - levanter.metrics.global_logger(self._logger), + levanter.tracker.current_tracker(self._tracker), self.device_mesh, hax.axis_mapping(self.parameter_axis_mapping), ] @@ -248,7 +248,7 @@ def initial_state( Returns: model, opt_state, key, resume_step """ - with levanter.metrics.global_logger(self._logger): + with levanter.tracker.current_tracker(self._tracker): if model is not None and model_init is not None: raise ValueError("only one of model and model_init should be specified") elif model is None and model_init is None: @@ -292,7 +292,7 @@ def train_step(self, state: TrainerState[M], *batch: X, **batch_kwargs) -> StepI """ Performs a single training step. """ - with capture_time() as step_time, levanter.metrics.global_logger(self._logger): + with capture_time() as step_time, levanter.tracker.current_tracker(self._tracker): key, new_key = jax.random.split(state.training_key) loss, new_model, new_optstate = self._train_step_fn( state.model, state.opt_state, *batch, **batch_kwargs, key=key @@ -309,13 +309,13 @@ def training_steps( Generator that yields training steps and runs hooks. """ iter_data = iter(train_loader) - with levanter.metrics.global_logger(self._logger): + with levanter.tracker.current_tracker(self._tracker): while state.step < self.config.num_train_steps: with capture_time() as loading_time: example = next(iter_data) # TODO: refactor logging - levanter.metrics.log_metrics({"throughput/loading_time": loading_time()}, step=state.step) + levanter.tracker.log_metrics({"throughput/loading_time": loading_time()}, step=state.step) info = self.train_step(state, example) state = info.state @@ -324,7 +324,7 @@ def training_steps( with capture_time() as hook_time: self.run_hooks(info) - levanter.metrics.log_metrics({"throughput/hook_time": hook_time()}, step=state.step) + levanter.tracker.log_metrics({"throughput/hook_time": hook_time()}, step=state.step) yield info diff --git a/tests/test_eval_lm.py b/tests/test_eval_lm.py index 7fb64ed5b..0ea76f211 100644 --- a/tests/test_eval_lm.py +++ b/tests/test_eval_lm.py @@ -11,8 +11,8 @@ import tiny_test_corpus from levanter.checkpoint import save_checkpoint from levanter.distributed import RayConfig -from levanter.metrics import WandbConfig from levanter.models.gpt2 import Gpt2LMHeadModel +from levanter.tracker import WandbConfig from levanter.utils.py_utils import logical_cpu_core_count diff --git a/tests/test_logging.py b/tests/test_logging.py index cfaf39350..14c13bad8 100644 --- a/tests/test_logging.py +++ b/tests/test_logging.py @@ -3,7 +3,7 @@ import pytest from git import InvalidGitRepositoryError, NoSuchPathError, Repo -from levanter.metrics import WandbConfig +from levanter.tracker import WandbConfig def test_infer_experiment_git_root(): diff --git a/tests/test_train_lm.py b/tests/test_train_lm.py index 33476b7c4..c3b8279a8 100644 --- a/tests/test_train_lm.py +++ b/tests/test_train_lm.py @@ -8,7 +8,7 @@ import levanter.main.train_lm as train_lm import tiny_test_corpus from levanter.distributed import RayConfig -from levanter.metrics import WandbConfig +from levanter.tracker 
import WandbConfig from levanter.utils.py_utils import logical_cpu_core_count diff --git a/tests/test_viz_lm.py b/tests/test_viz_lm.py index 0711c31a9..345d55c53 100644 --- a/tests/test_viz_lm.py +++ b/tests/test_viz_lm.py @@ -11,8 +11,8 @@ import tiny_test_corpus from levanter.checkpoint import save_checkpoint from levanter.distributed import RayConfig -from levanter.metrics import WandbConfig from levanter.models.gpt2 import Gpt2Config, Gpt2LMHeadModel +from levanter.tracker import WandbConfig from levanter.utils.py_utils import logical_cpu_core_count From abf7ec35b58d5baedd8f59e763e3c05ee8f14bd8 Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 7 Nov 2023 23:23:24 -0800 Subject: [PATCH 006/205] use generic infrastructure for summary --- examples/alpaca-lora/alpaca_lora.py | 12 +++++++++--- src/levanter/__init__.py | 2 ++ src/levanter/main/lora_lm.py | 13 +++++++++---- src/levanter/main/train_lm.py | 9 ++++++--- src/levanter/tracker.py | 2 +- 5 files changed, 27 insertions(+), 11 deletions(-) diff --git a/examples/alpaca-lora/alpaca_lora.py b/examples/alpaca-lora/alpaca_lora.py index a4380a92b..76151f6cd 100644 --- a/examples/alpaca-lora/alpaca_lora.py +++ b/examples/alpaca-lora/alpaca_lora.py @@ -8,11 +8,11 @@ import jax.random as jrandom import transformers -import wandb import haliax as hax import levanter +from levanter import tracker from levanter.compat.hf_checkpoints import HFCheckpointConverter from levanter.lora import ( LoraConfig, @@ -112,8 +112,14 @@ def compute_loss(model: LmHeadModel, example: LmExample, key=None): all_param_count = parameter_count(state.model) just_lora_params = parameter_count(trainer.trainable_params_only(state.model)) - wandb.summary["parameter_count"] = all_param_count - wandb.summary["trainable_parameter_count"] = just_lora_params + tracker.log_summary( + { + "parameter_count": all_param_count, + "trainable_parameter_count": just_lora_params, + "fraction_trainable": just_lora_params * 1.0 / all_param_count, + } + ) + logger.info(f"Total parameter count: {all_param_count}") logger.info(f"Trainable parameter count: {just_lora_params}") logger.info(f"Fraction of parameters that are trainable: {just_lora_params * 1.0 / all_param_count%.3}") diff --git a/src/levanter/__init__.py b/src/levanter/__init__.py index 33bcd249d..971dbeada 100644 --- a/src/levanter/__init__.py +++ b/src/levanter/__init__.py @@ -3,4 +3,6 @@ import levanter.data as data import levanter.distributed as distributed import levanter.logging as logging +import levanter.tracker as tracker import levanter.visualization as visualization +from levanter.tracker import current_tracker, log_metrics, log_summary diff --git a/src/levanter/main/lora_lm.py b/src/levanter/main/lora_lm.py index dbe597e30..ef02b687f 100644 --- a/src/levanter/main/lora_lm.py +++ b/src/levanter/main/lora_lm.py @@ -4,12 +4,11 @@ from typing import Optional import jax.random as jrandom -import wandb import haliax.random import levanter -from levanter import callbacks +from levanter import callbacks, tracker from levanter.compat.hf_checkpoints import HFCheckpointConverter from levanter.data.text import CausalLmDataset, LMDatasetConfig, LmExample from levanter.lora import ( @@ -95,8 +94,14 @@ def compute_loss(model, example: LmExample, key=None): all_param_count = parameter_count(state.model) just_lora_params = parameter_count(trainer.trainable_params_only(state.model)) - wandb.summary["parameter_count"] = all_param_count - wandb.summary["trainable_parameter_count"] = just_lora_params + tracker.log_summary( + { + 
"parameter_count": all_param_count, + "trainable_parameter_count": just_lora_params, + "fraction_trainable": just_lora_params * 1.0 / all_param_count, + } + ) + logger.info(f"Total parameter count: {all_param_count}") logger.info(f"Trainable parameter count: {just_lora_params}") logger.info(f"Fraction of parameters that are trainable: {just_lora_params * 1.0 / all_param_count%.3}") diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py index bb6dc057c..dd6ac20df 100644 --- a/src/levanter/main/train_lm.py +++ b/src/levanter/main/train_lm.py @@ -5,14 +5,13 @@ from typing import Optional, Union import jax.random as jrandom -import wandb import haliax as hax from haliax import Axis from haliax.partitioning import named_jit, round_axis_for_partitioning import levanter -from levanter import callbacks +from levanter import callbacks, tracker from levanter.compat.hf_checkpoints import HFCompatConfig, save_hf_checkpoint_callback from levanter.data.text import CausalLmDataset, LMDatasetConfig, LMMixtureDatasetConfig from levanter.models.gpt2 import Gpt2Config @@ -131,7 +130,11 @@ def compute_loss(model: LmHeadModel, example: LmExample, key=None): else: logger.info("No checkpoint found. Starting from scratch.") - wandb.summary["parameter_count"] = parameter_count(state.model) + tracker.log_summary( + { + "parameter_count": parameter_count(state.model), + } + ) # boilerplate hooks and such trainer.add_default_hooks() diff --git a/src/levanter/tracker.py b/src/levanter/tracker.py index ddc0a6082..0104864f0 100644 --- a/src/levanter/tracker.py +++ b/src/levanter/tracker.py @@ -235,7 +235,7 @@ def log_summary(self, metrics: dict[str, Any]): raise RuntimeError("Must call init before logging metrics") for k, v in metrics.items(): - self.writer.add_scalar(k, v, 0) + self.writer.add_scalar(k, v, global_step=None) def log_artifact(self, artifact, *, name: Optional[str] = None, type: Optional[str] = None): pylogger.warning("TensorboardLogger does not support logging artifacts yet") From 547cea8608d02708f12cb7dda1c4060737955841 Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 8 Nov 2023 09:40:30 -0800 Subject: [PATCH 007/205] wip towards a clean tracker package --- src/levanter/callbacks.py | 21 +- src/levanter/logging.py | 2 +- src/levanter/tracker.py | 434 ---------------------------- src/levanter/tracker/__init__.py | 9 + src/levanter/tracker/helpers.py | 71 +++++ src/levanter/tracker/tensorboard.py | 48 +++ src/levanter/tracker/tracker.py | 162 +++++++++++ src/levanter/tracker/wandb.py | 206 +++++++++++++ src/levanter/trainer.py | 20 +- tests/test_eval_lm.py | 2 +- tests/test_logging.py | 4 +- tests/test_train_lm.py | 2 +- tests/test_viz_lm.py | 2 +- 13 files changed, 523 insertions(+), 460 deletions(-) delete mode 100644 src/levanter/tracker.py create mode 100644 src/levanter/tracker/__init__.py create mode 100644 src/levanter/tracker/helpers.py create mode 100644 src/levanter/tracker/tensorboard.py create mode 100644 src/levanter/tracker/tracker.py create mode 100644 src/levanter/tracker/wandb.py diff --git a/src/levanter/callbacks.py b/src/levanter/callbacks.py index 9a41eab1c..4f9061347 100644 --- a/src/levanter/callbacks.py +++ b/src/levanter/callbacks.py @@ -16,7 +16,8 @@ import levanter.tracker from levanter.logging import save_xla_dumps_to_wandb -from levanter.tracker import WandbConfig, log_optimizer_hyperparams +from levanter.tracker.helpers import log_optimizer_hyperparams +from levanter.tracker.wandb import WandbConfig from levanter.trainer import StepInfo from 
levanter.utils.jax_utils import jnp_to_python from levanter.visualization import compute_and_visualize_log_probs as viz_probs @@ -63,7 +64,7 @@ def compute_loss(info: StepInfo): prefix = "eval" if name: prefix += "/" + name - levanter.tracker.log_metrics({f"{prefix}/loss": loss}, step=info.step) + levanter.log_metrics({f"{prefix}/loss": loss}, step=info.step) if name: logger.info(f"{name} validation loss: {loss:.3f}") @@ -76,7 +77,7 @@ def compute_loss(info: StepInfo): def log_step_info(step: StepInfo): - levanter.tracker.log_metrics({"train/loss": step.loss, "global_step": step.step}, step=step.step) + levanter.log_metrics({"train/loss": step.loss, "global_step": step.step}, step=step.step) log_optimizer_hyperparams(step.opt_state, step=step.step, prefix="optim") @@ -110,14 +111,14 @@ def log_performance_stats(step_info: StepInfo): # log these totals because it's useful for comparing different seqlens, batch sizes, etc total_tokens = tokens_per_example * batch_size * step_info.step - levanter.tracker.log_metrics({wrap_key("total_tokens"): total_tokens}, step=step_info.step) + levanter.log_metrics({wrap_key("total_tokens"): total_tokens}, step=step_info.step) if flops_per_example: total_flops = flops_per_example * batch_size * step_info.step - levanter.tracker.log_metrics({wrap_key("total_gflops"): total_flops / 1e9}, step=step_info.step) + levanter.log_metrics({wrap_key("total_gflops"): total_flops / 1e9}, step=step_info.step) if step_info.step_duration != 0.0: - levanter.tracker.log_metrics( + levanter.log_metrics( { wrap_key("examples_per_second"): float(batch_size) / step_info.step_duration, wrap_key("tokens_per_second"): float(tokens_per_example) / step_info.step_duration * batch_size, @@ -127,7 +128,7 @@ def log_performance_stats(step_info: StepInfo): ) if flops_per_example is not None: - levanter.tracker.log_metrics( + levanter.log_metrics( { wrap_key("gflops_per_second"): flops_per_example / 1e9 / step_info.step_duration * batch_size, }, @@ -220,7 +221,7 @@ def log_memory_usage(step: StepInfo): match = regex.search(by_kind) if match: memory_usage = humanfriendly.parse_size(match.group(1)) - levanter.tracker.log_metrics({"memory/total": memory_usage / 1e6}, step=step.step) + levanter.log_metrics({"memory/total": memory_usage / 1e6}, step=step.step) # this works for the "kind" and the individual devices regex = re.compile(r"([\d.]+[a-zA-Z]+) \(([\d.]+)%\): ([\w\d:_]+)") @@ -231,14 +232,14 @@ def log_memory_usage(step: StepInfo): for match in regex.finditer(per_device): memory_usage = humanfriendly.parse_size(match.group(1)) device_name = match.group(3) - levanter.tracker.log_metrics({f"memory/device/{device_name}": memory_usage / 1e6}, step=step.step) + levanter.log_metrics({f"memory/device/{device_name}": memory_usage / 1e6}, step=step.step) # now, get the memory usage per kind. 
# same regex as above
         for match in regex.finditer(by_kind):
             memory_usage = match.group(1)
             memory_usage = humanfriendly.parse_size(memory_usage)
-            levanter.tracker.log_metrics({f"memory/{match.group(3)}": memory_usage / 1e6}, step=step.step)
+            levanter.log_metrics({f"memory/{match.group(3)}": memory_usage / 1e6}, step=step.step)
 
     return log_memory_usage
 
diff --git a/src/levanter/logging.py b/src/levanter/logging.py
index bcd7440de..23cf63047 100644
--- a/src/levanter/logging.py
+++ b/src/levanter/logging.py
@@ -34,7 +34,7 @@ def init_logging(path: Union[str, Path], level: int = pylogging.INFO) -> None:
 def save_xla_dumps_to_wandb(initial_time: float):
     import os
 
-    from levanter.tracker import is_wandb_available
+    from levanter.tracker.wandb import is_wandb_available
 
     if not is_wandb_available():
         pylogger.warning("Wandb is not available, so we can't save XLA dumps")
diff --git a/src/levanter/tracker.py b/src/levanter/tracker.py
deleted file mode 100644
index 0104864f0..000000000
--- a/src/levanter/tracker.py
+++ /dev/null
@@ -1,434 +0,0 @@
-import abc
-import contextlib
-import dataclasses
-import os
-import tempfile
-import typing
-import warnings
-from dataclasses import dataclass
-from pathlib import Path
-from typing import Any, List, Optional, Union
-
-import draccus
-import jax
-import wandb
-from draccus import field
-from git import InvalidGitRepositoryError, NoSuchPathError, Repo
-from optax._src.wrappers import MultiStepsState
-
-from levanter.logging import pylogger
-from levanter.utils import jax_utils
-from levanter.utils.jax_utils import is_inside_jit, jnp_to_python
-
-
-_global_tracker: Optional["Tracker"] = None
-
-
-def log_metrics(metrics: dict[str, Any], *, step, commit: Optional[bool] = None):
-    """
-    Log metrics to the global tracker.
-
-    Args:
-        metrics: Metrics to log
-        step: Step to log at
-        commit: Whether to commit the metrics. If None, uses the default for the tracker.
-    """
-    global _global_tracker
-    if _global_tracker is None:
-        raise RuntimeError("No global tracker set")
-
-    if is_inside_jit():
-        # we're inside a jit, so we need to log from the host
-        if commit:
-            raise ValueError("Cannot commit from inside jit")
-        jit_log_metrics(metrics, step=step)
-    else:
-        # TODO: do we need to coerce to np here?
-        _global_tracker.log(metrics, step=step)
-
-
-def jit_log_metrics(metrics, *, step=None):
-    """uses jax effect callback to log to wandb from the host"""
-    jax.debug.callback(log_metrics, metrics, step=step)
-
-
-def log_summary(metrics: dict[str, Any]):
-    """
-    Log summary metrics to the global tracker.
-
-    :param metrics: Metrics to log
-    """
-    global _global_tracker
-    if _global_tracker is None:
-        raise RuntimeError("No global tracker set")
-    _global_tracker.log_summary(metrics)
-
-
-@typing.overload
-def current_tracker() -> "Tracker":
-    ...
-
-
-@typing.overload
-def current_tracker(tracker: "Tracker") -> contextlib.AbstractContextManager:
-    """Context manager for setting the global tracker"""
-    ...
-
-
-def current_tracker(
-    tracker: Optional["Tracker"] = None,
-) -> Union["Tracker", contextlib.AbstractContextManager]:
-    """
-    Get or set the global tracker.
-
-    :param tracker: If provided, returns a context manager that sets the global tracker to the provided tracker when used.
-    :return: The global tracker, or a context manager for setting the global tracker.
- """ - global _global_tracker - if tracker is None: - if _global_tracker is None: - raise RuntimeError("No global tracker set") - return _global_tracker - else: - return _GlobalLoggerContextManager(tracker) - - -class Tracker(abc.ABC): - """ - A tracker is responsible for logging metrics, hyperparameters, and artifacts. - Meant to be used with the [current_tracker][] context manager, but can also be used directly. - - The name is borrowed from Accelerate. - """ - - @abc.abstractmethod - def init(self, run_id: Optional[str]): - """ - Initialize the tracker with a run id. - """ - pass - - @abc.abstractmethod - def log_hyperparameters(self, hparams: dict[str, Any]): - pass - - @abc.abstractmethod - def log(self, metrics: dict[str, typing.Any], *, step, commit: Optional[bool] = None): - """ - Log metrics to the tracker. Step is always required. - - Args: - metrics: Metrics to log - step: Step to log at - commit: Whether to commit the metrics. If None, uses the default for the tracker. - """ - pass - - @abc.abstractmethod - def log_summary(self, metrics: dict[str, Any]): - pass - - @abc.abstractmethod - def log_artifact(self, artifact, *, name: Optional[str] = None, type: Optional[str] = None): - pass - - -class CompositeTracker(Tracker): - def __init__(self, loggers: List[Tracker]): - self.loggers = loggers - - def init(self, run_id: Optional[str]): - for tracker in self.loggers: - tracker.init(run_id) - - def log_hyperparameters(self, hparams: dict[str, Any]): - for tracker in self.loggers: - tracker.log_hyperparameters(hparams) - - def log(self, metrics: dict[str, Any], *, step, commit=None): - for tracker in self.loggers: - tracker.log(metrics, step=step, commit=commit) - - def log_summary(self, metrics: dict[str, Any]): - for tracker in self.loggers: - tracker.log_summary(metrics) - - def log_artifact(self, artifact, *, name: Optional[str] = None, type: Optional[str] = None): - for tracker in self.loggers: - tracker.log_artifact(artifact, name=name, type=type) - - -class _GlobalLoggerContextManager(contextlib.AbstractContextManager): - def __init__(self, tracker: "Tracker"): - self.tracker = tracker - - def __enter__(self): - global _global_tracker - self.old_tracker = _global_tracker - _global_tracker = self.tracker - - def __exit__(self, exc_type, exc_val, exc_tb): - global _global_tracker - _global_tracker = self.old_tracker - - -class WandbTracker(Tracker): - _run: Optional["wandb.sdk.wandb_run.Run"] - - def __init__(self, config: "WandbConfig"): - self.config = config - self._run = None - - def init(self, run_id: Optional[str]): - self._run = self.config.init(run_id) - - def log_hyperparameters(self, hparams: dict[str, Any]): - if self._run is None: - raise RuntimeError("Must call init before logging hyperparameters") - self._run.config.update(hparams) - - def log(self, metrics: dict[str, Any], *, step, commit=None): - if self._run is None: - raise RuntimeError("Must call init before logging metrics") - self._run.log(metrics, step=step, commit=commit) - - def log_summary(self, metrics: dict[str, Any]): - if self._run is None: - raise RuntimeError("Must call init before logging summary") - self._run.summary.update(metrics) - - def log_artifact(self, artifact, *, name: Optional[str] = None, type: Optional[str] = None): - if self._run is None: - raise RuntimeError("Must call init before logging artifacts") - self._run.log_artifact(artifact, name=name, type=type) - - -class TensorboardTracker(Tracker): - def __init__(self, logdir: Union[str, Path]): - self.logdir = logdir - self.writer = 
None
-
-    def init(self, run_id: Optional[str]):
-        from tensorboardX import SummaryWriter
-
-        dir_to_write = self.logdir
-        if run_id is not None:
-            dir_to_write = os.path.join(dir_to_write, run_id)
-        self.writer = SummaryWriter(dir_to_write)
-
-    def log_hyperparameters(self, hparams: dict[str, Any]):
-        if self.writer is None:
-            raise RuntimeError("Must call init before logging metrics")
-
-        self.writer.add_hparams(hparams, {"dummy": 0})
-
-    def log(self, metrics: dict[str, Any], *, step, commit=None):
-        del commit
-        if self.writer is None:
-            raise RuntimeError("Must call init before logging metrics")
-
-        for k, v in metrics.items():
-            self.writer.add_scalar(k, v, step)
-
-    def log_summary(self, metrics: dict[str, Any]):
-        if self.writer is None:
-            raise RuntimeError("Must call init before logging metrics")
-
-        for k, v in metrics.items():
-            self.writer.add_scalar(k, v, global_step=None)
-
-    def log_artifact(self, artifact, *, name: Optional[str] = None, type: Optional[str] = None):
-        pylogger.warning("TensorboardLogger does not support logging artifacts yet")
-        pass
-
-
-def log_optimizer_hyperparams(opt_state, prefix: Optional[str] = None, *, step=None):
-    if isinstance(opt_state, MultiStepsState):
-        opt_state = opt_state.inner_opt_state
-
-    def wrap_key(key):
-        if prefix:
-            return f"{prefix}/{key}"
-        return key
-
-    if hasattr(opt_state, "hyperparams"):
-        params = {wrap_key(k): jnp_to_python(v) for k, v in opt_state.hyperparams.items()}
-        log_metrics(params, step=step)
-
-
-def is_wandb_available():
-    try:
-        import wandb
-    except ImportError:
-        return False
-    return wandb is not None and wandb.run is not None
-
-
-@dataclass
-class WandbConfig:
-    """
-    Configuration for wandb.
-    """
-
-    entity: Optional[str] = None  # An entity is a username or team name where you send runs
-    project: Optional[str] = None  # The name of the project where you are sending the new run.
-    name: Optional[str] = None  # A short display name for this run, which is how you'll identify this run in the UI.
-    tags: List[str] = field(default_factory=list)  # Will populate the list of tags on this run in the UI.
-    id: Optional[str] = None  # A unique ID for this run, used for resuming. It must be unique in the project
-    group: Optional[str] = None  # Specify a group to organize individual runs into a larger experiment.
-    mode: Optional[str] = None  # Can be "online", "offline" or "disabled". If None, it will be online.
-    resume: Optional[Union[bool, str]] = None  #
-    """
-    Set the resume behavior. Options: "allow", "must", "never", "auto" or None.
-    By default, if the new run has the same ID as a previous run, this run overwrites that data.
-    Please refer to [init](https://docs.wandb.ai/ref/python/init) and [resume](https://docs.wandb.ai/guides/runs/resuming)
-    document for more details.
-    """
-
-    save_code: Union[bool, str] = True
-    """If string, will save code from that directory. If True, will attempt to sniff out the main directory (since we
-    typically don't run from the root of the repo)."""
-
-    save_xla_dumps: bool = False
-    """If True, will save the XLA code to wandb (as configured by XLA_FLAGS). This is useful for debugging."""
-
-    def init(self, run_id: Optional[str], hparams=None, **extra_hparams):
-        import wandb
-
-        if run_id is not None and self.id is not None and run_id != self.id:
-            warnings.warn(
-                f"Both trainer's id {run_id} and WandB's id {self.id} are set. WandB will use the id set in its"
-                " config."
- ) - - id = self.id - if id is None: - id = run_id - - if hparams is None: - hparams_to_save = {} - elif dataclasses.is_dataclass(hparams): - hparams_to_save = dataclasses.asdict(hparams) - else: - hparams_to_save = dict(hparams) - - if extra_hparams: - hparams_to_save.update(extra_hparams) - - # for distributed runs, we only want the primary worker to use wandb, so we make everyone else be disabled - # however, we do share information about the run id, so that we can link to it from the other workers - mode = self.mode - if jax.process_index() != 0: - mode = "disabled" - - if isinstance(self.save_code, str): - code_dir = self.save_code - elif self.save_code: - code_dir = WandbConfig._infer_experiment_git_root() or "." # type: ignore - else: - code_dir = None - - other_settings = dict() - if code_dir is not None: - pylogger.info(f"Setting wandb code_dir to {code_dir}") - other_settings["code_dir"] = code_dir - other_settings["git_root"] = code_dir - # for some reason, wandb isn't populating the git commit, so we do it here - try: - repo = Repo(code_dir) - other_settings["git_commit"] = repo.head.commit.hexsha - hparams_to_save["git_commit"] = repo.head.commit.hexsha - except (NoSuchPathError, InvalidGitRepositoryError): - pylogger.warning(f"Could not find git repo at {code_dir}") - pass - - r = wandb.init( - entity=self.entity, - project=self.project, - name=self.name, - tags=self.tags, - id=id, - group=self.group, - resume=self.resume, - mode=mode, - config=hparams_to_save, - settings=other_settings, - ) - - assert r is not None - - if jax.process_count() > 1: - # we need to share wandb run information across all hosts, because we use it for checkpoint paths and things - metadata_to_share = dict( - entity=r.entity, - project=r.project, - name=r.name, - tags=r.tags, - id=r.id, - group=r.group, - ) - metadata_to_share = jax_utils.multihost_broadcast_sync( - metadata_to_share, is_source=jax.process_index() == 0 - ) - - if jax.process_index() != 0: - assert r.mode == "disabled" - for k, v in metadata_to_share.items(): - setattr(r, k, v) - - pylogger.info(f"Synced wandb run information from process 0: {r.name} {r.id}") - - if dataclasses.is_dataclass(hparams): - with tempfile.TemporaryDirectory() as tmpdir: - config_path = os.path.join(tmpdir, "config.yaml") - with open(config_path, "w") as f: - draccus.dump(hparams, f, encoding="utf-8") - if wandb.run is not None: - wandb.run.log_artifact(str(config_path), name="config.yaml", type="config") - - # generate a pip freeze - with tempfile.TemporaryDirectory() as tmpdir: - requirements_path = os.path.join(tmpdir, "requirements.txt") - requirements = _generate_pip_freeze() - with open(requirements_path, "w") as f: - f.write(requirements) - if wandb.run is not None: - wandb.run.log_artifact(str(requirements_path), name="requirements.txt", type="requirements") - - wandb.summary["num_devices"] = jax.device_count() - wandb.summary["num_hosts"] = jax.process_count() - wandb.summary["backend"] = jax.default_backend() - - return r - - @staticmethod - def _infer_experiment_git_root() -> Optional[str | os.PathLike[str]]: - # sniff out the main directory (since we typically don't run from the root of the repo) - # we'll walk the stack and directories for the files in the stack the until we're at a git root - import os - import traceback - - stack = traceback.extract_stack() - # start from the top of the stack and work our way down since we want to hit the main file first - top_git_root = None - for frame in stack: - dirname = os.path.dirname(frame.filename) 
- # bit hacky but we want to skip anything that's in the python env - if any(x in dirname for x in ["site-packages", "dist-packages", "venv", "opt/homebrew", "conda", "pyenv"]): - continue - # see if it's under a git root - try: - repo = Repo(dirname, search_parent_directories=True) - top_git_root = repo.working_dir - break - except (NoSuchPathError, InvalidGitRepositoryError): - pylogger.debug(f"Skipping {dirname} since it's not a git root") - pass - return top_git_root - - -def _generate_pip_freeze(): - from importlib.metadata import distributions - - dists = distributions() - return "\n".join(f"{dist.name}=={dist.version}" for dist in dists) diff --git a/src/levanter/tracker/__init__.py b/src/levanter/tracker/__init__.py new file mode 100644 index 000000000..2bdcc8d58 --- /dev/null +++ b/src/levanter/tracker/__init__.py @@ -0,0 +1,9 @@ +from levanter.tracker.helpers import log_optimizer_hyperparams +from levanter.tracker.tracker import ( + CompositeTracker, + Tracker, + current_tracker, + jit_log_metrics, + log_metrics, + log_summary, +) diff --git a/src/levanter/tracker/helpers.py b/src/levanter/tracker/helpers.py new file mode 100644 index 000000000..26e5e0e2e --- /dev/null +++ b/src/levanter/tracker/helpers.py @@ -0,0 +1,71 @@ +import dataclasses +import logging +import os +from typing import Optional + +from git import InvalidGitRepositoryError, NoSuchPathError, Repo +from optax._src.wrappers import MultiStepsState + +import levanter.tracker +from levanter.utils.jax_utils import jnp_to_python + + +logger = logging.getLogger(__name__) + + +def log_optimizer_hyperparams(opt_state, prefix: Optional[str] = None, *, step=None): + if isinstance(opt_state, MultiStepsState): + opt_state = opt_state.inner_opt_state + + def wrap_key(key): + if prefix: + return f"{prefix}/{key}" + return key + + if hasattr(opt_state, "hyperparams"): + params = {wrap_key(k): jnp_to_python(v) for k, v in opt_state.hyperparams.items()} + levanter.log_metrics(params, step=step) + + +def hparams_to_dict(hparams, **extra_hparams): + if hparams is None: + hparams_to_save = {} + elif dataclasses.is_dataclass(hparams): + hparams_to_save = dataclasses.asdict(hparams) + else: + hparams_to_save = dict(hparams) + if extra_hparams: + hparams_to_save.update(extra_hparams) + return hparams_to_save + + +def infer_experiment_git_root() -> Optional[str | os.PathLike[str]]: + # sniff out the main directory (since we typically don't run from the root of the repo) + # we'll walk the stack and directories for the files in the stack the until we're at a git root + import os + import traceback + + stack = traceback.extract_stack() + # start from the top of the stack and work our way down since we want to hit the main file first + top_git_root = None + for frame in stack: + dirname = os.path.dirname(frame.filename) + # bit hacky but we want to skip anything that's in the python env + if any(x in dirname for x in ["site-packages", "dist-packages", "venv", "opt/homebrew", "conda", "pyenv"]): + continue + # see if it's under a git root + try: + repo = Repo(dirname, search_parent_directories=True) + top_git_root = repo.working_dir + break + except (NoSuchPathError, InvalidGitRepositoryError): + logger.debug(f"Skipping {dirname} since it's not a git root") + pass + return top_git_root + + +def generate_pip_freeze(): + from importlib.metadata import distributions + + dists = distributions() + return "\n".join(f"{dist.name}=={dist.version}" for dist in dists) diff --git a/src/levanter/tracker/tensorboard.py 
b/src/levanter/tracker/tensorboard.py new file mode 100644 index 000000000..1377167ba --- /dev/null +++ b/src/levanter/tracker/tensorboard.py @@ -0,0 +1,48 @@ +import logging +import os +from pathlib import Path +from typing import Any, Optional, Union + +from levanter.tracker import Tracker + + +pylogger = logging.getLogger(__name__) + + +class TensorboardTracker(Tracker): + def __init__(self, logdir: Union[str, Path]): + self.logdir = logdir + self.writer = None + + def init(self, run_id: Optional[str]): + from tensorboardX import SummaryWriter + + dir_to_write = self.logdir + if run_id is not None: + dir_to_write = os.path.join(dir_to_write, run_id) + self.writer = SummaryWriter(dir_to_write) + + def log_hyperparameters(self, hparams: dict[str, Any]): + if self.writer is None: + raise RuntimeError("Must call init before logging metrics") + + self.writer.add_hparams(hparams, {"dummy": 0}) + + def log(self, metrics: dict[str, Any], *, step, commit=None): + del commit + if self.writer is None: + raise RuntimeError("Must call init before logging metrics") + + for k, v in metrics.items(): + self.writer.add_scalar(k, v, step) + + def log_summary(self, metrics: dict[str, Any]): + if self.writer is None: + raise RuntimeError("Must call init before logging metrics") + + for k, v in metrics.items(): + self.writer.add_scalar(k, v, global_step=None) + + def log_artifact(self, artifact, *, name: Optional[str] = None, type: Optional[str] = None): + pylogger.warning("TensorboardLogger does not support logging artifacts yet") + pass diff --git a/src/levanter/tracker/tracker.py b/src/levanter/tracker/tracker.py new file mode 100644 index 000000000..3745fc26b --- /dev/null +++ b/src/levanter/tracker/tracker.py @@ -0,0 +1,162 @@ +import abc +import typing +from contextlib import AbstractContextManager +from typing import Any, List, Optional + +import draccus +import jax + +from levanter.utils.jax_utils import is_inside_jit + + +class Tracker(abc.ABC): + """ + A tracker is responsible for logging metrics, hyperparameters, and artifacts. + Meant to be used with the [current_tracker][] context manager, but can also be used directly. + + The name is borrowed from Accelerate. + + Examples: + >>> from levanter.tracker import current_tracker, log_metrics + >>> from levanter.tracker.wandb import WandbTracker + >>> with current_tracker(WandbTracker()): + ... log_metrics({"foo": 1}, step=0) + """ + + @abc.abstractmethod + def log_hyperparameters(self, hparams: dict[str, Any]): + pass + + @abc.abstractmethod + def log(self, metrics: dict[str, typing.Any], *, step, commit: Optional[bool] = None): + """ + Log metrics to the tracker. Step is always required. + + Args: + metrics: Metrics to log + step: Step to log at + commit: Whether to commit the metrics. If None, uses the default for the tracker. 
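+
+        A minimal usage sketch (assuming `tracker` is some initialized Tracker):
+
+            >>> tracker.log({"train/loss": 0.42}, step=100)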
+ """ + pass + + @abc.abstractmethod + def log_summary(self, metrics: dict[str, Any]): + pass + + @abc.abstractmethod + def log_artifact(self, artifact, *, name: Optional[str] = None, type: Optional[str] = None): + pass + + +class CompositeTracker(Tracker): + def __init__(self, loggers: List[Tracker]): + self.loggers = loggers + + def log_hyperparameters(self, hparams: dict[str, Any]): + for tracker in self.loggers: + tracker.log_hyperparameters(hparams) + + def log(self, metrics: dict[str, Any], *, step, commit=None): + for tracker in self.loggers: + tracker.log(metrics, step=step, commit=commit) + + def log_summary(self, metrics: dict[str, Any]): + for tracker in self.loggers: + tracker.log_summary(metrics) + + def log_artifact(self, artifact, *, name: Optional[str] = None, type: Optional[str] = None): + for tracker in self.loggers: + tracker.log_artifact(artifact, name=name, type=type) + + +_global_tracker: Optional["Tracker"] = None + + +class TrackerConfig(draccus.PluginRegistry): + discover_packages_path = "levanter.tracker" + + def init(self, run_id: Optional[str], hparams=None, **extra_hparams) -> Tracker: + raise NotImplementedError + + +def log_metrics(metrics: dict[str, Any], *, step, commit: Optional[bool] = None): + """ + Log metrics to the global tracker. + + Args + metrics: Metrics to log + step: Step to log at + commit: Whether to commit the metrics. If None, uses the default for the tracker. + """ + global _global_tracker + if _global_tracker is None: + raise RuntimeError("No global tracker set") + + if is_inside_jit(): + # we're inside a jit, so we need to log from the host + if commit: + raise ValueError("Cannot commit from inside jit") + jit_log_metrics(metrics, step=step) + else: + # TODO: do we need to coerce to np here? + _global_tracker.log(metrics, step=step) + + +def jit_log_metrics(metrics, *, step=None): + """uses jax effect callback to log to wandb from the host""" + jax.debug.callback(log_metrics, metrics, step=step) + + +def log_summary(metrics: dict[str, Any]): + """ + Log summary metrics to the global tracker. + + :param metrics: Metrics to log + """ + global _global_tracker + if _global_tracker is None: + raise RuntimeError("No global tracker set") + _global_tracker.log_summary(metrics) + + +@typing.overload +def current_tracker() -> "Tracker": + ... + + +@typing.overload +def current_tracker(tracker: "Tracker") -> typing.ContextManager: + """Returns a context manager for setting the global tracker""" + ... + + +def current_tracker( + tracker: Optional[Tracker] = None, +) -> Tracker | typing.ContextManager: + """ + Get or set the global tracker. + + :param tracker: If provided, returns a context manager that sets the global tracker to the provided tracker when used. + :return: The global tracker, or a context manager for setting the global tracker. 
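+
+    A usage sketch (`my_tracker` stands in for any Tracker instance):
+
+        >>> with current_tracker(my_tracker):
+        ...     current_tracker().log({"loss": 0.1}, step=0)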
+ """ + global _global_tracker + if tracker is None: + if _global_tracker is None: + raise RuntimeError("No global tracker set") + return _global_tracker + else: + return _GlobalLoggerContextManager(tracker) + + +class _GlobalLoggerContextManager(AbstractContextManager): + def __init__(self, tracker: "Tracker"): + self.tracker = tracker + + def __enter__(self): + global _global_tracker + self.old_tracker = _global_tracker + _global_tracker = self.tracker + + def __exit__(self, exc_type, exc_val, exc_tb): + global _global_tracker + _global_tracker = self.old_tracker diff --git a/src/levanter/tracker/wandb.py b/src/levanter/tracker/wandb.py new file mode 100644 index 000000000..1bd74e40c --- /dev/null +++ b/src/levanter/tracker/wandb.py @@ -0,0 +1,206 @@ +import dataclasses +import logging +import os +import tempfile +import typing +import warnings +from dataclasses import dataclass +from typing import Any, List, Optional, Union + +import draccus +import jax +from draccus import field +from git import InvalidGitRepositoryError, NoSuchPathError, Repo + +from levanter.tracker import Tracker +from levanter.tracker.helpers import generate_pip_freeze, hparams_to_dict, infer_experiment_git_root +from levanter.tracker.tracker import TrackerConfig +from levanter.utils import jax_utils + + +if typing.TYPE_CHECKING: + import wandb + import wandb.sdk.lib.disabled + + +logger = logging.getLogger(__name__) + +WandbRun = Union["wandb.sdk.wandb_run.Run", "wandb.sdk.lib.disabled.RunDisabled"] + + +class WandbTracker(Tracker): + _run: Optional[WandbRun] + + def __init__(self, run: Optional[WandbRun]): + import wandb + + if run is None: + if wandb.run is None: + logger.warning("Wandb run is not initialized. Initializing a new run.") + run = wandb.init() + + self._run = run + + def log_hyperparameters(self, hparams: dict[str, Any]): + if self._run is None: + raise RuntimeError("Must call init before logging hyperparameters") + self._run.config.update(hparams) + + def log(self, metrics: dict[str, Any], *, step, commit=None): + if self._run is None: + raise RuntimeError("Must call init before logging metrics") + self._run.log(metrics, step=step, commit=commit) + + def log_summary(self, metrics: dict[str, Any]): + if self._run is None: + raise RuntimeError("Must call init before logging summary") + self._run.summary.update(metrics) + + def log_artifact(self, artifact, *, name: Optional[str] = None, type: Optional[str] = None): + if self._run is None: + raise RuntimeError("Must call init before logging artifacts") + self._run.log_artifact(artifact, name=name, type=type) + + +def is_wandb_available(): + try: + import wandb + except ImportError: + return False + return wandb is not None and wandb.run is not None + + +@TrackerConfig.register_subclass("wandb") +@dataclass +class WandbConfig(TrackerConfig): + """ + Configuration for wandb. + """ + + entity: Optional[str] = None # An entity is a username or team name where you send runs + project: Optional[str] = None # The name of the project where you are sending the enw run. + name: Optional[str] = None # A short display name for this run, which is how you'll identify this run in the UI. + tags: List[str] = field(default_factory=list) # Will populate the list of tags on this run in the UI. + id: Optional[str] = None # A unique ID for this run, used for resuming. It must be unique in the project + group: Optional[str] = None # Specify a group to organize individual runs into a larger experiment. + mode: Optional[str] = None # Can be "online", "offline" or "disabled". 
If None, it will be online. + resume: Optional[Union[bool, str]] = None # + """ + Set the resume behavior. Options: "allow", "must", "never", "auto" or None. + By default, if the new run has the same ID as a previous run, this run overwrites that data. + Please refer to [init](https://docs.wandb.ai/ref/python/init) and [resume](https://docs.wandb.ai/guides/runs/resuming) + document for more details. + """ + + save_code: Union[bool, str] = True + """If string, will save code from that directory. If True, will attempt to sniff out the main directory (since we + typically don't run from the root of the repo).""" + + save_xla_dumps: bool = False + """If True, will save the XLA code to wandb (as configured by XLA_FLAGS). This is useful for debugging.""" + + def init(self, run_id: Optional[str], hparams=None, **extra_hparams) -> WandbTracker: + import wandb + + if run_id is not None and self.id is not None and run_id != self.id: + warnings.warn( + f"Both trainer's id {run_id} and WandB's id {self.id} are set. WandB will use the id set in its" + " config." + ) + + id = self.id + if id is None: + id = run_id + + hparams_to_save = hparams_to_dict(hparams, **extra_hparams) + + # for distributed runs, we only want the primary worker to use wandb, so we make everyone else be disabled + # however, we do share information about the run id, so that we can link to it from the other workers + mode = self.mode + if jax.process_index() != 0: + mode = "disabled" + + git_settings = self._git_settings() + + if "git_commit" in git_settings: + hparams_to_save["git_commit"] = git_settings["git_commit"] + + r = wandb.init( + entity=self.entity, + project=self.project, + name=self.name, + tags=self.tags, + id=id, + group=self.group, + resume=self.resume, + mode=mode, + config=hparams_to_save, + settings=git_settings, + ) + + assert r is not None + + if jax.process_count() > 1: + # we need to share wandb run information across all hosts, because we use it for checkpoint paths and things + metadata_to_share = dict( + entity=r.entity, + project=r.project, + name=r.name, + tags=r.tags, + id=r.id, + group=r.group, + ) + metadata_to_share = jax_utils.multihost_broadcast_sync( + metadata_to_share, is_source=jax.process_index() == 0 + ) + + if jax.process_index() != 0: + assert r.mode == "disabled" + for k, v in metadata_to_share.items(): + setattr(r, k, v) + + logger.info(f"Synced wandb run information from process 0: {r.name} {r.id}") + + if dataclasses.is_dataclass(hparams): + with tempfile.TemporaryDirectory() as tmpdir: + config_path = os.path.join(tmpdir, "config.yaml") + with open(config_path, "w") as f: + draccus.dump(hparams, f, encoding="utf-8") + if wandb.run is not None: + wandb.run.log_artifact(str(config_path), name="config.yaml", type="config") + + # generate a pip freeze + with tempfile.TemporaryDirectory() as tmpdir: + requirements_path = os.path.join(tmpdir, "requirements.txt") + requirements = generate_pip_freeze() + with open(requirements_path, "w") as f: + f.write(requirements) + if wandb.run is not None: + wandb.run.log_artifact(str(requirements_path), name="requirements.txt", type="requirements") + + wandb.summary["num_devices"] = jax.device_count() + wandb.summary["num_hosts"] = jax.process_count() + wandb.summary["backend"] = jax.default_backend() + + return WandbTracker(r) + + def _git_settings(self): + other_settings = dict() + if isinstance(self.save_code, str): + code_dir = self.save_code + elif self.save_code: + code_dir = infer_experiment_git_root() or "." 
# type: ignore + else: + code_dir = None + if code_dir is not None: + logger.info(f"Setting wandb code_dir to {code_dir}") + other_settings["code_dir"] = code_dir + other_settings["git_root"] = code_dir + # for some reason, wandb isn't populating the git commit, so we do it here + try: + repo = Repo(code_dir) + other_settings["git_commit"] = repo.head.commit.hexsha + except (NoSuchPathError, InvalidGitRepositoryError): + logger.warning(f"Could not find git repo at {code_dir}") + pass + return other_settings diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index 648c46016..4f51970d8 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -31,6 +31,8 @@ import levanter.logging import levanter.tracker +import levanter.tracker.tracker +import levanter.tracker.wandb from levanter import logging from levanter.checkpoint import CheckpointerConfig from levanter.config import JsonAtom @@ -38,7 +40,7 @@ from levanter.distributed import DistributedConfig, RayConfig from levanter.grad_accum import accumulate_gradients_sharded from levanter.logging import capture_time -from levanter.tracker import WandbConfig +from levanter.tracker.wandb import WandbConfig from levanter.types import FilterSpec from levanter.utils import cloud_utils from levanter.utils.jax_utils import is_inexact_arrayish @@ -117,7 +119,7 @@ class Trainer: config: "TrainerConfig" optimizer: GradientTransformation hooks: TrainerHooks - _tracker: levanter.tracker.Tracker + _tracker: levanter.tracker.tracker.Tracker is_trainable_param: Optional[PyTree[FilterSpec]] _raw_loss_function: Callable _cmanagers: List[typing.ContextManager] = [] @@ -146,9 +148,8 @@ def __init__( self._raw_loss_function = loss_fn self.optimizer = optimizer self.is_trainable_param = is_trainable - self._tracker = levanter.tracker.WandbTracker(self.config.wandb) # TODO: hacky hack - self._tracker._run = wandb.run + self._tracker = levanter.tracker.wandb.WandbTracker(wandb.run) self._cmanagers = [] @cached_property @@ -216,7 +217,7 @@ def __enter__(self): raise RuntimeError("Trainer is already entered") self._cmanagers = [ - levanter.tracker.current_tracker(self._tracker), + levanter.current_tracker(self._tracker), self.device_mesh, hax.axis_mapping(self.parameter_axis_mapping), ] @@ -292,7 +293,7 @@ def train_step(self, state: TrainerState[M], *batch: X, **batch_kwargs) -> StepI """ Performs a single training step. """ - with capture_time() as step_time, levanter.tracker.current_tracker(self._tracker): + with capture_time() as step_time, levanter.current_tracker(self._tracker): key, new_key = jax.random.split(state.training_key) loss, new_model, new_optstate = self._train_step_fn( state.model, state.opt_state, *batch, **batch_kwargs, key=key @@ -309,13 +310,12 @@ def training_steps( Generator that yields training steps and runs hooks. 
""" iter_data = iter(train_loader) - with levanter.tracker.current_tracker(self._tracker): + with levanter.current_tracker(self._tracker): while state.step < self.config.num_train_steps: with capture_time() as loading_time: example = next(iter_data) - # TODO: refactor logging - levanter.tracker.log_metrics({"throughput/loading_time": loading_time()}, step=state.step) + levanter.log_metrics({"throughput/loading_time": loading_time()}, step=state.step) info = self.train_step(state, example) state = info.state @@ -324,7 +324,7 @@ def training_steps( with capture_time() as hook_time: self.run_hooks(info) - levanter.tracker.log_metrics({"throughput/hook_time": hook_time()}, step=state.step) + levanter.log_metrics({"throughput/hook_time": hook_time()}, step=state.step) yield info diff --git a/tests/test_eval_lm.py b/tests/test_eval_lm.py index 0ea76f211..178069f26 100644 --- a/tests/test_eval_lm.py +++ b/tests/test_eval_lm.py @@ -12,7 +12,7 @@ from levanter.checkpoint import save_checkpoint from levanter.distributed import RayConfig from levanter.models.gpt2 import Gpt2LMHeadModel -from levanter.tracker import WandbConfig +from levanter.tracker.wandb import WandbConfig from levanter.utils.py_utils import logical_cpu_core_count diff --git a/tests/test_logging.py b/tests/test_logging.py index 14c13bad8..7c537b182 100644 --- a/tests/test_logging.py +++ b/tests/test_logging.py @@ -3,7 +3,7 @@ import pytest from git import InvalidGitRepositoryError, NoSuchPathError, Repo -from levanter.tracker import WandbConfig +from levanter.tracker.helpers import infer_experiment_git_root def test_infer_experiment_git_root(): @@ -13,7 +13,7 @@ def test_infer_experiment_git_root(): except (InvalidGitRepositoryError, NoSuchPathError): pytest.skip("test not running in a git repo") - root = WandbConfig._infer_experiment_git_root() + root = infer_experiment_git_root() # ensure that 1) this is a git root and 2) this source file is underneath assert root is not None diff --git a/tests/test_train_lm.py b/tests/test_train_lm.py index c3b8279a8..f95b27efb 100644 --- a/tests/test_train_lm.py +++ b/tests/test_train_lm.py @@ -8,7 +8,7 @@ import levanter.main.train_lm as train_lm import tiny_test_corpus from levanter.distributed import RayConfig -from levanter.tracker import WandbConfig +from levanter.tracker.wandb import WandbConfig from levanter.utils.py_utils import logical_cpu_core_count diff --git a/tests/test_viz_lm.py b/tests/test_viz_lm.py index 345d55c53..cf4fb74a6 100644 --- a/tests/test_viz_lm.py +++ b/tests/test_viz_lm.py @@ -12,7 +12,7 @@ from levanter.checkpoint import save_checkpoint from levanter.distributed import RayConfig from levanter.models.gpt2 import Gpt2Config, Gpt2LMHeadModel -from levanter.tracker import WandbConfig +from levanter.tracker.wandb import WandbConfig from levanter.utils.py_utils import logical_cpu_core_count From 2f481ed99c75a96a2307d61f268699a6aee42a86 Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 8 Nov 2023 16:49:37 -0800 Subject: [PATCH 008/205] wip --- src/levanter/data/shard_cache.py | 2 +- src/levanter/tracker/tracker.py | 2 +- src/levanter/tracker/wandb.py | 11 ++++++----- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/src/levanter/data/shard_cache.py b/src/levanter/data/shard_cache.py index a983cbdad..8a3fe7d9b 100644 --- a/src/levanter/data/shard_cache.py +++ b/src/levanter/data/shard_cache.py @@ -459,7 +459,7 @@ def __call__(self, metrics: InProgressCacheMetrics): self.last_metrics = metrics self.last_time = time.time() - levanter.tracker.log_metrics(to_log, 
step=None, commit=self.commit) + levanter.log_metrics(to_log, step=None, commit=self.commit) class LoggerMetricsMonitor(MetricsMonitor): diff --git a/src/levanter/tracker/tracker.py b/src/levanter/tracker/tracker.py index 3745fc26b..0e4228b4c 100644 --- a/src/levanter/tracker/tracker.py +++ b/src/levanter/tracker/tracker.py @@ -75,7 +75,7 @@ def log_artifact(self, artifact, *, name: Optional[str] = None, type: Optional[s class TrackerConfig(draccus.PluginRegistry): discover_packages_path = "levanter.tracker" - def init(self, run_id: Optional[str], hparams=None, **extra_hparams) -> Tracker: + def init(self, run_id: Optional[str], hparams=None) -> Tracker: raise NotImplementedError diff --git a/src/levanter/tracker/wandb.py b/src/levanter/tracker/wandb.py index 1bd74e40c..66e43a59e 100644 --- a/src/levanter/tracker/wandb.py +++ b/src/levanter/tracker/wandb.py @@ -83,7 +83,7 @@ class WandbConfig(TrackerConfig): tags: List[str] = field(default_factory=list) # Will populate the list of tags on this run in the UI. id: Optional[str] = None # A unique ID for this run, used for resuming. It must be unique in the project group: Optional[str] = None # Specify a group to organize individual runs into a larger experiment. - mode: Optional[str] = None # Can be "online", "offline" or "disabled". If None, it will be online. + mode: Optional[str] = None # Can be "online", "offline" or "disabled". If None, it will be whatever W&B decides. resume: Optional[Union[bool, str]] = None # """ Set the resume behavior. Options: "allow", "must", "never", "auto" or None. @@ -99,7 +99,7 @@ class WandbConfig(TrackerConfig): save_xla_dumps: bool = False """If True, will save the XLA code to wandb (as configured by XLA_FLAGS). This is useful for debugging.""" - def init(self, run_id: Optional[str], hparams=None, **extra_hparams) -> WandbTracker: + def init(self, run_id: Optional[str], hparams=None) -> WandbTracker: import wandb if run_id is not None and self.id is not None and run_id != self.id: @@ -112,12 +112,13 @@ def init(self, run_id: Optional[str], hparams=None, **extra_hparams) -> WandbTra if id is None: id = run_id - hparams_to_save = hparams_to_dict(hparams, **extra_hparams) + hparams_to_save = hparams_to_dict(hparams) # for distributed runs, we only want the primary worker to use wandb, so we make everyone else be disabled # however, we do share information about the run id, so that we can link to it from the other workers - mode = self.mode - if jax.process_index() != 0: + if is_rank_0: + mode = self.mode + else: mode = "disabled" git_settings = self._git_settings() From 0b080fb02099d427c9351ae97b6353fae2fb7f81 Mon Sep 17 00:00:00 2001 From: David Hall Date: Thu, 9 Nov 2023 09:55:21 -0800 Subject: [PATCH 009/205] remove more wandb deps --- pyproject.toml | 2 +- src/levanter/callbacks.py | 16 +++++---- src/levanter/main/cache_dataset.py | 10 +++--- src/levanter/tracker/__init__.py | 15 ++++++++ src/levanter/tracker/tracker.py | 19 +++++++++++ src/levanter/tracker/wandb.py | 2 +- src/levanter/trainer.py | 55 ++++++++++++++---------------- 7 files changed, 75 insertions(+), 44 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index df2eb7533..70ef900ac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,7 @@ dependencies = [ "transformers>=4.22.0", "optax", "wandb", - "draccus>=0.6", + "draccus>=0.7", "pyarrow>=11.0.0", "zstandard>=0.20.0", "datasets==2.11.0", diff --git a/src/levanter/callbacks.py b/src/levanter/callbacks.py index 4f9061347..85220cda5 100644 --- a/src/levanter/callbacks.py +++ 
b/src/levanter/callbacks.py @@ -11,7 +11,6 @@ import humanfriendly import jax -import wandb from tqdm import tqdm import levanter.tracker @@ -60,11 +59,10 @@ def compute_validation_loss( def compute_loss(info: StepInfo): loss = eval_loss_loop(loss_fn, info.model, dataset, max_batches=max_batches, name=name) - if wandb.run is not None: - prefix = "eval" - if name: - prefix += "/" + name - levanter.log_metrics({f"{prefix}/loss": loss}, step=info.step) + prefix = "eval" + if name: + prefix += "/" + name + levanter.log_metrics({f"{prefix}/loss": loss}, step=info.step) if name: logger.info(f"{name} validation loss: {loss:.3f}") @@ -82,6 +80,8 @@ def log_step_info(step: StepInfo): def wandb_xla_logger(config: WandbConfig): + import wandb + last_mtime = wandb.run and wandb.run.start_time or time.time() def log_xla_to_wandb(step: StepInfo): @@ -155,7 +155,7 @@ def update_pbar(step: StepInfo): def log_memory_usage(sample_interval: float = 1.0, log_individual_devices: bool = False): """ - Logs memory usage to wandb. This runs a loop that samples memory usage every `sample_interval` seconds. + Logs memory usage. This runs a loop that samples memory usage every `sample_interval` seconds. We only log when hooks are invoked, so there's not much point in running this much more frequently than you invoke the hook. @@ -266,6 +266,8 @@ def compute_and_viz_log_probs(step: StepInfo): viz_probs(path, model, tokenizer, log_prob_fn, test_data, max_docs=max_docs) # TODO: convert to generic logging + import wandb + wandb.log({"log_probs": wandb.Html(path)}, step=step.step) return compute_and_viz_log_probs diff --git a/src/levanter/main/cache_dataset.py b/src/levanter/main/cache_dataset.py index dfbbf800a..80ff6949e 100644 --- a/src/levanter/main/cache_dataset.py +++ b/src/levanter/main/cache_dataset.py @@ -1,14 +1,14 @@ import logging import os -from dataclasses import dataclass - -import wandb +from dataclasses import dataclass, field import levanter from levanter.data.shard_cache import LoggingMetricsMonitor, RichMetricsMonitor, build_cache from levanter.data.text import BatchTokenizer, LMDatasetConfig from levanter.distributed import RayConfig from levanter.logging import init_logging +from levanter.tracker import TrackerConfig +from levanter.tracker.tracker import NullTrackerConfig logger = logging.getLogger(__name__) @@ -16,7 +16,7 @@ @dataclass class RayCachedLMDatasetConfig(LMDatasetConfig, RayConfig): - pass + tracker: TrackerConfig = field(default_factory=NullTrackerConfig) @levanter.config.main() @@ -27,8 +27,6 @@ def main(args: RayCachedLMDatasetConfig): tokenizer = args.the_tokenizer - wandb.init(mode="offline") - for split in ["train", "validation"]: print(f"Caching {split} to {args.cache_dir}.") # connect or start the actor diff --git a/src/levanter/tracker/__init__.py b/src/levanter/tracker/__init__.py index 2bdcc8d58..744886c1e 100644 --- a/src/levanter/tracker/__init__.py +++ b/src/levanter/tracker/__init__.py @@ -1,9 +1,24 @@ from levanter.tracker.helpers import log_optimizer_hyperparams from levanter.tracker.tracker import ( CompositeTracker, + NullTrackerConfig, Tracker, + TrackerConfig, current_tracker, jit_log_metrics, log_metrics, log_summary, ) + + +__all__ = [ + "Tracker", + "TrackerConfig", + "CompositeTracker", + "log_metrics", + "log_summary", + "current_tracker", + "jit_log_metrics", + "log_optimizer_hyperparams", + "NullTrackerConfig", +] diff --git a/src/levanter/tracker/tracker.py b/src/levanter/tracker/tracker.py index 0e4228b4c..fb6f0493b 100644 --- 
a/src/levanter/tracker/tracker.py
+++ b/src/levanter/tracker/tracker.py
@@ -160,3 +160,22 @@ def __enter__(self):
     def __exit__(self, exc_type, exc_val, exc_tb):
         global _global_tracker
         _global_tracker = self.old_tracker
+
+
+class NullTracker(Tracker):
+    def log_hyperparameters(self, hparams: dict[str, Any]):
+        pass
+
+    def log(self, metrics: dict[str, Any], *, step, commit: Optional[bool] = None):
+        pass
+
+    def log_summary(self, metrics: dict[str, Any]):
+        pass
+
+    def log_artifact(self, artifact, *, name: Optional[str] = None, type: Optional[str] = None):
+        pass
+
+
+class NullTrackerConfig(TrackerConfig):
+    def init(self, run_id: Optional[str], hparams=None) -> Tracker:
+        return NullTracker()
diff --git a/src/levanter/tracker/wandb.py b/src/levanter/tracker/wandb.py
index 66e43a59e..6ef0125af 100644
--- a/src/levanter/tracker/wandb.py
+++ b/src/levanter/tracker/wandb.py
@@ -116,7 +116,7 @@ def init(self, run_id: Optional[str], hparams=None) -> WandbTracker:
 
         # for distributed runs, we only want the primary worker to use wandb, so we make everyone else be disabled
         # however, we do share information about the run id, so that we can link to it from the other workers
-        if is_rank_0:
+        if jax.process_index() == 0:
             mode = self.mode
         else:
             mode = "disabled"
diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py
index 4f51970d8..dcbfcbfff 100644
--- a/src/levanter/trainer.py
+++ b/src/levanter/trainer.py
@@ -9,7 +9,7 @@
 from dataclasses import dataclass
 from functools import cached_property
 from pathlib import Path
-from typing import Any, Callable, Dict, Generic, Iterable, List, Mapping, Optional, Tuple, TypeVar, Union
+from typing import Any, Callable, Dict, Generic, Iterable, List, Mapping, Optional, Sequence, Tuple, TypeVar, Union
 
 import equinox as eqx
 import jax
@@ -17,7 +17,6 @@
 import jmp
 import numpy as np
 import optax
-import wandb
 from draccus import field
 from jax import ShapeDtypeStruct
 from jax.experimental import multihost_utils
@@ -31,16 +30,15 @@
 
 import levanter.logging
 import levanter.tracker
-import levanter.tracker.tracker
 import levanter.tracker.wandb
-from levanter import logging
+from levanter import logging, tracker
 from levanter.checkpoint import CheckpointerConfig
 from levanter.config import JsonAtom
 from levanter.data import Dataset, ReplicatedBatchLoader, ShardableDataset, ShardedBatchLoader
 from levanter.distributed import DistributedConfig, RayConfig
 from levanter.grad_accum import accumulate_gradients_sharded
 from levanter.logging import capture_time
-from levanter.tracker.wandb import WandbConfig
+from levanter.tracker import TrackerConfig
 from levanter.types import FilterSpec
 from levanter.utils import cloud_utils
 from levanter.utils.jax_utils import is_inexact_arrayish
@@ -119,7 +117,7 @@ class Trainer:
     config: "TrainerConfig"
     optimizer: GradientTransformation
     hooks: TrainerHooks
-    _tracker: levanter.tracker.tracker.Tracker
+    tracker: levanter.tracker.Tracker
     is_trainable_param: Optional[PyTree[FilterSpec]]
     _raw_loss_function: Callable
     _cmanagers: List[typing.ContextManager] = []
@@ -148,8 +146,10 @@ def __init__(
         self._raw_loss_function = loss_fn
         self.optimizer = optimizer
         self.is_trainable_param = is_trainable
-        # TODO: hacky hack
-        self._tracker = levanter.tracker.wandb.WandbTracker(wandb.run)
+        if isinstance(config.tracker, Sequence):
+            self.tracker = levanter.tracker.CompositeTracker([c.init(self.run_id) for c in config.tracker])
+        else:
+            self.tracker = config.tracker.init(self.run_id)
         self._cmanagers = []
 
     @cached_property
@@ -216,7 +217,7 @@ def 
__enter__(self): raise RuntimeError("Trainer is already entered") self._cmanagers = [ - levanter.current_tracker(self._tracker), + levanter.current_tracker(self.tracker), self.device_mesh, hax.axis_mapping(self.parameter_axis_mapping), ] @@ -249,7 +249,7 @@ def initial_state( Returns: model, opt_state, key, resume_step """ - with levanter.tracker.current_tracker(self._tracker): + with levanter.tracker.current_tracker(self.tracker): if model is not None and model_init is not None: raise ValueError("only one of model and model_init should be specified") elif model is None and model_init is None: @@ -293,7 +293,7 @@ def train_step(self, state: TrainerState[M], *batch: X, **batch_kwargs) -> StepI """ Performs a single training step. """ - with capture_time() as step_time, levanter.current_tracker(self._tracker): + with capture_time() as step_time, levanter.current_tracker(self.tracker): key, new_key = jax.random.split(state.training_key) loss, new_model, new_optstate = self._train_step_fn( state.model, state.opt_state, *batch, **batch_kwargs, key=key @@ -310,7 +310,7 @@ def training_steps( Generator that yields training steps and runs hooks. """ iter_data = iter(train_loader) - with levanter.current_tracker(self._tracker): + with levanter.current_tracker(self.tracker): while state.step < self.config.num_train_steps: with capture_time() as loading_time: example = next(iter_data) @@ -348,7 +348,6 @@ def add_default_hooks(self, eval_dataset: Optional[Iterable[X]] = None): self.add_hook(callbacks.log_step_info, every=1) if eval_dataset is not None: self.add_eval_hook(eval_dataset) - self.add_hook(callbacks.wandb_xla_logger(self.config.wandb), every=self.config.steps_per_eval) # engine.add_hook(callbacks.log_memory_usage(), every=1) checkpointer = self.config.checkpointer.create(self.run_id, self.is_trainable_param) self.add_hook(checkpointer.on_step, every=1) # checkpointer manages its own frequency @@ -502,11 +501,13 @@ class TrainerConfig: seed: int = 0 # random seed mp: jmp.Policy = jmp.get_policy("f32") # mixed precision policy - wandb: WandbConfig = field(default_factory=WandbConfig) + wandb: Optional[tracker.wandb.WandbConfig] = None log_dir: Path = Path("logs/") run_base_dir: Path = Path("runs/") id: Optional[str] = None # run id. if None, will be set to a random string + tracker: TrackerConfig | Tuple[TrackerConfig, ...] = field(default_factory=tracker.wandb.WandbConfig) + # config related to partitioning batch_axis: Optional[str] = "batch" # Batch axis for data parallel. @@ -553,15 +554,6 @@ class TrainerConfig: # whether or not to shutdown the tpu at exit. If a float, shutdown after that many seconds. True = 5 minutes shutdown_at_exit: Union[bool, float] = False - @property - def run_name(self) -> str: - try: - import wandb - - return wandb.run and (wandb.run.name or wandb.run.id) or "unnamed" - except ImportError: - return "unnamed" - @property def TrainBatch(self): return Axis("batch", self.train_batch_size) @@ -570,15 +562,20 @@ def TrainBatch(self): def EvalBatch(self): return Axis("batch", self.eval_batch_size) + def __post_init__(self): + if self.wandb is not None: + warnings.warn("wandb is deprecated. 
use tracker with type wandb instead", DeprecationWarning) + self.tracker = self.wandb + def initialize(self, all_config): - """Initializes jax, wandb, logging, setting the run name/id in the process""" - self.distributed.initialize() - self._maybe_set_id() - self.ray.initialize() + """Initializes jax, logging, setting the run name/id in the process""" self._initialize_jax_config() + self.distributed.initialize() self._validate_and_set_defaults() - self.wandb.init(self.id, all_config) + + self._maybe_set_id() self._initialize_logging() + self.ray.initialize() if self.require_accelerator is None: self.require_accelerator = not sys.platform.startswith("darwin") @@ -659,7 +656,7 @@ def _maybe_set_id(self): # TODO: this doesn't work with wandb sweeps. need to reconcile when we merge if "RUN_ID" in os.environ: self.id = os.environ["RUN_ID"] - elif self.wandb.id is not None: + elif self.wandb is not None and self.wandb.id is not None: self.id = self.wandb.id else: # wandb run ids are 8 characters [a-z0-9], which we'll emulate here From a324ae50ce14610395bcdefbf2f497572807c5dd Mon Sep 17 00:00:00 2001 From: David Hall Date: Thu, 9 Nov 2023 13:16:28 -0800 Subject: [PATCH 010/205] tiny cleanup --- src/levanter/tracker/wandb.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/levanter/tracker/wandb.py b/src/levanter/tracker/wandb.py index 6ef0125af..5cc322fb5 100644 --- a/src/levanter/tracker/wandb.py +++ b/src/levanter/tracker/wandb.py @@ -84,7 +84,7 @@ class WandbConfig(TrackerConfig): id: Optional[str] = None # A unique ID for this run, used for resuming. It must be unique in the project group: Optional[str] = None # Specify a group to organize individual runs into a larger experiment. mode: Optional[str] = None # Can be "online", "offline" or "disabled". If None, it will be whatever W&B decides. - resume: Optional[Union[bool, str]] = None # + resume: Optional[Union[bool, str]] = None """ Set the resume behavior. Options: "allow", "must", "never", "auto" or None. By default, if the new run has the same ID as a previous run, this run overwrites that data. 
From cfdcbb961828a14da1313ec987b71d14ee1b9eb5 Mon Sep 17 00:00:00 2001 From: David Hall Date: Thu, 9 Nov 2023 13:49:54 -0800 Subject: [PATCH 011/205] add some tests --- src/levanter/main/cache_dataset.py | 4 +-- src/levanter/tracker/__init__.py | 4 +-- src/levanter/tracker/tracker.py | 9 ++++--- tests/test_tracker.py | 41 ++++++++++++++++++++++++++++++ 4 files changed, 51 insertions(+), 7 deletions(-) create mode 100644 tests/test_tracker.py diff --git a/src/levanter/main/cache_dataset.py b/src/levanter/main/cache_dataset.py index 80ff6949e..616d917fb 100644 --- a/src/levanter/main/cache_dataset.py +++ b/src/levanter/main/cache_dataset.py @@ -8,7 +8,7 @@ from levanter.distributed import RayConfig from levanter.logging import init_logging from levanter.tracker import TrackerConfig -from levanter.tracker.tracker import NullTrackerConfig +from levanter.tracker.tracker import NoopTrackerConfig logger = logging.getLogger(__name__) @@ -16,7 +16,7 @@ @dataclass class RayCachedLMDatasetConfig(LMDatasetConfig, RayConfig): - tracker: TrackerConfig = field(default_factory=NullTrackerConfig) + tracker: TrackerConfig = field(default_factory=NoopTrackerConfig) @levanter.config.main() diff --git a/src/levanter/tracker/__init__.py b/src/levanter/tracker/__init__.py index 744886c1e..440e3375b 100644 --- a/src/levanter/tracker/__init__.py +++ b/src/levanter/tracker/__init__.py @@ -1,7 +1,7 @@ from levanter.tracker.helpers import log_optimizer_hyperparams from levanter.tracker.tracker import ( CompositeTracker, - NullTrackerConfig, + NoopTracker, Tracker, TrackerConfig, current_tracker, @@ -20,5 +20,5 @@ "current_tracker", "jit_log_metrics", "log_optimizer_hyperparams", - "NullTrackerConfig", + "NoopTracker", ] diff --git a/src/levanter/tracker/tracker.py b/src/levanter/tracker/tracker.py index fb6f0493b..cafb9eb61 100644 --- a/src/levanter/tracker/tracker.py +++ b/src/levanter/tracker/tracker.py @@ -1,4 +1,5 @@ import abc +import dataclasses import typing from contextlib import AbstractContextManager from typing import Any, List, Optional @@ -162,7 +163,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): _global_tracker = self.old_tracker -class NullTracker(Tracker): +class NoopTracker(Tracker): def log_hyperparameters(self, hparams: dict[str, Any]): pass @@ -176,6 +177,8 @@ def log_artifact(self, artifact, *, name: Optional[str] = None, type: Optional[s pass -class NullTrackerConfig(TrackerConfig): +@TrackerConfig.register_subclass("noop") +@dataclasses.dataclass +class NoopTrackerConfig(TrackerConfig): def init(self, run_id: Optional[str], hparams=None) -> Tracker: - return NullTracker() + return NoopTracker() diff --git a/tests/test_tracker.py b/tests/test_tracker.py new file mode 100644 index 000000000..72f118467 --- /dev/null +++ b/tests/test_tracker.py @@ -0,0 +1,41 @@ +# NOTE: Do not explicitly import wandb/other trackers here, as this will cause the tests to trivially pass. +import dataclasses +from typing import Tuple + +import pytest +import yaml + +from levanter.tracker import TrackerConfig + + +def test_tracker_plugin_stuff_works(): + assert TrackerConfig.get_choice_class("wandb") is not None + with pytest.raises(KeyError): + TrackerConfig.get_choice_class("foo") + + +def test_tracker_plugin_multi_parsing_work(): + config = """ + tracker: + type: noop + """ + parsed = yaml.safe_load(config) + + @dataclasses.dataclass + class ConfigHolder: + tracker: TrackerConfig | Tuple[TrackerConfig, ...] 
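+        # (sketch) draccus decodes a single yaml mapping into one TrackerConfig and a
+        # yaml list into a tuple of them, which is what this test exercises below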
+ + import draccus + + from levanter.tracker.tracker import NoopTrackerConfig + + assert isinstance(draccus.decode(ConfigHolder, parsed).tracker, NoopTrackerConfig) + + config = """ + tracker: + - type: noop + - type: wandb + """ + parsed = yaml.safe_load(config) + decoded = draccus.decode(ConfigHolder, parsed).tracker + assert decoded == (NoopTrackerConfig(), TrackerConfig.get_choice_class("wandb")()) From 2ddc558e9e560833f48740509fcb3e485845069e Mon Sep 17 00:00:00 2001 From: David Hall Date: Thu, 9 Nov 2023 15:20:20 -0800 Subject: [PATCH 012/205] migrate alpaca-lora to new logger --- examples/alpaca-lora/alpaca_lora.py | 99 +++++++++++++++-------------- 1 file changed, 50 insertions(+), 49 deletions(-) diff --git a/examples/alpaca-lora/alpaca_lora.py b/examples/alpaca-lora/alpaca_lora.py index 76151f6cd..b87286a05 100644 --- a/examples/alpaca-lora/alpaca_lora.py +++ b/examples/alpaca-lora/alpaca_lora.py @@ -103,57 +103,58 @@ def compute_loss(model: LmHeadModel, example: LmExample, key=None): trainer = Trainer(config.trainer, optimizer, compute_loss, is_trainable=lora_param_filter) - # end major difference from Alpaca - - trainer.add_default_hooks() - state = trainer.initial_state(training_key, model=model) - - # log some info about the model - all_param_count = parameter_count(state.model) - just_lora_params = parameter_count(trainer.trainable_params_only(state.model)) - - tracker.log_summary( - { - "parameter_count": all_param_count, - "trainable_parameter_count": just_lora_params, - "fraction_trainable": just_lora_params * 1.0 / all_param_count, - } - ) - - logger.info(f"Total parameter count: {all_param_count}") - logger.info(f"Trainable parameter count: {just_lora_params}") - logger.info(f"Fraction of parameters that are trainable: {just_lora_params * 1.0 / all_param_count%.3}") - - # Levanter has two kinds of data loaders: sharded and replicated. Replicated is simpler and allows for - # single pass training. Sharded only loads a subset of the data on each device, and is more efficient for large - # datasets. We use replicated here since the dataset is small. 
-    loader = trainer.replicated_loader(train_dataset, trainer.TrainBatch)
-    loader = non_caching_cycle(loader)
-
-    if state.step != 0:
-        logger.info(f"Resuming training from step {state.step}")
-        for i in range(state.step):
-            next(loader)  # type: ignore
-
-    # Save HF PEFT checkpoints periodically (and at the end of training), which is just the lora weights
-    if config.hf_save_path is not None:
-        full_save_path = os.path.join(config.hf_save_path, trainer.run_id)
-        trainer.add_hook(
-            save_peft_checkpoint_callback(
-                full_save_path, config.lora, config.model_name_or_path, tokenizer, config.hf_upload
-            ),
-            every=config.hf_save_steps,
-        )
-
-    # Save merged HF checkpoints if requested
-    if config.merged_hf_save_path is not None:
-        full_save_path = os.path.join(config.merged_hf_save_path, trainer.run_id)
-        trainer.add_hook(
-            save_merged_hf_checkpoint_callback(full_save_path, converter, config.merged_hf_upload),
-            every=config.hf_save_steps,
-        )
-
-    trainer.train(state, loader)
+    with trainer:
+        # end major difference from Alpaca
+
+        trainer.add_default_hooks()
+        state = trainer.initial_state(training_key, model=model)
+
+        # log some info about the model
+        all_param_count = parameter_count(state.model)
+        just_lora_params = parameter_count(trainer.trainable_params_only(state.model))
+
+        tracker.log_summary(
+            {
+                "parameter_count": all_param_count,
+                "trainable_parameter_count": just_lora_params,
+                "fraction_trainable": just_lora_params * 1.0 / all_param_count,
+            }
+        )
+
+        logger.info(f"Total parameter count: {all_param_count}")
+        logger.info(f"Trainable parameter count: {just_lora_params}")
+        logger.info(f"Fraction of parameters that are trainable: {just_lora_params * 1.0 / all_param_count:.3}")
+
+        # Levanter has two kinds of data loaders: sharded and replicated. Replicated is simpler and allows for
+        # single pass training. Sharded only loads a subset of the data on each device, and is more efficient for large
+        # datasets. We use replicated here since the dataset is small.
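+        # (sketch) for a large dataset the sharded equivalent would be roughly:
+        #     loader = trainer.sharded_loader(train_dataset, trainer.TrainBatch)
+        # so that each host only materializes its own slice of every batch.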
+ loader = trainer.replicated_loader(train_dataset, trainer.TrainBatch) + loader = non_caching_cycle(loader) + + if state.step != 0: + logger.info(f"Resuming training from step {state.step}") + for i in range(state.step): + next(loader) # type: ignore + + # Save HF PEFT checkpoints periodically (and at the end of training), which is just the lora weights + if config.hf_save_path is not None: + full_save_path = os.path.join(config.hf_save_path, trainer.run_id) + trainer.add_hook( + save_peft_checkpoint_callback( + full_save_path, config.lora, config.model_name_or_path, tokenizer, config.hf_upload + ), + every=config.hf_save_steps, + ) + + # Save merged HF checkpoints if requested + if config.merged_hf_save_path is not None: + full_save_path = os.path.join(config.merged_hf_save_path, trainer.run_id) + trainer.add_hook( + save_merged_hf_checkpoint_callback(full_save_path, converter, config.merged_hf_upload), + every=config.hf_save_steps, + ) + + trainer.train(state, loader) if __name__ == "__main__": From 9b0df08b683d1c0bdaa37e53977f9d5252904e09 Mon Sep 17 00:00:00 2001 From: David Hall Date: Fri, 10 Nov 2023 13:15:39 -0800 Subject: [PATCH 013/205] sort of get tb to work --- config/gpt2_nano_tb.yaml | 26 +++++++++ src/levanter/tracker/tensorboard.py | 82 +++++++++++++++++++++-------- src/levanter/tracker/tracker.py | 3 +- 3 files changed, 87 insertions(+), 24 deletions(-) create mode 100644 config/gpt2_nano_tb.yaml diff --git a/config/gpt2_nano_tb.yaml b/config/gpt2_nano_tb.yaml new file mode 100644 index 000000000..9ada16aa3 --- /dev/null +++ b/config/gpt2_nano_tb.yaml @@ -0,0 +1,26 @@ +data: + id: dlwh/wikitext_103_detokenized +model: + type: gpt2 + hidden_dim: 32 + num_heads: 4 + num_layers: 2 +trainer: + mp: f32 + num_train_steps: 100 + + checkpointer: + keep: + - every: 50 + save_interval: 5m + + per_device_eval_parallelism: 1 + per_device_parallelism: 1 + train_batch_size: 32 + + tensor_parallel_axes: ["mlp", "heads"] + fsdp_axis: "embed" + batch_axis: "batch" + tracker: + type: tensorboard + logdir: tb_logs/ diff --git a/src/levanter/tracker/tensorboard.py b/src/levanter/tracker/tensorboard.py index 1377167ba..779856caf 100644 --- a/src/levanter/tracker/tensorboard.py +++ b/src/levanter/tracker/tensorboard.py @@ -1,48 +1,84 @@ import logging import os -from pathlib import Path -from typing import Any, Optional, Union +import typing +from dataclasses import dataclass +from typing import Any, Optional -from levanter.tracker import Tracker +from levanter.tracker import Tracker, TrackerConfig, helpers pylogger = logging.getLogger(__name__) +if typing.TYPE_CHECKING: + from tensorboardX import SummaryWriter # noqa: F401 -class TensorboardTracker(Tracker): - def __init__(self, logdir: Union[str, Path]): - self.logdir = logdir - self.writer = None - - def init(self, run_id: Optional[str]): - from tensorboardX import SummaryWriter - dir_to_write = self.logdir - if run_id is not None: - dir_to_write = os.path.join(dir_to_write, run_id) - self.writer = SummaryWriter(dir_to_write) +class TensorboardTracker(Tracker): + def __init__(self, writer: "SummaryWriter"): + self.writer = writer def log_hyperparameters(self, hparams: dict[str, Any]): - if self.writer is None: - raise RuntimeError("Must call init before logging metrics") - self.writer.add_hparams(hparams, {"dummy": 0}) def log(self, metrics: dict[str, Any], *, step, commit=None): del commit - if self.writer is None: - raise RuntimeError("Must call init before logging metrics") - for k, v in metrics.items(): self.writer.add_scalar(k, v, step) 
def log_summary(self, metrics: dict[str, Any]): - if self.writer is None: - raise RuntimeError("Must call init before logging metrics") - for k, v in metrics.items(): self.writer.add_scalar(k, v, global_step=None) def log_artifact(self, artifact, *, name: Optional[str] = None, type: Optional[str] = None): pylogger.warning("TensorboardLogger does not support logging artifacts yet") pass + + +@TrackerConfig.register_subclass("tensorboard") +@dataclass +class TensorboardTrackerConfig(TrackerConfig): + logdir: str = "tblogs" + comment: Optional[str] = "" + purge_step: Optional[int] = None + max_queue: Optional[int] = 10 + flush_secs: Optional[int] = 120 + filename_suffix: Optional[str] = "" + write_to_disk: Optional[bool] = True + + def init(self, run_id: Optional[str], hparams=None) -> TensorboardTracker: + dir_to_write = self.logdir + if run_id is not None: + dir_to_write = os.path.join(dir_to_write, run_id) + + pylogger.info(f"Writing Tensorboard logs to {dir_to_write}") + + from tensorboardX import SummaryWriter # noqa: F811 + + writer = SummaryWriter( + dir_to_write, + comment=self.comment, + purge_step=self.purge_step, + max_queue=self.max_queue, + flush_secs=self.flush_secs, + filename_suffix=self.filename_suffix, + write_to_disk=self.write_to_disk, + ) + + hparams_dict = helpers.hparams_to_dict(hparams) + hparams_dict = _flatten_nested_dict(hparams_dict) + + writer.add_hparams(hparams_dict, {"dummy": 0}) + + return TensorboardTracker(writer) + + +def _flatten_nested_dict(d): + def items(): + for key, value in d.items(): + if isinstance(value, dict): + for subkey, subvalue in _flatten_nested_dict(value).items(): + yield key + "/" + subkey, subvalue + else: + yield key, value + + return dict(items()) diff --git a/src/levanter/tracker/tracker.py b/src/levanter/tracker/tracker.py index cafb9eb61..3c94111cf 100644 --- a/src/levanter/tracker/tracker.py +++ b/src/levanter/tracker/tracker.py @@ -73,9 +73,10 @@ def log_artifact(self, artifact, *, name: Optional[str] = None, type: Optional[s _global_tracker: Optional["Tracker"] = None -class TrackerConfig(draccus.PluginRegistry): +class TrackerConfig(draccus.PluginRegistry, abc.ABC): discover_packages_path = "levanter.tracker" + @abc.abstractmethod def init(self, run_id: Optional[str], hparams=None) -> Tracker: raise NotImplementedError From 4fd2526294bc19553db5a3e0d89e38b0a3c92d4e Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 14 Nov 2023 15:13:10 -0800 Subject: [PATCH 014/205] wip --- src/levanter/doremi.py | 136 ++++++++++++++++++++++++++++++++ src/levanter/models/lm_model.py | 8 +- 2 files changed, 143 insertions(+), 1 deletion(-) create mode 100644 src/levanter/doremi.py diff --git a/src/levanter/doremi.py b/src/levanter/doremi.py new file mode 100644 index 000000000..96590a90b --- /dev/null +++ b/src/levanter/doremi.py @@ -0,0 +1,136 @@ +from typing import Iterator, TypeVar + +import jax.lax +import jax.random as jrandom +import jax.numpy as jnp +import optax +from jaxtyping import PRNGKeyArray + + +import haliax as hax +from haliax.types import IntScalar +from levanter.data import Dataset, ShardableDataset +from levanter.data.mixture import MixtureDataset +from levanter.trainer import Trainer, TrainerConfig + +T = TypeVar("T") + +def estimate_mixture_weights( + optimizer: optax.GradientTransformation, + loss_fn, + initial_proxy, + ref, + data_sources: dict[str, ShardableDataset[T]], + ref_weights: dict[str, float], + domain_weight_step: float = 1.0, + smoothing: float = 1e-3, + *, + key: PRNGKeyArray, +) -> dict[str, float]: + """ + 
Estimate the mixture weights for the data sources using DoReMi.
+    https://arxiv.org/abs/2305.10429
+    """
+    training_key, data_key = jrandom.split(key)
+    domain_indices = list(data_sources.keys())
+    domain_to_index = {domain: index for index, domain in enumerate(domain_indices)}
+    tagged_mixture = domain_tagged_mixture(data_sources, ref_weights, domain_to_index, key=data_key)
+
+
+    state = trainer.initial_state(training_key, model=initial_proxy)
+
+
+    del initial_proxy
+
+    # Initialize domain weights
+    Domain = hax.Axis("domain", len(domain_indices))
+    initial_alpha = hax.ones(Domain) / Domain.size
+
+    def doremi_step(opt_state, proxy, alpha, batch, domains):
+        # calculate per-elem losses for proxy and ref
+        proxy_losses = TODO
+        ref_losses = TODO
+        # calculate excess losses
+        excess_losses = hax.max(proxy_losses - ref_losses, 0)
+        def total_losses_per_domain(excess_losses, domain):
+            return jax.lax.cond(
+                hax.sum(domains == domain) == 0,
+                lambda: hax.zeros(()),
+                lambda: hax.mean(excess_losses, where=domains == domain),
+            )
+
+        per_domain_losses = hax.vmap(total_losses_per_domain, Domain)(excess_losses, hax.arange(Domain))
+        # Update domain weights (exp is entrywise): α′_t ← α_{t−1} ⊙ exp(η λ_t)
+        alpha = alpha * hax.exp(domain_weight_step * per_domain_losses)
+        # Renormalize and smooth domain weights: α_t ← (1 − c) α′_t / Σ_{i=1}^{k} α′_t[i] + c u
+        alpha /= hax.sum(alpha)
+        alpha = (1 - smoothing) * alpha + initial_alpha * smoothing
+        # Update proxy model weights θ_t for the objective L(θ_{t−1}, α_t) (using Adam, Adafactor, etc.)
+        optimizer.update()
+
+
+
+
+
+
+    def update_one_domain(per_token_losses, domains, target_domain):
+        total_in_domain = jnp.sum(domains == target_domain)
+
+
+    loader = trainer.sharded_loader(tagged_mixture, trainer.TrainBatch)
+    for batch, tags in loader:
+        state = trainer.train_step(state, batch)
+        trainer.run_hooks(state)
+
+    # Compute per-domain excess losses for each domain i ∈ {1, 2, ..., k} (ℓ_{θ,j}(x) is the j-th token-level loss):
+    # λ_t[i] ← (1 / |B ∩ D_i|) Σ_{x ∈ B ∩ D_i} (1 / |x|) Σ_{j=1}^{|x|} max{ℓ_{t−1,j}(x) − ℓ_{ref,j}(x), 0}
+
+
+
+def domain_tagged_mixture(
+    data_sources: dict[str, ShardableDataset[T]],
+    weights: dict[str, float],
+    domain_to_index: dict[str, int],
+    *,
+    key: PRNGKeyArray,
+) -> MixtureDataset[(T, IntScalar)]:
+    """
+    Domain tagged mixture dataset. This dataset will yield from the datasets according to the weights,
+    and will yield the domain index as a second element of the tuple.
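+
+    A usage sketch (with hypothetical datasets `wiki_ds` and `code_ds`):
+
+        mixture = domain_tagged_mixture(
+            {"wiki": wiki_ds, "code": code_ds},
+            weights={"wiki": 0.7, "code": 0.3},
+            domain_to_index={"wiki": 0, "code": 1},
+            key=jrandom.PRNGKey(0),
+        )
+        # each yielded element is an (example, domain_index) pair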
+ """ + tagged_datasets = { + domain_index: DomainTaggedDataset(data_sources[domain], domain_index) + for domain, domain_index in domain_to_index.items() + } + + return MixtureDataset(tagged_datasets, weights, key=key) + + +class DomainTaggedDataset(ShardableDataset[(T, hax.NamedArray)]): # named array is a scalar int + + def __init__( + self, + dataset: ShardableDataset[T], + domain_index: int|hax.NamedArray, + ): + self.dataset = dataset + + if isinstance(domain_index, int): + self.domain_index = hax.named(jnp.array(domain_index, dtype=int), ()) + else: + self.domain_index = domain_index + + def shard(self, shard_id: int, num_shards: int) -> "DomainTaggedDataset[T]": + return DomainTaggedDataset(self.dataset.shard(shard_id, num_shards), self.domain_index) + + def __iter__(self) -> Iterator[(T, IntScalar)]: + for item in self.dataset: + yield item, self.domain_index + + + + + + + + diff --git a/src/levanter/models/lm_model.py b/src/levanter/models/lm_model.py index 665137846..57e80d523 100644 --- a/src/levanter/models/lm_model.py +++ b/src/levanter/models/lm_model.py @@ -101,10 +101,16 @@ def compute_loss( logits = self(example.tokens, example.attn_mask, key=key) targets = hax.roll(example.tokens, -1, axis=self.Pos.name) target_y = hax.nn.one_hot(targets, self.Vocab, dtype=logits.dtype) - return cross_entropy_loss( + losses = cross_entropy_loss( logits, self.Vocab, target_y, reduction, reduction_axis=reduction_axis, where=example.loss_mask ) + if reduction is None: + return hax.where(example.loss_mask, losses, 0) + else: + return losses + + @property def vocab_size(self) -> int: return self.Vocab.size From a608a65b0ba816f81f45a7fe940b472595540642 Mon Sep 17 00:00:00 2001 From: David Hall Date: Thu, 16 Nov 2023 12:17:24 -0800 Subject: [PATCH 015/205] wip --- src/levanter/doremi.py | 112 +++++++++++++++++++---------------------- 1 file changed, 52 insertions(+), 60 deletions(-) diff --git a/src/levanter/doremi.py b/src/levanter/doremi.py index 96590a90b..ab8db205b 100644 --- a/src/levanter/doremi.py +++ b/src/levanter/doremi.py @@ -1,31 +1,32 @@ from typing import Iterator, TypeVar -import jax.lax -import jax.random as jrandom +import equinox as eqx import jax.numpy as jnp +import jax.random as jrandom import optax from jaxtyping import PRNGKeyArray - import haliax as hax from haliax.types import IntScalar + from levanter.data import Dataset, ShardableDataset from levanter.data.mixture import MixtureDataset -from levanter.trainer import Trainer, TrainerConfig + T = TypeVar("T") + def estimate_mixture_weights( - optimizer: optax.GradientTransformation, - loss_fn, - initial_proxy, - ref, - data_sources: dict[str, ShardableDataset[T]], - ref_weights: dict[str, float], - domain_weight_step: float = 1.0, - smoothing: float = 1e-3, - *, - key: PRNGKeyArray, + optimizer: optax.GradientTransformation, + loss_fn, + initial_proxy, + ref, + data_sources: dict[str, ShardableDataset[T]], + ref_weights: dict[str, float], + domain_weight_step: float = 1.0, + smoothing: float = 1e-3, + *, + key: PRNGKeyArray, ) -> dict[str, float]: """ Estimate the mixture weights for the data sources using DoReMi. 
@@ -36,63 +37,63 @@ def estimate_mixture_weights(
     domain_to_index = {domain: index for index, domain in enumerate(domain_indices)}
     tagged_mixture = domain_tagged_mixture(data_sources, ref_weights, domain_to_index, key=data_key)
 
-
     state = trainer.initial_state(training_key, model=initial_proxy)
 
-
     del initial_proxy
 
     # Initialize domain weights
     Domain = hax.Axis("domain", len(domain_indices))
     initial_alpha = hax.ones(Domain) / Domain.size
 
-    def doremi_step(opt_state, proxy, alpha, batch, domains):
-        # calculate per-elem losses for proxy and ref
+    # calculate per-token losses for proxy and ref
+    def compute_excess_loss(proxy, ref, batch):
         proxy_losses = TODO
         ref_losses = TODO
         # calculate excess losses
-        excess_losses = hax.max(proxy_losses - ref_losses, 0)
-        def total_losses_per_domain(excess_losses, domain):
-            return jax.lax.cond(
-                hax.sum(domains == domain) == 0,
-                lambda: hax.zeros(()),
-                lambda: hax.mean(excess_losses, where=domains == domain),
-            )
-
-        per_domain_losses = hax.vmap(total_losses_per_domain, Domain)(excess_losses, hax.arange(Domain))
-        # Update domain weights (exp is entrywise): α ← α exp(ηλt)
-        alpha = alpha * hax.exp(domain_weight_step * per_domain_losses)
-        # Renormalize and smooth domain weights: α_t ← (1 − c) α'_t / Σ_{i=1}^k α'_t[i] + c u
-        alpha /= hax.sum(alpha)
-        alpha = (1 - smoothing) * alpha + initial_alpha * smoothing
-        # Update proxy model weights θt for the objective L(θt−1, αt) (using Adam, Adafactor, etc.)
-        optimizer.update()
-
+        excess_losses = proxy_losses - ref_losses
+        return excess_losses
 
+    # Loss is alpha_d * (proxy - ref) (basically the unclipped excess loss with the new alpha)
+    def proxy_model_loss(excess_losses, domains, alpha):
+        one_hot_domains = hax.nn.one_hot(domains, Domain)  # Domain x Batch
+        # basically einsum(" * -> ", alpha, one_hot_domains, excess_losses)
+        loss = hax.dot(excess_losses.axes + (Domain,), alpha, one_hot_domains, excess_losses).scalar()
+        return loss
 
+    def doremi_step(opt_state, proxy, alpha, batch, domains):
+        # this is one of those times when PyTorch's backward() is nice
+        excess_losses, excess_backward = eqx.filter_vjp(lambda proxy: compute_excess_loss(proxy, ref, batch), proxy)
 
+        # Update domain weights
+        ## Compute per-domain excess losses
+        clipped_losses = hax.maximum(excess_losses, 0)
+        one_hot_domains = hax.nn.one_hot(domains, Domain)  # Domain x Batch
+        per_domain_losses = hax.dot(excess_losses.axes, one_hot_domains, clipped_losses)
 
-    def update_one_domain(per_token_losses, domains, target_domain):
-        total_in_domain = jnp.sum(domains == target_domain)
+        old_alpha = alpha
+        alpha = alpha * hax.exp(domain_weight_step * per_domain_losses)
+        alpha /= hax.sum(alpha)
+        alpha = (1 - smoothing) * alpha + initial_alpha * smoothing
+        alpha_distance = hax.sum(hax.abs(alpha - old_alpha))
 
-    loader = trainer.sharded_loader(tagged_mixture, trainer.TrainBatch)
-    for batch, tags in loader:
-        state = trainer.train_step(state, batch)
-        trainer.run_hooks(state)
+        # Update proxy model weights θt for the objective L(θt−1, αt) (using Adam, Adafactor, etc.)
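+        # (Chain rule, spelled out: dL/dθ = dL/d(excess_losses) · d(excess_losses)/dθ. We take the
+        # gradient of the weighted loss w.r.t. the per-token excess losses, then pull it back through
+        # the vjp computed above, so the proxy forward pass only happens once.)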
+        val, grad_loss = eqx.filter_value_and_grad(proxy_model_loss)(excess_losses, domains, alpha)
+        grad = excess_backward(grad_loss)
 
-        # Compute per-domain excess losses for each domain i ∈ {1, 2, ..., k} (ℓ_j(x) is the j-th token-level loss):
-        # λ_t[i] ← (1 / |B ∩ D_i|) Σ_{x ∈ B ∩ D_i} (1 / |x|) Σ_{j=1}^{|x|} max{ℓ_{t−1,j}(x) − ℓ_{ref,j}(x), 0}
+        updates, new_state = optimizer.update(opt_state, grad, params=proxy)
+        proxy = optax.apply_updates(proxy, updates)
 
+        return new_state, proxy, alpha
 
 
 def domain_tagged_mixture(
-        data_sources: dict[str, ShardableDataset[T]],
-        weights: dict[str, float],
-        domain_to_index: dict[str, int],
-        *,
-        key: PRNGKeyArray,
+    data_sources: dict[str, ShardableDataset[T]],
+    weights: dict[str, float],
+    domain_to_index: dict[str, int],
+    *,
+    key: PRNGKeyArray,
 ) -> MixtureDataset[(T, IntScalar)]:
     """
     Domain tagged mixture dataset. This dataset will yield from the datasets according to the weights,
@@ -107,11 +108,10 @@ def domain_tagged_mixture(
 
 class DomainTaggedDataset(ShardableDataset[(T, hax.NamedArray)]):  # named array is a scalar int
-
     def __init__(
-            self,
-            dataset: ShardableDataset[T],
-            domain_index: int|hax.NamedArray,
+        self,
+        dataset: ShardableDataset[T],
+        domain_index: int | hax.NamedArray,
     ):
         self.dataset = dataset
 
@@ -126,11 +126,3 @@ def shard(self, shard_id: int, num_shards: int) -> "DomainTaggedDataset[T]":
         return DomainTaggedDataset(self.dataset.shard(shard_id, num_shards), self.domain_index)
 
     def __iter__(self) -> Iterator[(T, IntScalar)]:
         for item in self.dataset:
             yield item, self.domain_index
-
-
-
-
-
-
-
-

From 8d34f6fca848d7f3a3ce5ba8ff501fa2b19732b0 Mon Sep 17 00:00:00 2001
From: David Hall
Date: Fri, 17 Nov 2023 00:30:29 -0800
Subject: [PATCH 016/205] update configs, expose a method to find trackers

---
 README.md | 3 +-
 config/backpack.yaml | 3 +-
 config/gpt2_1536.yaml | 3 +-
 config/gpt2_20b.yaml | 3 +-
 config/gpt2_7b.yaml | 3 +-
 config/gpt2_large.yaml | 3 +-
 config/gpt2_medium.yaml | 3 +-
 config/gpt2_micro.yaml | 3 +-
 config/gpt2_small.yaml | 3 +-
 config/gpt2_small_fast.yaml | 3 +-
 config/gpt2_small_fast_mix.yaml | 3 +-
 config/gpt2_small_fast_pile.yaml | 3 +-
 config/gpt2_small_fast_wiki.yaml | 3 +-
 config/gpt2_xl.yaml | 3 +-
 config/llama2_7b.yaml | 3 +-
 config/llama2_7b_continued.yaml | 3 +-
 config/llama2_nano.yaml | 3 +-
 config/lora/mpt_biomed.yaml | 3 +-
 config/mpt_7b_continued.yaml | 3 +-
 config/mpt_7b_continued_biomedlm.yaml | 3 +-
 docs/Configuration-Guide.md | 3 +-
 docs/Training-On-Your-Data.md | 3 +-
 examples/alpaca-lora/alpaca_lora.py | 3 +-
 src/levanter/__init__.py | 2 +-
 src/levanter/callbacks.py | 18 ++--
 src/levanter/data/shard_cache.py | 2 +-
 src/levanter/main/lora_lm.py | 4 +-
 src/levanter/main/train_lm.py | 4 +-
 src/levanter/tracker/__init__.py | 20 ++--
 src/levanter/tracker/helpers.py | 2 +-
 src/levanter/tracker/tensorboard.py | 2 +
 src/levanter/tracker/tracker.py | 96 +------------------
 src/levanter/tracker/tracker_fns.py | 130 ++++++++++++++++++++++++++
 src/levanter/tracker/wandb.py | 30 +++---
 src/levanter/trainer.py | 4 +-
 35 files changed, 224 insertions(+), 159 deletions(-)
 create mode 100644 src/levanter/tracker/tracker_fns.py

diff --git a/README.md b/README.md
index 3f41614e3..bbc5cc6c6 100644
--- a/README.md
+++ b/README.md
@@ -149,7 +149,8 @@ model:
   gradient_checkpointing: true
   scale_attn_by_inverse_layer_idx: true
 trainer:
-  wandb:
+  tracker:
+    type: wandb
     project: "levanter"
     tags: [ "openwebtext", "gpt2"]
 
diff --git a/config/backpack.yaml b/config/backpack.yaml
index 5b6cef3cb..02a53064a 100644
--- a/config/backpack.yaml
+++ b/config/backpack.yaml
@@ -10,7 +10,8 @@ model:
num_senses: 16 sense_intermediate_scale: 4 trainer: - wandb: + tracker: + type: wandb project: "levanter" tags: [ "openwebtext", "backpack" ] diff --git a/config/gpt2_1536.yaml b/config/gpt2_1536.yaml index 50ccbd882..ad552e5e0 100644 --- a/config/gpt2_1536.yaml +++ b/config/gpt2_1536.yaml @@ -8,7 +8,8 @@ model: gradient_checkpointing: true scale_attn_by_inverse_layer_idx: true trainer: - wandb: + tracker: + type: wandb project: "levanter" tags: [ "openwebtext", "gpt2"] diff --git a/config/gpt2_20b.yaml b/config/gpt2_20b.yaml index 76bf6ba96..670b47c46 100644 --- a/config/gpt2_20b.yaml +++ b/config/gpt2_20b.yaml @@ -12,7 +12,8 @@ model: use_bias: false fcm_prob: 0.15 trainer: - wandb: + tracker: + type: wandb project: "levanter" tags: ["pile", "gpt2"] diff --git a/config/gpt2_7b.yaml b/config/gpt2_7b.yaml index affb67aa5..1af6cf22c 100644 --- a/config/gpt2_7b.yaml +++ b/config/gpt2_7b.yaml @@ -11,7 +11,8 @@ model: resid_pdrop: 0.0 fcm_prob: 0.15 trainer: - wandb: + tracker: + type: wandb project: "levanter" tags: ["pile", "gpt2"] diff --git a/config/gpt2_large.yaml b/config/gpt2_large.yaml index 525a92c99..3d82d763a 100644 --- a/config/gpt2_large.yaml +++ b/config/gpt2_large.yaml @@ -8,7 +8,8 @@ model: gradient_checkpointing: true scale_attn_by_inverse_layer_idx: true trainer: - wandb: + tracker: + type: wandb project: "levanter" tags: [ "openwebtext", "gpt2"] diff --git a/config/gpt2_medium.yaml b/config/gpt2_medium.yaml index 9ea4408bc..0ca3162d1 100644 --- a/config/gpt2_medium.yaml +++ b/config/gpt2_medium.yaml @@ -8,7 +8,8 @@ model: gradient_checkpointing: true scale_attn_by_inverse_layer_idx: true trainer: - wandb: + tracker: + type: wandb project: "levanter" tags: [ "openwebtext", "gpt2"] diff --git a/config/gpt2_micro.yaml b/config/gpt2_micro.yaml index 274ecddaa..9b0fb8f60 100644 --- a/config/gpt2_micro.yaml +++ b/config/gpt2_micro.yaml @@ -6,7 +6,8 @@ model: num_heads: 8 num_layers: 4 trainer: - wandb: + tracker: + type: wandb project: "levanter" tags: [ "openwebtext", "gpt2"] diff --git a/config/gpt2_small.yaml b/config/gpt2_small.yaml index 74d0e031a..82f478320 100644 --- a/config/gpt2_small.yaml +++ b/config/gpt2_small.yaml @@ -8,7 +8,8 @@ model: gradient_checkpointing: true scale_attn_by_inverse_layer_idx: true trainer: - wandb: + tracker: + type: wandb project: "levanter" tags: [ "openwebtext", "gpt2"] diff --git a/config/gpt2_small_fast.yaml b/config/gpt2_small_fast.yaml index 4c8434f38..18c71a44c 100644 --- a/config/gpt2_small_fast.yaml +++ b/config/gpt2_small_fast.yaml @@ -8,7 +8,8 @@ model: gradient_checkpointing: true scale_attn_by_inverse_layer_idx: true trainer: - wandb: + tracker: + type: wandb project: "levanter" tags: [ "openwebtext", "gpt2", "itest"] diff --git a/config/gpt2_small_fast_mix.yaml b/config/gpt2_small_fast_mix.yaml index 0785e9103..5ba75af1b 100644 --- a/config/gpt2_small_fast_mix.yaml +++ b/config/gpt2_small_fast_mix.yaml @@ -21,7 +21,8 @@ model: gradient_checkpointing: true scale_attn_by_inverse_layer_idx: true trainer: - wandb: + tracker: + type: wandb project: "levanter" tags: [ "openwebtext+wiki", "gpt2", "itest"] diff --git a/config/gpt2_small_fast_pile.yaml b/config/gpt2_small_fast_pile.yaml index f30743c1d..47d4dda8a 100644 --- a/config/gpt2_small_fast_pile.yaml +++ b/config/gpt2_small_fast_pile.yaml @@ -8,7 +8,8 @@ model: gradient_checkpointing: true scale_attn_by_inverse_layer_idx: true trainer: - wandb: + tracker: + type: wandb project: "levanter" tags: [ "pile", "gpt2", "itest"] diff --git a/config/gpt2_small_fast_wiki.yaml 
b/config/gpt2_small_fast_wiki.yaml index 407d8705b..d36a47f69 100644 --- a/config/gpt2_small_fast_wiki.yaml +++ b/config/gpt2_small_fast_wiki.yaml @@ -9,7 +9,8 @@ model: gradient_checkpointing: true scale_attn_by_inverse_layer_idx: true trainer: - wandb: + tracker: + type: wandb project: "levanter" tags: [ "openwebtext", "gpt2", "itest"] diff --git a/config/gpt2_xl.yaml b/config/gpt2_xl.yaml index 8230b56a5..5084133e7 100644 --- a/config/gpt2_xl.yaml +++ b/config/gpt2_xl.yaml @@ -8,7 +8,8 @@ model: gradient_checkpointing: true scale_attn_by_inverse_layer_idx: true trainer: - wandb: + tracker: + type: wandb project: "levanter" tags: [ "openwebtext", "gpt2"] mp: p=f32,c=bfloat16 diff --git a/config/llama2_7b.yaml b/config/llama2_7b.yaml index 68931f3fa..b4ebe705f 100644 --- a/config/llama2_7b.yaml +++ b/config/llama2_7b.yaml @@ -11,7 +11,8 @@ model: # initialize_from_hf: "meta-llama/Llama-2-7b-hf" # use_hf_model_config: true trainer: - wandb: + tracker: + type: wandb project: "levanter" tags: ["openwebtext", "llama"] diff --git a/config/llama2_7b_continued.yaml b/config/llama2_7b_continued.yaml index e03be7168..edb72a7e4 100644 --- a/config/llama2_7b_continued.yaml +++ b/config/llama2_7b_continued.yaml @@ -6,7 +6,8 @@ model: initialize_from_hf: true use_hf_model_config: true trainer: - wandb: + tracker: + type: wandb project: "levanter" tags: ["pile", "llama2"] diff --git a/config/llama2_nano.yaml b/config/llama2_nano.yaml index d7196c59b..877c1da8f 100644 --- a/config/llama2_nano.yaml +++ b/config/llama2_nano.yaml @@ -11,7 +11,8 @@ model: num_heads: 4 num_layers: 2 trainer: - wandb: + tracker: + type: wandb project: "levanter" tags: ["openwebtext", "llama"] mp: p=f32 diff --git a/config/lora/mpt_biomed.yaml b/config/lora/mpt_biomed.yaml index f49267ca1..6b19d0ab5 100644 --- a/config/lora/mpt_biomed.yaml +++ b/config/lora/mpt_biomed.yaml @@ -11,7 +11,8 @@ lora: alpha: 32.0 target_modules: ["Wqkv"] trainer: - wandb: + tracker: + type: wandb project: "levanter" tags: ["mpt", "lora", "pubmed"] diff --git a/config/mpt_7b_continued.yaml b/config/mpt_7b_continued.yaml index a7eaf800b..8357967e6 100644 --- a/config/mpt_7b_continued.yaml +++ b/config/mpt_7b_continued.yaml @@ -4,7 +4,8 @@ model: initialize_from_hf: true use_hf_model_config: true trainer: - wandb: + tracker: + type: wandb project: "levanter" tags: ["pile", "mpt"] diff --git a/config/mpt_7b_continued_biomedlm.yaml b/config/mpt_7b_continued_biomedlm.yaml index 44961df46..c40ddc508 100644 --- a/config/mpt_7b_continued_biomedlm.yaml +++ b/config/mpt_7b_continued_biomedlm.yaml @@ -10,7 +10,8 @@ model: initialize_from_hf: "mosaicml/mpt-7b@68e1a8e0ebb9b30f3c45c1ef6195980f29063ae2" use_hf_model_config: true trainer: - wandb: + tracker: + type: wandb project: "levanter" tags: ["pubmed", "mpt", "continued"] diff --git a/docs/Configuration-Guide.md b/docs/Configuration-Guide.md index c7891e1e9..da253454e 100644 --- a/docs/Configuration-Guide.md +++ b/docs/Configuration-Guide.md @@ -35,7 +35,8 @@ model: gradient_checkpointing: true scale_attn_by_inverse_layer_idx: true trainer: - wandb: + tracker: + type: wandb project: "levanter" tags: [ "openwebtext", "gpt2"] diff --git a/docs/Training-On-Your-Data.md b/docs/Training-On-Your-Data.md index edf33e0af..4c543b04f 100644 --- a/docs/Training-On-Your-Data.md +++ b/docs/Training-On-Your-Data.md @@ -214,7 +214,8 @@ model: gradient_checkpointing: true scale_attn_by_inverse_layer_idx: true trainer: - wandb: + tracker: + type: wandb project: "levanter" # TODO tags: ["gpt2"] diff --git 
a/examples/alpaca-lora/alpaca_lora.py b/examples/alpaca-lora/alpaca_lora.py index b87286a05..4ad4dfba2 100644 --- a/examples/alpaca-lora/alpaca_lora.py +++ b/examples/alpaca-lora/alpaca_lora.py @@ -12,7 +12,6 @@ import haliax as hax import levanter -from levanter import tracker from levanter.compat.hf_checkpoints import HFCheckpointConverter from levanter.lora import ( LoraConfig, @@ -113,7 +112,7 @@ def compute_loss(model: LmHeadModel, example: LmExample, key=None): all_param_count = parameter_count(state.model) just_lora_params = parameter_count(trainer.trainable_params_only(state.model)) - tracker.log_summary( + levanter.tracker.log_summary( { "parameter_count": all_param_count, "trainable_parameter_count": just_lora_params, diff --git a/src/levanter/__init__.py b/src/levanter/__init__.py index 971dbeada..519e387ac 100644 --- a/src/levanter/__init__.py +++ b/src/levanter/__init__.py @@ -5,4 +5,4 @@ import levanter.logging as logging import levanter.tracker as tracker import levanter.visualization as visualization -from levanter.tracker import current_tracker, log_metrics, log_summary +from levanter.tracker import current_tracker, get_tracker, jit_log_metrics, log_metrics, log_summary diff --git a/src/levanter/callbacks.py b/src/levanter/callbacks.py index 85220cda5..a80d0619e 100644 --- a/src/levanter/callbacks.py +++ b/src/levanter/callbacks.py @@ -62,7 +62,7 @@ def compute_loss(info: StepInfo): prefix = "eval" if name: prefix += "/" + name - levanter.log_metrics({f"{prefix}/loss": loss}, step=info.step) + levanter.tracker.log_metrics({f"{prefix}/loss": loss}, step=info.step) if name: logger.info(f"{name} validation loss: {loss:.3f}") @@ -75,7 +75,7 @@ def compute_loss(info: StepInfo): def log_step_info(step: StepInfo): - levanter.log_metrics({"train/loss": step.loss, "global_step": step.step}, step=step.step) + levanter.tracker.log_metrics({"train/loss": step.loss, "global_step": step.step}, step=step.step) log_optimizer_hyperparams(step.opt_state, step=step.step, prefix="optim") @@ -111,14 +111,14 @@ def log_performance_stats(step_info: StepInfo): # log these totals because it's useful for comparing different seqlens, batch sizes, etc total_tokens = tokens_per_example * batch_size * step_info.step - levanter.log_metrics({wrap_key("total_tokens"): total_tokens}, step=step_info.step) + levanter.tracker.log_metrics({wrap_key("total_tokens"): total_tokens}, step=step_info.step) if flops_per_example: total_flops = flops_per_example * batch_size * step_info.step - levanter.log_metrics({wrap_key("total_gflops"): total_flops / 1e9}, step=step_info.step) + levanter.tracker.log_metrics({wrap_key("total_gflops"): total_flops / 1e9}, step=step_info.step) if step_info.step_duration != 0.0: - levanter.log_metrics( + levanter.tracker.log_metrics( { wrap_key("examples_per_second"): float(batch_size) / step_info.step_duration, wrap_key("tokens_per_second"): float(tokens_per_example) / step_info.step_duration * batch_size, @@ -128,7 +128,7 @@ def log_performance_stats(step_info: StepInfo): ) if flops_per_example is not None: - levanter.log_metrics( + levanter.tracker.log_metrics( { wrap_key("gflops_per_second"): flops_per_example / 1e9 / step_info.step_duration * batch_size, }, @@ -221,7 +221,7 @@ def log_memory_usage(step: StepInfo): match = regex.search(by_kind) if match: memory_usage = humanfriendly.parse_size(match.group(1)) - levanter.log_metrics({"memory/total": memory_usage / 1e6}, step=step.step) + levanter.tracker.log_metrics({"memory/total": memory_usage / 1e6}, step=step.step) # this works 
for the "kind" and the individual devices regex = re.compile(r"([\d.]+[a-zA-Z]+) \(([\d.]+)%\): ([\w\d:_]+)") @@ -232,14 +232,14 @@ def log_memory_usage(step: StepInfo): for match in regex.finditer(per_device): memory_usage = humanfriendly.parse_size(match.group(1)) device_name = match.group(3) - levanter.log_metrics({f"memory/device/{device_name}": memory_usage / 1e6}, step=step.step) + levanter.tracker.log_metrics({f"memory/device/{device_name}": memory_usage / 1e6}, step=step.step) # now, get the memory usage per kind. # same regex as above for match in regex.finditer(by_kind): memory_usage = match.group(1) memory_usage = humanfriendly.parse_size(memory_usage) - levanter.log_metrics({f"memory/{match.group(3)}": memory_usage / 1e6}, step=step.step) + levanter.tracker.log_metrics({f"memory/{match.group(3)}": memory_usage / 1e6}, step=step.step) return log_memory_usage diff --git a/src/levanter/data/shard_cache.py b/src/levanter/data/shard_cache.py index 48aa2c1a1..83e4a6c85 100644 --- a/src/levanter/data/shard_cache.py +++ b/src/levanter/data/shard_cache.py @@ -519,7 +519,7 @@ def __call__(self, metrics: InProgressCacheMetrics): self.last_metrics = metrics self.last_time = time.time() - levanter.log_metrics(to_log, step=None, commit=self.commit) + levanter.tracker.log_metrics(to_log, step=None, commit=self.commit) class LoggerMetricsMonitor(MetricsMonitor): diff --git a/src/levanter/main/lora_lm.py b/src/levanter/main/lora_lm.py index ef02b687f..4e621239e 100644 --- a/src/levanter/main/lora_lm.py +++ b/src/levanter/main/lora_lm.py @@ -8,7 +8,7 @@ import haliax.random import levanter -from levanter import callbacks, tracker +from levanter import callbacks from levanter.compat.hf_checkpoints import HFCheckpointConverter from levanter.data.text import CausalLmDataset, LMDatasetConfig, LmExample from levanter.lora import ( @@ -94,7 +94,7 @@ def compute_loss(model, example: LmExample, key=None): all_param_count = parameter_count(state.model) just_lora_params = parameter_count(trainer.trainable_params_only(state.model)) - tracker.log_summary( + levanter.tracker.log_summary( { "parameter_count": all_param_count, "trainable_parameter_count": just_lora_params, diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py index dd6ac20df..60d5dbbb6 100644 --- a/src/levanter/main/train_lm.py +++ b/src/levanter/main/train_lm.py @@ -11,7 +11,7 @@ from haliax.partitioning import named_jit, round_axis_for_partitioning import levanter -from levanter import callbacks, tracker +from levanter import callbacks from levanter.compat.hf_checkpoints import HFCompatConfig, save_hf_checkpoint_callback from levanter.data.text import CausalLmDataset, LMDatasetConfig, LMMixtureDatasetConfig from levanter.models.gpt2 import Gpt2Config @@ -130,7 +130,7 @@ def compute_loss(model: LmHeadModel, example: LmExample, key=None): else: logger.info("No checkpoint found. 
Starting from scratch.") - tracker.log_summary( + levanter.tracker.log_summary( { "parameter_count": parameter_count(state.model), } diff --git a/src/levanter/tracker/__init__.py b/src/levanter/tracker/__init__.py index 440e3375b..73e8b5872 100644 --- a/src/levanter/tracker/__init__.py +++ b/src/levanter/tracker/__init__.py @@ -1,24 +1,16 @@ from levanter.tracker.helpers import log_optimizer_hyperparams -from levanter.tracker.tracker import ( - CompositeTracker, - NoopTracker, - Tracker, - TrackerConfig, - current_tracker, - jit_log_metrics, - log_metrics, - log_summary, -) +from levanter.tracker.tracker import CompositeTracker, NoopTracker, Tracker, TrackerConfig +from levanter.tracker.tracker_fns import current_tracker, get_tracker, jit_log_metrics, log_metrics, log_summary __all__ = [ "Tracker", "TrackerConfig", "CompositeTracker", - "log_metrics", - "log_summary", - "current_tracker", - "jit_log_metrics", "log_optimizer_hyperparams", "NoopTracker", + "current_tracker", + "jit_log_metrics", + "log_metrics", + "log_summary", ] diff --git a/src/levanter/tracker/helpers.py b/src/levanter/tracker/helpers.py index 26e5e0e2e..31131d1ac 100644 --- a/src/levanter/tracker/helpers.py +++ b/src/levanter/tracker/helpers.py @@ -24,7 +24,7 @@ def wrap_key(key): if hasattr(opt_state, "hyperparams"): params = {wrap_key(k): jnp_to_python(v) for k, v in opt_state.hyperparams.items()} - levanter.log_metrics(params, step=step) + levanter.tracker.log_metrics(params, step=step) def hparams_to_dict(hparams, **extra_hparams): diff --git a/src/levanter/tracker/tensorboard.py b/src/levanter/tracker/tensorboard.py index 779856caf..f554eca9f 100644 --- a/src/levanter/tracker/tensorboard.py +++ b/src/levanter/tracker/tensorboard.py @@ -14,6 +14,8 @@ class TensorboardTracker(Tracker): + name: str = "tensorboard" + def __init__(self, writer: "SummaryWriter"): self.writer = writer diff --git a/src/levanter/tracker/tracker.py b/src/levanter/tracker/tracker.py index 3c94111cf..086d778a9 100644 --- a/src/levanter/tracker/tracker.py +++ b/src/levanter/tracker/tracker.py @@ -1,13 +1,9 @@ import abc import dataclasses import typing -from contextlib import AbstractContextManager from typing import Any, List, Optional import draccus -import jax - -from levanter.utils.jax_utils import is_inside_jit class Tracker(abc.ABC): @@ -15,7 +11,7 @@ class Tracker(abc.ABC): A tracker is responsible for logging metrics, hyperparameters, and artifacts. Meant to be used with the [current_tracker][] context manager, but can also be used directly. - The name is borrowed from Accelerate. + The name is borrowed from HF Accelerate. Examples: >>> from levanter.tracker import current_tracker, log_metrics @@ -24,6 +20,8 @@ class Tracker(abc.ABC): ... log_metrics({"foo": 1}, step=0) """ + name: str + @abc.abstractmethod def log_hyperparameters(self, hparams: dict[str, Any]): pass @@ -70,9 +68,6 @@ def log_artifact(self, artifact, *, name: Optional[str] = None, type: Optional[s tracker.log_artifact(artifact, name=name, type=type) -_global_tracker: Optional["Tracker"] = None - - class TrackerConfig(draccus.PluginRegistry, abc.ABC): discover_packages_path = "levanter.tracker" @@ -81,90 +76,9 @@ def init(self, run_id: Optional[str], hparams=None) -> Tracker: raise NotImplementedError -def log_metrics(metrics: dict[str, Any], *, step, commit: Optional[bool] = None): - """ - Log metrics to the global tracker. - - Args - metrics: Metrics to log - step: Step to log at - commit: Whether to commit the metrics. If None, uses the default for the tracker. 
- """ - global _global_tracker - if _global_tracker is None: - raise RuntimeError("No global tracker set") - - if is_inside_jit(): - # we're inside a jit, so we need to log from the host - if commit: - raise ValueError("Cannot commit from inside jit") - jit_log_metrics(metrics, step=step) - else: - # TODO: do we need to coerce to np here? - _global_tracker.log(metrics, step=step) - - -def jit_log_metrics(metrics, *, step=None): - """uses jax effect callback to log to wandb from the host""" - jax.debug.callback(log_metrics, metrics, step=step) - - -def log_summary(metrics: dict[str, Any]): - """ - Log summary metrics to the global tracker. - - :param metrics: Metrics to log - """ - global _global_tracker - if _global_tracker is None: - raise RuntimeError("No global tracker set") - _global_tracker.log_summary(metrics) - - -@typing.overload -def current_tracker() -> "Tracker": - ... - - -@typing.overload -def current_tracker(tracker: "Tracker") -> typing.ContextManager: - """Returns a context manager for setting the global tracker""" - ... - - -def current_tracker( - tracker: Optional[Tracker] = None, -) -> Tracker | typing.ContextManager: - """ - Get or set the global tracker. - - :param tracker: If provided, returns a context manager that sets the global tracker to the provided tracker when used. - :return: The global tracker, or a context manager for setting the global tracker. - """ - global _global_tracker - if tracker is None: - if _global_tracker is None: - raise RuntimeError("No global tracker set") - return _global_tracker - else: - return _GlobalLoggerContextManager(tracker) - - -class _GlobalLoggerContextManager(AbstractContextManager): - def __init__(self, tracker: "Tracker"): - self.tracker = tracker - - def __enter__(self): - global _global_tracker - self.old_tracker = _global_tracker - _global_tracker = self.tracker - - def __exit__(self, exc_type, exc_val, exc_tb): - global _global_tracker - _global_tracker = self.old_tracker - - class NoopTracker(Tracker): + name: str = "noop" + def log_hyperparameters(self, hparams: dict[str, Any]): pass diff --git a/src/levanter/tracker/tracker_fns.py b/src/levanter/tracker/tracker_fns.py new file mode 100644 index 000000000..a437102ee --- /dev/null +++ b/src/levanter/tracker/tracker_fns.py @@ -0,0 +1,130 @@ +import typing +from contextlib import AbstractContextManager +from typing import Any, Literal, Optional + +import jax + +from levanter.tracker import Tracker +from levanter.tracker.composite import CompositeTracker +from levanter.tracker.tensorboard import TensorboardTracker +from levanter.tracker.wandb import WandbTracker +from levanter.utils.jax_utils import is_inside_jit + + +_global_tracker: Optional["Tracker"] = None + + +def log_metrics(metrics: dict[str, Any], *, step, commit: Optional[bool] = None): + """ + Log metrics to the global tracker. + + Args + metrics: Metrics to log + step: Step to log at + commit: Whether to commit the metrics. If None, uses the default for the tracker. + """ + global _global_tracker + if _global_tracker is None: + raise RuntimeError("No global tracker set") + + if is_inside_jit(): + # we're inside a jit, so we need to log from the host + if commit: + raise ValueError("Cannot commit from inside jit") + jit_log_metrics(metrics, step=step) + else: + # TODO: do we need to coerce to np here? 
+ _global_tracker.log(metrics, step=step) + + +def jit_log_metrics(metrics, *, step=None): + """uses jax effect callback to log to wandb from the host""" + jax.debug.callback(log_metrics, metrics, step=step) + + +def log_summary(metrics: dict[str, Any]): + """ + Log summary metrics to the global tracker. + + :param metrics: Metrics to log + """ + global _global_tracker + if _global_tracker is None: + raise RuntimeError("No global tracker set") + _global_tracker.log_summary(metrics) + + +@typing.overload +def current_tracker() -> "Tracker": + ... + + +@typing.overload +def current_tracker(tracker: "Tracker") -> typing.ContextManager: + """Returns a context manager for setting the global tracker""" + ... + + +def current_tracker( + tracker: Optional[Tracker] = None, +) -> Tracker | typing.ContextManager: + """ + Get or set the global tracker. + + :param tracker: If provided, returns a context manager that sets the global tracker to the provided tracker when used. + :return: The global tracker, or a context manager for setting the global tracker. + """ + global _global_tracker + if tracker is None: + if _global_tracker is None: + raise RuntimeError("No global tracker set") + return _global_tracker + else: + return _GlobalLoggerContextManager(tracker) + + +@typing.overload +def get_tracker(name: Literal["wandb"]) -> WandbTracker: + ... + + +@typing.overload +def get_tracker(name: Literal["tensorboard"]) -> TensorboardTracker: + ... + + +@typing.overload +def get_tracker(name: str) -> Tracker: + ... + + +def get_tracker(name: str) -> Tracker: + """ + Lookup a tracker in the current global tracker with the provided name. + + :param name: Name of the tracker + :return: The tracker + """ + tracker = current_tracker() + if isinstance(tracker, CompositeTracker): + for t in tracker.loggers: + if t.name == name: + return t + elif tracker.name == name: + return tracker + + raise ValueError(f"Tracker with name {name} not found") + + +class _GlobalLoggerContextManager(AbstractContextManager): + def __init__(self, tracker: "Tracker"): + self.tracker = tracker + + def __enter__(self): + global _global_tracker + self.old_tracker = _global_tracker + _global_tracker = self.tracker + + def __exit__(self, exc_type, exc_val, exc_tb): + global _global_tracker + _global_tracker = self.old_tracker diff --git a/src/levanter/tracker/wandb.py b/src/levanter/tracker/wandb.py index 5cc322fb5..93e75f610 100644 --- a/src/levanter/tracker/wandb.py +++ b/src/levanter/tracker/wandb.py @@ -29,7 +29,8 @@ class WandbTracker(Tracker): - _run: Optional[WandbRun] + name: str = "wandb" + run: WandbRun def __init__(self, run: Optional[WandbRun]): import wandb @@ -37,29 +38,34 @@ def __init__(self, run: Optional[WandbRun]): if run is None: if wandb.run is None: logger.warning("Wandb run is not initialized. 
Initializing a new run.") - run = wandb.init() - - self._run = run + runx = wandb.init() + if runx is None: + raise RuntimeError("Wandb run is not initialized.") + self.run = runx + else: + self.run = wandb.run + else: + self.run = run def log_hyperparameters(self, hparams: dict[str, Any]): - if self._run is None: + if self.run is None: raise RuntimeError("Must call init before logging hyperparameters") - self._run.config.update(hparams) + self.run.config.update(hparams) def log(self, metrics: dict[str, Any], *, step, commit=None): - if self._run is None: + if self.run is None: raise RuntimeError("Must call init before logging metrics") - self._run.log(metrics, step=step, commit=commit) + self.run.log(metrics, step=step, commit=commit) def log_summary(self, metrics: dict[str, Any]): - if self._run is None: + if self.run is None: raise RuntimeError("Must call init before logging summary") - self._run.summary.update(metrics) + self.run.summary.update(metrics) def log_artifact(self, artifact, *, name: Optional[str] = None, type: Optional[str] = None): - if self._run is None: + if self.run is None: raise RuntimeError("Must call init before logging artifacts") - self._run.log_artifact(artifact, name=name, type=type) + self.run.log_artifact(artifact, name=name, type=type) def is_wandb_available(): diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index dcbfcbfff..4636b50d0 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -315,7 +315,7 @@ def training_steps( with capture_time() as loading_time: example = next(iter_data) - levanter.log_metrics({"throughput/loading_time": loading_time()}, step=state.step) + levanter.tracker.log_metrics({"throughput/loading_time": loading_time()}, step=state.step) info = self.train_step(state, example) state = info.state @@ -324,7 +324,7 @@ def training_steps( with capture_time() as hook_time: self.run_hooks(info) - levanter.log_metrics({"throughput/hook_time": hook_time()}, step=state.step) + levanter.tracker.log_metrics({"throughput/hook_time": hook_time()}, step=state.step) yield info From 42d7f2c682ff6e097ad940daa8a4d6e176cfeae0 Mon Sep 17 00:00:00 2001 From: David Hall Date: Fri, 17 Nov 2023 14:05:11 -0800 Subject: [PATCH 017/205] use `trainer` more to set logging --- examples/alpaca-lora/alpaca_lora.py | 7 +++---- examples/alpaca/alpaca.py | 4 +--- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/examples/alpaca-lora/alpaca_lora.py b/examples/alpaca-lora/alpaca_lora.py index 4ad4dfba2..31de93252 100644 --- a/examples/alpaca-lora/alpaca_lora.py +++ b/examples/alpaca-lora/alpaca_lora.py @@ -100,10 +100,9 @@ def loraize_hf_model(model): def compute_loss(model: LmHeadModel, example: LmExample, key=None): return model.compute_loss(example, key=key).scalar() - trainer = Trainer(config.trainer, optimizer, compute_loss, is_trainable=lora_param_filter) + # end major difference from Alpaca - with trainer: - # end major difference from Alpaca + with Trainer(config.trainer, optimizer, compute_loss, is_trainable=lora_param_filter) as trainer: trainer.add_default_hooks() state = trainer.initial_state(training_key, model=model) @@ -112,7 +111,7 @@ def compute_loss(model: LmHeadModel, example: LmExample, key=None): all_param_count = parameter_count(state.model) just_lora_params = parameter_count(trainer.trainable_params_only(state.model)) - levanter.tracker.log_summary( + levanter.log_summary( { "parameter_count": all_param_count, "trainable_parameter_count": just_lora_params, diff --git a/examples/alpaca/alpaca.py 
b/examples/alpaca/alpaca.py index 113cef91a..6ce7f06c0 100644 --- a/examples/alpaca/alpaca.py +++ b/examples/alpaca/alpaca.py @@ -209,9 +209,7 @@ def train(config: TrainArgs): def compute_loss(model: LmHeadModel, example: LmExample, key=None): return model.compute_loss(example, key=key).scalar() - trainer = Trainer(config.trainer, optimizer, compute_loss) - - with trainer.device_mesh: + with Trainer(config.trainer, optimizer, compute_loss) as trainer: # how we shard parameters across devices parameter_axis_mapping = trainer.parameter_axis_mapping From b8877615e4a90278066a3044490bd4d07dbf7735 Mon Sep 17 00:00:00 2001 From: David Hall Date: Fri, 17 Nov 2023 14:05:46 -0800 Subject: [PATCH 018/205] test the tracker get name stuff --- src/levanter/main/cache_dataset.py | 3 +-- src/levanter/tracker/__init__.py | 2 +- src/levanter/tracker/tracker.py | 14 ++++++++++++++ src/levanter/tracker/tracker_fns.py | 5 ++--- tests/test_tracker.py | 21 ++++++++++++++++++++- 5 files changed, 38 insertions(+), 7 deletions(-) diff --git a/src/levanter/main/cache_dataset.py b/src/levanter/main/cache_dataset.py index 616d917fb..eaec4e597 100644 --- a/src/levanter/main/cache_dataset.py +++ b/src/levanter/main/cache_dataset.py @@ -7,8 +7,7 @@ from levanter.data.text import BatchTokenizer, LMDatasetConfig from levanter.distributed import RayConfig from levanter.logging import init_logging -from levanter.tracker import TrackerConfig -from levanter.tracker.tracker import NoopTrackerConfig +from levanter.tracker import NoopTrackerConfig, TrackerConfig logger = logging.getLogger(__name__) diff --git a/src/levanter/tracker/__init__.py b/src/levanter/tracker/__init__.py index 73e8b5872..604e12b8b 100644 --- a/src/levanter/tracker/__init__.py +++ b/src/levanter/tracker/__init__.py @@ -1,5 +1,5 @@ from levanter.tracker.helpers import log_optimizer_hyperparams -from levanter.tracker.tracker import CompositeTracker, NoopTracker, Tracker, TrackerConfig +from levanter.tracker.tracker import CompositeTracker, NoopTracker, NoopTrackerConfig, Tracker, TrackerConfig from levanter.tracker.tracker_fns import current_tracker, get_tracker, jit_log_metrics, log_metrics, log_summary diff --git a/src/levanter/tracker/tracker.py b/src/levanter/tracker/tracker.py index 086d778a9..27b3d47cf 100644 --- a/src/levanter/tracker/tracker.py +++ b/src/levanter/tracker/tracker.py @@ -46,6 +46,20 @@ def log_summary(self, metrics: dict[str, Any]): def log_artifact(self, artifact, *, name: Optional[str] = None, type: Optional[str] = None): pass + def __enter__(self): + import levanter.tracker.tracker_fns as tracker_fns + + if hasattr(self, "_tracker_cm"): + raise RuntimeError("Tracker already set") + setattr(self, "_tracker_cm", tracker_fns.current_tracker(self)) + self._tracker_cm.__enter__() + + def __exit__(self, exc_type, exc_val, exc_tb): + if not hasattr(self, "_tracker_cm"): + raise RuntimeError("Tracker not set") + self._tracker_cm.__exit__(exc_type, exc_val, exc_tb) + delattr(self, "_tracker_cm") + class CompositeTracker(Tracker): def __init__(self, loggers: List[Tracker]): diff --git a/src/levanter/tracker/tracker_fns.py b/src/levanter/tracker/tracker_fns.py index a437102ee..d79d3cfa2 100644 --- a/src/levanter/tracker/tracker_fns.py +++ b/src/levanter/tracker/tracker_fns.py @@ -4,8 +4,7 @@ import jax -from levanter.tracker import Tracker -from levanter.tracker.composite import CompositeTracker +from levanter.tracker import CompositeTracker, Tracker from levanter.tracker.tensorboard import TensorboardTracker from levanter.tracker.wandb import 
WandbTracker from levanter.utils.jax_utils import is_inside_jit @@ -113,7 +112,7 @@ def get_tracker(name: str) -> Tracker: elif tracker.name == name: return tracker - raise ValueError(f"Tracker with name {name} not found") + raise KeyError(f"Tracker with name {name} not found") class _GlobalLoggerContextManager(AbstractContextManager): diff --git a/tests/test_tracker.py b/tests/test_tracker.py index 72f118467..62993fa16 100644 --- a/tests/test_tracker.py +++ b/tests/test_tracker.py @@ -5,7 +5,8 @@ import pytest import yaml -from levanter.tracker import TrackerConfig +import levanter.tracker +from levanter.tracker import CompositeTracker, TrackerConfig def test_tracker_plugin_stuff_works(): @@ -39,3 +40,21 @@ class ConfigHolder: parsed = yaml.safe_load(config) decoded = draccus.decode(ConfigHolder, parsed).tracker assert decoded == (NoopTrackerConfig(), TrackerConfig.get_choice_class("wandb")()) + + +def test_get_tracker_by_name(): + wandb_config = TrackerConfig.get_choice_class("wandb") + if wandb_config is None: + pytest.skip("wandb not installed") + + from levanter.tracker import NoopTracker + + wandb1 = wandb_config(mode="disabled").init(None) + tracker = CompositeTracker([wandb1, NoopTracker()]) + + with tracker: + assert levanter.tracker.get_tracker("wandb") is wandb1 + assert levanter.tracker.get_tracker("noop") is not None + + with pytest.raises(KeyError): + levanter.tracker.get_tracker("foo") From 3ebd1611c27a923aa2f8b8ca49772da6c15b147a Mon Sep 17 00:00:00 2001 From: David Hall Date: Fri, 17 Nov 2023 14:10:08 -0800 Subject: [PATCH 019/205] minor --- src/levanter/tracker/tracker.py | 4 ++-- src/levanter/tracker/tracker_fns.py | 18 +++++++++++++++--- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/src/levanter/tracker/tracker.py b/src/levanter/tracker/tracker.py index 27b3d47cf..b4e864ca5 100644 --- a/src/levanter/tracker/tracker.py +++ b/src/levanter/tracker/tracker.py @@ -50,13 +50,13 @@ def __enter__(self): import levanter.tracker.tracker_fns as tracker_fns if hasattr(self, "_tracker_cm"): - raise RuntimeError("Tracker already set") + raise RuntimeError("This tracker is already set as the global tracker") setattr(self, "_tracker_cm", tracker_fns.current_tracker(self)) self._tracker_cm.__enter__() def __exit__(self, exc_type, exc_val, exc_tb): if not hasattr(self, "_tracker_cm"): - raise RuntimeError("Tracker not set") + raise RuntimeError("This tracker is not set as the global tracker") self._tracker_cm.__exit__(exc_type, exc_val, exc_tb) delattr(self, "_tracker_cm") diff --git a/src/levanter/tracker/tracker_fns.py b/src/levanter/tracker/tracker_fns.py index d79d3cfa2..fb845d53d 100644 --- a/src/levanter/tracker/tracker_fns.py +++ b/src/levanter/tracker/tracker_fns.py @@ -68,10 +68,22 @@ def current_tracker( tracker: Optional[Tracker] = None, ) -> Tracker | typing.ContextManager: """ - Get or set the global tracker. + Get or set the global tracker. Note that setting the global tracker is not thread-safe, + and using a tracker from multiple threads is only supported if the tracker itself is thread-safe. - :param tracker: If provided, returns a context manager that sets the global tracker to the provided tracker when used. - :return: The global tracker, or a context manager for setting the global tracker. + Args + tracker: If provided, returns a context manager that sets the global tracker to the provided tracker when used. + + Returns + If no tracker is provided, returns the current global tracker. 
+ If a tracker is provided, returns a context manager that sets the global tracker to the provided tracker when used. + + Examples + >>> from levanter.tracker import current_tracker, log_metrics + >>> from levanter.tracker.wandb import WandbTracker + >>> with current_tracker(WandbTracker()): + ... log_metrics({"foo": 1}, step=0) + ... current_tracker().log_metrics({"foo": 2}, step=1) """ global _global_tracker if tracker is None: From 0d2efbc0a5d2fc139de4e06660197a44da4045e5 Mon Sep 17 00:00:00 2001 From: David Hall Date: Fri, 17 Nov 2023 16:59:52 -0800 Subject: [PATCH 020/205] making speccing the loss function simpler --- examples/alpaca/alpaca.py | 2 +- src/levanter/main/train_lm.py | 7 ++---- src/levanter/trainer.py | 15 +++++++++---- src/levanter/types.py | 40 +++++++++++++++++++++++++++++++---- 4 files changed, 50 insertions(+), 14 deletions(-) diff --git a/examples/alpaca/alpaca.py b/examples/alpaca/alpaca.py index 113cef91a..d8039437a 100644 --- a/examples/alpaca/alpaca.py +++ b/examples/alpaca/alpaca.py @@ -209,7 +209,7 @@ def train(config: TrainArgs): def compute_loss(model: LmHeadModel, example: LmExample, key=None): return model.compute_loss(example, key=key).scalar() - trainer = Trainer(config.trainer, optimizer, compute_loss) + trainer = Trainer(config.trainer, optimizer) with trainer.device_mesh: # how we shard parameters across devices diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py index 187a8d92d..0a933c7c6 100644 --- a/src/levanter/main/train_lm.py +++ b/src/levanter/main/train_lm.py @@ -16,7 +16,7 @@ from levanter.compat.hf_checkpoints import HFCompatConfig, save_hf_checkpoint_callback from levanter.data.text import CausalLmDataset, LMDatasetConfig, LMMixtureDatasetConfig from levanter.models.gpt2 import Gpt2Config -from levanter.models.lm_model import LmConfig, LmExample, LmHeadModel +from levanter.models.lm_model import LmConfig, LmExample from levanter.trainer import OptimizerConfig, Trainer, TrainerConfig from levanter.utils.jax_utils import parameter_count @@ -91,13 +91,10 @@ def main(config: TrainLmConfig): compute_axis_mapping = config.trainer.compute_axis_mapping parameter_axis_mapping = config.trainer.parameter_axis_mapping - def compute_loss(model: LmHeadModel, example: LmExample, key=None): - return model.compute_loss(example, key=key).scalar() - optimizer = config.optimizer.build(config.trainer.num_train_steps) # Our trainer is a wrapper around the optimizer and compute_loss function that handles checkpointing and fsdp - trainer = Trainer(config.trainer, optimizer, compute_loss) + trainer = Trainer(config.trainer, optimizer) eval_datasets = config.data.validation_sets(Pos.size) train_dataset = CausalLmDataset(config.data.train_set(Pos.size), Pos, KeyPos) diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index aadeb97a8..754ea977e 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -36,7 +36,7 @@ from levanter.distributed import DistributedConfig, RayConfig from levanter.grad_accum import accumulate_gradients_sharded from levanter.logging import WandbConfig, capture_time -from levanter.types import FilterSpec +from levanter.types import FilterSpec, LossFunction, ModuleLoss from levanter.utils import cloud_utils from levanter.utils.jax_utils import is_inexact_arrayish from levanter.utils.tree_utils import inference_mode @@ -121,7 +121,7 @@ def __init__( self, config: "TrainerConfig", optimizer: GradientTransformation, - loss_fn: Callable, + loss_fn: Optional[LossFunction] = None, *, is_trainable: 
PyTree[FilterSpec] = True, ): @@ -138,9 +138,9 @@ def __init__( """ self.hooks = TrainerHooks() self.config = config - self._raw_loss_function = loss_fn self.optimizer = optimizer self.is_trainable_param = is_trainable + self._raw_loss_function = loss_fn or ModuleLoss() @cached_property def loss_fn(self): @@ -153,7 +153,7 @@ def loss_fn(self): def fn(model, *batch, **batch_kwargs): with hax.axis_mapping(self.compute_axis_mapping): model = self.mp.cast_to_compute(model) - return self._raw_loss_function(model, *batch, **batch_kwargs) + return _ensure_scalar(self._raw_loss_function(model, *batch, **batch_kwargs)) return fn @@ -766,3 +766,10 @@ def _convert_ratio_or_steps(ratio_or_steps: float, num_train_steps: int): return int(ratio_or_steps * num_train_steps) else: return int(ratio_or_steps) + + +def _ensure_scalar(x: hax.types.Scalar | hax.NamedArray) -> hax.types.Scalar: + if isinstance(x, hax.NamedArray): + return x.scalar() + else: + return x diff --git a/src/levanter/types.py b/src/levanter/types.py index 954578d27..e28499aab 100644 --- a/src/levanter/types.py +++ b/src/levanter/types.py @@ -1,17 +1,21 @@ -from typing import Any, Callable, Protocol, Tuple, TypeVar, Union +from typing import Any, Callable, Optional, Protocol, Tuple, TypeVar, Union + +import haliax as hax +from haliax.types import Scalar M = TypeVar("M") # Model +M_con = TypeVar("M_con", contravariant=True) # Model X = TypeVar("X", contravariant=True) # Input class ValAndGradFn(Protocol[M, X]): - def __call__(self, model: M, *inputs: X, **input_kwargs) -> Tuple[float, M]: + def __call__(self, model: M, *inputs: X, **input_kwargs) -> Tuple[Scalar, M]: ... -class ValFn(Protocol[M, X]): - def __call__(self, model: M, *inputs: X, **input_kwargs) -> Tuple[float, M]: +class ValFn(Protocol[M_con, X]): + def __call__(self, model: M_con, *inputs: X, **input_kwargs) -> Scalar: ... @@ -21,3 +25,31 @@ def __call__(self, model: M, *inputs: X, **input_kwargs) -> Tuple[float, M]: treated as-is, while callables are called on each element of the pytree. If the callable returns True, the element is kept, otherwise it is filtered out. """ + + +class LossFunction(Protocol[M_con, X]): + def __call__( + self, + model: M_con, + *inputs: X, + reduction: Optional[hax.ReductionFunction] = hax.mean, + reduction_axis: Optional[hax.AxisSelection] = None, + **kwargs, + ) -> Scalar | hax.NamedArray: + ... + + +class ModuleLoss(LossFunction[M, X]): + """ + Loss that just delegates to the model's compute_loss method. 
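+    (This is the default the Trainer falls back to when it is constructed without an explicit loss_fn.)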
+ """ + + def __call__( + self, + model, + *inputs: X, + reduction: Optional[hax.ReductionFunction] = hax.mean, + reduction_axis: Optional[hax.AxisSelection] = None, + **kwargs, + ) -> Scalar | hax.NamedArray: + return model.compute_loss(*inputs, reduction=reduction, reduction_axis=reduction_axis, **kwargs) From f085287ad94b989f994c05839cea30389a985ca7 Mon Sep 17 00:00:00 2001 From: David Hall Date: Sat, 18 Nov 2023 13:07:52 -0800 Subject: [PATCH 021/205] stop requiring a loss function for every model definition --- docs/LoRA.md | 5 +---- examples/alpaca-lora/alpaca_lora.py | 7 ++----- examples/alpaca/alpaca.py | 3 --- src/levanter/main/lora_lm.py | 8 ++------ src/levanter/main/train_lm.py | 2 -- 5 files changed, 5 insertions(+), 20 deletions(-) diff --git a/docs/LoRA.md b/docs/LoRA.md index 06ee44c1a..ec90527da 100644 --- a/docs/LoRA.md +++ b/docs/LoRA.md @@ -82,10 +82,7 @@ def train(config: TrainArgs): lora_param_filter = lora_trainable_params_filter(model) - def compute_loss(model: LmHeadModel, example: LmExample, key=None): - return model.compute_loss(example, key=key).scalar() - - trainer = Trainer(config.trainer, optimizer, compute_loss, is_trainable=lora_param_filter) + trainer = Trainer(config.trainer, optimizer, is_trainable=lora_param_filter) ``` ### 3. Serialize a PEFT-compatible checkpoint diff --git a/examples/alpaca-lora/alpaca_lora.py b/examples/alpaca-lora/alpaca_lora.py index a4380a92b..2e10b99e8 100644 --- a/examples/alpaca-lora/alpaca_lora.py +++ b/examples/alpaca-lora/alpaca_lora.py @@ -21,7 +21,7 @@ save_merged_hf_checkpoint_callback, save_peft_checkpoint_callback, ) -from levanter.models.lm_model import LmExample, LmHeadModel +from levanter.models.lm_model import LmHeadModel from levanter.trainer import Trainer from levanter.utils.jax_utils import parameter_count from levanter.utils.py_utils import non_caching_cycle @@ -98,10 +98,7 @@ def loraize_hf_model(model): lora_param_filter = lora_trainable_params_filter(model) - def compute_loss(model: LmHeadModel, example: LmExample, key=None): - return model.compute_loss(example, key=key).scalar() - - trainer = Trainer(config.trainer, optimizer, compute_loss, is_trainable=lora_param_filter) + trainer = Trainer(config.trainer, optimizer, is_trainable=lora_param_filter) # end major difference from Alpaca diff --git a/examples/alpaca/alpaca.py b/examples/alpaca/alpaca.py index d8039437a..5f6b738ba 100644 --- a/examples/alpaca/alpaca.py +++ b/examples/alpaca/alpaca.py @@ -206,9 +206,6 @@ def train(config: TrainArgs): optimizer = config.optimizer.build(config.trainer.num_train_steps) - def compute_loss(model: LmHeadModel, example: LmExample, key=None): - return model.compute_loss(example, key=key).scalar() - trainer = Trainer(config.trainer, optimizer) with trainer.device_mesh: diff --git a/src/levanter/main/lora_lm.py b/src/levanter/main/lora_lm.py index dbe597e30..8a760a31a 100644 --- a/src/levanter/main/lora_lm.py +++ b/src/levanter/main/lora_lm.py @@ -11,7 +11,7 @@ import levanter from levanter import callbacks from levanter.compat.hf_checkpoints import HFCheckpointConverter -from levanter.data.text import CausalLmDataset, LMDatasetConfig, LmExample +from levanter.data.text import CausalLmDataset, LMDatasetConfig from levanter.lora import ( LoraConfig, lora_trainable_params_filter, @@ -83,13 +83,9 @@ def loraize_hf_model(model): lora_param_filter = lora_trainable_params_filter(model) - def compute_loss(model, example: LmExample, key=None): - return model.compute_loss(example, key=key).scalar() - optimizer = 
config.optimizer.build(config.trainer.num_train_steps) - # Our trainer is a wrapper around the optimizer and compute_loss function that handles checkpointing and fsdp - trainer = Trainer(config.trainer, optimizer, compute_loss, is_trainable=lora_param_filter) + trainer = Trainer(config.trainer, optimizer, is_trainable=lora_param_filter) state = trainer.initial_state(training_key, model=model) all_param_count = parameter_count(state.model) diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py index 0a933c7c6..74819a150 100644 --- a/src/levanter/main/train_lm.py +++ b/src/levanter/main/train_lm.py @@ -92,8 +92,6 @@ def main(config: TrainLmConfig): parameter_axis_mapping = config.trainer.parameter_axis_mapping optimizer = config.optimizer.build(config.trainer.num_train_steps) - - # Our trainer is a wrapper around the optimizer and compute_loss function that handles checkpointing and fsdp trainer = Trainer(config.trainer, optimizer) eval_datasets = config.data.validation_sets(Pos.size) From f21cf4b015228c462b3770fc81e688328be7a6a1 Mon Sep 17 00:00:00 2001 From: David Hall Date: Sat, 18 Nov 2023 17:13:50 -0800 Subject: [PATCH 022/205] wip --- src/levanter/doremi.py | 78 +++++++++++++++++++++++++++++++++-------- src/levanter/trainer.py | 12 ++++--- src/levanter/types.py | 9 +++-- 3 files changed, 78 insertions(+), 21 deletions(-) diff --git a/src/levanter/doremi.py b/src/levanter/doremi.py index ab8db205b..9dc0192cc 100644 --- a/src/levanter/doremi.py +++ b/src/levanter/doremi.py @@ -1,6 +1,7 @@ from typing import Iterator, TypeVar import equinox as eqx +import jax import jax.numpy as jnp import jax.random as jrandom import optax @@ -9,23 +10,27 @@ import haliax as hax from haliax.types import IntScalar -from levanter.data import Dataset, ShardableDataset +from levanter.data import ShardableDataset from levanter.data.mixture import MixtureDataset +from levanter.logging import capture_time +from levanter.trainer import StepInfo, Trainer, TrainerState +from levanter.types import ComputeLossFunction T = TypeVar("T") def estimate_mixture_weights( - optimizer: optax.GradientTransformation, - loss_fn, + trainer: Trainer, + loss_fn: ComputeLossFunction, initial_proxy, ref, data_sources: dict[str, ShardableDataset[T]], ref_weights: dict[str, float], + *, domain_weight_step: float = 1.0, smoothing: float = 1e-3, - *, + eps_alpha: float = 1e-6, key: PRNGKeyArray, ) -> dict[str, float]: """ @@ -35,20 +40,17 @@ def estimate_mixture_weights( training_key, data_key = jrandom.split(key) domain_indices = list(data_sources.keys()) domain_to_index = {domain: index for index, domain in enumerate(domain_indices)} - tagged_mixture = domain_tagged_mixture(data_sources, ref_weights, domain_to_index, key=data_key) - - state = trainer.initial_state(training_key, model=initial_proxy) - del initial_proxy - # Initialize domain weights + # Initialize domain weights. + # TODO: should we initialize to the ref or to uniform? 
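+    # (Observation: the smoothing step in doremi_step mixes alpha toward initial_alpha every update,
+    # so this choice sets the target of the smoothing term as well as the starting point.)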
Domain = hax.Axis("domain", len(domain_indices))
     initial_alpha = hax.ones(Domain) / Domain.size
 
     # calculate per-token losses for proxy and ref
     def compute_excess_loss(proxy, ref, batch):
-        proxy_losses = TODO
-        ref_losses = TODO
+        proxy_losses = loss_fn(proxy, batch, reduction_axis=())
+        ref_losses = loss_fn(ref, batch, reduction_axis=())
         # calculate excess losses
         excess_losses = proxy_losses - ref_losses
         return excess_losses
 
     # Loss is alpha_d * (proxy - ref) (basically the unclipped excess loss with the new alpha)
     def proxy_model_loss(excess_losses, domains, alpha):
         one_hot_domains = hax.nn.one_hot(domains, Domain)  # Domain x Batch
         # basically einsum(" * -> ", alpha, one_hot_domains, excess_losses)
+        # TODO: I'd like to make the syntax for this nicer. einsum would be like
+        # einsum("d,bd,b... -> ()" or something)
+        # but it's really just collapsing all axes
         loss = hax.dot(excess_losses.axes + (Domain,), alpha, one_hot_domains, excess_losses).scalar()
         return loss
 
-    def doremi_step(opt_state, proxy, alpha, batch, domains):
+    @hax.named_jit(axis_resources=trainer.parameter_axis_mapping)
+    def doremi_step(proxy, opt_state, alpha, batch, domains):
         # this is one of those times when PyTorch's backward() is nice
         excess_losses, excess_backward = eqx.filter_vjp(lambda proxy: compute_excess_loss(proxy, ref, batch), proxy)
 
         # Update domain weights
         ## Compute per-domain excess losses
         clipped_losses = hax.maximum(excess_losses, 0)
         one_hot_domains = hax.nn.one_hot(domains, Domain)  # Domain x Batch
         per_domain_losses = hax.dot(excess_losses.axes, one_hot_domains, clipped_losses)
 
         old_alpha = alpha
         alpha = alpha * hax.exp(domain_weight_step * per_domain_losses)
         alpha /= hax.sum(alpha)
         alpha = (1 - smoothing) * alpha + initial_alpha * smoothing
 
+        # TODO: log this
         alpha_distance = hax.sum(hax.abs(alpha - old_alpha))
 
         # Update proxy model weights θt for the objective L(θt−1, αt) (using Adam, Adafactor, etc.)
-        val, grad_loss = eqx.filter_value_and_grad(proxy_model_loss)(excess_losses, domains, alpha)
-        grad = excess_backward(grad_loss)
+        loss, grad_loss = eqx.filter_value_and_grad(proxy_model_loss)(excess_losses, domains, alpha)
+        # filter_vjp's backward pass returns a tuple of cotangents, one per primal
+        (grad,) = excess_backward(grad_loss)
 
-        updates, new_state = optimizer.update(opt_state, grad, params=proxy)
+        # optax's update takes (updates, state, params), so the gradient goes first
+        updates, new_opt_state = trainer.optimizer.update(grad, opt_state, params=proxy)
         proxy = optax.apply_updates(proxy, updates)
 
-        return new_state, proxy, alpha
+        return loss, proxy, new_opt_state, alpha, alpha_distance
+
+    # TODO: we don't support serializing stuff from anything other than the model and the opt_state. should fix.
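+    # Worked example of the alpha update above (a sketch with made-up numbers): with two domains,
+    # alpha = [0.5, 0.5], per-domain clipped excess losses [0.5, 0.0], and domain_weight_step = 1.0:
+    #   alpha * exp(step * losses) = [0.5 * e^0.5, 0.5 * e^0] ≈ [0.824, 0.500]
+    #   renormalized: ≈ [0.622, 0.378]; smoothing with c = 1e-3 barely moves it.
+    # Domains with larger excess loss gain weight, so the proxy trains harder on them.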
+
+    running_alpha_mean = initial_alpha
+
+    # we're not actually going to use the trainer for very much, but it holds hooks and sets up contexts
+    with trainer:
+        tagged_mixture = domain_tagged_mixture(data_sources, ref_weights, domain_to_index, key=data_key)
+        state = trainer.initial_state(training_key, model=initial_proxy)
+        del initial_proxy
+        train_loader = iter(trainer.sharded_loader(tagged_mixture, trainer.TrainBatch))
+
+        if state.step > 0:
+            # step is after the batch, so we need to seek to step
+            # TODO: implement iter_data.seek(resume_step +1)
+            import tqdm
+
+            for _ in tqdm.tqdm(range(state.step), desc="seeking data for resume"):
+                next(train_loader)
+
+        # alpha lives outside the TrainerState for now (see the serialization TODO above)
+        alpha = initial_alpha
+
+        while state.step < trainer.num_train_steps:
+            example, ex_domains = next(train_loader)
+
+            key, new_key = jax.random.split(state.training_key)
+            proxy = state.model
+
+            with capture_time() as step_time:
+                loss, proxy, new_opt_state, alpha, alpha_distance = doremi_step(
+                    proxy, state.opt_state, alpha, example, ex_domains
+                )
+                loss = loss.item()  # type: ignore
+
+            # running average of alpha over steps; DoReMi returns this as the final mixture
+            running_alpha_mean = running_alpha_mean + (alpha - running_alpha_mean) / (state.step + 1)
+
+            new_info = StepInfo(TrainerState(state.step + 1, proxy, new_opt_state, new_key), loss, step_time())
+
+            trainer.run_hooks(new_info)
+
+            state = new_info.state
+
+    return {
+        domain: float(running_alpha_mean[{"domain": index}].scalar()) for domain, index in domain_to_index.items()
+    }
 
 
 def domain_tagged_mixture(
diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py
index 754ea977e..6e01ede47 100644
--- a/src/levanter/trainer.py
+++ b/src/levanter/trainer.py
@@ -36,7 +36,7 @@
 from levanter.distributed import DistributedConfig, RayConfig
 from levanter.grad_accum import accumulate_gradients_sharded
 from levanter.logging import WandbConfig, capture_time
-from levanter.types import FilterSpec, LossFunction, ModuleLoss
+from levanter.types import ComputeLossFunction, FilterSpec, ModuleComputeLoss
 from levanter.utils import cloud_utils
 from levanter.utils.jax_utils import is_inexact_arrayish
 from levanter.utils.tree_utils import inference_mode
@@ -121,7 +121,7 @@ def __init__(
         self,
         config: "TrainerConfig",
         optimizer: GradientTransformation,
-        loss_fn: Optional[LossFunction] = None,
+        loss_fn: Optional[ComputeLossFunction] = None,
         *,
         is_trainable: PyTree[FilterSpec] = True,
     ):
@@ -140,7 +140,7 @@ def __init__(
         self.config = config
         self.optimizer = optimizer
         self.is_trainable_param = is_trainable
-        self._raw_loss_function = loss_fn or ModuleLoss()
+        self._raw_loss_function = loss_fn or ModuleComputeLoss()
 
     @cached_property
     def loss_fn(self):
@@ -168,6 +168,10 @@ def mp(self) -> jmp.Policy:
         """Returns the mixed precision policy"""
         return self.config.mp
 
+    @property
+    def num_train_steps(self) -> int:
+        return self.config.num_train_steps
+
     @typing.overload
     def add_hook(self, fn: Callable[[StepInfo], Any], *, every: int = 1):
         ...
@@ -273,7 +277,7 @@ def training_steps(
         """
         iter_data = iter(train_loader)
 
-        while state.step < self.config.num_train_steps:
+        while state.step < self.num_train_steps:
             with capture_time() as loading_time:
                 example = next(iter_data)
 
diff --git a/src/levanter/types.py b/src/levanter/types.py
index e28499aab..60d7b82a0 100644
--- a/src/levanter/types.py
+++ b/src/levanter/types.py
@@ -27,7 +27,12 @@ def __call__(self, model: M_con, *inputs: X, **input_kwargs) -> Scalar:
         """
 
 
-class LossFunction(Protocol[M_con, X]):
+class ComputeLossFunction(Protocol[M_con, X]):
+    """
+    Function signature for "compute_loss" functions in Levanter: these
+    couple the computation of the logits and the evaluation of the loss.
+    """
+
     def __call__(
         self,
         model: M_con,
@@ -39,7 +44,7 @@ def __call__(
         ...
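 
+# An illustrative implementation of the protocol (hypothetical, for exposition; the real
+# ModuleComputeLoss below just delegates to the model's own compute_loss):
+#
+#     class DelegatingComputeLoss:
+#         def __call__(self, model, *inputs, **kwargs):
+#             return model.compute_loss(*inputs, **kwargs)
 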
-class ModuleLoss(LossFunction[M, X]):
+class ModuleComputeLoss(ComputeLossFunction[M, X]):
     """
     Loss that just delegates to the model's compute_loss method.
     """

From 01c8b87d84c6d6af7ed5801955480c7d9c48744d Mon Sep 17 00:00:00 2001
From: David Hall
Date: Sat, 18 Nov 2023 17:25:56 -0800
Subject: [PATCH 023/205] add a note on subclassing Trainer

---
 src/levanter/trainer.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py
index 6e01ede47..1d6b07fc3 100644
--- a/src/levanter/trainer.py
+++ b/src/levanter/trainer.py
@@ -110,6 +110,10 @@ def decorator(fn: Callable[[StepInfo], None]):
     return decorator(fn)
 
 
+# A note on subclassing Trainer.
+# Trainer wasn't explicitly designed to be subclassed, though we are working on extending it in this direction.
+
+
 class Trainer:
     config: "TrainerConfig"
     optimizer: GradientTransformation

From e3746971a35864c2e9f26c79b566728cbc94f690 Mon Sep 17 00:00:00 2001
From: David Hall
Date: Tue, 21 Nov 2023 23:50:55 -0600
Subject: [PATCH 024/205] tweak

---
 src/levanter/__init__.py  | 2 +-
 src/levanter/callbacks.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/levanter/__init__.py b/src/levanter/__init__.py
index 519e387ac..d89ea4945 100644
--- a/src/levanter/__init__.py
+++ b/src/levanter/__init__.py
@@ -5,4 +5,4 @@
 import levanter.logging as logging
 import levanter.tracker as tracker
 import levanter.visualization as visualization
-from levanter.tracker import current_tracker, get_tracker, jit_log_metrics, log_metrics, log_summary
+from levanter.tracker import current_tracker, get_tracker, log_metrics, log_summary
diff --git a/src/levanter/callbacks.py b/src/levanter/callbacks.py
index a80d0619e..154099e8a 100644
--- a/src/levanter/callbacks.py
+++ b/src/levanter/callbacks.py
@@ -62,7 +62,7 @@ def compute_loss(info: StepInfo):
         prefix = "eval"
         if name:
             prefix += "/" + name
-        levanter.tracker.log_metrics({f"{prefix}/loss": loss}, step=info.step)
+        levanter.log_metrics({f"{prefix}/loss": loss}, step=info.step)
 
         if name:
             logger.info(f"{name} validation loss: {loss:.3f}")

From 921acf81842bfeb4306d2eaf004c66b1990e583c Mon Sep 17 00:00:00 2001
From: David Hall
Date: Wed, 22 Nov 2023 00:02:22 -0600
Subject: [PATCH 025/205] register default hooks by default...
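
Move default-hook registration into Trainer.__init__ (behind a new add_default_hooks flag) so
callers no longer need to call trainer.add_default_hooks() themselves. A sketch of the resulting
call pattern (illustrative; names as used elsewhere in this series):

    with Trainer(config.trainer, optimizer, compute_loss) as trainer:
        state = trainer.initial_state(training_key, model=model)
        trainer.train(state, train_loader)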
--- examples/alpaca-lora/alpaca_lora.py | 2 - examples/alpaca/alpaca.py | 1 - src/levanter/main/lora_lm.py | 110 +++++++++++++++------------- src/levanter/main/train_lm.py | 4 +- src/levanter/trainer.py | 6 +- 5 files changed, 64 insertions(+), 59 deletions(-) diff --git a/examples/alpaca-lora/alpaca_lora.py b/examples/alpaca-lora/alpaca_lora.py index 31de93252..c3cf71ce6 100644 --- a/examples/alpaca-lora/alpaca_lora.py +++ b/examples/alpaca-lora/alpaca_lora.py @@ -103,8 +103,6 @@ def compute_loss(model: LmHeadModel, example: LmExample, key=None): # end major difference from Alpaca with Trainer(config.trainer, optimizer, compute_loss, is_trainable=lora_param_filter) as trainer: - - trainer.add_default_hooks() state = trainer.initial_state(training_key, model=model) # log some info about the model diff --git a/examples/alpaca/alpaca.py b/examples/alpaca/alpaca.py index 6ce7f06c0..eba8ab918 100644 --- a/examples/alpaca/alpaca.py +++ b/examples/alpaca/alpaca.py @@ -226,7 +226,6 @@ def compute_loss(model: LmHeadModel, example: LmExample, key=None): loader = trainer.replicated_loader(train_dataset, trainer.TrainBatch) loader = non_caching_cycle(loader) - trainer.add_default_hooks() state = trainer.initial_state(training_key, model=model) if state.step != 0: diff --git a/src/levanter/main/lora_lm.py b/src/levanter/main/lora_lm.py index 4e621239e..5904cca8b 100644 --- a/src/levanter/main/lora_lm.py +++ b/src/levanter/main/lora_lm.py @@ -88,62 +88,68 @@ def compute_loss(model, example: LmExample, key=None): optimizer = config.optimizer.build(config.trainer.num_train_steps) # Our trainer is a wrapper around the optimizer and compute_loss function that handles checkpointing and fsdp - trainer = Trainer(config.trainer, optimizer, compute_loss, is_trainable=lora_param_filter) - state = trainer.initial_state(training_key, model=model) - - all_param_count = parameter_count(state.model) - just_lora_params = parameter_count(trainer.trainable_params_only(state.model)) - - levanter.tracker.log_summary( - { - "parameter_count": all_param_count, - "trainable_parameter_count": just_lora_params, - "fraction_trainable": just_lora_params * 1.0 / all_param_count, - } - ) - - logger.info(f"Total parameter count: {all_param_count}") - logger.info(f"Trainable parameter count: {just_lora_params}") - logger.info(f"Fraction of parameters that are trainable: {just_lora_params * 1.0 / all_param_count%.3}") - - # data loaders - eval_dataset = CausalLmDataset(config.data.validation_set(Pos.size), Pos, KeyPos) - - train_dataset = CausalLmDataset(config.data.train_set(Pos.size), Pos, KeyPos) - train_loader = trainer.sharded_loader(train_dataset, Batch) - - # boilerplate hooks and such - trainer.add_default_hooks(eval_dataset) - trainer.add_hook(callbacks.log_performance_stats(Pos.size, trainer.config.train_batch_size), every=1) - if config.peft_save_path is not None: - full_save_path = os.path.join(config.peft_save_path, trainer.run_id) - trainer.add_hook( - save_peft_checkpoint_callback( - full_save_path, config.lora, config.initialize_from_hf, tokenizer, config.peft_hf_upload - ), - every=config.hf_save_steps, - ) - - if config.merged_hf_save_path is not None: - full_save_path = os.path.join(config.merged_hf_save_path, trainer.run_id) - trainer.add_hook( - save_merged_hf_checkpoint_callback(full_save_path, converter, config.merged_hf_upload), - every=config.hf_save_steps, - ) + with Trainer(config.trainer, optimizer, compute_loss, is_trainable=lora_param_filter) as trainer: + eval_datasets = 
config.data.validation_sets(Pos.size)
+
+            state = trainer.initial_state(training_key, model=model)
+
+            all_param_count = parameter_count(state.model)
+            just_lora_params = parameter_count(trainer.trainable_params_only(state.model))
+
+            levanter.tracker.log_summary(
+                {
+                    "parameter_count": all_param_count,
+                    "trainable_parameter_count": just_lora_params,
+                    "fraction_trainable": just_lora_params * 1.0 / all_param_count,
+                }
+            )
+
+            logger.info(f"Total parameter count: {all_param_count}")
+            logger.info(f"Trainable parameter count: {just_lora_params}")
+            logger.info(f"Fraction of parameters that are trainable: {just_lora_params * 1.0 / all_param_count:.3g}")
+
+            # data loaders
+            if len(eval_datasets) == 0:
+                logger.warning("No evaluation datasets provided.")
+
+            for name, eval_dataset in eval_datasets.items():
+                eval_dataset = CausalLmDataset(eval_dataset, Pos, KeyPos)
+                trainer.add_eval_hook(eval_dataset, name=name)
+
+            train_dataset = CausalLmDataset(config.data.train_set(Pos.size), Pos, KeyPos)
+            train_loader = trainer.sharded_loader(train_dataset, Batch)
+
+            # boilerplate hooks and such
+            trainer.add_hook(callbacks.log_performance_stats(Pos.size, trainer.config.train_batch_size), every=1)
+            if config.peft_save_path is not None:
+                full_save_path = os.path.join(config.peft_save_path, trainer.run_id)
+                trainer.add_hook(
+                    save_peft_checkpoint_callback(
+                        full_save_path, config.lora, config.initialize_from_hf, tokenizer, config.peft_hf_upload
+                    ),
+                    every=config.hf_save_steps,
+                )
+
+            if config.merged_hf_save_path is not None:
+                full_save_path = os.path.join(config.merged_hf_save_path, trainer.run_id)
+                trainer.add_hook(
+                    save_merged_hf_checkpoint_callback(full_save_path, converter, config.merged_hf_upload),
+                    every=config.hf_save_steps,
+                )
+
+            # data loader. may need to seek to the right place if we're resuming
+            iter_data = non_caching_cycle(train_loader)
+
+            if state.step > 0:
+                # step is after the batch, so we need to seek to step
+                # TODO: implement iter_data.seek(resume_step +1)
+                import tqdm
+
+                for _ in tqdm.tqdm(range(state.step + 1), desc="seeking data for resume"):
+                    next(iter_data)
+
+            ## OK, actually run training!
+ trainer.train(state, iter_data) if __name__ == "__main__": diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py index 60d5dbbb6..132d1ed33 100644 --- a/src/levanter/main/train_lm.py +++ b/src/levanter/main/train_lm.py @@ -136,9 +136,6 @@ def compute_loss(model: LmHeadModel, example: LmExample, key=None): } ) - # boilerplate hooks and such - trainer.add_default_hooks() - if len(eval_datasets) == 0: logger.warning("No evaluation datasets provided.") @@ -146,6 +143,7 @@ def compute_loss(model: LmHeadModel, example: LmExample, key=None): eval_dataset = CausalLmDataset(eval_dataset, Pos, KeyPos) trainer.add_eval_hook(eval_dataset, name=name) + # Register hooks trainer.add_hook(callbacks.log_performance_stats(Pos.size, trainer.config.train_batch_size), every=1) if config.hf_save_path is not None: full_save_path = os.path.join(config.hf_save_path, trainer.run_id) diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index 4636b50d0..db7cbd3a5 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -129,6 +129,7 @@ def __init__( loss_fn: Callable, *, is_trainable: PyTree[FilterSpec] = True, + add_default_hooks: bool = True, ): """ @@ -152,6 +153,9 @@ def __init__( self.tracker = config.tracker.init(self.run_id) self._cmanagers = [] + if add_default_hooks: + self._add_default_hooks() + @cached_property def loss_fn(self): """ @@ -341,7 +345,7 @@ def train(self, state: TrainerState[M], train_loader: Iterable[X], run_hooks: bo return info - def add_default_hooks(self, eval_dataset: Optional[Iterable[X]] = None): + def _add_default_hooks(self, eval_dataset: Optional[Iterable[X]] = None): from levanter import callbacks self.add_hook(callbacks.pbar_logger(total=self.config.num_train_steps), every=1) From c8a5d6c6d6b3efcbf647a13f1d7fc994b4fd04d7 Mon Sep 17 00:00:00 2001 From: David Hall Date: Fri, 24 Nov 2023 13:59:33 -0600 Subject: [PATCH 026/205] wip --- config/gpt2_nano.yaml | 4 +- examples/alpaca-lora/alpaca_lora.py | 2 +- src/levanter/checkpoint.py | 175 +++++++++++++------- src/levanter/main/lora_lm.py | 2 +- src/levanter/tensorstore_serialization.py | 2 - src/levanter/trainer.py | 190 +++++++++++----------- 6 files changed, 216 insertions(+), 159 deletions(-) diff --git a/config/gpt2_nano.yaml b/config/gpt2_nano.yaml index 993302670..15d2b2482 100644 --- a/config/gpt2_nano.yaml +++ b/config/gpt2_nano.yaml @@ -1,5 +1,5 @@ -data: - id: dlwh/wikitext_103_detokenized +#data: +# id: dlwh/wikitext_103_detokenized model: type: gpt2 hidden_dim: 32 diff --git a/examples/alpaca-lora/alpaca_lora.py b/examples/alpaca-lora/alpaca_lora.py index c3cf71ce6..49e1ac9dc 100644 --- a/examples/alpaca-lora/alpaca_lora.py +++ b/examples/alpaca-lora/alpaca_lora.py @@ -107,7 +107,7 @@ def compute_loss(model: LmHeadModel, example: LmExample, key=None): # log some info about the model all_param_count = parameter_count(state.model) - just_lora_params = parameter_count(trainer.trainable_params_only(state.model)) + just_lora_params = parameter_count(trainer._trainable_params_only(state.model)) levanter.log_summary( { diff --git a/src/levanter/checkpoint.py b/src/levanter/checkpoint.py index a9eaa0a22..26a312155 100644 --- a/src/levanter/checkpoint.py +++ b/src/levanter/checkpoint.py @@ -29,7 +29,7 @@ PathLike = Union[str, pathlib.Path] -M = TypeVar("M") +M = TypeVar("M", bound=PyTree) S = TypeVar("S") @@ -56,13 +56,13 @@ class Checkpointer: _last_temporary_checkpoint: Optional[str] = None def __init__( - self, - base_path: PathLike, - save_interval: Optional[datetime.timedelta], - 
step_policies: Sequence[CheckpointInterval], - *, - keep_params: PyTree[FilterSpec] = True, - dt_now_injection: Optional[Callable[[], datetime.datetime]] = None, + self, + base_path: PathLike, + save_interval: Optional[datetime.timedelta], + step_policies: Sequence[CheckpointInterval], + *, + keep_params: PyTree[FilterSpec] = True, + dt_now_injection: Optional[Callable[[], datetime.datetime]] = None, ): """ Class for managing checkpoints. Saves checkpoints according to two policies: time and step. @@ -102,39 +102,39 @@ def __init__( raise ValueError("Step policies must be sorted by 'until' value") def load_checkpoint( - self, - model: M, - training_state: S, - path: Optional[PathLike] = None, - *, - discover_latest: bool = True, - axis_mapping: Optional[haliax.partitioning.ResourceMapping] = None, - mesh: Optional[haliax.partitioning.Mesh] = None, - ) -> Optional[Tuple[M, S, int]]: + self, + state: M, + path: Optional[PathLike] = None, + *, + discover_latest: bool = True, + axis_mapping: Optional[haliax.partitioning.ResourceMapping] = None, + mesh: Optional[haliax.partitioning.Mesh] = None, + ) -> Optional[Tuple[M, int]]: if path is None: path = self.base_path return load_checkpoint( - model, training_state, path, discover_latest=discover_latest, axis_mapping=axis_mapping, mesh=mesh + state, path, discover_latest=discover_latest, axis_mapping=axis_mapping, mesh=mesh ) def load_model( - self, - model: M, - path: Optional[str] = None, - *, - discover_latest: bool = True, - axis_mapping: Optional[haliax.partitioning.ResourceMapping] = None, - mesh: Optional[haliax.partitioning.Mesh] = None, + self, + model: M, + path: Optional[str] = None, + *, + discover_latest: bool = True, + axis_mapping: Optional[haliax.partitioning.ResourceMapping] = None, + mesh: Optional[haliax.partitioning.Mesh] = None, ) -> Optional[Tuple[M, int]]: - if path is None: - path = self.base_path - ckpt = load_checkpoint( - model, None, path, discover_latest=discover_latest, axis_mapping=axis_mapping, mesh=mesh - ) - if ckpt is None: - return None - model, _, step = ckpt - return model, step + """ + Convenience method/holdover from previous API for loading checkpoints. + Loads just the model assuming the model is in the `model` subdir of the discovered checkpoint. + """ + ret_dict = self.load_checkpoint({"model": model}, + path, + discover_latest=discover_latest, + axis_mapping=axis_mapping, + mesh=mesh) + return ret_dict["model"] def on_step(self, info, force: bool = False): step = info.step @@ -213,17 +213,15 @@ def _rm_checkpoint(self, checkpoint): cp_path = os.path.join(plain_path, checkpoint) logger.info(f"Deleting checkpoint {checkpoint} from {cp_path}") fs.rm(cp_path, recursive=True) - # don't let this take down a run except Exception: # pylint: disable=broad-except logger.exception("Failed to delete checkpoint", exc_info=True) def save_checkpoint(self, info, destination: str): path = os.path.join(self.base_path, destination) logger.info(f"Saving checkpoint at step {info.step} to {path}") - model = equinox.filter(info.model, self.keep_params) + state = equinox.partition(info.state, info.state.is_trainable) save_checkpoint( - model=model, - training_state=(info.opt_state, info.next_key), + state, step=info.step, checkpoint_path=path, ) @@ -237,7 +235,7 @@ def save_checkpoint(self, info, destination: str): logger.info(f"Saved checkpoint at step {info.step} to {path}. 
Save time is {self._last_save_time}") -def save_checkpoint(model, training_state, step: int, checkpoint_path: PathLike, *, exist_ok: bool = False): +def save_checkpoint(tree: M, step: int, checkpoint_path: PathLike, *, exist_ok: bool = False): """ Save a checkpoint to a given path using TensorStore. If exist_ok is True, the checkpoint will be saved even if a checkpoint already exists at the given path. @@ -255,10 +253,7 @@ def save_checkpoint(model, training_state, step: int, checkpoint_path: PathLike, fs, plain_path = _get_fs_and_plain_path(checkpoint_path) fs.makedirs(plain_path, exist_ok=exist_ok) - tree_serialize_leaves_tensorstore(os.path.join(checkpoint_path, "model"), model) - if training_state is not None: - tree_serialize_leaves_tensorstore(os.path.join(checkpoint_path, "training_state"), training_state) - + tree_serialize_leaves_tensorstore(checkpoint_path, tree) save_metadata(checkpoint_path, fs, step) logger.info(f"Saved checkpoint for step {step}") @@ -280,13 +275,70 @@ def save_metadata(checkpoint_path, fs, step): def load_checkpoint( - model: M, - training_state: S, - checkpoint_path: PathLike, - *, - discover_latest=True, - axis_mapping: Optional[haliax.partitioning.ResourceMapping] = None, - mesh: Optional[jax.sharding.Mesh] = None, + tree: M, + checkpoint_path: PathLike, + *, + discover_latest=True, + axis_mapping: Optional[haliax.partitioning.ResourceMapping] = None, + mesh: Optional[jax.sharding.Mesh] = None, +) -> Optional[tuple[M, int]]: + fs: AbstractFileSystem + fs, _ = _get_fs_and_plain_path(checkpoint_path) + + checkpoint_path = str(checkpoint_path) + + if discover_latest: + checkpoint_path = discover_latest_checkpoint(checkpoint_path) # type: ignore + + if checkpoint_path is None or not fs.exists(checkpoint_path): + return None + + logger.info(f"Loading checkpoint from {checkpoint_path}") + metadata = load_metadata(checkpoint_path, fs) + + try: + tree = tree_deserialize_leaves_tensorstore(checkpoint_path, tree, axis_mapping=axis_mapping, mesh=mesh) + except: # noqa + from levanter.trainer import TrainerState + if not isinstance(tree, TrainerState): + raise + else: + logger.warning("Attempting to load old-style checkpoint") + model, training_state = tree.model, (tree.opt_state, tree.training_key) + + model = tree_deserialize_leaves_tensorstore( + os.path.join(checkpoint_path, "model"), model, axis_mapping=axis_mapping, mesh=mesh + ) + + if training_state is None: + opt_state = None + key = None + else: + training_state = tree_deserialize_leaves_tensorstore( + os.path.join(checkpoint_path, "training_state"), training_state, axis_mapping=axis_mapping, mesh=mesh + ) + opt_state, key = training_state + + # TODO: pretty sure this is right, but should verify + step = metadata["step"] + new_state = dataclasses.replace( + tree, # type: ignore + step=step + 1, + model=model, + opt_state=opt_state, + training_key=key) + return new_state, step + + + +def _old_load_checkpoint( + model: M, + training_state: S, + checkpoint_path: PathLike, + *, + discover_latest=True, + axis_mapping: Optional[haliax.partitioning.ResourceMapping] = None, + mesh: Optional[jax.sharding.Mesh] = None, ) -> Optional[Tuple[M, S, int]]: """ Load a checkpoint from a given path. 
@@ -370,10 +422,10 @@ def checkpoint_sort_key(ckpt_dir): def tree_serialise_leaves( - path: PathLike, - pytree: PyTree, - filter_spec=default_serialise_filter_spec, - is_leaf: Optional[Callable[[Any], bool]] = None, + path: PathLike, + pytree: PyTree, + filter_spec=default_serialise_filter_spec, + is_leaf: Optional[Callable[[Any], bool]] = None, ) -> None: """Analog to `equinox.tree_serialise_leaves`, but saves the leaves of a PyTree using fsspec.""" @@ -391,11 +443,11 @@ def __serialise(y): def tree_deserialise_leaves( - path: PathLike, - like: PyTree, - filter_spec=default_deserialise_filter_spec, - is_leaf: Optional[Callable[[Any], bool]] = None, - fs=None, + path: PathLike, + like: PyTree, + filter_spec=default_deserialise_filter_spec, + is_leaf: Optional[Callable[[Any], bool]] = None, + fs=None, ) -> PyTree: """ Analog to `equinox.tree_deserialise_leaves`, but loads the leaves of a PyTree using fsspec. @@ -451,13 +503,12 @@ class CheckpointerConfig: def expanded_path(self, run_id): return os.path.expanduser(os.path.join(self.base_path, run_id)) - def create(self, run_id, keep_params: PyTree[FilterSpec] = True) -> Checkpointer: + def create(self, run_id) -> Checkpointer: keeps = [CheckpointInterval(**k) for k in self.keep] return Checkpointer( base_path=self.expanded_path(run_id), save_interval=self.save_interval, step_policies=keeps, - keep_params=keep_params, ) def __post_init__(self): @@ -470,6 +521,6 @@ def __post_init__(self): if prev_interval is not None: assert prev_interval["until"] is not None, "Only the last checkpoint interval can be None" assert ( - interval["until"] is None or interval["until"] > prev_interval["until"] + interval["until"] is None or interval["until"] > prev_interval["until"] ), "Checkpoint intervals must be monotonic" prev_interval = interval diff --git a/src/levanter/main/lora_lm.py b/src/levanter/main/lora_lm.py index 5904cca8b..b709dfb14 100644 --- a/src/levanter/main/lora_lm.py +++ b/src/levanter/main/lora_lm.py @@ -94,7 +94,7 @@ def compute_loss(model, example: LmExample, key=None): state = trainer.initial_state(training_key, model=model) all_param_count = parameter_count(state.model) - just_lora_params = parameter_count(trainer.trainable_params_only(state.model)) + just_lora_params = parameter_count(trainer._trainable_params_only(state.model)) levanter.tracker.log_summary( { diff --git a/src/levanter/tensorstore_serialization.py b/src/levanter/tensorstore_serialization.py index 9809665f3..78037ce3d 100644 --- a/src/levanter/tensorstore_serialization.py +++ b/src/levanter/tensorstore_serialization.py @@ -31,8 +31,6 @@ def tree_serialize_leaves_tensorstore(checkpoint_dir, pytree): specs = jtu.tree_map(partial(_tensorstore_spec_for, checkpoint_dir), leaf_key_paths, is_leaf=is_named_array) # TODO: jax array_ser has a fancy async manager thing to checkpoint while training, would be good but not right now. - # array_ser only supports saving sharded arrays, so we can't use its top-level function run_serialization. - # however we're inspired by its implementation, meaning we'll make a tree of futures and wait on them. 
async def _do_serialize(): futures = jtu.tree_map(_serialize_one_leaf, pytree, specs, is_leaf=is_named_array) return await asyncio.gather(*jtu.tree_leaves(futures)) diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index db7cbd3a5..1e089169b 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -1,5 +1,6 @@ import atexit import copy +import dataclasses import functools import logging as pylogging import os @@ -49,35 +50,33 @@ X = TypeVar("X") # Input M = TypeVar("M", bound=PyTree) -S = TypeVar("S", bound=PyTree) DEFAULT_JAX_CONFIG = { "jax_threefry_partitionable": True, "jax_softmax_custom_jvp": True, } -# A note on the semantics of "step" vs "next_step": -# The "step" of a TrainerState is the state after `step` steps have been taken. -# A "StepInfo"'s step is the step that was just completed. If you want the next step, use `next_step`. - -@dataclass -class TrainerState(Generic[M]): +# TODO: figure out how to get Generic[M] back +class TrainerState(eqx.Module): step: int model: M opt_state: OptState training_key: PRNGKeyArray + is_trainable: PyTree[FilterSpec] = eqx.field(static=True) +# A note on the semantics of "step" vs "next_step": +# The "step" of a TrainerState is the state after `step` steps have been taken. +# A "StepInfo"'s step is the step that was just completed. If you want the next step, use `next_step`. @dataclass -class StepInfo(Generic[M]): - state: TrainerState[M] +class StepInfo: + state: TrainerState loss: float step_duration: float model = property(lambda self: self.state.model) opt_state = property(lambda self: self.state.opt_state) - next_key = property(lambda self: self.state.training_key) step = property(lambda self: self.state.step - 1) """ @@ -118,7 +117,6 @@ class Trainer: optimizer: GradientTransformation hooks: TrainerHooks tracker: levanter.tracker.Tracker - is_trainable_param: Optional[PyTree[FilterSpec]] _raw_loss_function: Callable _cmanagers: List[typing.ContextManager] = [] @@ -128,7 +126,6 @@ def __init__( optimizer: GradientTransformation, loss_fn: Callable, *, - is_trainable: PyTree[FilterSpec] = True, add_default_hooks: bool = True, ): """ @@ -138,15 +135,11 @@ def __init__( optimizer: the optimizer, e.g. `optax.adam(1e-3)` or produced by [levanter.trainer.OptimizerConfig][] loss_fn (Callable): the loss function. This should be a function that takes a model and some inputs and returns a scalar loss. It should be jit-able and should not have any side effects. - is_trainable: optional filter spec for the trainable parameters. This is used to filter out non-trainable - parameters for the optimizer state and for computing gradients. Non-trainable parameters are also - not checkpointed. If you don't specify this, all parameters are assumed to be trainable. 
""" self.hooks = TrainerHooks() self.config = config self._raw_loss_function = loss_fn self.optimizer = optimizer - self.is_trainable_param = is_trainable if isinstance(config.tracker, Sequence): self.tracker = levanter.tracker.CompositeTracker([c.init(self.run_id) for c in config.tracker]) else: @@ -245,11 +238,19 @@ def __exit__(self, *args): raise RuntimeError("Exception(s) occurred while exiting trainer", problems) from problems[0] def initial_state( - self, training_key: PRNGKeyArray, model: Optional[M] = None, model_init: Optional[Callable[[], M]] = None + self, training_key: PRNGKeyArray, + model: Optional[M] = None, model_init: Optional[Callable[[], M]] = None, + *, + is_trainable: PyTree[FilterSpec] = True, ) -> TrainerState: """ Initializes the model, optimizer state, and random key. Also handles loading a checkpoint if needed. + Args + is_trainable: optional filter spec for the trainable parameters. This is used to filter out non-trainable + parameters for the optimizer state and for computing gradients. Non-trainable parameters are also + not checkpointed. If you don't specify this, all parameters are assumed to be trainable. + Returns: model, opt_state, key, resume_step """ @@ -260,55 +261,56 @@ def initial_state( raise ValueError("one of model and model_init must be specified") if model is not None: - # we can't just use `lambda: model` because JAX jit can't see captures, but it can see partials - # We can't use plain partials because they aren't pytrees + # we can't just use `lambda: model` because JAX jit can't see captures, but it can see jax partials model_init = jax.tree_util.Partial(lambda m: m, model) - model_shape, opt_state_shape = eqx.filter_eval_shape(self._init_model_and_opt_state, model_init) + model_shape, opt_state_shape = eqx.filter_eval_shape(self._init_model_and_opt_state, model_init, is_trainable) # we only checkpoint the trainable parameters, so we need to filter out the non-trainable ones - trainable_model_shape = self.trainable_params_only(model_shape) + trainable_model_shape = _trainable_params_only(model_shape, is_trainable) - ckpt = self.maybe_load_checkpoint( - trainable_model_shape, - (opt_state_shape, training_key), - axis_mapping=self.parameter_axis_mapping, - mesh=self.device_mesh, - ) + trainer_state_shape = TrainerState(0, + trainable_model_shape, + opt_state_shape, + training_key, + is_trainable=is_trainable) + + ckpt = self._maybe_load_checkpoint(trainer_state_shape) if ckpt is not None: - trainable_model, (opt_state, training_key), completed_step = ckpt + opt_state = ckpt.opt_state + training_key = ckpt.training_key + step = ckpt.step + if model is not None: - model = eqx.combine(trainable_model, model) + model = eqx.combine(ckpt.model, model) elif any(isinstance(leaf, ShapeDtypeStruct) for leaf in jax.tree_leaves(trainable_model)): # if we're resuming, we need to re-initialize the non-trainable parameters to their original values non_trainable = named_jit(self._init_non_trainable_params, self.parameter_axis_mapping)(model_init) model = eqx.combine(trainable_model, non_trainable) else: model = trainable_model - step = completed_step + 1 else: - model, opt_state = named_jit(self._init_model_and_opt_state, self.parameter_axis_mapping)(model_init) + model, opt_state = named_jit(self._init_model_and_opt_state, self.parameter_axis_mapping)(model_init, is_trainable) step = 0 - return TrainerState(step, model, opt_state, training_key) + return TrainerState(step, model, opt_state, training_key, is_trainable) - def train_step(self, state: 
TrainerState[M], *batch: X, **batch_kwargs) -> StepInfo[M]: + def train_step(self, state: TrainerState, *batch: X, **batch_kwargs) -> StepInfo: """ Performs a single training step. """ with capture_time() as step_time, levanter.current_tracker(self.tracker): - key, new_key = jax.random.split(state.training_key) - loss, new_model, new_optstate = self._train_step_fn( - state.model, state.opt_state, *batch, **batch_kwargs, key=key + loss, new_state = self._train_step_fn( + state, *batch, **batch_kwargs, key=key ) # force the loss so timing numbers are accurate. laziness isn't going to help here (i think?) loss = loss.item() # type: ignore - return StepInfo(TrainerState(state.step + 1, new_model, new_optstate, new_key), loss, step_time()) + return StepInfo(new_state, loss, step_time()) def training_steps( - self, state: TrainerState[M], train_loader, run_hooks: bool = True + self, state: TrainerState, train_loader, run_hooks: bool = True ) -> typing.Iterator[StepInfo]: """ Generator that yields training steps and runs hooks. @@ -332,7 +334,7 @@ def training_steps( yield info - def train(self, state: TrainerState[M], train_loader: Iterable[X], run_hooks: bool = True) -> StepInfo[M]: + def train(self, state: TrainerState, train_loader: Iterable[X], run_hooks: bool = True) -> StepInfo: """ Performs training until the number of steps is reached. """ @@ -353,7 +355,7 @@ def _add_default_hooks(self, eval_dataset: Optional[Iterable[X]] = None): if eval_dataset is not None: self.add_eval_hook(eval_dataset) # engine.add_hook(callbacks.log_memory_usage(), every=1) - checkpointer = self.config.checkpointer.create(self.run_id, self.is_trainable_param) + checkpointer = self.config.checkpointer.create(self.run_id) self.add_hook(checkpointer.on_step, every=1) # checkpointer manages its own frequency def add_eval_hook(self, eval_dataset, name: Optional[str] = None): @@ -406,10 +408,12 @@ def _train_step_fn(self): @named_jit( axis_resources=self.parameter_axis_mapping, out_axis_resources=self.parameter_axis_mapping, - donate_args=(True, True), + donate_args=(True,), ) - def train_step(model, opt_state, *batch, **batch_kwargs): - model = inference_mode(model, False) + def train_step(state, *batch, **batch_kwargs): + key, new_key = jax.random.split(state.training_key) + opt_state = state.opt_state + model = inference_mode(state.model, False) # we do this so that we only take the gradients of the trainable parameters trainable_model, rest_model = self.partition_trainable_params(model) @@ -425,14 +429,22 @@ def split_loss_fn(trainable_model, *batch, **batch_kwargs): updates, opt_state = self.optimizer.update(grads, opt_state, params=trainable_model) model = eqx.apply_updates(model, updates) - return loss, model, opt_state + new_state = dataclasses.replace( + state, + step=state.step + 1, + model=model, + opt_state=opt_state, + training_key=new_key + ) + + return loss, new_state return train_step - def _init_model_and_opt_state(self, model_init): + def _init_model_and_opt_state(self, model_init, is_trainable): model = model_init() # only force trainable params to param precision. 
Other params are cast to compute precision - trainable, non_trainable = self.partition_trainable_params(model) + trainable, non_trainable = _partition_trainable_params(model, is_trainable) trainable = self.mp.cast_to_param(trainable) non_trainable = self.mp.cast_to_compute(non_trainable) model = eqx.combine(trainable, non_trainable) @@ -446,58 +458,29 @@ def _init_non_trainable_params(self, model_init): non_trainable = self.mp.cast_to_compute(non_trainable) return non_trainable - def trainable_params_only(self, model: M) -> M: - """ - Filters out non-trainable parameters from the model. This is used internally to - for the optimizer state and to compute gradients, but you can also use it to filter out - params for logging or something. - """ - return self.partition_trainable_params(model)[0] - - def partition_trainable_params(self, model): - """ - Partitions the model into trainable and non-trainable parameters. This is used internally - for the gradient calculation and checkpointing, but you can also use it to filter out params for logging - or something. - Returns: - trainable, non-trainable - """ - - def trainable_and_diffable(pred): - if callable(pred): - return lambda x: pred(x) and is_inexact_arrayish(x) - elif pred is True: - return is_inexact_arrayish - else: - return pred - - combined_mask = jax.tree_util.tree_map(trainable_and_diffable, self.is_trainable_param) - return eqx.partition(model, combined_mask) - - def maybe_load_checkpoint( - self, model: M, training_state: S, *, axis_mapping=None, mesh=None - ) -> Optional[Tuple[M, S, int]]: + def _maybe_load_checkpoint(self, state: TrainerState) -> Optional[TrainerState]: """Loads a checkpoint if one exists and we're supposed to load it, otherwise returns the model and training state as is""" - if self.config.load_checkpoint is not False: - # TODO: don't remake the checkpointer every time - checkpointer = self.config.checkpointer.create(self.run_id) - load_checkpoint_path = self.config.load_checkpoint_path + with self.device_mesh: + if self.config.load_checkpoint is not False: + # TODO: don't remake the checkpointer every time + checkpointer = self.config.checkpointer.create(self.run_id) + load_checkpoint_path = self.config.load_checkpoint_path - if load_checkpoint_path is None: - load_checkpoint_path = self.config.checkpointer.expanded_path(self.run_id) + if load_checkpoint_path is None: + load_checkpoint_path = self.config.checkpointer.expanded_path(self.run_id) - ckpt = checkpointer.load_checkpoint( - model, training_state, load_checkpoint_path, axis_mapping=axis_mapping, mesh=mesh - ) + ckpt = checkpointer.load_checkpoint( + state, load_checkpoint_path, axis_mapping=self.parameter_axis_mapping, mesh=self.device_mesh + ) - if ckpt is None and self.config.load_checkpoint is True: - raise ValueError(f"Could not load checkpoint from {load_checkpoint_path}") + if ckpt is None and self.config.load_checkpoint is True: + raise ValueError(f"Could not load checkpoint from {load_checkpoint_path}") - return ckpt - else: - return None + return ckpt + else: + return None @dataclass @@ -804,3 +787,28 @@ def _convert_ratio_or_steps(ratio_or_steps: float, num_train_steps: int): return int(ratio_or_steps * num_train_steps) else: return int(ratio_or_steps) + + +def _trainable_params_only(model: M, filter: PyTree[FilterSpec]) -> M: + return _partition_trainable_params(model, filter)[0] + +def _partition_trainable_params(model, filter): + """ + Partitions the model into trainable and non-trainable parameters. 
This is used internally + for the gradient calculation and checkpointing, but you can also use it to filter out params for logging + or something. + + Returns: + trainable, non-trainable + """ + + def trainable_and_diffable(pred): + if callable(pred): + return lambda x: pred(x) and is_inexact_arrayish(x) + elif pred is True: + return is_inexact_arrayish + else: + return pred + + combined_mask = jax.tree_util.tree_map(trainable_and_diffable, filter) + return eqx.partition(model, combined_mask) \ No newline at end of file From 639d334bba5ef57abfb4b474c4ae3e8511f37efc Mon Sep 17 00:00:00 2001 From: David Hall Date: Fri, 24 Nov 2023 14:08:01 -0600 Subject: [PATCH 027/205] make it so we can evaluate if we have a cache but no sources --- src/levanter/data/text.py | 36 ++++++++++++++++++++++-------------- 1 file changed, 22 insertions(+), 14 deletions(-) diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index 5a4890efb..b106f92df 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -552,19 +552,22 @@ class LMDatasetConfig(LMDatasetSourceConfig, LMTaskConfig): def train_set( self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True ) -> ShardableDataset[np.ndarray]: - return self.token_seq_dataset("train", seq_len, monitors) + ds = self.token_seq_dataset("train", seq_len, monitors) + if ds is None: + raise ValueError("No training set!") + return ds - def validation_set(self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True): - if self._has_validation_set: - return self.token_seq_dataset("validation", seq_len, monitors) - else: - return None + def validation_set( + self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True + ) -> Optional[TokenSeqDataset]: + return self.token_seq_dataset("validation", seq_len, monitors) def validation_sets( self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True ) -> Mapping[str, ShardableDataset[np.ndarray]]: - if self._has_validation_set: - return {"": self.validation_set(seq_len, monitors)} + validation_set = self.validation_set(seq_len, monitors) + if validation_set is not None: + return {"": validation_set} else: return {} @@ -585,22 +588,27 @@ def _has_validation_set(self): def token_seq_dataset( self, split: str, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True - ) -> TokenSeqDataset: + ) -> Optional[TokenSeqDataset]: cache = self.build_or_load_cache(split, monitors=monitors) + if cache is None: + return None return TokenSeqDataset(cache, seq_len) def build_or_load_cache( self, split: str, monitors: Union[bool, List[MetricsMonitor]] = True ) -> Optional[TokenizedDocumentCache]: - source = self.get_shard_source(split) - if source is None: - return None - split_cache_dir = os.path.join(self.cache_dir, split) try: return TokenizedDocumentCache.load(split_cache_dir, flatten_docs=True) except FileNotFoundError: - logger.info(f"Building cache for {split}...") + pass + + source = self.get_shard_source(split) + if source is None: + logger.info(f"No data for {split}") + return None + + logger.info(f"Building cache for {split}...") if monitors is True: monitors = [ From ec35e9bf5a4e920bd6fc2235efa693c83b00c826 Mon Sep 17 00:00:00 2001 From: David Hall Date: Sat, 25 Nov 2023 15:21:17 -0800 Subject: [PATCH 028/205] about got the checkpoint refactor done --- src/levanter/checkpoint.py | 131 +++++++++++----------- src/levanter/main/lora_lm.py | 128 ++++++++++----------- src/levanter/main/train_lm.py | 5 +- src/levanter/tensorstore_serialization.py | 10 +- 
src/levanter/trainer.py | 64 +++++------ src/levanter/utils/jax_utils.py | 7 +- tests/test_checkpoint.py | 83 ++++++-------- 7 files changed, 212 insertions(+), 216 deletions(-) diff --git a/src/levanter/checkpoint.py b/src/levanter/checkpoint.py index 26a312155..2d8681b16 100644 --- a/src/levanter/checkpoint.py +++ b/src/levanter/checkpoint.py @@ -56,13 +56,13 @@ class Checkpointer: _last_temporary_checkpoint: Optional[str] = None def __init__( - self, - base_path: PathLike, - save_interval: Optional[datetime.timedelta], - step_policies: Sequence[CheckpointInterval], - *, - keep_params: PyTree[FilterSpec] = True, - dt_now_injection: Optional[Callable[[], datetime.datetime]] = None, + self, + base_path: PathLike, + save_interval: Optional[datetime.timedelta], + step_policies: Sequence[CheckpointInterval], + *, + keep_params: PyTree[FilterSpec] = True, + dt_now_injection: Optional[Callable[[], datetime.datetime]] = None, ): """ Class for managing checkpoints. Saves checkpoints according to two policies: time and step. @@ -102,38 +102,36 @@ def __init__( raise ValueError("Step policies must be sorted by 'until' value") def load_checkpoint( - self, - state: M, - path: Optional[PathLike] = None, - *, - discover_latest: bool = True, - axis_mapping: Optional[haliax.partitioning.ResourceMapping] = None, - mesh: Optional[haliax.partitioning.Mesh] = None, - ) -> Optional[Tuple[M, int]]: + self, + state: M, + path: Optional[PathLike] = None, + *, + discover_latest: bool = True, + axis_mapping: Optional[haliax.partitioning.ResourceMapping] = None, + mesh: Optional[haliax.partitioning.Mesh] = None, + ) -> Optional[M]: if path is None: path = self.base_path - return load_checkpoint( - state, path, discover_latest=discover_latest, axis_mapping=axis_mapping, mesh=mesh - ) + return load_checkpoint(state, path, discover_latest=discover_latest, axis_mapping=axis_mapping, mesh=mesh) def load_model( - self, - model: M, - path: Optional[str] = None, - *, - discover_latest: bool = True, - axis_mapping: Optional[haliax.partitioning.ResourceMapping] = None, - mesh: Optional[haliax.partitioning.Mesh] = None, - ) -> Optional[Tuple[M, int]]: + self, + model: M, + path: Optional[str] = None, + *, + discover_latest: bool = True, + axis_mapping: Optional[haliax.partitioning.ResourceMapping] = None, + mesh: Optional[haliax.partitioning.Mesh] = None, + ) -> Optional[M]: """ - Convenience method/holdover from previous API for loading checkpoints. - Loads just the model assuming the model is in the `model` subdir of the discovered checkpoint. + Convenience method/holdover from previous API for loading checkpoints. + Loads just the model assuming the model is in the `model` subdir of the discovered checkpoint. 
""" - ret_dict = self.load_checkpoint({"model": model}, - path, - discover_latest=discover_latest, - axis_mapping=axis_mapping, - mesh=mesh) + ret_dict = self.load_checkpoint( + {"model": model}, path, discover_latest=discover_latest, axis_mapping=axis_mapping, mesh=mesh + ) + if ret_dict is None: + return None return ret_dict["model"] def on_step(self, info, force: bool = False): @@ -219,7 +217,7 @@ def _rm_checkpoint(self, checkpoint): def save_checkpoint(self, info, destination: str): path = os.path.join(self.base_path, destination) logger.info(f"Saving checkpoint at step {info.step} to {path}") - state = equinox.partition(info.state, info.state.is_trainable) + state = equinox.filter(info.state, info.state.is_trainable) save_checkpoint( state, step=info.step, @@ -275,13 +273,13 @@ def save_metadata(checkpoint_path, fs, step): def load_checkpoint( - tree: M, - checkpoint_path: PathLike, - *, - discover_latest=True, - axis_mapping: Optional[haliax.partitioning.ResourceMapping] = None, - mesh: Optional[jax.sharding.Mesh] = None, -) -> Optional[tuple[M, int]]: + tree: M, + checkpoint_path: PathLike, + *, + discover_latest=True, + axis_mapping: Optional[haliax.partitioning.ResourceMapping] = None, + mesh: Optional[jax.sharding.Mesh] = None, +) -> Optional[M]: fs: AbstractFileSystem fs, _ = _get_fs_and_plain_path(checkpoint_path) @@ -298,8 +296,10 @@ def load_checkpoint( try: tree = tree_deserialize_leaves_tensorstore(checkpoint_path, tree, axis_mapping=axis_mapping, mesh=mesh) + return tree except: # noqa from levanter.trainer import TrainerState + if not isinstance(tree, TrainerState): raise else: @@ -315,30 +315,29 @@ def load_checkpoint( key = None else: training_state = tree_deserialize_leaves_tensorstore( - os.path.join(checkpoint_path, "training_state"), training_state, axis_mapping=axis_mapping, mesh=mesh + os.path.join(checkpoint_path, "training_state"), + training_state, + axis_mapping=axis_mapping, + mesh=mesh, ) opt_state, key = training_state # TODO: pretty sure this is right, but should verify step = metadata["step"] new_state = dataclasses.replace( - tree, # type: ignore - step=step + 1, - model=model, - opt_state=opt_state, - training_key=key) - return new_state, step - + tree, step=step + 1, model=model, opt_state=opt_state, training_key=key # type: ignore + ) + return new_state def _old_load_checkpoint( - model: M, - training_state: S, - checkpoint_path: PathLike, - *, - discover_latest=True, - axis_mapping: Optional[haliax.partitioning.ResourceMapping] = None, - mesh: Optional[jax.sharding.Mesh] = None, + model: M, + training_state: S, + checkpoint_path: PathLike, + *, + discover_latest=True, + axis_mapping: Optional[haliax.partitioning.ResourceMapping] = None, + mesh: Optional[jax.sharding.Mesh] = None, ) -> Optional[Tuple[M, S, int]]: """ Load a checkpoint from a given path. 
@@ -422,10 +421,10 @@ def checkpoint_sort_key(ckpt_dir):
 
 
 def tree_serialise_leaves(
-        path: PathLike,
-        pytree: PyTree,
-        filter_spec=default_serialise_filter_spec,
-        is_leaf: Optional[Callable[[Any], bool]] = None,
+    path: PathLike,
+    pytree: PyTree,
+    filter_spec=default_serialise_filter_spec,
+    is_leaf: Optional[Callable[[Any], bool]] = None,
 ) -> None:
     """Analog to `equinox.tree_serialise_leaves`, but saves the leaves of a PyTree using fsspec."""
 
@@ -443,11 +442,11 @@ def __serialise(y):
 
 
 def tree_deserialise_leaves(
-        path: PathLike,
-        like: PyTree,
-        filter_spec=default_deserialise_filter_spec,
-        is_leaf: Optional[Callable[[Any], bool]] = None,
-        fs=None,
+    path: PathLike,
+    like: PyTree,
+    filter_spec=default_deserialise_filter_spec,
+    is_leaf: Optional[Callable[[Any], bool]] = None,
+    fs=None,
 ) -> PyTree:
     """
     Analog to `equinox.tree_deserialise_leaves`, but loads the leaves of a PyTree using fsspec.
@@ -521,6 +520,6 @@ def __post_init__(self):
             if prev_interval is not None:
                 assert prev_interval["until"] is not None, "Only the last checkpoint interval can be None"
                 assert (
-                        interval["until"] is None or interval["until"] > prev_interval["until"]
+                    interval["until"] is None or interval["until"] > prev_interval["until"]
                 ), "Checkpoint intervals must be monotonic"
                 prev_interval = interval
diff --git a/src/levanter/main/lora_lm.py b/src/levanter/main/lora_lm.py
index b709dfb14..e86ef1669 100644
--- a/src/levanter/main/lora_lm.py
+++ b/src/levanter/main/lora_lm.py
@@ -3,6 +3,7 @@
 from dataclasses import dataclass, field
 from typing import Optional
 
+import equinox as eqx
 import jax.random as jrandom
 
 import haliax.random
@@ -66,7 +67,12 @@ def main(config: LoraLmConfig):
     Pos = model_config.Pos
     KeyPos = model_config.KeyPos
 
-    with config.trainer.device_mesh:
+    optimizer = config.optimizer.build(config.trainer.num_train_steps)
+
+    def compute_loss(model, example: LmExample, key=None):
+        return model.compute_loss(example, key=key).scalar()
+
+    with Trainer(config.trainer, optimizer, compute_loss) as trainer:
         # how we shard parameters across devices
         parameter_axis_mapping = config.trainer.parameter_axis_mapping
 
@@ -82,74 +88,68 @@ def loraize_hf_model(model):
 
         lora_param_filter = lora_trainable_params_filter(model)
 
-        def compute_loss(model, example: LmExample, key=None):
-            return model.compute_loss(example, key=key).scalar()
-
-        optimizer = config.optimizer.build(config.trainer.num_train_steps)
-
-        # Our trainer is a wrapper around the optimizer and compute_loss function that handles checkpointing and fsdp
-        with Trainer(config.trainer, optimizer, compute_loss, is_trainable=lora_param_filter) as trainer:
-            eval_datasets = config.data.validation_sets(Pos.size)
+        eval_datasets = config.data.validation_sets(Pos.size)
+
+        state = trainer.initial_state(training_key, model=model, is_trainable=lora_param_filter)
+
+        all_param_count = parameter_count(state.model)
+        just_lora_params = parameter_count(eqx.filter(state.model, lora_param_filter))
+
+        levanter.tracker.log_summary(
+            {
+                "parameter_count": all_param_count,
+                "trainable_parameter_count": just_lora_params,
+                "fraction_trainable": just_lora_params * 1.0 / all_param_count,
+            }
+        )
+
+        logger.info(f"Total parameter count: {all_param_count}")
+        logger.info(f"Trainable parameter count: {just_lora_params}")
+        logger.info(f"Fraction of parameters that are trainable: {just_lora_params * 1.0 / all_param_count:.3g}")
+
+        # data loaders
+        if len(eval_datasets) == 0:
+            logger.warning("No evaluation datasets provided.")
+
+        for name, eval_dataset in
eval_datasets.items(): + eval_dataset = CausalLmDataset(eval_dataset, Pos, KeyPos) + trainer.add_eval_hook(eval_dataset, name=name) + + train_dataset = CausalLmDataset(config.data.train_set(Pos.size), Pos, KeyPos) + train_loader = trainer.sharded_loader(train_dataset, Batch) + + # boilerplate hooks and such + trainer.add_hook(callbacks.log_performance_stats(Pos.size, trainer.config.train_batch_size), every=1) + if config.peft_save_path is not None: + full_save_path = os.path.join(config.peft_save_path, trainer.run_id) + trainer.add_hook( + save_peft_checkpoint_callback( + full_save_path, config.lora, config.initialize_from_hf, tokenizer, config.peft_hf_upload + ), + every=config.hf_save_steps, + ) + + if config.merged_hf_save_path is not None: + full_save_path = os.path.join(config.merged_hf_save_path, trainer.run_id) + trainer.add_hook( + save_merged_hf_checkpoint_callback(full_save_path, converter, config.merged_hf_upload), + every=config.hf_save_steps, + ) - state = trainer.initial_state(training_key, model=model) + # data loader. may need to seek to the right place if we're resuming + iter_data = non_caching_cycle(train_loader) - all_param_count = parameter_count(state.model) - just_lora_params = parameter_count(trainer._trainable_params_only(state.model)) + if state.step > 0: + # step is after the batch, so we need to seek to step + # TODO: implement iter_data.seek(resume_step +1) + import tqdm - levanter.tracker.log_summary( - { - "parameter_count": all_param_count, - "trainable_parameter_count": just_lora_params, - "fraction_trainable": just_lora_params * 1.0 / all_param_count, - } - ) + for _ in tqdm.tqdm(range(state.step), desc="seeking data for resume"): + next(iter_data) - logger.info(f"Total parameter count: {all_param_count}") - logger.info(f"Trainable parameter count: {just_lora_params}") - logger.info(f"Fraction of parameters that are trainable: {just_lora_params * 1.0 / all_param_count%.3}") - - # data loaders - if len(eval_datasets) == 0: - logger.warning("No evaluation datasets provided.") - - for name, eval_dataset in eval_datasets.items(): - eval_dataset = CausalLmDataset(eval_dataset, Pos, KeyPos) - trainer.add_eval_hook(eval_dataset, name=name) - - train_dataset = CausalLmDataset(config.data.train_set(Pos.size), Pos, KeyPos) - train_loader = trainer.sharded_loader(train_dataset, Batch) - - # boilerplate hooks and such - trainer.add_hook(callbacks.log_performance_stats(Pos.size, trainer.config.train_batch_size), every=1) - if config.peft_save_path is not None: - full_save_path = os.path.join(config.peft_save_path, trainer.run_id) - trainer.add_hook( - save_peft_checkpoint_callback( - full_save_path, config.lora, config.initialize_from_hf, tokenizer, config.peft_hf_upload - ), - every=config.hf_save_steps, - ) - - if config.merged_hf_save_path is not None: - full_save_path = os.path.join(config.merged_hf_save_path, trainer.run_id) - trainer.add_hook( - save_merged_hf_checkpoint_callback(full_save_path, converter, config.merged_hf_upload), - every=config.hf_save_steps, - ) - - # data loader. may need to seek to the right place if we're resuming - iter_data = non_caching_cycle(train_loader) - - if state.step > 0: - # step is after the batch, so we need to seek to step - # TODO: implement iter_data.seek(resume_step +1) - import tqdm - - for _ in tqdm.tqdm(range(state.step + 1), desc="seeking data for resume"): - next(iter_data) - - ## OK, actually run training! - trainer.train(state, iter_data) + ## OK, actually run training! 
+ trainer.train(state, iter_data) if __name__ == "__main__": diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py index 132d1ed33..b03745988 100644 --- a/src/levanter/main/train_lm.py +++ b/src/levanter/main/train_lm.py @@ -5,6 +5,7 @@ from typing import Optional, Union import jax.random as jrandom +import wandb import haliax as hax from haliax import Axis @@ -181,12 +182,14 @@ def compute_log_probs(model, example: LmExample): # TODO: implement iter_data.seek(resume_step +1) import tqdm - for _ in tqdm.tqdm(range(state.step + 1), desc="seeking data for resume"): + for _ in tqdm.tqdm(range(state.step), desc="seeking data for resume"): next(train_loader) ## OK, actually run training! + trainer.add_hook(lambda s: print(s.loss), every=20) trainer.train(state, train_loader) # checkpointer.on_step(last_step, force=True) + wandb.finish() if __name__ == "__main__": diff --git a/src/levanter/tensorstore_serialization.py b/src/levanter/tensorstore_serialization.py index 78037ce3d..51a253163 100644 --- a/src/levanter/tensorstore_serialization.py +++ b/src/levanter/tensorstore_serialization.py @@ -26,13 +26,17 @@ logger = logging.getLogger(__name__) +def _is_named_or_none(x): + return x is None or is_named_array(x) + + def tree_serialize_leaves_tensorstore(checkpoint_dir, pytree): - leaf_key_paths = jax_utils.leaf_key_paths(pytree, is_leaf=is_named_array) - specs = jtu.tree_map(partial(_tensorstore_spec_for, checkpoint_dir), leaf_key_paths, is_leaf=is_named_array) + leaf_key_paths = jax_utils.leaf_key_paths(pytree, is_leaf=_is_named_or_none) + specs = jtu.tree_map(partial(_tensorstore_spec_for, checkpoint_dir), leaf_key_paths, is_leaf=_is_named_or_none) # TODO: jax array_ser has a fancy async manager thing to checkpoint while training, would be good but not right now. async def _do_serialize(): - futures = jtu.tree_map(_serialize_one_leaf, pytree, specs, is_leaf=is_named_array) + futures = jtu.tree_map(_serialize_one_leaf, pytree, specs, is_leaf=_is_named_or_none) return await asyncio.gather(*jtu.tree_leaves(futures)) asyncio.run(_do_serialize()) diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index 1e089169b..ada138004 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -57,8 +57,7 @@ } -# TODO: figure out how to get Generic[M] back -class TrainerState(eqx.Module): +class TrainerState(eqx.Module, Generic[M]): step: int model: M opt_state: OptState @@ -70,8 +69,8 @@ class TrainerState(eqx.Module): # The "step" of a TrainerState is the state after `step` steps have been taken. # A "StepInfo"'s step is the step that was just completed. If you want the next step, use `next_step`. @dataclass -class StepInfo: - state: TrainerState +class StepInfo(Generic[M]): + state: TrainerState[M] loss: float step_duration: float @@ -238,11 +237,13 @@ def __exit__(self, *args): raise RuntimeError("Exception(s) occurred while exiting trainer", problems) from problems[0] def initial_state( - self, training_key: PRNGKeyArray, - model: Optional[M] = None, model_init: Optional[Callable[[], M]] = None, - *, - is_trainable: PyTree[FilterSpec] = True, - ) -> TrainerState: + self, + training_key: PRNGKeyArray, + model: Optional[M] = None, + model_init: Optional[Callable[[], M]] = None, + *, + is_trainable: PyTree[FilterSpec] = True, + ) -> TrainerState[M]: """ Initializes the model, optimizer state, and random key. Also handles loading a checkpoint if needed. 
@@ -264,20 +265,21 @@ def initial_state( # we can't just use `lambda: model` because JAX jit can't see captures, but it can see jax partials model_init = jax.tree_util.Partial(lambda m: m, model) - model_shape, opt_state_shape = eqx.filter_eval_shape(self._init_model_and_opt_state, model_init, is_trainable) + model_shape, opt_state_shape = eqx.filter_eval_shape( + self._init_model_and_opt_state, model_init, is_trainable + ) # we only checkpoint the trainable parameters, so we need to filter out the non-trainable ones trainable_model_shape = _trainable_params_only(model_shape, is_trainable) - trainer_state_shape = TrainerState(0, - trainable_model_shape, - opt_state_shape, - training_key, - is_trainable=is_trainable) + trainer_state_shape: TrainerState = TrainerState( + 0, trainable_model_shape, opt_state_shape, training_key, is_trainable=is_trainable + ) ckpt = self._maybe_load_checkpoint(trainer_state_shape) if ckpt is not None: + trainable_model = ckpt.model opt_state = ckpt.opt_state training_key = ckpt.training_key step = ckpt.step @@ -291,27 +293,27 @@ def initial_state( else: model = trainable_model else: - model, opt_state = named_jit(self._init_model_and_opt_state, self.parameter_axis_mapping)(model_init, is_trainable) + model, opt_state = named_jit(self._init_model_and_opt_state, self.parameter_axis_mapping)( + model_init, is_trainable + ) step = 0 return TrainerState(step, model, opt_state, training_key, is_trainable) - def train_step(self, state: TrainerState, *batch: X, **batch_kwargs) -> StepInfo: + def train_step(self, state: TrainerState[M], *batch: X, **batch_kwargs) -> StepInfo[M]: """ Performs a single training step. """ with capture_time() as step_time, levanter.current_tracker(self.tracker): - loss, new_state = self._train_step_fn( - state, *batch, **batch_kwargs, key=key - ) + loss, new_state = self._train_step_fn(state, *batch, **batch_kwargs) # force the loss so timing numbers are accurate. laziness isn't going to help here (i think?) loss = loss.item() # type: ignore return StepInfo(new_state, loss, step_time()) def training_steps( - self, state: TrainerState, train_loader, run_hooks: bool = True - ) -> typing.Iterator[StepInfo]: + self, state: TrainerState[M], train_loader, run_hooks: bool = True + ) -> typing.Iterator[StepInfo[M]]: """ Generator that yields training steps and runs hooks. """ @@ -334,7 +336,7 @@ def training_steps( yield info - def train(self, state: TrainerState, train_loader: Iterable[X], run_hooks: bool = True) -> StepInfo: + def train(self, state: TrainerState[M], train_loader: Iterable[X], run_hooks: bool = True) -> StepInfo[M]: """ Performs training until the number of steps is reached. 
""" @@ -410,17 +412,17 @@ def _train_step_fn(self): out_axis_resources=self.parameter_axis_mapping, donate_args=(True,), ) - def train_step(state, *batch, **batch_kwargs): + def train_step(state: TrainerState, *batch, **batch_kwargs): key, new_key = jax.random.split(state.training_key) opt_state = state.opt_state model = inference_mode(state.model, False) # we do this so that we only take the gradients of the trainable parameters - trainable_model, rest_model = self.partition_trainable_params(model) + trainable_model, rest_model = _partition_trainable_params(model, state.is_trainable) def split_loss_fn(trainable_model, *batch, **batch_kwargs): model = eqx.combine(trainable_model, rest_model) - return self.loss_fn(model, *batch, **batch_kwargs) + return self.loss_fn(model, *batch, **batch_kwargs, key=key) loss, grads = accumulate_gradients_sharded( split_loss_fn, self.TrainBatch, self.config.per_device_parallelism, self.parameter_axis_mapping @@ -430,11 +432,7 @@ def split_loss_fn(trainable_model, *batch, **batch_kwargs): model = eqx.apply_updates(model, updates) new_state = dataclasses.replace( - state, - step=state.step + 1, - model=model, - opt_state=opt_state, - training_key=new_key + state, step=state.step + 1, model=model, opt_state=opt_state, training_key=new_key ) return loss, new_state @@ -458,7 +456,6 @@ def _init_non_trainable_params(self, model_init): non_trainable = self.mp.cast_to_compute(non_trainable) return non_trainable - def _maybe_load_checkpoint(self, state: TrainerState) -> Optional[TrainerState]: """Loads a checkpoint if one exists and we're supposed to load it, otherwise returns the model and training state as is""" @@ -792,6 +789,7 @@ def _convert_ratio_or_steps(ratio_or_steps: float, num_train_steps: int): def _trainable_params_only(model: M, filter: PyTree[FilterSpec]) -> M: return _partition_trainable_params(model, filter)[0] + def _partition_trainable_params(model, filter): """ Partitions the model into trainable and non-trainable parameters. 
This is used internally @@ -811,4 +809,4 @@ def trainable_and_diffable(pred): return pred combined_mask = jax.tree_util.tree_map(trainable_and_diffable, filter) - return eqx.partition(model, combined_mask) \ No newline at end of file + return eqx.partition(model, combined_mask) diff --git a/src/levanter/utils/jax_utils.py b/src/levanter/utils/jax_utils.py index 038c5e9b5..cd018819f 100644 --- a/src/levanter/utils/jax_utils.py +++ b/src/levanter/utils/jax_utils.py @@ -243,7 +243,12 @@ def leaf_key_paths( rec_value = rec(field, field_name) rec_values.append(rec_value) - return eqx.tree_at(lambda m: [getattr(m, name) for name in names], pytree, rec_values) + + _, tree_def = eqx.tree_flatten_one_level(pytree) + out = jax.tree_util.tree_unflatten(tree_def, rec_values) + return out + # this doesn't work reliably because tree_at doesn't like none values + # return eqx.tree_at(lambda m: [getattr(m, name) for name in names], pytree, rec_values, is_leaf=lambda x: x is None) else: leaves, treedef = jax.tree_util.tree_flatten(pytree, is_leaf=is_leaf) if len(leaves) == 1: diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index b8f588df4..f181dce7f 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -1,3 +1,4 @@ +import dataclasses import datetime import pathlib import tempfile @@ -30,6 +31,7 @@ def _dummy_step_info(step): model=None, opt_state=(), training_key=(), + is_trainable=True, ), loss=0.0, step_duration=0.0, @@ -139,43 +141,42 @@ def advance_time(delta_seconds): assert _get_checkpoint_steps(tmpdir) == [2, 4, 6, 8, 10, 15, 20, 30, 40, 49] # 49 is last temporary checkpoint +def _make_state(step, key): + model = MLP(in_size=2, out_size=1, width_size=2, depth=3, key=key) + optim = optax.adam(1e-4) + opt_state = optim.init(arrays_only(model)) + + return TrainerState(step, model, opt_state, key, True) + + def test_checkpoint_simple(): key0 = jax.random.PRNGKey(0) key1 = jax.random.PRNGKey(1) - def make_state(key): - model = MLP(in_size=2, out_size=1, width_size=2, depth=3, key=key) - optim = optax.adam(1e-4) - opt_state = optim.init(arrays_only(model)) - - return model, opt_state, key - - initial_model, initial_opt_state, initial_key = make_state(key0) - rep_model, rep_state, rep_key = make_state(key1) + initial_state = _make_state(10, key0) + rep_state = _make_state(2, key1) - assert_trees_not_close(initial_model, rep_model) + assert_trees_not_close(initial_state.model, rep_state.model) with tempfile.TemporaryDirectory() as tmpdir: save_checkpoint( - initial_model, - (initial_opt_state, initial_key), - step=10, + initial_state, + step=initial_state.step, checkpoint_path=tmpdir, exist_ok=True, ) - restored_model, (restored_optstate, rkey), step = load_checkpoint( - rep_model, - (rep_state, rep_key), + restored_state = load_checkpoint( + rep_state, checkpoint_path=tmpdir, discover_latest=False, ) assert_trees_all_close( - jax.tree_util.tree_leaves(arrays_only(restored_model)), - jax.tree_util.tree_leaves(arrays_only(initial_model)), + jax.tree_util.tree_leaves(arrays_only(restored_state.model)), + jax.tree_util.tree_leaves(arrays_only(initial_state.model)), ) - assert all(np.isclose(rkey, initial_key)) - assert step == 10 + assert all(np.isclose(restored_state.training_key, initial_state.training_key)) + assert restored_state.step == initial_state.step def test_checkpoint_steps(): @@ -184,13 +185,7 @@ def test_checkpoint_steps(): optim = optax.adam(1e-4) - def make_state(key): - model = MLP(in_size=2, out_size=1, width_size=2, depth=3, key=key) - opt_state = 
optim.init(arrays_only(model)) - - return model, opt_state, key - - initial_model, initial_opt_state, initial_key = make_state(key0) + initial_state = _make_state(10, key0) data = jax.random.uniform(key0, (2, 2)) @eqx.filter_grad @@ -198,41 +193,33 @@ def loss_fn(model, data): m = jax.vmap(model) return jnp.mean(jnp.square(m(data))) - model, state = initial_model, initial_opt_state + state = initial_state for i in range(3): - grad = loss_fn(model, data) - updates, state = optim.update(grad, state) - model = eqx.apply_updates(model, updates) + grad = loss_fn(state.model, data) + updates, new_state = optim.update(grad, state.opt_state) + model = eqx.apply_updates(state.model, updates) + state = dataclasses.replace(state, step=state.step + 1, model=model, opt_state=new_state) - assert_trees_not_close(model, initial_model) - assert_trees_not_close(state, initial_opt_state) + assert_trees_not_close(state, initial_state) - rep_model, rep_state, rep_key = make_state(key1) - assert_trees_not_close(model, rep_model) + rep_state = _make_state(42, key1) assert_trees_not_close(state, rep_state) with tempfile.TemporaryDirectory() as tmpdir: - save_checkpoint(model, state, step=3, checkpoint_path=tmpdir, exist_ok=True) - restored_model, restored_optstate, step = load_checkpoint( - rep_model, rep_state, checkpoint_path=tmpdir, discover_latest=False - ) + save_checkpoint(state, step=3, checkpoint_path=tmpdir, exist_ok=True) + restored_state = load_checkpoint(rep_state, checkpoint_path=tmpdir, discover_latest=False) assert_trees_all_close( - jax.tree_util.tree_leaves(arrays_only(restored_model)), - jax.tree_util.tree_leaves(arrays_only(model)), - ) - assert_trees_all_close( - jax.tree_util.tree_leaves(arrays_only(restored_optstate)), + jax.tree_util.tree_leaves(arrays_only(restored_state)), jax.tree_util.tree_leaves(arrays_only(state)), ) - assert step == 3 def test_checkpoint_discovery(): with tempfile.TemporaryDirectory() as tempdir: - save_checkpoint(model=1, training_state=2, step=10, checkpoint_path=f"{tempdir}/step-10") - save_checkpoint(model=3, training_state=4, step=20, checkpoint_path=f"{tempdir}/step-20") - save_checkpoint(model=5, training_state=6, step=30, checkpoint_path=f"{tempdir}/step-30") + save_checkpoint(dict(model=1, training_state=2), step=10, checkpoint_path=f"{tempdir}/step-10") + save_checkpoint(dict(model=3, training_state=4), step=20, checkpoint_path=f"{tempdir}/step-20") + save_checkpoint(dict(model=5, training_state=6), step=30, checkpoint_path=f"{tempdir}/step-30") latest = discover_latest_checkpoint(tempdir) assert latest == f"{tempdir}/step-30" From ed1350261b20217be2c5db799741f141dfecf48e Mon Sep 17 00:00:00 2001 From: David Hall Date: Sat, 25 Nov 2023 15:40:34 -0800 Subject: [PATCH 029/205] about got the checkpoint refactor done --- examples/alpaca-lora/alpaca_lora.py | 98 ++++++++++++++-------------- src/levanter/checkpoint.py | 4 ++ src/levanter/main/eval_lm.py | 10 +-- src/levanter/main/export_lm_to_hf.py | 5 +- src/levanter/main/viz_logprobs.py | 4 +- tests/test_eval_lm.py | 5 +- tests/test_export_to_hf.py | 6 +- tests/test_text.py | 12 ++-- 8 files changed, 76 insertions(+), 68 deletions(-) diff --git a/examples/alpaca-lora/alpaca_lora.py b/examples/alpaca-lora/alpaca_lora.py index 49e1ac9dc..3784e80fc 100644 --- a/examples/alpaca-lora/alpaca_lora.py +++ b/examples/alpaca-lora/alpaca_lora.py @@ -6,6 +6,7 @@ from dataclasses import dataclass from typing import Optional +import equinox as eqx import jax.random as jrandom import transformers @@ -79,7 +80,12 @@ def 
train(config: TrainArgs): optimizer = config.optimizer.build(config.trainer.num_train_steps) - with config.trainer.device_mesh: + def compute_loss(model: LmHeadModel, example: LmExample, key=None): + return model.compute_loss(example, key=key).scalar() + + # end major difference from Alpaca + + with Trainer(config.trainer, optimizer, compute_loss) as trainer: # how we shard parameters across devices parameter_axis_mapping = config.trainer.parameter_axis_mapping @@ -97,60 +103,54 @@ def loraize_hf_model(model): lora_param_filter = lora_trainable_params_filter(model) - def compute_loss(model: LmHeadModel, example: LmExample, key=None): - return model.compute_loss(example, key=key).scalar() + state = trainer.initial_state(training_key, model=model, is_trainable=lora_param_filter) - # end major difference from Alpaca + # log some info about the model + all_param_count = parameter_count(state.model) + just_lora_params = parameter_count(eqx.filter(state.model, lora_param_filter)) - with Trainer(config.trainer, optimizer, compute_loss, is_trainable=lora_param_filter) as trainer: - state = trainer.initial_state(training_key, model=model) + levanter.log_summary( + { + "parameter_count": all_param_count, + "trainable_parameter_count": just_lora_params, + "fraction_trainable": just_lora_params * 1.0 / all_param_count, + } + ) - # log some info about the model - all_param_count = parameter_count(state.model) - just_lora_params = parameter_count(trainer._trainable_params_only(state.model)) + logger.info(f"Total parameter count: {all_param_count}") + logger.info(f"Trainable parameter count: {just_lora_params}") + logger.info(f"Fraction of parameters that are trainable: {just_lora_params * 1.0 / all_param_count:.3}") + + # Levanter has two kinds of data loaders: sharded and replicated. Replicated is simpler and allows for + # single pass training.
Sharded only loads a subset of the data on each device, and is more efficient for large - # datasets. We use replicated here since the dataset is small. - loader = trainer.replicated_loader(train_dataset, trainer.TrainBatch) - loader = non_caching_cycle(loader) - - if state.step != 0: - logger.info(f"Resuming training from step {state.step}") - for i in range(state.step): - next(loader) # type: ignore - - # Save HF PEFT checkpoints periodically (and at the end of training), which is just the lora weights - if config.hf_save_path is not None: - full_save_path = os.path.join(config.hf_save_path, trainer.run_id) - trainer.add_hook( - save_peft_checkpoint_callback( - full_save_path, config.lora, config.model_name_or_path, tokenizer, config.hf_upload - ), - every=config.hf_save_steps, - ) - - # Save merged HF checkpoints if requested - if config.merged_hf_save_path is not None: - full_save_path = os.path.join(config.merged_hf_save_path, trainer.run_id) - trainer.add_hook( - save_merged_hf_checkpoint_callback(full_save_path, converter, config.merged_hf_upload), - every=config.hf_save_steps, - ) - - trainer.train(state, loader) + trainer.train(state, loader) if __name__ == "__main__": diff --git a/src/levanter/checkpoint.py b/src/levanter/checkpoint.py index 2d8681b16..f9c94e058 100644 --- a/src/levanter/checkpoint.py +++ b/src/levanter/checkpoint.py @@ -276,6 +276,7 @@ def load_checkpoint( tree: M, checkpoint_path: PathLike, *, + subpath: Optional[str] = None, discover_latest=True, axis_mapping: Optional[haliax.partitioning.ResourceMapping] = None, mesh: Optional[jax.sharding.Mesh] = None, @@ -294,6 +295,9 @@ def load_checkpoint( logger.info(f"Loading checkpoint from {checkpoint_path}") metadata = load_metadata(checkpoint_path, fs) + if subpath: + checkpoint_path = os.path.join(checkpoint_path, subpath) + try: tree = tree_deserialize_leaves_tensorstore(checkpoint_path, tree, axis_mapping=axis_mapping, mesh=mesh) return tree diff --git a/src/levanter/main/eval_lm.py b/src/levanter/main/eval_lm.py index bea7a5e2b..806127173 100644 --- a/src/levanter/main/eval_lm.py +++ b/src/levanter/main/eval_lm.py @@ -51,7 +51,11 @@ def main(config: EvalLmConfig): if config.eval_on_train: raw_dataset = CausalLmDataset(config.data.train_set(Pos.size), Pos, KeyPos) else: - raw_dataset = CausalLmDataset(config.data.validation_set(Pos.size), Pos, KeyPos) + validation_set = config.data.validation_set(Pos.size) + if validation_set is None: + raise ValueError("Can't eval on validation_set b/c there isn't one!") + + raw_dataset = CausalLmDataset(validation_set, Pos, KeyPos) eval_loader = ReplicatedBatchLoader(raw_dataset, config.trainer.device_mesh, Batch) compute_axis_mapping = config.trainer.compute_axis_mapping @@ -81,14 +85,12 @@ def compute_loss(model: LmHeadModel, example: LmExample): with use_cpu_device(): model = eqx.filter_eval_shape(config.model.build, Vocab, key=key) # TODO: don't load the entire checkpoint into CPU memory when we only need our share of the model - ckpt = load_checkpoint(model, None, config.checkpoint_path) + ckpt = load_checkpoint(model, config.checkpoint_path, subpath="model") assert ckpt is not None - model, _, _ = ckpt model = hax.shard_with_axis_mapping(model, parameter_axis_mapping) - # TODO: switch to throwing instead of returning None loss = callbacks.eval_loss_loop(compute_loss, model, eval_loader, max_batches=total) del model diff --git a/src/levanter/main/export_lm_to_hf.py b/src/levanter/main/export_lm_to_hf.py index 50a8e4b92..7fd4d073d 100644 --- 
a/src/levanter/main/export_lm_to_hf.py +++ b/src/levanter/main/export_lm_to_hf.py @@ -51,10 +51,9 @@ def main(config: ConvertLmConfig): model: LmHeadModel = eqx.filter_eval_shape(config.model.build, Vocab, key=key) trainable, non_trainable = eqx.partition(model, is_inexact_arrayish) # TODO: don't load the entire checkpoint into CPU memory when we only need our share of the model - ckpt = load_checkpoint(trainable, None, config.checkpoint_path) + trainable = load_checkpoint(trainable, config.checkpoint_path, subpath="model") - assert ckpt is not None - trainable, _, _ = ckpt + assert trainable is not None model = eqx.combine(trainable, non_trainable) if config.override_vocab_size: diff --git a/src/levanter/main/viz_logprobs.py b/src/levanter/main/viz_logprobs.py index ad85a0c7d..6f8d08640 100644 --- a/src/levanter/main/viz_logprobs.py +++ b/src/levanter/main/viz_logprobs.py @@ -46,7 +46,7 @@ def main(config: VizGpt2Config): KeyPos = config.model.KeyPos eval_loader = ReplicatedBatchLoader( - CausalLmDataset(config.data.validation_set(Pos.size), Pos, KeyPos), + CausalLmDataset(config.data.validation_set(Pos.size), Pos, KeyPos), # type: ignore config.trainer.device_mesh, EvalBatch, ) @@ -83,7 +83,7 @@ def compute_log_probs(model: LmHeadModel, example: LmExample): with use_cpu_device(): model = eqx.filter_eval_shape(config.model.build, Vocab, key=key) # TODO: don't load the entire checkpoint into CPU memory when we only need our share of the model - ckpt = load_checkpoint(model, None, config.checkpoint_path) + ckpt = load_checkpoint(model, config.checkpoint_path) assert ckpt is not None model, _, _ = ckpt diff --git a/tests/test_eval_lm.py b/tests/test_eval_lm.py index 178069f26..a6bf3c8d9 100644 --- a/tests/test_eval_lm.py +++ b/tests/test_eval_lm.py @@ -13,6 +13,7 @@ from levanter.distributed import RayConfig from levanter.models.gpt2 import Gpt2LMHeadModel from levanter.tracker.wandb import WandbConfig +from levanter.trainer import TrainerState from levanter.utils.py_utils import logical_cpu_core_count @@ -43,7 +44,9 @@ def test_eval_lm(): Vocab = haliax.Axis("vocab", len(tok)) model = Gpt2LMHeadModel.init(Vocab, model_config, key=jax.random.PRNGKey(0)) - save_checkpoint(model, None, 0, f"{f}/ckpt") + state = TrainerState(0, model, model, jax.random.PRNGKey(0), True) + + save_checkpoint(state, 0, f"{f}/ckpt") config = eval_lm.EvalLmConfig( data=data_config, diff --git a/tests/test_export_to_hf.py b/tests/test_export_to_hf.py index b50bde9cb..84d3c3081 100644 --- a/tests/test_export_to_hf.py +++ b/tests/test_export_to_hf.py @@ -34,7 +34,7 @@ def test_export_lm_to_hf(): # in our trainer, we only export the trainable params trainable, non_trainable = eqx.partition(model, is_inexact_arrayish) - save_checkpoint(trainable, None, 0, f"{tmpdir}/ckpt") + save_checkpoint({"model": trainable}, 0, f"{tmpdir}/ckpt") try: config = export_lm_to_hf.ConvertLmConfig( @@ -50,8 +50,8 @@ def test_export_lm_to_hf(): export_lm_to_hf.main(config) if has_torch(): - m = AutoModelForCausalLM.from_pretrained(f"{tmpdir}/output") - print(m) + # mostly just make sure it loads + AutoModelForCausalLM.from_pretrained(f"{tmpdir}/output") finally: try: diff --git a/tests/test_text.py b/tests/test_text.py index 21e2887db..07d9436a5 100644 --- a/tests/test_text.py +++ b/tests/test_text.py @@ -1,11 +1,11 @@ +from tempfile import TemporaryDirectory + from levanter.data.text import LMDatasetConfig def test_dont_blow_up_without_validation_set(): - config = LMDatasetConfig( - train_urls=["kaa"], - validation_urls=[], - ) + with 
TemporaryDirectory() as td: + config = LMDatasetConfig(train_urls=["kaa"], validation_urls=[], cache_dir=f"{td}") - # mostly just making sure this doesn't blow up - assert config.validation_set(10) is None + # mostly just making sure this doesn't blow up + assert config.validation_set(10) is None From 634407ec1a9e0df988589d0fc98166ba8610df4a Mon Sep 17 00:00:00 2001 From: David Hall Date: Sat, 25 Nov 2023 15:42:51 -0800 Subject: [PATCH 030/205] minor dead code removal --- src/levanter/main/viz_logprobs.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/levanter/main/viz_logprobs.py b/src/levanter/main/viz_logprobs.py index 6f8d08640..2b6d32406 100644 --- a/src/levanter/main/viz_logprobs.py +++ b/src/levanter/main/viz_logprobs.py @@ -51,10 +51,6 @@ def main(config: VizGpt2Config): EvalBatch, ) - # some axes we use outside the model proper - Pos = config.model.Pos - KeyPos = config.model.KeyPos - compute_axis_mapping = config.trainer.compute_axis_mapping parameter_axis_mapping = config.trainer.parameter_axis_mapping From 4208e03f2bc0fb0aa81e319525c89968f625a012 Mon Sep 17 00:00:00 2001 From: David Hall Date: Sat, 25 Nov 2023 21:04:46 -0800 Subject: [PATCH 031/205] fix tests --- src/levanter/checkpoint.py | 48 +--------- src/levanter/main/eval_lm.py | 4 +- src/levanter/main/train_lm.py | 3 - src/levanter/main/viz_logprobs.py | 5 +- src/levanter/tracker/tracker_fns.py | 4 +- src/levanter/trainer.py | 131 +++++++++++++--------------- tests/test_viz_lm.py | 2 +- 7 files changed, 68 insertions(+), 129 deletions(-) diff --git a/src/levanter/checkpoint.py b/src/levanter/checkpoint.py index f9c94e058..eab851fa4 100644 --- a/src/levanter/checkpoint.py +++ b/src/levanter/checkpoint.py @@ -7,7 +7,7 @@ import urllib.parse from dataclasses import dataclass from datetime import timedelta -from typing import Any, Callable, List, Optional, Sequence, Tuple, TypeVar, Union +from typing import Any, Callable, List, Optional, Sequence, TypeVar, Union import equinox import fsspec @@ -334,52 +334,6 @@ def load_checkpoint( return new_state -def _old_load_checkpoint( - model: M, - training_state: S, - checkpoint_path: PathLike, - *, - discover_latest=True, - axis_mapping: Optional[haliax.partitioning.ResourceMapping] = None, - mesh: Optional[jax.sharding.Mesh] = None, -) -> Optional[Tuple[M, S, int]]: - """ - Load a checkpoint from a given path. - - Returns the loaded model state, training state, and step. If discover_latest is True, - the latest checkpoint in the given path will be loaded. Otherwise, the checkpoint at - the given path will be loaded. If no checkpoint is found, returns None - - If training_state is None, no training state will be loaded. 
- """ - fs: AbstractFileSystem - fs, _ = _get_fs_and_plain_path(checkpoint_path) - - checkpoint_path = str(checkpoint_path) - - if discover_latest: - checkpoint_path = discover_latest_checkpoint(checkpoint_path) # type: ignore - - if checkpoint_path is None or not fs.exists(checkpoint_path): - return None - - logger.info(f"Loading checkpoint from {checkpoint_path}") - metadata = load_metadata(checkpoint_path, fs) - - model = tree_deserialize_leaves_tensorstore( - os.path.join(checkpoint_path, "model"), model, axis_mapping=axis_mapping, mesh=mesh - ) - - if training_state is None: - training_state = None - else: - training_state = tree_deserialize_leaves_tensorstore( - os.path.join(checkpoint_path, "training_state"), training_state, axis_mapping=axis_mapping, mesh=mesh - ) - - return model, training_state, metadata["step"] - - def load_metadata(checkpoint_path, fs=None): if fs is None: fs: AbstractFileSystem diff --git a/src/levanter/main/eval_lm.py b/src/levanter/main/eval_lm.py index 806127173..340e7e496 100644 --- a/src/levanter/main/eval_lm.py +++ b/src/levanter/main/eval_lm.py @@ -85,9 +85,7 @@ def compute_loss(model: LmHeadModel, example: LmExample): with use_cpu_device(): model = eqx.filter_eval_shape(config.model.build, Vocab, key=key) # TODO: don't load the entire checkpoint into CPU memory when we only need our share of the model - ckpt = load_checkpoint(model, config.checkpoint_path, subpath="model") - - assert ckpt is not None + model = load_checkpoint(model, config.checkpoint_path, subpath="model") model = hax.shard_with_axis_mapping(model, parameter_axis_mapping) diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py index b03745988..4d0c2f5b5 100644 --- a/src/levanter/main/train_lm.py +++ b/src/levanter/main/train_lm.py @@ -5,7 +5,6 @@ from typing import Optional, Union import jax.random as jrandom -import wandb import haliax as hax from haliax import Axis @@ -186,10 +185,8 @@ def compute_log_probs(model, example: LmExample): next(train_loader) ## OK, actually run training! - trainer.add_hook(lambda s: print(s.loss), every=20) trainer.train(state, train_loader) # checkpointer.on_step(last_step, force=True) - wandb.finish() if __name__ == "__main__": diff --git a/src/levanter/main/viz_logprobs.py b/src/levanter/main/viz_logprobs.py index 2b6d32406..2bc3d43cc 100644 --- a/src/levanter/main/viz_logprobs.py +++ b/src/levanter/main/viz_logprobs.py @@ -79,10 +79,9 @@ def compute_log_probs(model: LmHeadModel, example: LmExample): with use_cpu_device(): model = eqx.filter_eval_shape(config.model.build, Vocab, key=key) # TODO: don't load the entire checkpoint into CPU memory when we only need our share of the model - ckpt = load_checkpoint(model, config.checkpoint_path) + model = load_checkpoint(model, config.checkpoint_path, subpath="model") - assert ckpt is not None - model, _, _ = ckpt + assert model is not None model = hax.shard_with_axis_mapping(model, parameter_axis_mapping) diff --git a/src/levanter/tracker/tracker_fns.py b/src/levanter/tracker/tracker_fns.py index fb845d53d..26d0db444 100644 --- a/src/levanter/tracker/tracker_fns.py +++ b/src/levanter/tracker/tracker_fns.py @@ -83,7 +83,7 @@ def current_tracker( >>> from levanter.tracker.wandb import WandbTracker >>> with current_tracker(WandbTracker()): ... log_metrics({"foo": 1}, step=0) - ... current_tracker().log_metrics({"foo": 2}, step=1) + ... 
current_tracker().log({"foo": 2}, step=1) """ global _global_tracker if tracker is None: @@ -136,6 +136,8 @@ def __enter__(self): self.old_tracker = _global_tracker _global_tracker = self.tracker + return self.tracker + def __exit__(self, exc_type, exc_val, exc_tb): global _global_tracker _global_tracker = self.old_tracker diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index ada138004..b9aa7565b 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -32,7 +32,7 @@ import levanter.logging import levanter.tracker import levanter.tracker.wandb -from levanter import logging, tracker +from levanter import tracker from levanter.checkpoint import CheckpointerConfig from levanter.config import JsonAtom from levanter.data import Dataset, ReplicatedBatchLoader, ShardableDataset, ShardedBatchLoader @@ -65,6 +65,9 @@ class TrainerState(eqx.Module, Generic[M]): is_trainable: PyTree[FilterSpec] = eqx.field(static=True) +S = TypeVar("S", bound=TrainerState) + + # A note on the semantics of "step" vs "next_step": # The "step" of a TrainerState is the state after `step` steps have been taken. # A "StepInfo"'s step is the step that was just completed. If you want the next step, use `next_step`. @@ -209,30 +212,28 @@ def EvalBatch(self): return self.config.EvalBatch def __enter__(self): - if len(self._cmanagers) > 0: - raise RuntimeError("Trainer is already entered") - - self._cmanagers = [ + this_managers = [ levanter.current_tracker(self.tracker), self.device_mesh, hax.axis_mapping(self.parameter_axis_mapping), ] + self._cmanagers.append(this_managers) - for cmanager in self._cmanagers: + for cmanager in this_managers: cmanager.__enter__() return self def __exit__(self, *args): + assert len(self._cmanagers) > 0, "Trainer.__exit__ called without corresponding Trainer.__enter__" + cur_managers = self._cmanagers.pop() problems = [] - for cmanager in reversed(self._cmanagers): + for cmanager in reversed(cur_managers): try: cmanager.__exit__(*args) except Exception as e: problems.append(e) - self._cmanagers = [] - if len(problems) > 0: raise RuntimeError("Exception(s) occurred while exiting trainer", problems) from problems[0] @@ -255,50 +256,56 @@ def initial_state( Returns: model, opt_state, key, resume_step """ - with levanter.tracker.current_tracker(self.tracker): - if model is not None and model_init is not None: - raise ValueError("only one of model and model_init should be specified") - elif model is None and model_init is None: - raise ValueError("one of model and model_init must be specified") - - if model is not None: - # we can't just use `lambda: model` because JAX jit can't see captures, but it can see jax partials - model_init = jax.tree_util.Partial(lambda m: m, model) - - model_shape, opt_state_shape = eqx.filter_eval_shape( - self._init_model_and_opt_state, model_init, is_trainable - ) + if model is not None and model_init is not None: + raise ValueError("only one of model and model_init should be specified") + elif model is None and model_init is None: + raise ValueError("one of model and model_init must be specified") - # we only checkpoint the trainable parameters, so we need to filter out the non-trainable ones - trainable_model_shape = _trainable_params_only(model_shape, is_trainable) + if model is not None: + # we can't just use `lambda: model` because JAX jit can't see captures, but it can see jax partials + model_init = jax.tree_util.Partial(lambda m: m, model) - trainer_state_shape: TrainerState = TrainerState( - 0, trainable_model_shape, opt_state_shape, 
training_key, is_trainable=is_trainable - ) + with self: + if self.config.load_checkpoint is not False: + trainer_state_shape = eqx.filter_eval_shape( + self._initialize_state_from_scratch, model_init, training_key, is_trainable + ) - ckpt = self._maybe_load_checkpoint(trainer_state_shape) + # TODO: don't remake the checkpointer every time + checkpointer = self.config.checkpointer.create(self.run_id) + load_checkpoint_path = self.config.load_checkpoint_path - if ckpt is not None: - trainable_model = ckpt.model - opt_state = ckpt.opt_state - training_key = ckpt.training_key - step = ckpt.step + if load_checkpoint_path is None: + load_checkpoint_path = self.config.checkpointer.expanded_path(self.run_id) - if model is not None: - model = eqx.combine(ckpt.model, model) - elif any(isinstance(leaf, ShapeDtypeStruct) for leaf in jax.tree_leaves(trainable_model)): - # if we're resuming, we need to re-initialize the non-trainable parameters to their original values - non_trainable = named_jit(self._init_non_trainable_params, self.parameter_axis_mapping)(model_init) - model = eqx.combine(trainable_model, non_trainable) - else: - model = trainable_model - else: - model, opt_state = named_jit(self._init_model_and_opt_state, self.parameter_axis_mapping)( - model_init, is_trainable + ckpt = checkpointer.load_checkpoint( + trainer_state_shape, + load_checkpoint_path, + axis_mapping=self.parameter_axis_mapping, + mesh=self.device_mesh, ) - step = 0 - return TrainerState(step, model, opt_state, training_key, is_trainable) + if ckpt is None: + if self.config.load_checkpoint is True: + raise ValueError(f"Could not load checkpoint from {load_checkpoint_path}") + else: + if model is not None: + model = eqx.combine(ckpt.model, model) + elif any(isinstance(leaf, ShapeDtypeStruct) for leaf in jax.tree_leaves(ckpt.model)): + # if we're resuming, we need to re-initialize the non-trainable parameters to their original values + # TODO: do we want to extend this to non-model things that don't get initialized from a ckpt? + non_trainable = named_jit(self._init_non_trainable_params, self.parameter_axis_mapping)( + model_init + ) + model = eqx.combine(ckpt.model, non_trainable) + else: + model = ckpt.model + + return dataclasses.replace(ckpt, model=model) + + return named_jit(self._initialize_state_from_scratch, self.parameter_axis_mapping)( + model_init, training_key, is_trainable + ) def train_step(self, state: TrainerState[M], *batch: X, **batch_kwargs) -> StepInfo[M]: """ @@ -439,15 +446,20 @@ def split_loss_fn(trainable_model, *batch, **batch_kwargs): return train_step - def _init_model_and_opt_state(self, model_init, is_trainable): + def _initialize_state_from_scratch(self, model_init: Callable[[], M], training_key, is_trainable): model = model_init() + # only force trainable params to param precision. 
Other params are cast to compute precision trainable, non_trainable = _partition_trainable_params(model, is_trainable) trainable = self.mp.cast_to_param(trainable) non_trainable = self.mp.cast_to_compute(non_trainable) model = eqx.combine(trainable, non_trainable) + opt_state = self.optimizer.init(trainable) - return model, opt_state + + trainer_state: TrainerState = TrainerState(0, model, opt_state, training_key, is_trainable=is_trainable) + + return trainer_state def _init_non_trainable_params(self, model_init): model = model_init() @@ -456,29 +468,6 @@ def _init_non_trainable_params(self, model_init): non_trainable = self.mp.cast_to_compute(non_trainable) return non_trainable - def _maybe_load_checkpoint(self, state: TrainerState) -> Optional[TrainerState]: - """Loads a checkpoint if one exists and we're supposed to load it, - otherwise returns the model and training state as is""" - with self.device_mesh: - if self.config.load_checkpoint is not False: - # TODO: don't remake the checkpointer every time - checkpointer = self.config.checkpointer.create(self.run_id) - load_checkpoint_path = self.config.load_checkpoint_path - - if load_checkpoint_path is None: - load_checkpoint_path = self.config.checkpointer.expanded_path(self.run_id) - - ckpt = checkpointer.load_checkpoint( - state, load_checkpoint_path, axis_mapping=self.parameter_axis_mapping, mesh=self.device_mesh - ) - - if ckpt is None and self.config.load_checkpoint is True: - raise ValueError(f"Could not load checkpoint from {load_checkpoint_path}") - - return ckpt - else: - return None - @dataclass class TrainerConfig: diff --git a/tests/test_viz_lm.py b/tests/test_viz_lm.py index cf4fb74a6..25d5e8fb0 100644 --- a/tests/test_viz_lm.py +++ b/tests/test_viz_lm.py @@ -43,7 +43,7 @@ def test_viz_lm(): Vocab = haliax.Axis("vocab", len(tok)) model = Gpt2LMHeadModel.init(Vocab, model_config, key=jax.random.PRNGKey(0)) - save_checkpoint(model, None, 0, f"{f}/ckpt") + save_checkpoint({"model": model}, 0, f"{f}/ckpt") config = viz_logprobs.VizGpt2Config( data=data_config, From 958488403d1113091182ff0242e9983a840254b6 Mon Sep 17 00:00:00 2001 From: David Hall Date: Sat, 25 Nov 2023 21:09:51 -0800 Subject: [PATCH 032/205] cleanup --- src/levanter/checkpoint.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/levanter/checkpoint.py b/src/levanter/checkpoint.py index eab851fa4..965343433 100644 --- a/src/levanter/checkpoint.py +++ b/src/levanter/checkpoint.py @@ -30,7 +30,6 @@ PathLike = Union[str, pathlib.Path] M = TypeVar("M", bound=PyTree) -S = TypeVar("S") @dataclass(frozen=True) From 5a1867874a03c9c586948fa734257e08e2f52930 Mon Sep 17 00:00:00 2001 From: David Hall Date: Sat, 25 Nov 2023 21:27:30 -0800 Subject: [PATCH 033/205] cleanup --- src/levanter/trainer.py | 57 ++++++++++++++++++++++++----------------- 1 file changed, 33 insertions(+), 24 deletions(-) diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index b9aa7565b..c886fe85a 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -28,6 +28,7 @@ import haliax as hax from haliax import Axis from haliax.partitioning import ResourceAxis, ResourceMapping, named_jit +from haliax.types import Scalar import levanter.logging import levanter.tracker @@ -114,6 +115,17 @@ def decorator(fn: Callable[[StepInfo], None]): return decorator(fn) +# A note on extending Trainer: +# First, consider whether you can do what you want with hooks. Hooks can cover a lot of use cases. +# Sometimes, however, you need to do something more complicated. 
In that case, you can extend Trainer. +# In order to do that, you need to: +# * Extend TrainerState to add your additional state +# * Override `_train_step` to add your additional logic +# * Override `initial_state` or `_initialize_state_from_scratch` to initialize your additional state. (The latter is +# simpler and means you don't need to handle the checkpointing logic yourself.) +# * You might also need to override `training_steps` if you want to make the type checker happy. + + class Trainer: config: "TrainerConfig" optimizer: GradientTransformation @@ -414,37 +426,36 @@ def sharded_loader(self, dataset: ShardableDataset[X], batch_axis: Axis) -> Shar @cached_property def _train_step_fn(self): - @named_jit( + return named_jit( axis_resources=self.parameter_axis_mapping, out_axis_resources=self.parameter_axis_mapping, donate_args=(True,), - ) - def train_step(state: TrainerState, *batch, **batch_kwargs): - key, new_key = jax.random.split(state.training_key) - opt_state = state.opt_state - model = inference_mode(state.model, False) + )(self._train_step) - # we do this so that we only take the gradients of the trainable parameters - trainable_model, rest_model = _partition_trainable_params(model, state.is_trainable) + def _train_step(self, state: TrainerState, *batch, **batch_kwargs) -> tuple[Scalar, TrainerState]: + key, new_key = jax.random.split(state.training_key) + opt_state = state.opt_state + model = inference_mode(state.model, False) - def split_loss_fn(trainable_model, *batch, **batch_kwargs): - model = eqx.combine(trainable_model, rest_model) - return self.loss_fn(model, *batch, **batch_kwargs, key=key) + # we do this so that we only take the gradients of the trainable parameters + trainable_model, rest_model = _partition_trainable_params(model, state.is_trainable) - loss, grads = accumulate_gradients_sharded( - split_loss_fn, self.TrainBatch, self.config.per_device_parallelism, self.parameter_axis_mapping - )(trainable_model, *batch, **batch_kwargs) + def split_loss_fn(trainable_model, *batch, **batch_kwargs): + model = eqx.combine(trainable_model, rest_model) + return self.loss_fn(model, *batch, **batch_kwargs, key=key) - updates, opt_state = self.optimizer.update(grads, opt_state, params=trainable_model) - model = eqx.apply_updates(model, updates) + loss, grads = accumulate_gradients_sharded( + split_loss_fn, self.TrainBatch, self.config.per_device_parallelism, self.parameter_axis_mapping + )(trainable_model, *batch, **batch_kwargs) - new_state = dataclasses.replace( - state, step=state.step + 1, model=model, opt_state=opt_state, training_key=new_key - ) + updates, opt_state = self.optimizer.update(grads, opt_state, params=trainable_model) + model = eqx.apply_updates(model, updates) - return loss, new_state + new_state = dataclasses.replace( + state, step=state.step + 1, model=model, opt_state=opt_state, training_key=new_key + ) - return train_step + return loss, new_state def _initialize_state_from_scratch(self, model_init: Callable[[], M], training_key, is_trainable): model = model_init() @@ -457,9 +468,7 @@ def _initialize_state_from_scratch(self, model_init: Callable[[], M], training_k opt_state = self.optimizer.init(trainable) - trainer_state: TrainerState = TrainerState(0, model, opt_state, training_key, is_trainable=is_trainable) - - return trainer_state + return TrainerState(0, model, opt_state, training_key, is_trainable) def _init_non_trainable_params(self, model_init): model = model_init() From c355106ac176792e60b3ec5649b01aca3e8e495e Mon Sep 17 00:00:00 2001 From: 
David Hall Date: Sat, 25 Nov 2023 22:56:10 -0800 Subject: [PATCH 034/205] minor --- src/levanter/checkpoint.py | 17 +++++++++++++++++ src/levanter/trainer.py | 10 ++++------ 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/src/levanter/checkpoint.py b/src/levanter/checkpoint.py index 965343433..a633f1e09 100644 --- a/src/levanter/checkpoint.py +++ b/src/levanter/checkpoint.py @@ -280,6 +280,23 @@ def load_checkpoint( axis_mapping: Optional[haliax.partitioning.ResourceMapping] = None, mesh: Optional[jax.sharding.Mesh] = None, ) -> Optional[M]: + """ + Load a checkpoint from a given path. If discover_latest is True, then the latest checkpoint + in a subdirectory of the given path will be loaded. If subpath is not None, only that + subpath of the checkpoint is loaded. This is useful for loading, e.g., just the model and not + the entire training state. + + Args: + tree: an exemplar of the tree to load. Can be a PyTree[ShapeDtypeStruct] instead of a PyTree[Any] + checkpoint_path: the path to load the checkpoint from + subpath: the subpath to load from the checkpoint + discover_latest: whether to discover the latest checkpoint in the given path + axis_mapping: the axis mapping to use for loading the checkpoint + mesh: the mesh to use for loading the checkpoint + Returns: + the loaded checkpoint, with the same structure as the exemplar tree + + """ fs: AbstractFileSystem fs, _ = _get_fs_and_plain_path(checkpoint_path) diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index c886fe85a..bc078414d 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -34,7 +34,7 @@ import levanter.tracker import levanter.tracker.wandb from levanter import tracker -from levanter.checkpoint import CheckpointerConfig +from levanter.checkpoint import CheckpointerConfig, load_checkpoint from levanter.config import JsonAtom from levanter.data import Dataset, ReplicatedBatchLoader, ShardableDataset, ShardedBatchLoader from levanter.distributed import DistributedConfig, RayConfig @@ -283,14 +283,12 @@ def initial_state( self._initialize_state_from_scratch, model_init, training_key, is_trainable ) - # TODO: don't remake the checkpointer every time - checkpointer = self.config.checkpointer.create(self.run_id) load_checkpoint_path = self.config.load_checkpoint_path if load_checkpoint_path is None: load_checkpoint_path = self.config.checkpointer.expanded_path(self.run_id) - ckpt = checkpointer.load_checkpoint( + ckpt = load_checkpoint( trainer_state_shape, load_checkpoint_path, axis_mapping=self.parameter_axis_mapping, mesh=self.device_mesh, ) @@ -440,13 +438,13 @@ def _train_step(self, state: TrainerState, *batch, **batch_kwargs) -> tuple[Scal # we do this so that we only take the gradients of the trainable parameters trainable_model, rest_model = _partition_trainable_params(model, state.is_trainable) - def split_loss_fn(trainable_model, *batch, **batch_kwargs): + def split_loss_fn(trainable_model, rest_model, *batch, **batch_kwargs): model = eqx.combine(trainable_model, rest_model) return self.loss_fn(model, *batch, **batch_kwargs, key=key) loss, grads = accumulate_gradients_sharded( split_loss_fn, self.TrainBatch, self.config.per_device_parallelism, self.parameter_axis_mapping - )(trainable_model, *batch, **batch_kwargs) + )(trainable_model, rest_model, *batch, **batch_kwargs) updates, opt_state = self.optimizer.update(grads, opt_state, params=trainable_model) model = eqx.apply_updates(model, updates) From d2e0de12904cea77346a4341b94d8b108370939b Mon Sep 17 00:00:00 2001 From: David Hall
Date: Sat, 25 Nov 2023 23:15:54 -0800 Subject: [PATCH 035/205] wip --- src/levanter/doremi.py | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/src/levanter/doremi.py b/src/levanter/doremi.py index 9dc0192cc..4ddb04298 100644 --- a/src/levanter/doremi.py +++ b/src/levanter/doremi.py @@ -1,4 +1,5 @@ -from typing import Iterator, TypeVar +import dataclasses +from typing import Callable, Iterator, TypeVar import equinox as eqx import jax @@ -12,13 +13,31 @@ from levanter.data import ShardableDataset from levanter.data.mixture import MixtureDataset -from levanter.logging import capture_time -from levanter.trainer import StepInfo, Trainer, TrainerState +from levanter.trainer import M, StepInfo, Trainer, TrainerState from levanter.types import ComputeLossFunction T = TypeVar("T") +class DoremiState(TrainerState): + alpha: hax.NamedArray + average_alpha: hax.NamedArray + + def __init__(self, step: int, model, opt_state, training_key, alpha): + super().__init__(step, model, opt_state, training_key) + self.alpha = alpha + + def update_alpha(self, alpha): + # make it stable + average_alpha = self.average_alpha + (alpha - self.average_alpha) / (self.step + 1) + return dataclasses.replace(self, alpha=alpha, average_alpha=average_alpha) + + +# class DoReMiTrainer(Trainer): +# def _initialize_state_from_scratch(self, model_init: Callable[[], M], training_key, is_trainable): +# base_state = super()._initialize_state_from_scratch(model_init, training_key, is_trainable) +# + def estimate_mixture_weights( trainer: Trainer, From be99631ed76ebf93ceaa2bafa511a416569b91ef Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 22 Nov 2023 00:02:22 -0600 Subject: [PATCH 036/205] register default hooks by default... --- examples/alpaca-lora/alpaca_lora.py | 2 - examples/alpaca/alpaca.py | 1 - src/levanter/main/lora_lm.py | 110 +++++++++++++++------------- src/levanter/main/train_lm.py | 4 +- src/levanter/trainer.py | 6 +- 5 files changed, 64 insertions(+), 59 deletions(-) diff --git a/examples/alpaca-lora/alpaca_lora.py b/examples/alpaca-lora/alpaca_lora.py index 31de93252..c3cf71ce6 100644 --- a/examples/alpaca-lora/alpaca_lora.py +++ b/examples/alpaca-lora/alpaca_lora.py @@ -103,8 +103,6 @@ def compute_loss(model: LmHeadModel, example: LmExample, key=None): # end major difference from Alpaca with Trainer(config.trainer, optimizer, compute_loss, is_trainable=lora_param_filter) as trainer: - - trainer.add_default_hooks() state = trainer.initial_state(training_key, model=model) # log some info about the model diff --git a/examples/alpaca/alpaca.py b/examples/alpaca/alpaca.py index 20cb98a33..d49056506 100644 --- a/examples/alpaca/alpaca.py +++ b/examples/alpaca/alpaca.py @@ -234,7 +234,6 @@ def compute_loss(model: LmHeadModel, example: LmExample, key=None): loader = trainer.replicated_loader(train_dataset, trainer.TrainBatch) loader = non_caching_cycle(loader) - trainer.add_default_hooks() state = trainer.initial_state(training_key, model=model) if state.step != 0: diff --git a/src/levanter/main/lora_lm.py b/src/levanter/main/lora_lm.py index 4e621239e..5904cca8b 100644 --- a/src/levanter/main/lora_lm.py +++ b/src/levanter/main/lora_lm.py @@ -88,62 +88,68 @@ def compute_loss(model, example: LmExample, key=None): optimizer = config.optimizer.build(config.trainer.num_train_steps) # Our trainer is a wrapper around the optimizer and compute_loss function that handles checkpointing and fsdp - trainer = Trainer(config.trainer, optimizer, compute_loss, 
is_trainable=lora_param_filter) - state = trainer.initial_state(training_key, model=model) - - all_param_count = parameter_count(state.model) - just_lora_params = parameter_count(trainer.trainable_params_only(state.model)) - - levanter.tracker.log_summary( - { - "parameter_count": all_param_count, - "trainable_parameter_count": just_lora_params, - "fraction_trainable": just_lora_params * 1.0 / all_param_count, - } - ) - - logger.info(f"Total parameter count: {all_param_count}") - logger.info(f"Trainable parameter count: {just_lora_params}") - logger.info(f"Fraction of parameters that are trainable: {just_lora_params * 1.0 / all_param_count%.3}") - - # data loaders - eval_dataset = CausalLmDataset(config.data.validation_set(Pos.size), Pos, KeyPos) - - train_dataset = CausalLmDataset(config.data.train_set(Pos.size), Pos, KeyPos) - train_loader = trainer.sharded_loader(train_dataset, Batch) - - # boilerplate hooks and such - trainer.add_default_hooks(eval_dataset) - trainer.add_hook(callbacks.log_performance_stats(Pos.size, trainer.config.train_batch_size), every=1) - if config.peft_save_path is not None: - full_save_path = os.path.join(config.peft_save_path, trainer.run_id) - trainer.add_hook( - save_peft_checkpoint_callback( - full_save_path, config.lora, config.initialize_from_hf, tokenizer, config.peft_hf_upload - ), - every=config.hf_save_steps, - ) - - if config.merged_hf_save_path is not None: - full_save_path = os.path.join(config.merged_hf_save_path, trainer.run_id) - trainer.add_hook( - save_merged_hf_checkpoint_callback(full_save_path, converter, config.merged_hf_upload), - every=config.hf_save_steps, - ) + with Trainer(config.trainer, optimizer, compute_loss, is_trainable=lora_param_filter) as trainer: + eval_datasets = config.data.validation_sets(Pos.size) - # data loader. may need to seek to the right place if we're resuming - iter_data = non_caching_cycle(train_loader) + state = trainer.initial_state(training_key, model=model) - if state.step > 0: - # step is after the batch, so we need to seek to step - # TODO: implement iter_data.seek(resume_step +1) - import tqdm + all_param_count = parameter_count(state.model) + just_lora_params = parameter_count(trainer.trainable_params_only(state.model)) - for _ in tqdm.tqdm(range(state.step + 1), desc="seeking data for resume"): - next(iter_data) + levanter.tracker.log_summary( + { + "parameter_count": all_param_count, + "trainable_parameter_count": just_lora_params, + "fraction_trainable": just_lora_params * 1.0 / all_param_count, + } + ) - ## OK, actually run training! - trainer.train(state, iter_data) + logger.info(f"Total parameter count: {all_param_count}") + logger.info(f"Trainable parameter count: {just_lora_params}") + logger.info(f"Fraction of parameters that are trainable: {just_lora_params * 1.0 / all_param_count:.3}") + + # data loaders + if len(eval_datasets) == 0: + logger.warning("No evaluation datasets provided.") + + for name, eval_dataset in eval_datasets.items(): + eval_dataset = CausalLmDataset(eval_dataset, Pos, KeyPos) + trainer.add_eval_hook(eval_dataset, name=name) + + train_dataset = CausalLmDataset(config.data.train_set(Pos.size), Pos, KeyPos) + train_loader = trainer.sharded_loader(train_dataset, Batch) + + # boilerplate hooks and such + trainer.add_hook(callbacks.log_performance_stats(Pos.size, trainer.config.train_batch_size), every=1) + if config.peft_save_path is not None: + full_save_path = os.path.join(config.peft_save_path, trainer.run_id) + trainer.add_hook( + save_peft_checkpoint_callback( + full_save_path, config.lora, config.initialize_from_hf, tokenizer, config.peft_hf_upload + ), + every=config.hf_save_steps, + ) + + if config.merged_hf_save_path is not None: + full_save_path = os.path.join(config.merged_hf_save_path, trainer.run_id) + trainer.add_hook( + save_merged_hf_checkpoint_callback(full_save_path, converter, config.merged_hf_upload), + every=config.hf_save_steps, + ) + + # data loader. may need to seek to the right place if we're resuming + iter_data = non_caching_cycle(train_loader) + + if state.step > 0: + # step is after the batch, so we need to seek to step + # TODO: implement iter_data.seek(resume_step +1) + import tqdm + + for _ in tqdm.tqdm(range(state.step + 1), desc="seeking data for resume"): + next(iter_data) + + ## OK, actually run training!
- trainer.train(state, iter_data) + logger.info(f"Total parameter count: {all_param_count}") + logger.info(f"Trainable parameter count: {just_lora_params}") + logger.info(f"Fraction of parameters that are trainable: {just_lora_params * 1.0 / all_param_count%.3}") + + # data loaders + if len(eval_datasets) == 0: + logger.warning("No evaluation datasets provided.") + + for name, eval_dataset in eval_datasets.items(): + eval_dataset = CausalLmDataset(eval_dataset, Pos, KeyPos) + trainer.add_eval_hook(eval_dataset, name=name) + + train_dataset = CausalLmDataset(config.data.train_set(Pos.size), Pos, KeyPos) + train_loader = trainer.sharded_loader(train_dataset, Batch) + + # boilerplate hooks and such + trainer.add_hook(callbacks.log_performance_stats(Pos.size, trainer.config.train_batch_size), every=1) + if config.peft_save_path is not None: + full_save_path = os.path.join(config.peft_save_path, trainer.run_id) + trainer.add_hook( + save_peft_checkpoint_callback( + full_save_path, config.lora, config.initialize_from_hf, tokenizer, config.peft_hf_upload + ), + every=config.hf_save_steps, + ) + + if config.merged_hf_save_path is not None: + full_save_path = os.path.join(config.merged_hf_save_path, trainer.run_id) + trainer.add_hook( + save_merged_hf_checkpoint_callback(full_save_path, converter, config.merged_hf_upload), + every=config.hf_save_steps, + ) + + # data loader. may need to seek to the right place if we're resuming + iter_data = non_caching_cycle(train_loader) + + if state.step > 0: + # step is after the batch, so we need to seek to step + # TODO: implement iter_data.seek(resume_step +1) + import tqdm + + for _ in tqdm.tqdm(range(state.step + 1), desc="seeking data for resume"): + next(iter_data) + + ## OK, actually run training! + trainer.train(state, iter_data) if __name__ == "__main__": diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py index 60d5dbbb6..132d1ed33 100644 --- a/src/levanter/main/train_lm.py +++ b/src/levanter/main/train_lm.py @@ -136,9 +136,6 @@ def compute_loss(model: LmHeadModel, example: LmExample, key=None): } ) - # boilerplate hooks and such - trainer.add_default_hooks() - if len(eval_datasets) == 0: logger.warning("No evaluation datasets provided.") @@ -146,6 +143,7 @@ def compute_loss(model: LmHeadModel, example: LmExample, key=None): eval_dataset = CausalLmDataset(eval_dataset, Pos, KeyPos) trainer.add_eval_hook(eval_dataset, name=name) + # Register hooks trainer.add_hook(callbacks.log_performance_stats(Pos.size, trainer.config.train_batch_size), every=1) if config.hf_save_path is not None: full_save_path = os.path.join(config.hf_save_path, trainer.run_id) diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index 4636b50d0..db7cbd3a5 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -129,6 +129,7 @@ def __init__( loss_fn: Callable, *, is_trainable: PyTree[FilterSpec] = True, + add_default_hooks: bool = True, ): """ @@ -152,6 +153,9 @@ def __init__( self.tracker = config.tracker.init(self.run_id) self._cmanagers = [] + if add_default_hooks: + self._add_default_hooks() + @cached_property def loss_fn(self): """ @@ -341,7 +345,7 @@ def train(self, state: TrainerState[M], train_loader: Iterable[X], run_hooks: bo return info - def add_default_hooks(self, eval_dataset: Optional[Iterable[X]] = None): + def _add_default_hooks(self, eval_dataset: Optional[Iterable[X]] = None): from levanter import callbacks self.add_hook(callbacks.pbar_logger(total=self.config.num_train_steps), every=1) From 
5d033eb74ada00ae379bb0872ab51a15e0a9f4ef Mon Sep 17 00:00:00 2001 From: David Hall Date: Fri, 24 Nov 2023 13:59:33 -0600 Subject: [PATCH 037/205] wip --- config/gpt2_nano.yaml | 4 +- examples/alpaca-lora/alpaca_lora.py | 2 +- src/levanter/checkpoint.py | 175 +++++++++++++------- src/levanter/main/lora_lm.py | 2 +- src/levanter/tensorstore_serialization.py | 2 - src/levanter/trainer.py | 190 +++++++++++----------- 6 files changed, 216 insertions(+), 159 deletions(-) diff --git a/config/gpt2_nano.yaml b/config/gpt2_nano.yaml index 993302670..15d2b2482 100644 --- a/config/gpt2_nano.yaml +++ b/config/gpt2_nano.yaml @@ -1,5 +1,5 @@ -data: - id: dlwh/wikitext_103_detokenized +#data: +# id: dlwh/wikitext_103_detokenized model: type: gpt2 hidden_dim: 32 diff --git a/examples/alpaca-lora/alpaca_lora.py b/examples/alpaca-lora/alpaca_lora.py index c3cf71ce6..49e1ac9dc 100644 --- a/examples/alpaca-lora/alpaca_lora.py +++ b/examples/alpaca-lora/alpaca_lora.py @@ -107,7 +107,7 @@ def compute_loss(model: LmHeadModel, example: LmExample, key=None): # log some info about the model all_param_count = parameter_count(state.model) - just_lora_params = parameter_count(trainer.trainable_params_only(state.model)) + just_lora_params = parameter_count(trainer._trainable_params_only(state.model)) levanter.log_summary( { diff --git a/src/levanter/checkpoint.py b/src/levanter/checkpoint.py index a9eaa0a22..26a312155 100644 --- a/src/levanter/checkpoint.py +++ b/src/levanter/checkpoint.py @@ -29,7 +29,7 @@ PathLike = Union[str, pathlib.Path] -M = TypeVar("M") +M = TypeVar("M", bound=PyTree) S = TypeVar("S") @@ -56,13 +56,13 @@ class Checkpointer: _last_temporary_checkpoint: Optional[str] = None def __init__( - self, - base_path: PathLike, - save_interval: Optional[datetime.timedelta], - step_policies: Sequence[CheckpointInterval], - *, - keep_params: PyTree[FilterSpec] = True, - dt_now_injection: Optional[Callable[[], datetime.datetime]] = None, + self, + base_path: PathLike, + save_interval: Optional[datetime.timedelta], + step_policies: Sequence[CheckpointInterval], + *, + keep_params: PyTree[FilterSpec] = True, + dt_now_injection: Optional[Callable[[], datetime.datetime]] = None, ): """ Class for managing checkpoints. Saves checkpoints according to two policies: time and step. 
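To make the two policies concrete, here is a sketch of constructing a Checkpointer by hand. The path and intervals are invented for illustration, and it assumes CheckpointInterval exposes the every/until fields used elsewhere in this file:

import datetime

from levanter.checkpoint import Checkpointer, CheckpointInterval

checkpointer = Checkpointer(
    base_path="/tmp/my-run/checkpoints",  # hypothetical path
    save_interval=datetime.timedelta(minutes=30),  # time policy: rolling temporary checkpoints
    step_policies=[
        CheckpointInterval(every=1000, until=10000),  # permanent checkpoints every 1k steps,
        CheckpointInterval(every=5000, until=None),  # then every 5k; policies must stay sorted by `until`
    ],
)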
@@ -102,39 +102,39 @@ def __init__( raise ValueError("Step policies must be sorted by 'until' value") def load_checkpoint( - self, - model: M, - training_state: S, - path: Optional[PathLike] = None, - *, - discover_latest: bool = True, - axis_mapping: Optional[haliax.partitioning.ResourceMapping] = None, - mesh: Optional[haliax.partitioning.Mesh] = None, - ) -> Optional[Tuple[M, S, int]]: + self, + state: M, + path: Optional[PathLike] = None, + *, + discover_latest: bool = True, + axis_mapping: Optional[haliax.partitioning.ResourceMapping] = None, + mesh: Optional[haliax.partitioning.Mesh] = None, + ) -> Optional[Tuple[M, int]]: if path is None: path = self.base_path return load_checkpoint( - model, training_state, path, discover_latest=discover_latest, axis_mapping=axis_mapping, mesh=mesh + state, path, discover_latest=discover_latest, axis_mapping=axis_mapping, mesh=mesh ) def load_model( - self, - model: M, - path: Optional[str] = None, - *, - discover_latest: bool = True, - axis_mapping: Optional[haliax.partitioning.ResourceMapping] = None, - mesh: Optional[haliax.partitioning.Mesh] = None, + self, + model: M, + path: Optional[str] = None, + *, + discover_latest: bool = True, + axis_mapping: Optional[haliax.partitioning.ResourceMapping] = None, + mesh: Optional[haliax.partitioning.Mesh] = None, ) -> Optional[Tuple[M, int]]: - if path is None: - path = self.base_path - ckpt = load_checkpoint( - model, None, path, discover_latest=discover_latest, axis_mapping=axis_mapping, mesh=mesh - ) - if ckpt is None: - return None - model, _, step = ckpt - return model, step + """ + Convenience method/holdover from previous API for loading checkpoints. + Loads just the model assuming the model is in the `model` subdir of the discovered checkpoint. + """ + ret_dict = self.load_checkpoint({"model": model}, + path, + discover_latest=discover_latest, + axis_mapping=axis_mapping, + mesh=mesh) + return ret_dict["model"] def on_step(self, info, force: bool = False): step = info.step @@ -213,17 +213,15 @@ def _rm_checkpoint(self, checkpoint): cp_path = os.path.join(plain_path, checkpoint) logger.info(f"Deleting checkpoint {checkpoint} from {cp_path}") fs.rm(cp_path, recursive=True) - # don't let this take down a run except Exception: # pylint: disable=broad-except logger.exception("Failed to delete checkpoint", exc_info=True) def save_checkpoint(self, info, destination: str): path = os.path.join(self.base_path, destination) logger.info(f"Saving checkpoint at step {info.step} to {path}") - model = equinox.filter(info.model, self.keep_params) + state = equinox.partition(info.state, info.state.is_trainable) save_checkpoint( - model=model, - training_state=(info.opt_state, info.next_key), + state, step=info.step, checkpoint_path=path, ) @@ -237,7 +235,7 @@ def save_checkpoint(self, info, destination: str): logger.info(f"Saved checkpoint at step {info.step} to {path}. Save time is {self._last_save_time}") -def save_checkpoint(model, training_state, step: int, checkpoint_path: PathLike, *, exist_ok: bool = False): +def save_checkpoint(tree: M, step: int, checkpoint_path: PathLike, *, exist_ok: bool = False): """ Save a checkpoint to a given path using TensorStore. If exist_ok is True, the checkpoint will be saved even if a checkpoint already exists at the given path. 
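The `{"model": model}` wrapper in `load_model` above works because any pytree can serve as the deserialization template: only the subtree named by the template is read back. A sketch of the same idiom from user code, with `MyModel` and `Vocab` as hypothetical stand-ins, assuming the checkpoint was written from a full `TrainerState`:

import equinox as eqx
import jax

# Build an abstract template (shapes/dtypes only, no parameter allocation),
# then read back just the "model" subtree of the checkpoint.
model_template = eqx.filter_eval_shape(MyModel.init, Vocab, key=jax.random.PRNGKey(0))
restored = checkpointer.load_checkpoint({"model": model_template})
model = restored["model"] if restored is not None else None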
@@ -255,10 +253,7 @@ def save_checkpoint(model, training_state, step: int, checkpoint_path: PathLike, fs, plain_path = _get_fs_and_plain_path(checkpoint_path) fs.makedirs(plain_path, exist_ok=exist_ok) - tree_serialize_leaves_tensorstore(os.path.join(checkpoint_path, "model"), model) - if training_state is not None: - tree_serialize_leaves_tensorstore(os.path.join(checkpoint_path, "training_state"), training_state) - + tree_serialize_leaves_tensorstore(checkpoint_path, tree) save_metadata(checkpoint_path, fs, step) logger.info(f"Saved checkpoint for step {step}") @@ -280,13 +275,70 @@ def save_metadata(checkpoint_path, fs, step): def load_checkpoint( - model: M, - training_state: S, - checkpoint_path: PathLike, - *, - discover_latest=True, - axis_mapping: Optional[haliax.partitioning.ResourceMapping] = None, - mesh: Optional[jax.sharding.Mesh] = None, + tree: M, + checkpoint_path: PathLike, + *, + discover_latest=True, + axis_mapping: Optional[haliax.partitioning.ResourceMapping] = None, + mesh: Optional[jax.sharding.Mesh] = None, +) -> Optional[tuple[M, int]]: + fs: AbstractFileSystem + fs, _ = _get_fs_and_plain_path(checkpoint_path) + + checkpoint_path = str(checkpoint_path) + + if discover_latest: + checkpoint_path = discover_latest_checkpoint(checkpoint_path) # type: ignore + + if checkpoint_path is None or not fs.exists(checkpoint_path): + return None + + logger.info(f"Loading checkpoint from {checkpoint_path}") + metadata = load_metadata(checkpoint_path, fs) + + try: + tree = tree_deserialize_leaves_tensorstore(checkpoint_path, tree, axis_mapping=axis_mapping, mesh=mesh) + except: # noqa + from levanter.trainer import TrainerState + if not isinstance(tree, TrainerState): + raise + else: + logger.warning("Attempting to load old-style checkpoint") + model, training_state = tree.model, (tree.opt_state, tree.training_key) + + model = tree_deserialize_leaves_tensorstore( + os.path.join(checkpoint_path, "model"), model, axis_mapping=axis_mapping, mesh=mesh + ) + + if training_state is None: + opt_state = None + key = None + else: + training_state = tree_deserialize_leaves_tensorstore( + os.path.join(checkpoint_path, "training_state"), training_state, axis_mapping=axis_mapping, mesh=mesh + ) + opt_state, key = training_state + + # TODO: pretty sure this is right, but should verify + step = metadata["step"] + new_state = dataclasses.replace( + tree, # type: ignore + step=step + 1, + model=model, + opt_state=opt_state, + training_key=key) + return new_state, step + + + +def _old_load_checkpoint( + model: M, + training_state: S, + checkpoint_path: PathLike, + *, + discover_latest=True, + axis_mapping: Optional[haliax.partitioning.ResourceMapping] = None, + mesh: Optional[jax.sharding.Mesh] = None, ) -> Optional[Tuple[M, S, int]]: """ Load a checkpoint from a given path. 
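Taken together, the new-style path means a whole `TrainerState` round-trips as a single tree, with the passed-in state acting as the template. A minimal sketch, modeled on the updated tests later in this series (local path, toy MLP):

import equinox as eqx
import jax
import optax

from levanter.checkpoint import load_checkpoint, save_checkpoint
from levanter.trainer import TrainerState

key = jax.random.PRNGKey(0)
model = eqx.nn.MLP(in_size=2, out_size=1, width_size=2, depth=3, key=key)
opt_state = optax.adam(1e-4).init(eqx.filter(model, eqx.is_inexact_array))
state = TrainerState(10, model, opt_state, key, is_trainable=True)

save_checkpoint(state, step=state.step, checkpoint_path="/tmp/ckpt", exist_ok=True)

# `state` doubles as the deserialization template; new-style checkpoints
# come back as one TrainerState rather than (model, training_state, step).
restored = load_checkpoint(state, "/tmp/ckpt", discover_latest=False)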
@@ -370,10 +422,10 @@ def checkpoint_sort_key(ckpt_dir): def tree_serialise_leaves( - path: PathLike, - pytree: PyTree, - filter_spec=default_serialise_filter_spec, - is_leaf: Optional[Callable[[Any], bool]] = None, + path: PathLike, + pytree: PyTree, + filter_spec=default_serialise_filter_spec, + is_leaf: Optional[Callable[[Any], bool]] = None, ) -> None: """Analog to `equinox.tree_serialise_leaves`, but saves the leaves of a PyTree using fsspec.""" @@ -391,11 +443,11 @@ def __serialise(y): def tree_deserialise_leaves( - path: PathLike, - like: PyTree, - filter_spec=default_deserialise_filter_spec, - is_leaf: Optional[Callable[[Any], bool]] = None, - fs=None, + path: PathLike, + like: PyTree, + filter_spec=default_deserialise_filter_spec, + is_leaf: Optional[Callable[[Any], bool]] = None, + fs=None, ) -> PyTree: """ Analog to `equinox.tree_deserialise_leaves`, but loads the leaves of a PyTree using fsspec. @@ -451,13 +503,12 @@ class CheckpointerConfig: def expanded_path(self, run_id): return os.path.expanduser(os.path.join(self.base_path, run_id)) - def create(self, run_id, keep_params: PyTree[FilterSpec] = True) -> Checkpointer: + def create(self, run_id) -> Checkpointer: keeps = [CheckpointInterval(**k) for k in self.keep] return Checkpointer( base_path=self.expanded_path(run_id), save_interval=self.save_interval, step_policies=keeps, - keep_params=keep_params, ) def __post_init__(self): @@ -470,6 +521,6 @@ def __post_init__(self): if prev_interval is not None: assert prev_interval["until"] is not None, "Only the last checkpoint interval can be None" assert ( - interval["until"] is None or interval["until"] > prev_interval["until"] + interval["until"] is None or interval["until"] > prev_interval["until"] ), "Checkpoint intervals must be monotonic" prev_interval = interval diff --git a/src/levanter/main/lora_lm.py b/src/levanter/main/lora_lm.py index 5904cca8b..b709dfb14 100644 --- a/src/levanter/main/lora_lm.py +++ b/src/levanter/main/lora_lm.py @@ -94,7 +94,7 @@ def compute_loss(model, example: LmExample, key=None): state = trainer.initial_state(training_key, model=model) all_param_count = parameter_count(state.model) - just_lora_params = parameter_count(trainer.trainable_params_only(state.model)) + just_lora_params = parameter_count(trainer._trainable_params_only(state.model)) levanter.tracker.log_summary( { diff --git a/src/levanter/tensorstore_serialization.py b/src/levanter/tensorstore_serialization.py index 9809665f3..78037ce3d 100644 --- a/src/levanter/tensorstore_serialization.py +++ b/src/levanter/tensorstore_serialization.py @@ -31,8 +31,6 @@ def tree_serialize_leaves_tensorstore(checkpoint_dir, pytree): specs = jtu.tree_map(partial(_tensorstore_spec_for, checkpoint_dir), leaf_key_paths, is_leaf=is_named_array) # TODO: jax array_ser has a fancy async manager thing to checkpoint while training, would be good but not right now. - # array_ser only supports saving sharded arrays, so we can't use its top-level function run_serialization. - # however we're inspired by its implementation, meaning we'll make a tree of futures and wait on them. 
async def _do_serialize(): futures = jtu.tree_map(_serialize_one_leaf, pytree, specs, is_leaf=is_named_array) return await asyncio.gather(*jtu.tree_leaves(futures)) diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index db7cbd3a5..1e089169b 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -1,5 +1,6 @@ import atexit import copy +import dataclasses import functools import logging as pylogging import os @@ -49,35 +50,33 @@ X = TypeVar("X") # Input M = TypeVar("M", bound=PyTree) -S = TypeVar("S", bound=PyTree) DEFAULT_JAX_CONFIG = { "jax_threefry_partitionable": True, "jax_softmax_custom_jvp": True, } -# A note on the semantics of "step" vs "next_step": -# The "step" of a TrainerState is the state after `step` steps have been taken. -# A "StepInfo"'s step is the step that was just completed. If you want the next step, use `next_step`. - -@dataclass -class TrainerState(Generic[M]): +# TODO: figure out how to get Generic[M] back +class TrainerState(eqx.Module): step: int model: M opt_state: OptState training_key: PRNGKeyArray + is_trainable: PyTree[FilterSpec] = eqx.field(static=True) +# A note on the semantics of "step" vs "next_step": +# The "step" of a TrainerState is the state after `step` steps have been taken. +# A "StepInfo"'s step is the step that was just completed. If you want the next step, use `next_step`. @dataclass -class StepInfo(Generic[M]): - state: TrainerState[M] +class StepInfo: + state: TrainerState loss: float step_duration: float model = property(lambda self: self.state.model) opt_state = property(lambda self: self.state.opt_state) - next_key = property(lambda self: self.state.training_key) step = property(lambda self: self.state.step - 1) """ @@ -118,7 +117,6 @@ class Trainer: optimizer: GradientTransformation hooks: TrainerHooks tracker: levanter.tracker.Tracker - is_trainable_param: Optional[PyTree[FilterSpec]] _raw_loss_function: Callable _cmanagers: List[typing.ContextManager] = [] @@ -128,7 +126,6 @@ def __init__( optimizer: GradientTransformation, loss_fn: Callable, *, - is_trainable: PyTree[FilterSpec] = True, add_default_hooks: bool = True, ): """ @@ -138,15 +135,11 @@ def __init__( optimizer: the optimizer, e.g. `optax.adam(1e-3)` or produced by [levanter.trainer.OptimizerConfig][] loss_fn (Callable): the loss function. This should be a function that takes a model and some inputs and returns a scalar loss. It should be jit-able and should not have any side effects. - is_trainable: optional filter spec for the trainable parameters. This is used to filter out non-trainable - parameters for the optimizer state and for computing gradients. Non-trainable parameters are also - not checkpointed. If you don't specify this, all parameters are assumed to be trainable. 
""" self.hooks = TrainerHooks() self.config = config self._raw_loss_function = loss_fn self.optimizer = optimizer - self.is_trainable_param = is_trainable if isinstance(config.tracker, Sequence): self.tracker = levanter.tracker.CompositeTracker([c.init(self.run_id) for c in config.tracker]) else: @@ -245,11 +238,19 @@ def __exit__(self, *args): raise RuntimeError("Exception(s) occurred while exiting trainer", problems) from problems[0] def initial_state( - self, training_key: PRNGKeyArray, model: Optional[M] = None, model_init: Optional[Callable[[], M]] = None + self, training_key: PRNGKeyArray, + model: Optional[M] = None, model_init: Optional[Callable[[], M]] = None, + *, + is_trainable: PyTree[FilterSpec] = True, ) -> TrainerState: """ Initializes the model, optimizer state, and random key. Also handles loading a checkpoint if needed. + Args + is_trainable: optional filter spec for the trainable parameters. This is used to filter out non-trainable + parameters for the optimizer state and for computing gradients. Non-trainable parameters are also + not checkpointed. If you don't specify this, all parameters are assumed to be trainable. + Returns: model, opt_state, key, resume_step """ @@ -260,55 +261,56 @@ def initial_state( raise ValueError("one of model and model_init must be specified") if model is not None: - # we can't just use `lambda: model` because JAX jit can't see captures, but it can see partials - # We can't use plain partials because they aren't pytrees + # we can't just use `lambda: model` because JAX jit can't see captures, but it can see jax partials model_init = jax.tree_util.Partial(lambda m: m, model) - model_shape, opt_state_shape = eqx.filter_eval_shape(self._init_model_and_opt_state, model_init) + model_shape, opt_state_shape = eqx.filter_eval_shape(self._init_model_and_opt_state, model_init, is_trainable) # we only checkpoint the trainable parameters, so we need to filter out the non-trainable ones - trainable_model_shape = self.trainable_params_only(model_shape) + trainable_model_shape = _trainable_params_only(model_shape, is_trainable) - ckpt = self.maybe_load_checkpoint( - trainable_model_shape, - (opt_state_shape, training_key), - axis_mapping=self.parameter_axis_mapping, - mesh=self.device_mesh, - ) + trainer_state_shape = TrainerState(0, + trainable_model_shape, + opt_state_shape, + training_key, + is_trainable=is_trainable) + + ckpt = self._maybe_load_checkpoint(trainer_state_shape) if ckpt is not None: - trainable_model, (opt_state, training_key), completed_step = ckpt + opt_state = ckpt.opt_state + training_key = ckpt.training_key + step = ckpt.step + if model is not None: - model = eqx.combine(trainable_model, model) + model = eqx.combine(ckpt.model, model) elif any(isinstance(leaf, ShapeDtypeStruct) for leaf in jax.tree_leaves(trainable_model)): # if we're resuming, we need to re-initialize the non-trainable parameters to their original values non_trainable = named_jit(self._init_non_trainable_params, self.parameter_axis_mapping)(model_init) model = eqx.combine(trainable_model, non_trainable) else: model = trainable_model - step = completed_step + 1 else: - model, opt_state = named_jit(self._init_model_and_opt_state, self.parameter_axis_mapping)(model_init) + model, opt_state = named_jit(self._init_model_and_opt_state, self.parameter_axis_mapping)(model_init, is_trainable) step = 0 - return TrainerState(step, model, opt_state, training_key) + return TrainerState(step, model, opt_state, training_key, is_trainable) - def train_step(self, state: 
TrainerState[M], *batch: X, **batch_kwargs) -> StepInfo[M]: + def train_step(self, state: TrainerState, *batch: X, **batch_kwargs) -> StepInfo: """ Performs a single training step. """ with capture_time() as step_time, levanter.current_tracker(self.tracker): - key, new_key = jax.random.split(state.training_key) - loss, new_model, new_optstate = self._train_step_fn( - state.model, state.opt_state, *batch, **batch_kwargs, key=key + loss, new_state = self._train_step_fn( + state, *batch, **batch_kwargs, key=key ) # force the loss so timing numbers are accurate. laziness isn't going to help here (i think?) loss = loss.item() # type: ignore - return StepInfo(TrainerState(state.step + 1, new_model, new_optstate, new_key), loss, step_time()) + return StepInfo(new_state, loss, step_time()) def training_steps( - self, state: TrainerState[M], train_loader, run_hooks: bool = True + self, state: TrainerState, train_loader, run_hooks: bool = True ) -> typing.Iterator[StepInfo]: """ Generator that yields training steps and runs hooks. @@ -332,7 +334,7 @@ def training_steps( yield info - def train(self, state: TrainerState[M], train_loader: Iterable[X], run_hooks: bool = True) -> StepInfo[M]: + def train(self, state: TrainerState, train_loader: Iterable[X], run_hooks: bool = True) -> StepInfo: """ Performs training until the number of steps is reached. """ @@ -353,7 +355,7 @@ def _add_default_hooks(self, eval_dataset: Optional[Iterable[X]] = None): if eval_dataset is not None: self.add_eval_hook(eval_dataset) # engine.add_hook(callbacks.log_memory_usage(), every=1) - checkpointer = self.config.checkpointer.create(self.run_id, self.is_trainable_param) + checkpointer = self.config.checkpointer.create(self.run_id) self.add_hook(checkpointer.on_step, every=1) # checkpointer manages its own frequency def add_eval_hook(self, eval_dataset, name: Optional[str] = None): @@ -406,10 +408,12 @@ def _train_step_fn(self): @named_jit( axis_resources=self.parameter_axis_mapping, out_axis_resources=self.parameter_axis_mapping, - donate_args=(True, True), + donate_args=(True,), ) - def train_step(model, opt_state, *batch, **batch_kwargs): - model = inference_mode(model, False) + def train_step(state, *batch, **batch_kwargs): + key, new_key = jax.random.split(state.training_key) + opt_state = state.opt_state + model = inference_mode(state.model, False) # we do this so that we only take the gradients of the trainable parameters trainable_model, rest_model = self.partition_trainable_params(model) @@ -425,14 +429,22 @@ def split_loss_fn(trainable_model, *batch, **batch_kwargs): updates, opt_state = self.optimizer.update(grads, opt_state, params=trainable_model) model = eqx.apply_updates(model, updates) - return loss, model, opt_state + new_state = dataclasses.replace( + state, + step=state.step + 1, + model=model, + opt_state=opt_state, + training_key=new_key + ) + + return loss, new_state return train_step - def _init_model_and_opt_state(self, model_init): + def _init_model_and_opt_state(self, model_init, is_trainable): model = model_init() # only force trainable params to param precision. 
Other params are cast to compute precision - trainable, non_trainable = self.partition_trainable_params(model) + trainable, non_trainable = _partition_trainable_params(model, is_trainable) trainable = self.mp.cast_to_param(trainable) non_trainable = self.mp.cast_to_compute(non_trainable) model = eqx.combine(trainable, non_trainable) @@ -446,58 +458,29 @@ def _init_non_trainable_params(self, model_init): non_trainable = self.mp.cast_to_compute(non_trainable) return non_trainable - def trainable_params_only(self, model: M) -> M: - """ - Filters out non-trainable parameters from the model. This is used internally to - for the optimizer state and to compute gradients, but you can also use it to filter out - params for logging or something. - """ - return self.partition_trainable_params(model)[0] - - def partition_trainable_params(self, model): - """ - Partitions the model into trainable and non-trainable parameters. This is used internally - for the gradient calculation and checkpointing, but you can also use it to filter out params for logging - or something. - Returns: - trainable, non-trainable - """ - - def trainable_and_diffable(pred): - if callable(pred): - return lambda x: pred(x) and is_inexact_arrayish(x) - elif pred is True: - return is_inexact_arrayish - else: - return pred - - combined_mask = jax.tree_util.tree_map(trainable_and_diffable, self.is_trainable_param) - return eqx.partition(model, combined_mask) - - def maybe_load_checkpoint( - self, model: M, training_state: S, *, axis_mapping=None, mesh=None - ) -> Optional[Tuple[M, S, int]]: + def _maybe_load_checkpoint(self, state: TrainerState) -> Optional[TrainerState]: """Loads a checkpoint if one exists and we're supposed to load it, otherwise returns the model and training state as is""" - if self.config.load_checkpoint is not False: - # TODO: don't remake the checkpointer every time - checkpointer = self.config.checkpointer.create(self.run_id) - load_checkpoint_path = self.config.load_checkpoint_path + with self.device_mesh: + if self.config.load_checkpoint is not False: + # TODO: don't remake the checkpointer every time + checkpointer = self.config.checkpointer.create(self.run_id) + load_checkpoint_path = self.config.load_checkpoint_path - if load_checkpoint_path is None: - load_checkpoint_path = self.config.checkpointer.expanded_path(self.run_id) + if load_checkpoint_path is None: + load_checkpoint_path = self.config.checkpointer.expanded_path(self.run_id) - ckpt = checkpointer.load_checkpoint( - model, training_state, load_checkpoint_path, axis_mapping=axis_mapping, mesh=mesh - ) + ckpt = checkpointer.load_checkpoint( + state, load_checkpoint_path, axis_mapping=self.parameter_axis_mapping, mesh=self.device_mesh + ) - if ckpt is None and self.config.load_checkpoint is True: - raise ValueError(f"Could not load checkpoint from {load_checkpoint_path}") + if ckpt is None and self.config.load_checkpoint is True: + raise ValueError(f"Could not load checkpoint from {load_checkpoint_path}") - return ckpt - else: - return None + return ckpt + else: + return None @dataclass @@ -804,3 +787,28 @@ def _convert_ratio_or_steps(ratio_or_steps: float, num_train_steps: int): return int(ratio_or_steps * num_train_steps) else: return int(ratio_or_steps) + + +def _trainable_params_only(model: M, filter: PyTree[FilterSpec]) -> M: + return _partition_trainable_params(model, filter)[0] + +def _partition_trainable_params(model, filter): + """ + Partitions the model into trainable and non-trainable parameters. 
This is used internally + for the gradient calculation and checkpointing, but you can also use it to filter out params for logging + or something. + + Returns: + trainable, non-trainable + """ + + def trainable_and_diffable(pred): + if callable(pred): + return lambda x: pred(x) and is_inexact_arrayish(x) + elif pred is True: + return is_inexact_arrayish + else: + return pred + + combined_mask = jax.tree_util.tree_map(trainable_and_diffable, filter) + return eqx.partition(model, combined_mask) \ No newline at end of file From c4a916027d82ab570357420c11a7e9e83ed09adf Mon Sep 17 00:00:00 2001 From: David Hall Date: Fri, 24 Nov 2023 14:08:01 -0600 Subject: [PATCH 038/205] make it so we can evaluate if we have a cache but no sources --- src/levanter/data/text.py | 36 ++++++++++++++++++++++-------------- 1 file changed, 22 insertions(+), 14 deletions(-) diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index 4ad535114..3562efa8a 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -552,19 +552,22 @@ class LMDatasetConfig(LMDatasetSourceConfig, LMTaskConfig): def train_set( self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True ) -> ShardableDataset[np.ndarray]: - return self.token_seq_dataset("train", seq_len, monitors) + ds = self.token_seq_dataset("train", seq_len, monitors) + if ds is None: + raise ValueError("No training set!") + return ds - def validation_set(self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True): - if self._has_validation_set: - return self.token_seq_dataset("validation", seq_len, monitors) - else: - return None + def validation_set( + self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True + ) -> Optional[TokenSeqDataset]: + return self.token_seq_dataset("validation", seq_len, monitors) def validation_sets( self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True ) -> Mapping[str, ShardableDataset[np.ndarray]]: - if self._has_validation_set: - return {"": self.validation_set(seq_len, monitors)} + validation_set = self.validation_set(seq_len, monitors) + if validation_set is not None: + return {"": validation_set} else: return {} @@ -585,22 +588,27 @@ def _has_validation_set(self): def token_seq_dataset( self, split: str, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True - ) -> TokenSeqDataset: + ) -> Optional[TokenSeqDataset]: cache = self.build_or_load_cache(split, monitors=monitors) + if cache is None: + return None return TokenSeqDataset(cache, seq_len) def build_or_load_cache( self, split: str, monitors: Union[bool, List[MetricsMonitor]] = True ) -> Optional[TokenizedDocumentCache]: - source = self.get_shard_source(split) - if source is None: - return None - split_cache_dir = os.path.join(self.cache_dir, split) try: return TokenizedDocumentCache.load(split_cache_dir, flatten_docs=True) except FileNotFoundError: - logger.info(f"Building cache for {split}...") + pass + + source = self.get_shard_source(split) + if source is None: + logger.info(f"No data for {split}") + return None + + logger.info(f"Building cache for {split}...") if monitors is True: monitors = [ From b888065758ca1865fe85e62c6297756fa56974df Mon Sep 17 00:00:00 2001 From: David Hall Date: Sat, 25 Nov 2023 15:21:17 -0800 Subject: [PATCH 039/205] about got the checkpoint refactor done --- src/levanter/checkpoint.py | 131 +++++++++++----------- src/levanter/main/lora_lm.py | 128 ++++++++++----------- src/levanter/main/train_lm.py | 5 +- src/levanter/tensorstore_serialization.py | 10 +- 
src/levanter/trainer.py | 64 +++++------ src/levanter/utils/jax_utils.py | 7 +- tests/test_checkpoint.py | 83 ++++++-------- 7 files changed, 212 insertions(+), 216 deletions(-) diff --git a/src/levanter/checkpoint.py b/src/levanter/checkpoint.py index 26a312155..2d8681b16 100644 --- a/src/levanter/checkpoint.py +++ b/src/levanter/checkpoint.py @@ -56,13 +56,13 @@ class Checkpointer: _last_temporary_checkpoint: Optional[str] = None def __init__( - self, - base_path: PathLike, - save_interval: Optional[datetime.timedelta], - step_policies: Sequence[CheckpointInterval], - *, - keep_params: PyTree[FilterSpec] = True, - dt_now_injection: Optional[Callable[[], datetime.datetime]] = None, + self, + base_path: PathLike, + save_interval: Optional[datetime.timedelta], + step_policies: Sequence[CheckpointInterval], + *, + keep_params: PyTree[FilterSpec] = True, + dt_now_injection: Optional[Callable[[], datetime.datetime]] = None, ): """ Class for managing checkpoints. Saves checkpoints according to two policies: time and step. @@ -102,38 +102,36 @@ def __init__( raise ValueError("Step policies must be sorted by 'until' value") def load_checkpoint( - self, - state: M, - path: Optional[PathLike] = None, - *, - discover_latest: bool = True, - axis_mapping: Optional[haliax.partitioning.ResourceMapping] = None, - mesh: Optional[haliax.partitioning.Mesh] = None, - ) -> Optional[Tuple[M, int]]: + self, + state: M, + path: Optional[PathLike] = None, + *, + discover_latest: bool = True, + axis_mapping: Optional[haliax.partitioning.ResourceMapping] = None, + mesh: Optional[haliax.partitioning.Mesh] = None, + ) -> Optional[M]: if path is None: path = self.base_path - return load_checkpoint( - state, path, discover_latest=discover_latest, axis_mapping=axis_mapping, mesh=mesh - ) + return load_checkpoint(state, path, discover_latest=discover_latest, axis_mapping=axis_mapping, mesh=mesh) def load_model( - self, - model: M, - path: Optional[str] = None, - *, - discover_latest: bool = True, - axis_mapping: Optional[haliax.partitioning.ResourceMapping] = None, - mesh: Optional[haliax.partitioning.Mesh] = None, - ) -> Optional[Tuple[M, int]]: + self, + model: M, + path: Optional[str] = None, + *, + discover_latest: bool = True, + axis_mapping: Optional[haliax.partitioning.ResourceMapping] = None, + mesh: Optional[haliax.partitioning.Mesh] = None, + ) -> Optional[M]: """ - Convenience method/holdover from previous API for loading checkpoints. - Loads just the model assuming the model is in the `model` subdir of the discovered checkpoint. + Convenience method/holdover from previous API for loading checkpoints. + Loads just the model assuming the model is in the `model` subdir of the discovered checkpoint. 
""" - ret_dict = self.load_checkpoint({"model": model}, - path, - discover_latest=discover_latest, - axis_mapping=axis_mapping, - mesh=mesh) + ret_dict = self.load_checkpoint( + {"model": model}, path, discover_latest=discover_latest, axis_mapping=axis_mapping, mesh=mesh + ) + if ret_dict is None: + return None return ret_dict["model"] def on_step(self, info, force: bool = False): @@ -219,7 +217,7 @@ def _rm_checkpoint(self, checkpoint): def save_checkpoint(self, info, destination: str): path = os.path.join(self.base_path, destination) logger.info(f"Saving checkpoint at step {info.step} to {path}") - state = equinox.partition(info.state, info.state.is_trainable) + state = equinox.filter(info.state, info.state.is_trainable) save_checkpoint( state, step=info.step, @@ -275,13 +273,13 @@ def save_metadata(checkpoint_path, fs, step): def load_checkpoint( - tree: M, - checkpoint_path: PathLike, - *, - discover_latest=True, - axis_mapping: Optional[haliax.partitioning.ResourceMapping] = None, - mesh: Optional[jax.sharding.Mesh] = None, -) -> Optional[tuple[M, int]]: + tree: M, + checkpoint_path: PathLike, + *, + discover_latest=True, + axis_mapping: Optional[haliax.partitioning.ResourceMapping] = None, + mesh: Optional[jax.sharding.Mesh] = None, +) -> Optional[M]: fs: AbstractFileSystem fs, _ = _get_fs_and_plain_path(checkpoint_path) @@ -298,8 +296,10 @@ def load_checkpoint( try: tree = tree_deserialize_leaves_tensorstore(checkpoint_path, tree, axis_mapping=axis_mapping, mesh=mesh) + return tree except: # noqa from levanter.trainer import TrainerState + if not isinstance(tree, TrainerState): raise else: @@ -315,30 +315,29 @@ def load_checkpoint( key = None else: training_state = tree_deserialize_leaves_tensorstore( - os.path.join(checkpoint_path, "training_state"), training_state, axis_mapping=axis_mapping, mesh=mesh + os.path.join(checkpoint_path, "training_state"), + training_state, + axis_mapping=axis_mapping, + mesh=mesh, ) opt_state, key = training_state # TODO: pretty sure this is right, but should verify step = metadata["step"] new_state = dataclasses.replace( - tree, # type: ignore - step=step + 1, - model=model, - opt_state=opt_state, - training_key=key) - return new_state, step - + tree, step=step + 1, model=model, opt_state=opt_state, training_key=key # type: ignore + ) + return new_state def _old_load_checkpoint( - model: M, - training_state: S, - checkpoint_path: PathLike, - *, - discover_latest=True, - axis_mapping: Optional[haliax.partitioning.ResourceMapping] = None, - mesh: Optional[jax.sharding.Mesh] = None, + model: M, + training_state: S, + checkpoint_path: PathLike, + *, + discover_latest=True, + axis_mapping: Optional[haliax.partitioning.ResourceMapping] = None, + mesh: Optional[jax.sharding.Mesh] = None, ) -> Optional[Tuple[M, S, int]]: """ Load a checkpoint from a given path. 
@@ -422,10 +421,10 @@ def checkpoint_sort_key(ckpt_dir): def tree_serialise_leaves( - path: PathLike, - pytree: PyTree, - filter_spec=default_serialise_filter_spec, - is_leaf: Optional[Callable[[Any], bool]] = None, + path: PathLike, + pytree: PyTree, + filter_spec=default_serialise_filter_spec, + is_leaf: Optional[Callable[[Any], bool]] = None, ) -> None: """Analog to `equinox.tree_serialise_leaves`, but saves the leaves of a PyTree using fsspec.""" @@ -443,11 +442,11 @@ def __serialise(y): def tree_deserialise_leaves( - path: PathLike, - like: PyTree, - filter_spec=default_deserialise_filter_spec, - is_leaf: Optional[Callable[[Any], bool]] = None, - fs=None, + path: PathLike, + like: PyTree, + filter_spec=default_deserialise_filter_spec, + is_leaf: Optional[Callable[[Any], bool]] = None, + fs=None, ) -> PyTree: """ Analog to `equinox.tree_deserialise_leaves`, but loads the leaves of a PyTree using fsspec. @@ -521,6 +520,6 @@ def __post_init__(self): if prev_interval is not None: assert prev_interval["until"] is not None, "Only the last checkpoint interval can be None" assert ( - interval["until"] is None or interval["until"] > prev_interval["until"] + interval["until"] is None or interval["until"] > prev_interval["until"] ), "Checkpoint intervals must be monotonic" prev_interval = interval diff --git a/src/levanter/main/lora_lm.py b/src/levanter/main/lora_lm.py index b709dfb14..e86ef1669 100644 --- a/src/levanter/main/lora_lm.py +++ b/src/levanter/main/lora_lm.py @@ -3,6 +3,7 @@ from dataclasses import dataclass, field from typing import Optional +import equinox as eqx import jax.random as jrandom import haliax.random @@ -66,7 +67,12 @@ def main(config: LoraLmConfig): Pos = model_config.Pos KeyPos = model_config.KeyPos - with config.trainer.device_mesh: + optimizer = config.optimizer.build(config.trainer.num_train_steps) + + def compute_loss(model, example: LmExample, key=None): + return model.compute_loss(example, key=key).scalar() + + with Trainer(config.trainer, optimizer, compute_loss) as trainer: # how we shard parameters across devices parameter_axis_mapping = config.trainer.parameter_axis_mapping @@ -82,74 +88,68 @@ def loraize_hf_model(model): lora_param_filter = lora_trainable_params_filter(model) - def compute_loss(model, example: LmExample, key=None): - return model.compute_loss(example, key=key).scalar() - - optimizer = config.optimizer.build(config.trainer.num_train_steps) - # Our trainer is a wrapper around the optimizer and compute_loss function that handles checkpointing and fsdp - with Trainer(config.trainer, optimizer, compute_loss, is_trainable=lora_param_filter) as trainer: - eval_datasets = config.data.validation_sets(Pos.size) + eval_datasets = config.data.validation_sets(Pos.size) + + state = trainer.initial_state(training_key, model=model, is_trainable=lora_param_filter) + + all_param_count = parameter_count(state.model) + just_lora_params = parameter_count(eqx.filter(state.model, lora_param_filter)) + + levanter.tracker.log_summary( + { + "parameter_count": all_param_count, + "trainable_parameter_count": just_lora_params, + "fraction_trainable": just_lora_params * 1.0 / all_param_count, + } + ) + + logger.info(f"Total parameter count: {all_param_count}") + logger.info(f"Trainable parameter count: {just_lora_params}") + logger.info(f"Fraction of parameters that are trainable: {just_lora_params * 1.0 / all_param_count%.3}") + + # data loaders + if len(eval_datasets) == 0: + logger.warning("No evaluation datasets provided.") + + for name, eval_dataset in 
eval_datasets.items(): + eval_dataset = CausalLmDataset(eval_dataset, Pos, KeyPos) + trainer.add_eval_hook(eval_dataset, name=name) + + train_dataset = CausalLmDataset(config.data.train_set(Pos.size), Pos, KeyPos) + train_loader = trainer.sharded_loader(train_dataset, Batch) + + # boilerplate hooks and such + trainer.add_hook(callbacks.log_performance_stats(Pos.size, trainer.config.train_batch_size), every=1) + if config.peft_save_path is not None: + full_save_path = os.path.join(config.peft_save_path, trainer.run_id) + trainer.add_hook( + save_peft_checkpoint_callback( + full_save_path, config.lora, config.initialize_from_hf, tokenizer, config.peft_hf_upload + ), + every=config.hf_save_steps, + ) + + if config.merged_hf_save_path is not None: + full_save_path = os.path.join(config.merged_hf_save_path, trainer.run_id) + trainer.add_hook( + save_merged_hf_checkpoint_callback(full_save_path, converter, config.merged_hf_upload), + every=config.hf_save_steps, + ) - state = trainer.initial_state(training_key, model=model) + # data loader. may need to seek to the right place if we're resuming + iter_data = non_caching_cycle(train_loader) - all_param_count = parameter_count(state.model) - just_lora_params = parameter_count(trainer._trainable_params_only(state.model)) + if state.step > 0: + # step is after the batch, so we need to seek to step + # TODO: implement iter_data.seek(resume_step +1) + import tqdm - levanter.tracker.log_summary( - { - "parameter_count": all_param_count, - "trainable_parameter_count": just_lora_params, - "fraction_trainable": just_lora_params * 1.0 / all_param_count, - } - ) + for _ in tqdm.tqdm(range(state.step), desc="seeking data for resume"): + next(iter_data) - logger.info(f"Total parameter count: {all_param_count}") - logger.info(f"Trainable parameter count: {just_lora_params}") - logger.info(f"Fraction of parameters that are trainable: {just_lora_params * 1.0 / all_param_count%.3}") - - # data loaders - if len(eval_datasets) == 0: - logger.warning("No evaluation datasets provided.") - - for name, eval_dataset in eval_datasets.items(): - eval_dataset = CausalLmDataset(eval_dataset, Pos, KeyPos) - trainer.add_eval_hook(eval_dataset, name=name) - - train_dataset = CausalLmDataset(config.data.train_set(Pos.size), Pos, KeyPos) - train_loader = trainer.sharded_loader(train_dataset, Batch) - - # boilerplate hooks and such - trainer.add_hook(callbacks.log_performance_stats(Pos.size, trainer.config.train_batch_size), every=1) - if config.peft_save_path is not None: - full_save_path = os.path.join(config.peft_save_path, trainer.run_id) - trainer.add_hook( - save_peft_checkpoint_callback( - full_save_path, config.lora, config.initialize_from_hf, tokenizer, config.peft_hf_upload - ), - every=config.hf_save_steps, - ) - - if config.merged_hf_save_path is not None: - full_save_path = os.path.join(config.merged_hf_save_path, trainer.run_id) - trainer.add_hook( - save_merged_hf_checkpoint_callback(full_save_path, converter, config.merged_hf_upload), - every=config.hf_save_steps, - ) - - # data loader. may need to seek to the right place if we're resuming - iter_data = non_caching_cycle(train_loader) - - if state.step > 0: - # step is after the batch, so we need to seek to step - # TODO: implement iter_data.seek(resume_step +1) - import tqdm - - for _ in tqdm.tqdm(range(state.step + 1), desc="seeking data for resume"): - next(iter_data) - - ## OK, actually run training! - trainer.train(state, iter_data) + ## OK, actually run training! 
+ trainer.train(state, iter_data) if __name__ == "__main__": diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py index 132d1ed33..b03745988 100644 --- a/src/levanter/main/train_lm.py +++ b/src/levanter/main/train_lm.py @@ -5,6 +5,7 @@ from typing import Optional, Union import jax.random as jrandom +import wandb import haliax as hax from haliax import Axis @@ -181,12 +182,14 @@ def compute_log_probs(model, example: LmExample): # TODO: implement iter_data.seek(resume_step +1) import tqdm - for _ in tqdm.tqdm(range(state.step + 1), desc="seeking data for resume"): + for _ in tqdm.tqdm(range(state.step), desc="seeking data for resume"): next(train_loader) ## OK, actually run training! + trainer.add_hook(lambda s: print(s.loss), every=20) trainer.train(state, train_loader) # checkpointer.on_step(last_step, force=True) + wandb.finish() if __name__ == "__main__": diff --git a/src/levanter/tensorstore_serialization.py b/src/levanter/tensorstore_serialization.py index 78037ce3d..51a253163 100644 --- a/src/levanter/tensorstore_serialization.py +++ b/src/levanter/tensorstore_serialization.py @@ -26,13 +26,17 @@ logger = logging.getLogger(__name__) +def _is_named_or_none(x): + return x is None or is_named_array(x) + + def tree_serialize_leaves_tensorstore(checkpoint_dir, pytree): - leaf_key_paths = jax_utils.leaf_key_paths(pytree, is_leaf=is_named_array) - specs = jtu.tree_map(partial(_tensorstore_spec_for, checkpoint_dir), leaf_key_paths, is_leaf=is_named_array) + leaf_key_paths = jax_utils.leaf_key_paths(pytree, is_leaf=_is_named_or_none) + specs = jtu.tree_map(partial(_tensorstore_spec_for, checkpoint_dir), leaf_key_paths, is_leaf=_is_named_or_none) # TODO: jax array_ser has a fancy async manager thing to checkpoint while training, would be good but not right now. async def _do_serialize(): - futures = jtu.tree_map(_serialize_one_leaf, pytree, specs, is_leaf=is_named_array) + futures = jtu.tree_map(_serialize_one_leaf, pytree, specs, is_leaf=_is_named_or_none) return await asyncio.gather(*jtu.tree_leaves(futures)) asyncio.run(_do_serialize()) diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index 1e089169b..ada138004 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -57,8 +57,7 @@ } -# TODO: figure out how to get Generic[M] back -class TrainerState(eqx.Module): +class TrainerState(eqx.Module, Generic[M]): step: int model: M opt_state: OptState @@ -70,8 +69,8 @@ class TrainerState(eqx.Module): # The "step" of a TrainerState is the state after `step` steps have been taken. # A "StepInfo"'s step is the step that was just completed. If you want the next step, use `next_step`. @dataclass -class StepInfo: - state: TrainerState +class StepInfo(Generic[M]): + state: TrainerState[M] loss: float step_duration: float @@ -238,11 +237,13 @@ def __exit__(self, *args): raise RuntimeError("Exception(s) occurred while exiting trainer", problems) from problems[0] def initial_state( - self, training_key: PRNGKeyArray, - model: Optional[M] = None, model_init: Optional[Callable[[], M]] = None, - *, - is_trainable: PyTree[FilterSpec] = True, - ) -> TrainerState: + self, + training_key: PRNGKeyArray, + model: Optional[M] = None, + model_init: Optional[Callable[[], M]] = None, + *, + is_trainable: PyTree[FilterSpec] = True, + ) -> TrainerState[M]: """ Initializes the model, optimizer state, and random key. Also handles loading a checkpoint if needed. 
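Hooks receive a `StepInfo`, so one-off instrumentation like the `print(s.loss)` hook added to train_lm.py above can key off any of its fields. A small sketch — the threshold and names are illustrative, and `step_duration` is the wall-clock time from `capture_time`, assumed to be in seconds:

def log_slow_steps(info, threshold=2.0):
    # info.step is the step just completed (state.step - 1, per StepInfo)
    if info.step_duration > threshold:
        print(f"step {info.step}: {info.step_duration:.2f}s, loss={info.loss:.4f}")

trainer.add_hook(log_slow_steps, every=1)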
@@ -264,20 +265,21 @@ def initial_state( # we can't just use `lambda: model` because JAX jit can't see captures, but it can see jax partials model_init = jax.tree_util.Partial(lambda m: m, model) - model_shape, opt_state_shape = eqx.filter_eval_shape(self._init_model_and_opt_state, model_init, is_trainable) + model_shape, opt_state_shape = eqx.filter_eval_shape( + self._init_model_and_opt_state, model_init, is_trainable + ) # we only checkpoint the trainable parameters, so we need to filter out the non-trainable ones trainable_model_shape = _trainable_params_only(model_shape, is_trainable) - trainer_state_shape = TrainerState(0, - trainable_model_shape, - opt_state_shape, - training_key, - is_trainable=is_trainable) + trainer_state_shape: TrainerState = TrainerState( + 0, trainable_model_shape, opt_state_shape, training_key, is_trainable=is_trainable + ) ckpt = self._maybe_load_checkpoint(trainer_state_shape) if ckpt is not None: + trainable_model = ckpt.model opt_state = ckpt.opt_state training_key = ckpt.training_key step = ckpt.step @@ -291,27 +293,27 @@ def initial_state( else: model = trainable_model else: - model, opt_state = named_jit(self._init_model_and_opt_state, self.parameter_axis_mapping)(model_init, is_trainable) + model, opt_state = named_jit(self._init_model_and_opt_state, self.parameter_axis_mapping)( + model_init, is_trainable + ) step = 0 return TrainerState(step, model, opt_state, training_key, is_trainable) - def train_step(self, state: TrainerState, *batch: X, **batch_kwargs) -> StepInfo: + def train_step(self, state: TrainerState[M], *batch: X, **batch_kwargs) -> StepInfo[M]: """ Performs a single training step. """ with capture_time() as step_time, levanter.current_tracker(self.tracker): - loss, new_state = self._train_step_fn( - state, *batch, **batch_kwargs, key=key - ) + loss, new_state = self._train_step_fn(state, *batch, **batch_kwargs) # force the loss so timing numbers are accurate. laziness isn't going to help here (i think?) loss = loss.item() # type: ignore return StepInfo(new_state, loss, step_time()) def training_steps( - self, state: TrainerState, train_loader, run_hooks: bool = True - ) -> typing.Iterator[StepInfo]: + self, state: TrainerState[M], train_loader, run_hooks: bool = True + ) -> typing.Iterator[StepInfo[M]]: """ Generator that yields training steps and runs hooks. """ @@ -334,7 +336,7 @@ def training_steps( yield info - def train(self, state: TrainerState, train_loader: Iterable[X], run_hooks: bool = True) -> StepInfo: + def train(self, state: TrainerState[M], train_loader: Iterable[X], run_hooks: bool = True) -> StepInfo[M]: """ Performs training until the number of steps is reached. 
""" @@ -410,17 +412,17 @@ def _train_step_fn(self): out_axis_resources=self.parameter_axis_mapping, donate_args=(True,), ) - def train_step(state, *batch, **batch_kwargs): + def train_step(state: TrainerState, *batch, **batch_kwargs): key, new_key = jax.random.split(state.training_key) opt_state = state.opt_state model = inference_mode(state.model, False) # we do this so that we only take the gradients of the trainable parameters - trainable_model, rest_model = self.partition_trainable_params(model) + trainable_model, rest_model = _partition_trainable_params(model, state.is_trainable) def split_loss_fn(trainable_model, *batch, **batch_kwargs): model = eqx.combine(trainable_model, rest_model) - return self.loss_fn(model, *batch, **batch_kwargs) + return self.loss_fn(model, *batch, **batch_kwargs, key=key) loss, grads = accumulate_gradients_sharded( split_loss_fn, self.TrainBatch, self.config.per_device_parallelism, self.parameter_axis_mapping @@ -430,11 +432,7 @@ def split_loss_fn(trainable_model, *batch, **batch_kwargs): model = eqx.apply_updates(model, updates) new_state = dataclasses.replace( - state, - step=state.step + 1, - model=model, - opt_state=opt_state, - training_key=new_key + state, step=state.step + 1, model=model, opt_state=opt_state, training_key=new_key ) return loss, new_state @@ -458,7 +456,6 @@ def _init_non_trainable_params(self, model_init): non_trainable = self.mp.cast_to_compute(non_trainable) return non_trainable - def _maybe_load_checkpoint(self, state: TrainerState) -> Optional[TrainerState]: """Loads a checkpoint if one exists and we're supposed to load it, otherwise returns the model and training state as is""" @@ -792,6 +789,7 @@ def _convert_ratio_or_steps(ratio_or_steps: float, num_train_steps: int): def _trainable_params_only(model: M, filter: PyTree[FilterSpec]) -> M: return _partition_trainable_params(model, filter)[0] + def _partition_trainable_params(model, filter): """ Partitions the model into trainable and non-trainable parameters. 
This is used internally @@ -811,4 +809,4 @@ def trainable_and_diffable(pred): return pred combined_mask = jax.tree_util.tree_map(trainable_and_diffable, filter) - return eqx.partition(model, combined_mask) \ No newline at end of file + return eqx.partition(model, combined_mask) diff --git a/src/levanter/utils/jax_utils.py b/src/levanter/utils/jax_utils.py index 038c5e9b5..cd018819f 100644 --- a/src/levanter/utils/jax_utils.py +++ b/src/levanter/utils/jax_utils.py @@ -243,7 +243,12 @@ def leaf_key_paths( rec_value = rec(field, field_name) rec_values.append(rec_value) - return eqx.tree_at(lambda m: [getattr(m, name) for name in names], pytree, rec_values) + + _, tree_def = eqx.tree_flatten_one_level(pytree) + out = jax.tree_util.tree_unflatten(tree_def, rec_values) + return out + # this doesn't work reliably because tree_at doesn't like none values + # return eqx.tree_at(lambda m: [getattr(m, name) for name in names], pytree, rec_values, is_leaf=lambda x: x is None) else: leaves, treedef = jax.tree_util.tree_flatten(pytree, is_leaf=is_leaf) if len(leaves) == 1: diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index b8f588df4..f181dce7f 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -1,3 +1,4 @@ +import dataclasses import datetime import pathlib import tempfile @@ -30,6 +31,7 @@ def _dummy_step_info(step): model=None, opt_state=(), training_key=(), + is_trainable=True, ), loss=0.0, step_duration=0.0, @@ -139,43 +141,42 @@ def advance_time(delta_seconds): assert _get_checkpoint_steps(tmpdir) == [2, 4, 6, 8, 10, 15, 20, 30, 40, 49] # 49 is last temporary checkpoint +def _make_state(step, key): + model = MLP(in_size=2, out_size=1, width_size=2, depth=3, key=key) + optim = optax.adam(1e-4) + opt_state = optim.init(arrays_only(model)) + + return TrainerState(step, model, opt_state, key, True) + + def test_checkpoint_simple(): key0 = jax.random.PRNGKey(0) key1 = jax.random.PRNGKey(1) - def make_state(key): - model = MLP(in_size=2, out_size=1, width_size=2, depth=3, key=key) - optim = optax.adam(1e-4) - opt_state = optim.init(arrays_only(model)) - - return model, opt_state, key - - initial_model, initial_opt_state, initial_key = make_state(key0) - rep_model, rep_state, rep_key = make_state(key1) + initial_state = _make_state(10, key0) + rep_state = _make_state(2, key1) - assert_trees_not_close(initial_model, rep_model) + assert_trees_not_close(initial_state.model, rep_state.model) with tempfile.TemporaryDirectory() as tmpdir: save_checkpoint( - initial_model, - (initial_opt_state, initial_key), - step=10, + initial_state, + step=initial_state.step, checkpoint_path=tmpdir, exist_ok=True, ) - restored_model, (restored_optstate, rkey), step = load_checkpoint( - rep_model, - (rep_state, rep_key), + restored_state = load_checkpoint( + rep_state, checkpoint_path=tmpdir, discover_latest=False, ) assert_trees_all_close( - jax.tree_util.tree_leaves(arrays_only(restored_model)), - jax.tree_util.tree_leaves(arrays_only(initial_model)), + jax.tree_util.tree_leaves(arrays_only(restored_state.model)), + jax.tree_util.tree_leaves(arrays_only(initial_state.model)), ) - assert all(np.isclose(rkey, initial_key)) - assert step == 10 + assert all(np.isclose(restored_state.training_key, initial_state.training_key)) + assert restored_state.step == initial_state.step def test_checkpoint_steps(): @@ -184,13 +185,7 @@ def test_checkpoint_steps(): optim = optax.adam(1e-4) - def make_state(key): - model = MLP(in_size=2, out_size=1, width_size=2, depth=3, key=key) - opt_state = 
optim.init(arrays_only(model)) - - return model, opt_state, key - - initial_model, initial_opt_state, initial_key = make_state(key0) + initial_state = _make_state(10, key0) data = jax.random.uniform(key0, (2, 2)) @eqx.filter_grad @@ -198,41 +193,33 @@ def loss_fn(model, data): m = jax.vmap(model) return jnp.mean(jnp.square(m(data))) - model, state = initial_model, initial_opt_state + state = initial_state for i in range(3): - grad = loss_fn(model, data) - updates, state = optim.update(grad, state) - model = eqx.apply_updates(model, updates) + grad = loss_fn(state.model, data) + updates, new_state = optim.update(grad, state.opt_state) + model = eqx.apply_updates(state.model, updates) + state = dataclasses.replace(state, step=state.step + 1, model=model, opt_state=new_state) - assert_trees_not_close(model, initial_model) - assert_trees_not_close(state, initial_opt_state) + assert_trees_not_close(state, initial_state) - rep_model, rep_state, rep_key = make_state(key1) - assert_trees_not_close(model, rep_model) + rep_state = _make_state(42, key1) assert_trees_not_close(state, rep_state) with tempfile.TemporaryDirectory() as tmpdir: - save_checkpoint(model, state, step=3, checkpoint_path=tmpdir, exist_ok=True) - restored_model, restored_optstate, step = load_checkpoint( - rep_model, rep_state, checkpoint_path=tmpdir, discover_latest=False - ) + save_checkpoint(state, step=3, checkpoint_path=tmpdir, exist_ok=True) + restored_state = load_checkpoint(rep_state, checkpoint_path=tmpdir, discover_latest=False) assert_trees_all_close( - jax.tree_util.tree_leaves(arrays_only(restored_model)), - jax.tree_util.tree_leaves(arrays_only(model)), - ) - assert_trees_all_close( - jax.tree_util.tree_leaves(arrays_only(restored_optstate)), + jax.tree_util.tree_leaves(arrays_only(restored_state)), jax.tree_util.tree_leaves(arrays_only(state)), ) - assert step == 3 def test_checkpoint_discovery(): with tempfile.TemporaryDirectory() as tempdir: - save_checkpoint(model=1, training_state=2, step=10, checkpoint_path=f"{tempdir}/step-10") - save_checkpoint(model=3, training_state=4, step=20, checkpoint_path=f"{tempdir}/step-20") - save_checkpoint(model=5, training_state=6, step=30, checkpoint_path=f"{tempdir}/step-30") + save_checkpoint(dict(model=1, training_state=2), step=10, checkpoint_path=f"{tempdir}/step-10") + save_checkpoint(dict(model=3, training_state=4), step=20, checkpoint_path=f"{tempdir}/step-20") + save_checkpoint(dict(model=5, training_state=6), step=30, checkpoint_path=f"{tempdir}/step-30") latest = discover_latest_checkpoint(tempdir) assert latest == f"{tempdir}/step-30" From c47ae97a1d68fc44c5c9aaa0fd0dea706bfdf73d Mon Sep 17 00:00:00 2001 From: David Hall Date: Sat, 25 Nov 2023 15:40:34 -0800 Subject: [PATCH 040/205] about got the checkpoint refactor done --- examples/alpaca-lora/alpaca_lora.py | 98 ++++++++++++++-------------- src/levanter/checkpoint.py | 4 ++ src/levanter/main/eval_lm.py | 10 +-- src/levanter/main/export_lm_to_hf.py | 5 +- src/levanter/main/viz_logprobs.py | 4 +- tests/test_eval_lm.py | 5 +- tests/test_export_to_hf.py | 6 +- tests/test_text.py | 12 ++-- 8 files changed, 76 insertions(+), 68 deletions(-) diff --git a/examples/alpaca-lora/alpaca_lora.py b/examples/alpaca-lora/alpaca_lora.py index 49e1ac9dc..3784e80fc 100644 --- a/examples/alpaca-lora/alpaca_lora.py +++ b/examples/alpaca-lora/alpaca_lora.py @@ -6,6 +6,7 @@ from dataclasses import dataclass from typing import Optional +import equinox as eqx import jax.random as jrandom import transformers @@ -79,7 +80,12 @@ def 
train(config: TrainArgs): optimizer = config.optimizer.build(config.trainer.num_train_steps) - with config.trainer.device_mesh: + def compute_loss(model: LmHeadModel, example: LmExample, key=None): + return model.compute_loss(example, key=key).scalar() + + # end major difference from Alpaca + + with Trainer(config.trainer, optimizer, compute_loss) as trainer: # how we shard parameters across devices parameter_axis_mapping = config.trainer.parameter_axis_mapping @@ -97,60 +103,54 @@ def loraize_hf_model(model): lora_param_filter = lora_trainable_params_filter(model) - def compute_loss(model: LmHeadModel, example: LmExample, key=None): - return model.compute_loss(example, key=key).scalar() + state = trainer.initial_state(training_key, model=model, is_trainable=lora_param_filter) - # end major difference from Alpaca + # log some info about the model + all_param_count = parameter_count(state.model) + just_lora_params = parameter_count(eqx.filter(state.model, lora_param_filter)) - with Trainer(config.trainer, optimizer, compute_loss, is_trainable=lora_param_filter) as trainer: - state = trainer.initial_state(training_key, model=model) + levanter.log_summary( + { + "parameter_count": all_param_count, + "trainable_parameter_count": just_lora_params, + "fraction_trainable": just_lora_params * 1.0 / all_param_count, + } + ) - # log some info about the model - all_param_count = parameter_count(state.model) - just_lora_params = parameter_count(trainer._trainable_params_only(state.model)) + logger.info(f"Total parameter count: {all_param_count}") + logger.info(f"Trainable parameter count: {just_lora_params}") + logger.info(f"Fraction of parameters that are trainable: {just_lora_params * 1.0 / all_param_count%.3}") + + # Levanter has two kinds of data loaders: sharded and replicated. Replicated is simpler and allows for + # single pass training. Sharded only loads a subset of the data on each device, and is more efficient for large + # datasets. We use replicated here since the dataset is small. + loader = trainer.replicated_loader(train_dataset, trainer.TrainBatch) + loader = non_caching_cycle(loader) + + if state.step != 0: + logger.info(f"Resuming training from step {state.step}") + for i in range(state.step): + next(loader) # type: ignore + + # Save HF PEFT checkpoints periodically (and at the end of training), which is just the lora weights + if config.hf_save_path is not None: + full_save_path = os.path.join(config.hf_save_path, trainer.run_id) + trainer.add_hook( + save_peft_checkpoint_callback( + full_save_path, config.lora, config.model_name_or_path, tokenizer, config.hf_upload + ), + every=config.hf_save_steps, + ) - levanter.log_summary( - { - "parameter_count": all_param_count, - "trainable_parameter_count": just_lora_params, - "fraction_trainable": just_lora_params * 1.0 / all_param_count, - } + # Save merged HF checkpoints if requested + if config.merged_hf_save_path is not None: + full_save_path = os.path.join(config.merged_hf_save_path, trainer.run_id) + trainer.add_hook( + save_merged_hf_checkpoint_callback(full_save_path, converter, config.merged_hf_upload), + every=config.hf_save_steps, ) - logger.info(f"Total parameter count: {all_param_count}") - logger.info(f"Trainable parameter count: {just_lora_params}") - logger.info(f"Fraction of parameters that are trainable: {just_lora_params * 1.0 / all_param_count%.3}") - - # Levanter has two kinds of data loaders: sharded and replicated. Replicated is simpler and allows for - # single pass training. 
Sharded only loads a subset of the data on each device, and is more efficient for large - # datasets. We use replicated here since the dataset is small. - loader = trainer.replicated_loader(train_dataset, trainer.TrainBatch) - loader = non_caching_cycle(loader) - - if state.step != 0: - logger.info(f"Resuming training from step {state.step}") - for i in range(state.step): - next(loader) # type: ignore - - # Save HF PEFT checkpoints periodically (and at the end of training), which is just the lora weights - if config.hf_save_path is not None: - full_save_path = os.path.join(config.hf_save_path, trainer.run_id) - trainer.add_hook( - save_peft_checkpoint_callback( - full_save_path, config.lora, config.model_name_or_path, tokenizer, config.hf_upload - ), - every=config.hf_save_steps, - ) - - # Save merged HF checkpoints if requested - if config.merged_hf_save_path is not None: - full_save_path = os.path.join(config.merged_hf_save_path, trainer.run_id) - trainer.add_hook( - save_merged_hf_checkpoint_callback(full_save_path, converter, config.merged_hf_upload), - every=config.hf_save_steps, - ) - - trainer.train(state, loader) + trainer.train(state, loader) if __name__ == "__main__": diff --git a/src/levanter/checkpoint.py b/src/levanter/checkpoint.py index 2d8681b16..f9c94e058 100644 --- a/src/levanter/checkpoint.py +++ b/src/levanter/checkpoint.py @@ -276,6 +276,7 @@ def load_checkpoint( tree: M, checkpoint_path: PathLike, *, + subpath: Optional[str] = None, discover_latest=True, axis_mapping: Optional[haliax.partitioning.ResourceMapping] = None, mesh: Optional[jax.sharding.Mesh] = None, @@ -294,6 +295,9 @@ def load_checkpoint( logger.info(f"Loading checkpoint from {checkpoint_path}") metadata = load_metadata(checkpoint_path, fs) + if subpath: + checkpoint_path = os.path.join(checkpoint_path, subpath) + try: tree = tree_deserialize_leaves_tensorstore(checkpoint_path, tree, axis_mapping=axis_mapping, mesh=mesh) return tree diff --git a/src/levanter/main/eval_lm.py b/src/levanter/main/eval_lm.py index bea7a5e2b..806127173 100644 --- a/src/levanter/main/eval_lm.py +++ b/src/levanter/main/eval_lm.py @@ -51,7 +51,11 @@ def main(config: EvalLmConfig): if config.eval_on_train: raw_dataset = CausalLmDataset(config.data.train_set(Pos.size), Pos, KeyPos) else: - raw_dataset = CausalLmDataset(config.data.validation_set(Pos.size), Pos, KeyPos) + validation_set = config.data.validation_set(Pos.size) + if validation_set is None: + raise ValueError("Can't eval on validation_set b/c there isn't one!") + + raw_dataset = CausalLmDataset(validation_set, Pos, KeyPos) eval_loader = ReplicatedBatchLoader(raw_dataset, config.trainer.device_mesh, Batch) compute_axis_mapping = config.trainer.compute_axis_mapping @@ -81,14 +85,12 @@ def compute_loss(model: LmHeadModel, example: LmExample): with use_cpu_device(): model = eqx.filter_eval_shape(config.model.build, Vocab, key=key) # TODO: don't load the entire checkpoint into CPU memory when we only need our share of the model - ckpt = load_checkpoint(model, None, config.checkpoint_path) + ckpt = load_checkpoint(model, config.checkpoint_path, subpath="model") assert ckpt is not None - model, _, _ = ckpt model = hax.shard_with_axis_mapping(model, parameter_axis_mapping) - # TODO: switch to throwing instead of returning None loss = callbacks.eval_loss_loop(compute_loss, model, eval_loader, max_batches=total) del model diff --git a/src/levanter/main/export_lm_to_hf.py b/src/levanter/main/export_lm_to_hf.py index 50a8e4b92..7fd4d073d 100644 --- 
a/src/levanter/main/export_lm_to_hf.py +++ b/src/levanter/main/export_lm_to_hf.py @@ -51,10 +51,9 @@ def main(config: ConvertLmConfig): model: LmHeadModel = eqx.filter_eval_shape(config.model.build, Vocab, key=key) trainable, non_trainable = eqx.partition(model, is_inexact_arrayish) # TODO: don't load the entire checkpoint into CPU memory when we only need our share of the model - ckpt = load_checkpoint(trainable, None, config.checkpoint_path) + trainable = load_checkpoint(trainable, config.checkpoint_path, subpath="model") - assert ckpt is not None - trainable, _, _ = ckpt + assert trainable is not None model = eqx.combine(trainable, non_trainable) if config.override_vocab_size: diff --git a/src/levanter/main/viz_logprobs.py b/src/levanter/main/viz_logprobs.py index ad85a0c7d..6f8d08640 100644 --- a/src/levanter/main/viz_logprobs.py +++ b/src/levanter/main/viz_logprobs.py @@ -46,7 +46,7 @@ def main(config: VizGpt2Config): KeyPos = config.model.KeyPos eval_loader = ReplicatedBatchLoader( - CausalLmDataset(config.data.validation_set(Pos.size), Pos, KeyPos), + CausalLmDataset(config.data.validation_set(Pos.size), Pos, KeyPos), # type: ignore config.trainer.device_mesh, EvalBatch, ) @@ -83,7 +83,7 @@ def compute_log_probs(model: LmHeadModel, example: LmExample): with use_cpu_device(): model = eqx.filter_eval_shape(config.model.build, Vocab, key=key) # TODO: don't load the entire checkpoint into CPU memory when we only need our share of the model - ckpt = load_checkpoint(model, None, config.checkpoint_path) + ckpt = load_checkpoint(model, config.checkpoint_path) assert ckpt is not None model, _, _ = ckpt diff --git a/tests/test_eval_lm.py b/tests/test_eval_lm.py index 178069f26..a6bf3c8d9 100644 --- a/tests/test_eval_lm.py +++ b/tests/test_eval_lm.py @@ -13,6 +13,7 @@ from levanter.distributed import RayConfig from levanter.models.gpt2 import Gpt2LMHeadModel from levanter.tracker.wandb import WandbConfig +from levanter.trainer import TrainerState from levanter.utils.py_utils import logical_cpu_core_count @@ -43,7 +44,9 @@ def test_eval_lm(): Vocab = haliax.Axis("vocab", len(tok)) model = Gpt2LMHeadModel.init(Vocab, model_config, key=jax.random.PRNGKey(0)) - save_checkpoint(model, None, 0, f"{f}/ckpt") + state = TrainerState(0, model, model, jax.random.PRNGKey(0), True) + + save_checkpoint(state, 0, f"{f}/ckpt") config = eval_lm.EvalLmConfig( data=data_config, diff --git a/tests/test_export_to_hf.py b/tests/test_export_to_hf.py index b50bde9cb..84d3c3081 100644 --- a/tests/test_export_to_hf.py +++ b/tests/test_export_to_hf.py @@ -34,7 +34,7 @@ def test_export_lm_to_hf(): # in our trainer, we only export the trainable params trainable, non_trainable = eqx.partition(model, is_inexact_arrayish) - save_checkpoint(trainable, None, 0, f"{tmpdir}/ckpt") + save_checkpoint({"model": trainable}, 0, f"{tmpdir}/ckpt") try: config = export_lm_to_hf.ConvertLmConfig( @@ -50,8 +50,8 @@ def test_export_lm_to_hf(): export_lm_to_hf.main(config) if has_torch(): - m = AutoModelForCausalLM.from_pretrained(f"{tmpdir}/output") - print(m) + # mostly just make sure it loads + AutoModelForCausalLM.from_pretrained(f"{tmpdir}/output") finally: try: diff --git a/tests/test_text.py b/tests/test_text.py index 21e2887db..07d9436a5 100644 --- a/tests/test_text.py +++ b/tests/test_text.py @@ -1,11 +1,11 @@ +from tempfile import TemporaryDirectory + from levanter.data.text import LMDatasetConfig def test_dont_blow_up_without_validation_set(): - config = LMDatasetConfig( - train_urls=["kaa"], - validation_urls=[], - ) + with 
TemporaryDirectory() as td: + config = LMDatasetConfig(train_urls=["kaa"], validation_urls=[], cache_dir=f"{td}") - # mostly just making sure this doesn't blow up - assert config.validation_set(10) is None + # mostly just making sure this doesn't blow up + assert config.validation_set(10) is None From f0613c789f6e95e3dc9023f1a9a8dc8787db61c6 Mon Sep 17 00:00:00 2001 From: David Hall Date: Sat, 25 Nov 2023 15:42:51 -0800 Subject: [PATCH 041/205] minor dead code removal --- src/levanter/main/viz_logprobs.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/levanter/main/viz_logprobs.py b/src/levanter/main/viz_logprobs.py index 6f8d08640..2b6d32406 100644 --- a/src/levanter/main/viz_logprobs.py +++ b/src/levanter/main/viz_logprobs.py @@ -51,10 +51,6 @@ def main(config: VizGpt2Config): EvalBatch, ) - # some axes we use outside the model proper - Pos = config.model.Pos - KeyPos = config.model.KeyPos - compute_axis_mapping = config.trainer.compute_axis_mapping parameter_axis_mapping = config.trainer.parameter_axis_mapping From 85c56780d189f0407fd998f237bf56e0ed105d4c Mon Sep 17 00:00:00 2001 From: David Hall Date: Sat, 25 Nov 2023 21:04:46 -0800 Subject: [PATCH 042/205] fix tests --- src/levanter/checkpoint.py | 48 +--------- src/levanter/main/eval_lm.py | 4 +- src/levanter/main/train_lm.py | 3 - src/levanter/main/viz_logprobs.py | 5 +- src/levanter/tracker/tracker_fns.py | 4 +- src/levanter/trainer.py | 131 +++++++++++++--------------- tests/test_viz_lm.py | 2 +- 7 files changed, 68 insertions(+), 129 deletions(-) diff --git a/src/levanter/checkpoint.py b/src/levanter/checkpoint.py index f9c94e058..eab851fa4 100644 --- a/src/levanter/checkpoint.py +++ b/src/levanter/checkpoint.py @@ -7,7 +7,7 @@ import urllib.parse from dataclasses import dataclass from datetime import timedelta -from typing import Any, Callable, List, Optional, Sequence, Tuple, TypeVar, Union +from typing import Any, Callable, List, Optional, Sequence, TypeVar, Union import equinox import fsspec @@ -334,52 +334,6 @@ def load_checkpoint( return new_state -def _old_load_checkpoint( - model: M, - training_state: S, - checkpoint_path: PathLike, - *, - discover_latest=True, - axis_mapping: Optional[haliax.partitioning.ResourceMapping] = None, - mesh: Optional[jax.sharding.Mesh] = None, -) -> Optional[Tuple[M, S, int]]: - """ - Load a checkpoint from a given path. - - Returns the loaded model state, training state, and step. If discover_latest is True, - the latest checkpoint in the given path will be loaded. Otherwise, the checkpoint at - the given path will be loaded. If no checkpoint is found, returns None - - If training_state is None, no training state will be loaded. 
- """ - fs: AbstractFileSystem - fs, _ = _get_fs_and_plain_path(checkpoint_path) - - checkpoint_path = str(checkpoint_path) - - if discover_latest: - checkpoint_path = discover_latest_checkpoint(checkpoint_path) # type: ignore - - if checkpoint_path is None or not fs.exists(checkpoint_path): - return None - - logger.info(f"Loading checkpoint from {checkpoint_path}") - metadata = load_metadata(checkpoint_path, fs) - - model = tree_deserialize_leaves_tensorstore( - os.path.join(checkpoint_path, "model"), model, axis_mapping=axis_mapping, mesh=mesh - ) - - if training_state is None: - training_state = None - else: - training_state = tree_deserialize_leaves_tensorstore( - os.path.join(checkpoint_path, "training_state"), training_state, axis_mapping=axis_mapping, mesh=mesh - ) - - return model, training_state, metadata["step"] - - def load_metadata(checkpoint_path, fs=None): if fs is None: fs: AbstractFileSystem diff --git a/src/levanter/main/eval_lm.py b/src/levanter/main/eval_lm.py index 806127173..340e7e496 100644 --- a/src/levanter/main/eval_lm.py +++ b/src/levanter/main/eval_lm.py @@ -85,9 +85,7 @@ def compute_loss(model: LmHeadModel, example: LmExample): with use_cpu_device(): model = eqx.filter_eval_shape(config.model.build, Vocab, key=key) # TODO: don't load the entire checkpoint into CPU memory when we only need our share of the model - ckpt = load_checkpoint(model, config.checkpoint_path, subpath="model") - - assert ckpt is not None + model = load_checkpoint(model, config.checkpoint_path, subpath="model") model = hax.shard_with_axis_mapping(model, parameter_axis_mapping) diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py index b03745988..4d0c2f5b5 100644 --- a/src/levanter/main/train_lm.py +++ b/src/levanter/main/train_lm.py @@ -5,7 +5,6 @@ from typing import Optional, Union import jax.random as jrandom -import wandb import haliax as hax from haliax import Axis @@ -186,10 +185,8 @@ def compute_log_probs(model, example: LmExample): next(train_loader) ## OK, actually run training! - trainer.add_hook(lambda s: print(s.loss), every=20) trainer.train(state, train_loader) # checkpointer.on_step(last_step, force=True) - wandb.finish() if __name__ == "__main__": diff --git a/src/levanter/main/viz_logprobs.py b/src/levanter/main/viz_logprobs.py index 2b6d32406..2bc3d43cc 100644 --- a/src/levanter/main/viz_logprobs.py +++ b/src/levanter/main/viz_logprobs.py @@ -79,10 +79,9 @@ def compute_log_probs(model: LmHeadModel, example: LmExample): with use_cpu_device(): model = eqx.filter_eval_shape(config.model.build, Vocab, key=key) # TODO: don't load the entire checkpoint into CPU memory when we only need our share of the model - ckpt = load_checkpoint(model, config.checkpoint_path) + model = load_checkpoint(model, config.checkpoint_path, subpath="model") - assert ckpt is not None - model, _, _ = ckpt + assert model is not None model = hax.shard_with_axis_mapping(model, parameter_axis_mapping) diff --git a/src/levanter/tracker/tracker_fns.py b/src/levanter/tracker/tracker_fns.py index 69ab4ca0b..70d59cacb 100644 --- a/src/levanter/tracker/tracker_fns.py +++ b/src/levanter/tracker/tracker_fns.py @@ -84,7 +84,7 @@ def current_tracker( >>> from levanter.tracker.wandb import WandbTracker >>> with current_tracker(WandbTracker()): ... log_metrics({"foo": 1}, step=0) - ... current_tracker().log_metrics({"foo": 2}, step=1) + ... 
current_tracker().log({"foo": 2}, step=1) """ global _global_tracker if tracker is None: @@ -147,6 +147,8 @@ def __enter__(self): self.old_tracker = _global_tracker _global_tracker = self.tracker + return self.tracker + def __exit__(self, exc_type, exc_val, exc_tb): global _global_tracker _global_tracker = self.old_tracker diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index ada138004..b9aa7565b 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -32,7 +32,7 @@ import levanter.logging import levanter.tracker import levanter.tracker.wandb -from levanter import logging, tracker +from levanter import tracker from levanter.checkpoint import CheckpointerConfig from levanter.config import JsonAtom from levanter.data import Dataset, ReplicatedBatchLoader, ShardableDataset, ShardedBatchLoader @@ -65,6 +65,9 @@ class TrainerState(eqx.Module, Generic[M]): is_trainable: PyTree[FilterSpec] = eqx.field(static=True) +S = TypeVar("S", bound=TrainerState) + + # A note on the semantics of "step" vs "next_step": # The "step" of a TrainerState is the state after `step` steps have been taken. # A "StepInfo"'s step is the step that was just completed. If you want the next step, use `next_step`. @@ -209,30 +212,28 @@ def EvalBatch(self): return self.config.EvalBatch def __enter__(self): - if len(self._cmanagers) > 0: - raise RuntimeError("Trainer is already entered") - - self._cmanagers = [ + this_managers = [ levanter.current_tracker(self.tracker), self.device_mesh, hax.axis_mapping(self.parameter_axis_mapping), ] + self._cmanagers.append(this_managers) - for cmanager in self._cmanagers: + for cmanager in this_managers: cmanager.__enter__() return self def __exit__(self, *args): + assert len(self._cmanagers) > 0, "Trainer.__exit__ called without corresponding Trainer.__enter__" + cur_managers = self._cmanagers.pop() problems = [] - for cmanager in reversed(self._cmanagers): + for cmanager in reversed(cur_managers): try: cmanager.__exit__(*args) except Exception as e: problems.append(e) - self._cmanagers = [] - if len(problems) > 0: raise RuntimeError("Exception(s) occurred while exiting trainer", problems) from problems[0] @@ -255,50 +256,56 @@ def initial_state( Returns: model, opt_state, key, resume_step """ - with levanter.tracker.current_tracker(self.tracker): - if model is not None and model_init is not None: - raise ValueError("only one of model and model_init should be specified") - elif model is None and model_init is None: - raise ValueError("one of model and model_init must be specified") - - if model is not None: - # we can't just use `lambda: model` because JAX jit can't see captures, but it can see jax partials - model_init = jax.tree_util.Partial(lambda m: m, model) - - model_shape, opt_state_shape = eqx.filter_eval_shape( - self._init_model_and_opt_state, model_init, is_trainable - ) + if model is not None and model_init is not None: + raise ValueError("only one of model and model_init should be specified") + elif model is None and model_init is None: + raise ValueError("one of model and model_init must be specified") - # we only checkpoint the trainable parameters, so we need to filter out the non-trainable ones - trainable_model_shape = _trainable_params_only(model_shape, is_trainable) + if model is not None: + # we can't just use `lambda: model` because JAX jit can't see captures, but it can see jax partials + model_init = jax.tree_util.Partial(lambda m: m, model) - trainer_state_shape: TrainerState = TrainerState( - 0, trainable_model_shape, opt_state_shape, 
training_key, is_trainable=is_trainable - ) + with self: + if self.config.load_checkpoint is not False: + trainer_state_shape = eqx.filter_eval_shape( + self._initialize_state_from_scratch, model_init, training_key, is_trainable + ) - ckpt = self._maybe_load_checkpoint(trainer_state_shape) + # TODO: don't remake the checkpointer every time + checkpointer = self.config.checkpointer.create(self.run_id) + load_checkpoint_path = self.config.load_checkpoint_path - if ckpt is not None: - trainable_model = ckpt.model - opt_state = ckpt.opt_state - training_key = ckpt.training_key - step = ckpt.step + if load_checkpoint_path is None: + load_checkpoint_path = self.config.checkpointer.expanded_path(self.run_id) - if model is not None: - model = eqx.combine(ckpt.model, model) - elif any(isinstance(leaf, ShapeDtypeStruct) for leaf in jax.tree_leaves(trainable_model)): - # if we're resuming, we need to re-initialize the non-trainable parameters to their original values - non_trainable = named_jit(self._init_non_trainable_params, self.parameter_axis_mapping)(model_init) - model = eqx.combine(trainable_model, non_trainable) - else: - model = trainable_model - else: - model, opt_state = named_jit(self._init_model_and_opt_state, self.parameter_axis_mapping)( - model_init, is_trainable + ckpt = checkpointer.load_checkpoint( + trainer_state_shape, + load_checkpoint_path, + axis_mapping=self.parameter_axis_mapping, + mesh=self.device_mesh, ) - step = 0 - return TrainerState(step, model, opt_state, training_key, is_trainable) + if ckpt is None: + if self.config.load_checkpoint is True: + raise ValueError(f"Could not load checkpoint from {load_checkpoint_path}") + else: + if model is not None: + model = eqx.combine(ckpt.model, model) + elif any(isinstance(leaf, ShapeDtypeStruct) for leaf in jax.tree_leaves(ckpt.model)): + # if we're resuming, we need to re-initialize the non-trainable parameters to their original values + # TODO: do we want to extend this to non-model things that don't get initialized from a ckpt? + non_trainable = named_jit(self._init_non_trainable_params, self.parameter_axis_mapping)( + model_init + ) + model = eqx.combine(ckpt.model, non_trainable) + else: + model = ckpt.model + + return dataclasses.replace(ckpt, model=model) + + return named_jit(self._initialize_state_from_scratch, self.parameter_axis_mapping)( + model_init, training_key, is_trainable + ) def train_step(self, state: TrainerState[M], *batch: X, **batch_kwargs) -> StepInfo[M]: """ @@ -439,15 +446,20 @@ def split_loss_fn(trainable_model, *batch, **batch_kwargs): return train_step - def _init_model_and_opt_state(self, model_init, is_trainable): + def _initialize_state_from_scratch(self, model_init: Callable[[], M], training_key, is_trainable): model = model_init() + # only force trainable params to param precision. 
Other params are cast to compute precision trainable, non_trainable = _partition_trainable_params(model, is_trainable) trainable = self.mp.cast_to_param(trainable) non_trainable = self.mp.cast_to_compute(non_trainable) model = eqx.combine(trainable, non_trainable) + opt_state = self.optimizer.init(trainable) - return model, opt_state + + trainer_state: TrainerState = TrainerState(0, model, opt_state, training_key, is_trainable=is_trainable) + + return trainer_state def _init_non_trainable_params(self, model_init): model = model_init() @@ -456,29 +468,6 @@ def _init_non_trainable_params(self, model_init): non_trainable = self.mp.cast_to_compute(non_trainable) return non_trainable - def _maybe_load_checkpoint(self, state: TrainerState) -> Optional[TrainerState]: - """Loads a checkpoint if one exists and we're supposed to load it, - otherwise returns the model and training state as is""" - with self.device_mesh: - if self.config.load_checkpoint is not False: - # TODO: don't remake the checkpointer every time - checkpointer = self.config.checkpointer.create(self.run_id) - load_checkpoint_path = self.config.load_checkpoint_path - - if load_checkpoint_path is None: - load_checkpoint_path = self.config.checkpointer.expanded_path(self.run_id) - - ckpt = checkpointer.load_checkpoint( - state, load_checkpoint_path, axis_mapping=self.parameter_axis_mapping, mesh=self.device_mesh - ) - - if ckpt is None and self.config.load_checkpoint is True: - raise ValueError(f"Could not load checkpoint from {load_checkpoint_path}") - - return ckpt - else: - return None - @dataclass class TrainerConfig: diff --git a/tests/test_viz_lm.py b/tests/test_viz_lm.py index cf4fb74a6..25d5e8fb0 100644 --- a/tests/test_viz_lm.py +++ b/tests/test_viz_lm.py @@ -43,7 +43,7 @@ def test_viz_lm(): Vocab = haliax.Axis("vocab", len(tok)) model = Gpt2LMHeadModel.init(Vocab, model_config, key=jax.random.PRNGKey(0)) - save_checkpoint(model, None, 0, f"{f}/ckpt") + save_checkpoint({"model": model}, 0, f"{f}/ckpt") config = viz_logprobs.VizGpt2Config( data=data_config, From 8f848221d087132a113e91a679d873f1e98d036f Mon Sep 17 00:00:00 2001 From: David Hall Date: Sat, 25 Nov 2023 21:09:51 -0800 Subject: [PATCH 043/205] cleanup --- src/levanter/checkpoint.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/levanter/checkpoint.py b/src/levanter/checkpoint.py index eab851fa4..965343433 100644 --- a/src/levanter/checkpoint.py +++ b/src/levanter/checkpoint.py @@ -30,7 +30,6 @@ PathLike = Union[str, pathlib.Path] M = TypeVar("M", bound=PyTree) -S = TypeVar("S") @dataclass(frozen=True) From e54bad05e8740aed1b94ea0e7aad6d59d70955c7 Mon Sep 17 00:00:00 2001 From: David Hall Date: Sat, 25 Nov 2023 21:27:30 -0800 Subject: [PATCH 044/205] cleanup --- src/levanter/trainer.py | 57 ++++++++++++++++++++++++----------------- 1 file changed, 33 insertions(+), 24 deletions(-) diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index b9aa7565b..c886fe85a 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -28,6 +28,7 @@ import haliax as hax from haliax import Axis from haliax.partitioning import ResourceAxis, ResourceMapping, named_jit +from haliax.types import Scalar import levanter.logging import levanter.tracker @@ -114,6 +115,17 @@ def decorator(fn: Callable[[StepInfo], None]): return decorator(fn) +# A note on extending Trainer: +# First, consider whether you can do what you want with hooks. Hooks can cover a lot of use cases. +# Sometimes, however, you need to do something more complicated. 
In that case, you can extend Trainer. +# In order to do that, you need to: +# * Extend TrainerState to add your additional state +# * Override `_train_step` to add your additional logic +# * Override `initial_state` or `_initialize_state_from_scratch` to initialize your additional state. (The latter is +# simpler and means you don't need to handle the checkpointing logic yourself.) +# * You might also need to override `training_steps` if you want to make the type checker happy. + + class Trainer: config: "TrainerConfig" optimizer: GradientTransformation @@ -414,37 +426,36 @@ def sharded_loader(self, dataset: ShardableDataset[X], batch_axis: Axis) -> Shar @cached_property def _train_step_fn(self): - @named_jit( + return named_jit( axis_resources=self.parameter_axis_mapping, out_axis_resources=self.parameter_axis_mapping, donate_args=(True,), - ) - def train_step(state: TrainerState, *batch, **batch_kwargs): - key, new_key = jax.random.split(state.training_key) - opt_state = state.opt_state - model = inference_mode(state.model, False) + )(self._train_step) - # we do this so that we only take the gradients of the trainable parameters - trainable_model, rest_model = _partition_trainable_params(model, state.is_trainable) + def _train_step(self, state: TrainerState, *batch, **batch_kwargs) -> tuple[Scalar, TrainerState]: + key, new_key = jax.random.split(state.training_key) + opt_state = state.opt_state + model = inference_mode(state.model, False) - def split_loss_fn(trainable_model, *batch, **batch_kwargs): - model = eqx.combine(trainable_model, rest_model) - return self.loss_fn(model, *batch, **batch_kwargs, key=key) + # we do this so that we only take the gradients of the trainable parameters + trainable_model, rest_model = _partition_trainable_params(model, state.is_trainable) - loss, grads = accumulate_gradients_sharded( - split_loss_fn, self.TrainBatch, self.config.per_device_parallelism, self.parameter_axis_mapping - )(trainable_model, *batch, **batch_kwargs) + def split_loss_fn(trainable_model, *batch, **batch_kwargs): + model = eqx.combine(trainable_model, rest_model) + return self.loss_fn(model, *batch, **batch_kwargs, key=key) - updates, opt_state = self.optimizer.update(grads, opt_state, params=trainable_model) - model = eqx.apply_updates(model, updates) + loss, grads = accumulate_gradients_sharded( + split_loss_fn, self.TrainBatch, self.config.per_device_parallelism, self.parameter_axis_mapping + )(trainable_model, *batch, **batch_kwargs) - new_state = dataclasses.replace( - state, step=state.step + 1, model=model, opt_state=opt_state, training_key=new_key - ) + updates, opt_state = self.optimizer.update(grads, opt_state, params=trainable_model) + model = eqx.apply_updates(model, updates) - return loss, new_state + new_state = dataclasses.replace( + state, step=state.step + 1, model=model, opt_state=opt_state, training_key=new_key + ) - return train_step + return loss, new_state def _initialize_state_from_scratch(self, model_init: Callable[[], M], training_key, is_trainable): model = model_init() @@ -457,9 +468,7 @@ def _initialize_state_from_scratch(self, model_init: Callable[[], M], training_k opt_state = self.optimizer.init(trainable) - trainer_state: TrainerState = TrainerState(0, model, opt_state, training_key, is_trainable=is_trainable) - - return trainer_state + return TrainerState(0, model, opt_state, training_key, is_trainable) def _init_non_trainable_params(self, model_init): model = model_init() From 85dd89b658e6b88d1847f2cf0d1eedaf1bf0a98b Mon Sep 17 00:00:00 2001 From: 
David Hall
Date: Sat, 25 Nov 2023 22:56:10 -0800
Subject: [PATCH 045/205] minor

---
 src/levanter/checkpoint.py | 17 +++++++++++++++++
 src/levanter/trainer.py    | 10 ++++------
 2 files changed, 21 insertions(+), 6 deletions(-)

diff --git a/src/levanter/checkpoint.py b/src/levanter/checkpoint.py
index 965343433..a633f1e09 100644
--- a/src/levanter/checkpoint.py
+++ b/src/levanter/checkpoint.py
@@ -280,6 +280,23 @@ def load_checkpoint(
     axis_mapping: Optional[haliax.partitioning.ResourceMapping] = None,
     mesh: Optional[jax.sharding.Mesh] = None,
 ) -> Optional[M]:
+    """
+    Load a checkpoint from a given path. If discover_latest is True, then the latest checkpoint
+    in a subdirectory of the given path will be loaded. If subpath is not None, only that subpath
+    of the checkpoint is loaded. This is useful for loading, e.g., just the model and not the
+    entire training state.
+
+    Args:
+        tree: an exemplar of the tree to load. Can be a PyTree[ShapeDtypeStruct] instead of a PyTree[Any]
+        checkpoint_path: the path to load the checkpoint from
+        subpath: the subpath to load from the checkpoint
+        discover_latest: whether to discover the latest checkpoint in the given path
+        axis_mapping: the axis mapping to use for loading the checkpoint
+        mesh: the mesh to use for loading the checkpoint
+    Returns:
+        the loaded checkpoint, with the same structure as the exemplar tree
+
+    """
     fs: AbstractFileSystem
     fs, _ = _get_fs_and_plain_path(checkpoint_path)
 
diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py
index c886fe85a..bc078414d 100644
--- a/src/levanter/trainer.py
+++ b/src/levanter/trainer.py
@@ -34,7 +34,7 @@
 import levanter.tracker
 import levanter.tracker.wandb
 from levanter import tracker
-from levanter.checkpoint import CheckpointerConfig
+from levanter.checkpoint import CheckpointerConfig, load_checkpoint
 from levanter.config import JsonAtom
 from levanter.data import Dataset, ReplicatedBatchLoader, ShardableDataset, ShardedBatchLoader
 from levanter.distributed import DistributedConfig, RayConfig
@@ -283,14 +283,12 @@ def initial_state(
                 self._initialize_state_from_scratch, model_init, training_key, is_trainable
             )
 
-            # TODO: don't remake the checkpointer every time
-            checkpointer = self.config.checkpointer.create(self.run_id)
             load_checkpoint_path = self.config.load_checkpoint_path
 
             if load_checkpoint_path is None:
                 load_checkpoint_path = self.config.checkpointer.expanded_path(self.run_id)
 
-            ckpt = checkpointer.load_checkpoint(
+            ckpt = load_checkpoint(
                 trainer_state_shape,
                 load_checkpoint_path,
                 axis_mapping=self.parameter_axis_mapping,
                 mesh=self.device_mesh,
             )
@@ -440,13 +438,13 @@ def _train_step(self, state: TrainerState, *batch, **batch_kwargs) -> tuple[Scal
         # we do this so that we only take the gradients of the trainable parameters
         trainable_model, rest_model = _partition_trainable_params(model, state.is_trainable)
 
-        def split_loss_fn(trainable_model, *batch, **batch_kwargs):
+        def split_loss_fn(trainable_model, rest_model, *batch, **batch_kwargs):
             model = eqx.combine(trainable_model, rest_model)
             return self.loss_fn(model, *batch, **batch_kwargs, key=key)
 
         loss, grads = accumulate_gradients_sharded(
             split_loss_fn, self.TrainBatch, self.config.per_device_parallelism, self.parameter_axis_mapping
-        )(trainable_model, *batch, **batch_kwargs)
+        )(trainable_model, rest_model, *batch, **batch_kwargs)
 
         updates, opt_state = self.optimizer.update(grads, opt_state, params=trainable_model)
         model = eqx.apply_updates(model, updates)

From c61824ebb0c5fe4903d2655b30889a9478c4721e Mon Sep 17 00:00:00 2001
From: David Hall
Date: Sun, 26 Nov 2023 21:05:41 -0800 Subject: [PATCH 046/205] generalize and extract the checkpoint loading logic so it can be used separately from trainer --- src/levanter/checkpoint.py | 151 +++++++++++++++++++++---------------- src/levanter/trainer.py | 65 ++++++---------- tests/test_checkpoint.py | 53 +++++++++++++ 3 files changed, 166 insertions(+), 103 deletions(-) diff --git a/src/levanter/checkpoint.py b/src/levanter/checkpoint.py index a633f1e09..e11182240 100644 --- a/src/levanter/checkpoint.py +++ b/src/levanter/checkpoint.py @@ -1,5 +1,7 @@ +import contextlib import dataclasses import datetime +import functools import json import logging import os @@ -7,19 +9,23 @@ import urllib.parse from dataclasses import dataclass from datetime import timedelta -from typing import Any, Callable, List, Optional, Sequence, TypeVar, Union +from typing import Callable, List, Optional, ParamSpec, Sequence, TypeVar, Union import equinox +import equinox as eqx import fsspec import jax import jax.numpy as jnp from draccus import field -from equinox import default_deserialise_filter_spec, default_serialise_filter_spec from fsspec import AbstractFileSystem +from jax import ShapeDtypeStruct +from jax._src.interpreters.pxla import Mesh from jax.experimental.multihost_utils import broadcast_one_to_all, sync_global_devices from jaxtyping import PyTree +import haliax as hax import haliax.partitioning +from haliax.partitioning import ResourceMapping from levanter.tensorstore_serialization import tree_deserialize_leaves_tensorstore, tree_serialize_leaves_tensorstore from levanter.types import FilterSpec @@ -394,53 +400,6 @@ def checkpoint_sort_key(ckpt_dir): return None -def tree_serialise_leaves( - path: PathLike, - pytree: PyTree, - filter_spec=default_serialise_filter_spec, - is_leaf: Optional[Callable[[Any], bool]] = None, -) -> None: - """Analog to `equinox.tree_serialise_leaves`, but saves the leaves of a PyTree using fsspec.""" - - with fsspec.open(str(path), "wb") as f: - logger.info(f"Serializing to {path}") - - def _serialise(spec, x): - def __serialise(y): - spec(f, y) - return y - - jax.tree_map(__serialise, x, is_leaf=is_leaf) - - jax.tree_map(_serialise, filter_spec, pytree) - - -def tree_deserialise_leaves( - path: PathLike, - like: PyTree, - filter_spec=default_deserialise_filter_spec, - is_leaf: Optional[Callable[[Any], bool]] = None, - fs=None, -) -> PyTree: - """ - Analog to `equinox.tree_deserialise_leaves`, but loads the leaves of a PyTree using fsspec. 
- """ - - fs, path_to_open = _get_fs_and_plain_path(path, fs) - - with fs.open(path_to_open, "rb") as f: - - def _deserialise(spec, x): - def __deserialise(y): - return spec(f, y) - - return jax.tree_util.tree_map(__deserialise, x, is_leaf=is_leaf) - - out = jax.tree_util.tree_map(_deserialise, filter_spec, like) - jax.tree_util.tree_map(_assert_same, out, like, is_leaf=is_leaf) - return out - - def _get_fs_and_plain_path(path, fs=None): if fs is None: fs, _, (path_to_open,) = fsspec.get_fs_token_paths(str(path)) @@ -449,20 +408,6 @@ def _get_fs_and_plain_path(path, fs=None): return fs, path_to_open -# similar to eqx but it's a bit more permissive: it just wants things that have shapes and dtypes to be the same -def _assert_same(new, old): - if hasattr(new, "shape") and hasattr(old, "shape"): - assert new.shape == old.shape, f"Shapes don't match: {new.shape} vs {old.shape}" - if hasattr(new, "dtype") and hasattr(old, "dtype"): - assert new.dtype == old.dtype, f"Dtypes don't match: {new.dtype} vs {old.dtype}" - - # now get mad if one has a shape and the other doesn't - if hasattr(new, "shape") != hasattr(old, "shape"): - raise ValueError(f"One has a shape and the other doesn't: {new} vs {old}") - if hasattr(new, "dtype") != hasattr(old, "dtype"): - raise ValueError(f"One has a dtype and the other doesn't: {new} vs {old}") - - @dataclass class CheckpointerConfig: base_path: str = "checkpoints/" @@ -497,3 +442,83 @@ def __post_init__(self): interval["until"] is None or interval["until"] > prev_interval["until"] ), "Checkpoint intervals must be monotonic" prev_interval = interval + + +P = ParamSpec("P") + + +def load_from_checkpoint_or_initialize( + init_fn: Callable[P, M], + checkpoint_path: str, + axis_mapping: Optional[ResourceMapping] = None, + mesh: Optional[Mesh] = None, + *, + # TODO: add this back in + # allow_partial_checkpoint: bool, + force_load_checkpoint: Optional[bool] = None, + is_checkpointed: Optional[PyTree[FilterSpec]] = True, +) -> Callable[P, M]: + """ + Loads a checkpoint if it exists, otherwise initializes from scratch. + + Args: + init_fn: the initialization function. This should be a function that takes some arguments and returns a + model or state. It should be jit-able and should not have any (destructive) side effects. It will + likely be called 2-3 times. + checkpoint_path: the path to the checkpoint + axis_mapping: the axis mapping to use for initialization. If None, the default axis mapping will be used. + mesh: the mesh to use for initialization. If None, the default mesh will be used. + force_load_checkpoint: if True, we must load a checkpoint. If False, we will not load a checkpoint. If None, + we will load a checkpoint if it exists. + is_checkpointed: a filter spec for the checkpointed parameters. This is used to filter out non-checkpointed + parameters for the initialization. If you don't specify this, all parameters are assumed to be + checkpointed. 
+ """ + if force_load_checkpoint is False: + cmanager = mesh or contextlib.nullcontext() + + @functools.wraps(init_fn) + @hax.named_jit(axis_resources=axis_mapping) + def fn(*args, **kwargs): + with cmanager: + return init_fn(*args, **kwargs) + + return fn + else: + + @functools.wraps(init_fn) + def fn(*args, **kwargs): + with contextlib.ExitStack() as stack: + if mesh is not None: + stack.enter_context(mesh) + if axis_mapping is not None: + stack.enter_context(hax.axis_mapping(axis_mapping)) + + ckpt_shape = eqx.filter_eval_shape(init_fn, *args, **kwargs) + + ckpt = load_checkpoint( + eqx.filter(ckpt_shape, is_checkpointed), checkpoint_path, axis_mapping=axis_mapping, mesh=mesh + ) + + if ckpt is None: + if force_load_checkpoint is True: + raise ValueError(f"Could not load checkpoint from {checkpoint_path}") + else: + out = hax.named_jit(init_fn, axis_mapping)(*args, **kwargs) + return out + else: + ckpt = eqx.combine(ckpt, ckpt_shape) + if any(isinstance(leaf, ShapeDtypeStruct) for leaf in jax.tree_leaves(ckpt)): + # if we're resuming, we need to initialize any non-checkpointed values + @hax.named_jit(axis_resources=axis_mapping) + def partial_init(init_fn, *args, **kwargs): + m = init_fn(*args, **kwargs) + return eqx.filter(m, is_checkpointed, inverse=True) + + non_checkpointed = partial_init(init_fn, *args, **kwargs) + ckpt = eqx.filter(ckpt, lambda x: not isinstance(x, ShapeDtypeStruct)) + return eqx.combine(ckpt, non_checkpointed) + else: + return ckpt + + return fn diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index bc078414d..71695cf11 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -19,7 +19,6 @@ import numpy as np import optax from draccus import field -from jax import ShapeDtypeStruct from jax.experimental import multihost_utils from jax.sharding import Mesh from jaxtyping import PRNGKeyArray, PyTree @@ -34,7 +33,7 @@ import levanter.tracker import levanter.tracker.wandb from levanter import tracker -from levanter.checkpoint import CheckpointerConfig, load_checkpoint +from levanter.checkpoint import CheckpointerConfig, load_from_checkpoint_or_initialize from levanter.config import JsonAtom from levanter.data import Dataset, ReplicatedBatchLoader, ShardableDataset, ShardedBatchLoader from levanter.distributed import DistributedConfig, RayConfig @@ -266,7 +265,7 @@ def initial_state( not checkpointed. If you don't specify this, all parameters are assumed to be trainable. 
        Returns:
-            model, opt_state, key, resume_step
+            TrainerState: the initial state
         """
         if model is not None and model_init is not None:
             raise ValueError("only one of model and model_init should be specified")
         elif model is None and model_init is None:
             raise ValueError("one of model and model_init must be specified")
 
         if model is not None:
             # we can't just use `lambda: model` because JAX jit can't see captures, but it can see jax partials
             model_init = jax.tree_util.Partial(lambda m: m, model)
 
         with self:
-            if self.config.load_checkpoint is not False:
-                trainer_state_shape = eqx.filter_eval_shape(
-                    self._initialize_state_from_scratch, model_init, training_key, is_trainable
-                )
-
-                load_checkpoint_path = self.config.load_checkpoint_path
-
-                if load_checkpoint_path is None:
-                    load_checkpoint_path = self.config.checkpointer.expanded_path(self.run_id)
-
-                ckpt = load_checkpoint(
-                    trainer_state_shape,
-                    load_checkpoint_path,
-                    axis_mapping=self.parameter_axis_mapping,
-                    mesh=self.device_mesh,
-                )
-
-                if ckpt is None:
-                    if self.config.load_checkpoint is True:
-                        raise ValueError(f"Could not load checkpoint from {load_checkpoint_path}")
-                else:
-                    if model is not None:
-                        model = eqx.combine(ckpt.model, model)
-                    elif any(isinstance(leaf, ShapeDtypeStruct) for leaf in jax.tree_leaves(ckpt.model)):
-                        # if we're resuming, we need to re-initialize the non-trainable parameters to their original values
-                        # TODO: do we want to extend this to non-model things that don't get initialized from a ckpt?
-                        non_trainable = named_jit(self._init_non_trainable_params, self.parameter_axis_mapping)(
-                            model_init
-                        )
-                        model = eqx.combine(ckpt.model, non_trainable)
-                    else:
-                        model = ckpt.model
-
-                    return dataclasses.replace(ckpt, model=model)
-
-            return named_jit(self._initialize_state_from_scratch, self.parameter_axis_mapping)(
-                model_init, training_key, is_trainable
+            load_checkpoint_path = self.config.load_checkpoint_path
+
+            if load_checkpoint_path is None:
+                load_checkpoint_path = self.config.checkpointer.expanded_path(self.run_id)
+
+            # if we're loading a checkpoint, we need to know which parameters are trainable
+            is_checkpointed = TrainerState(True, is_trainable, True, True, is_trainable)  # type: ignore
+
+            assert model_init is not None
+
+            state = load_from_checkpoint_or_initialize(
+                self._initialize_state_from_scratch,
+                load_checkpoint_path,
+                axis_mapping=self.parameter_axis_mapping,
+                mesh=self.device_mesh,
+                force_load_checkpoint=self.config.load_checkpoint,
+                is_checkpointed=is_checkpointed,
+            )(
+                model_init,
+                training_key,
+                is_trainable,
             )
 
+        return state
+
     def train_step(self, state: TrainerState[M], *batch: X, **batch_kwargs) -> StepInfo[M]:
         """
         Performs a single training step.
diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index f181dce7f..bca7000c3 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -9,13 +9,19 @@ import numpy as np import optax from chex import assert_trees_all_close +from jax import ShapeDtypeStruct from jax import numpy as jnp +from jax import tree_util as jtu + +import haliax.nn +from haliax import Axis from levanter.checkpoint import ( Checkpointer, CheckpointInterval, discover_latest_checkpoint, load_checkpoint, + load_from_checkpoint_or_initialize, load_metadata, save_checkpoint, ) @@ -225,3 +231,50 @@ def test_checkpoint_discovery(): assert latest == f"{tempdir}/step-30" assert discover_latest_checkpoint("file:///tmp/does-not-exist") is None + + +def test_load_from_checkpoint_or_initialize(): + In = Axis("in", 2) + Out = Axis("out", 1) + + def init_fn(key): + return haliax.nn.MLP.init(In, Out, 2, 3, key=key) + + k0 = jax.random.PRNGKey(0) + k1 = jax.random.PRNGKey(1) + + model0 = init_fn(k0) + model1 = init_fn(k1) + + is_checkpointed = jtu.tree_map(lambda _: False, model0) + is_checkpointed = eqx.tree_at(lambda t: t.layers[-1], is_checkpointed, replace=True) + is_checkpointed1 = jtu.tree_map(lambda _: False, model1) + is_checkpointed1 = eqx.tree_at(lambda t: t.layers[-1], is_checkpointed, replace=True) + with jax.sharding.Mesh(jax.devices(), ("devices",)): + with tempfile.TemporaryDirectory() as tmpdir: + filtered = eqx.filter(model0, is_checkpointed) + save_checkpoint(filtered, step=0, checkpoint_path=tmpdir, exist_ok=True) + + loaded = load_from_checkpoint_or_initialize(init_fn, tmpdir, is_checkpointed=is_checkpointed)(k1) + + assert not any(jax.tree_util.tree_leaves(eqx.filter(loaded, lambda x: isinstance(x, ShapeDtypeStruct)))) + + assert_trees_all_close( + jax.tree_util.tree_leaves(arrays_only(eqx.filter(loaded, is_checkpointed))), + jax.tree_util.tree_leaves(arrays_only(eqx.filter(model0, is_checkpointed))), + ) + + assert_trees_not_close( + jax.tree_util.tree_leaves(arrays_only(eqx.filter(loaded, is_checkpointed, inverse=True))), + jax.tree_util.tree_leaves(arrays_only(eqx.filter(model0, is_checkpointed, inverse=True))), + ) + + assert_trees_not_close( + jax.tree_util.tree_leaves(arrays_only(eqx.filter(loaded, is_checkpointed))), + jax.tree_util.tree_leaves(arrays_only(eqx.filter(model1, is_checkpointed))), + ) + + assert_trees_all_close( + jax.tree_util.tree_leaves(arrays_only(eqx.filter(loaded, is_checkpointed, inverse=True))), + jax.tree_util.tree_leaves(arrays_only(eqx.filter(model1, is_checkpointed1, inverse=True))), + ) From 739147540ff7641df365bc494611d568b417cebc Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 27 Nov 2023 22:55:06 -0800 Subject: [PATCH 047/205] Revert "Temporarily Revert "Generic Tracker interface, support for TB logging (#367)"" This reverts commit 137bd4b6cd724a5f484a4697e8a09196712b2b26. 
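PATCH 046 above introduces `load_from_checkpoint_or_initialize`, and its test shows the intended usage pattern: wrap an init function, and the wrapper either deserializes a matching checkpoint or falls back to running the init function. A minimal sketch follows, assuming only the signature and the `haliax.nn.MLP` used in the test; the `"my-ckpt/"` path and the axis sizes are placeholders, not part of the library:

```python
# A minimal sketch of load_from_checkpoint_or_initialize, mirroring the test
# in PATCH 046. The axes, MLP sizes, and the "my-ckpt/" path are illustrative.
import jax

import haliax.nn
from haliax import Axis

from levanter.checkpoint import load_from_checkpoint_or_initialize

In = Axis("in", 2)
Out = Axis("out", 1)


def init_fn(key):
    # init_fn should be jit-able and free of destructive side effects,
    # since the wrapper may call (or eval_shape) it more than once
    return haliax.nn.MLP.init(In, Out, 2, 3, key=key)


with jax.sharding.Mesh(jax.devices(), ("devices",)):
    # If "my-ckpt/" holds a checkpoint with a matching tree structure, it is
    # loaded (non-checkpointed leaves are re-initialized from init_fn);
    # otherwise init_fn simply runs. force_load_checkpoint=True would instead
    # raise when no checkpoint is found.
    model = load_from_checkpoint_or_initialize(init_fn, "my-ckpt/")(jax.random.PRNGKey(0))
```

The related `subpath` argument to `load_checkpoint` covers the narrower case of pulling a single sub-tree out of a full training-state checkpoint, as `eval_lm` does with `load_checkpoint(model, config.checkpoint_path, subpath="model")`.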
--- README.md | 3 +- config/backpack.yaml | 2 +- config/gpt2_1536.yaml | 2 +- config/gpt2_20b.yaml | 2 +- config/gpt2_7b.yaml | 2 +- config/gpt2_large.yaml | 2 +- config/gpt2_medium.yaml | 2 +- config/gpt2_micro.yaml | 2 +- config/gpt2_nano_tb.yaml | 26 +++ config/gpt2_small.yaml | 2 +- config/gpt2_small_fast.yaml | 9 +- config/gpt2_small_fast_mix.yaml | 2 +- config/gpt2_small_fast_pile.yaml | 2 +- config/gpt2_small_fast_wiki.yaml | 2 +- config/gpt2_xl.yaml | 2 +- config/llama2_7b.yaml | 3 +- config/llama2_7b_continued.yaml | 3 +- config/llama2_nano.yaml | 2 +- config/lora/mpt_biomed.yaml | 3 +- config/mpt_7b_continued.yaml | 2 +- config/mpt_7b_continued_biomedlm.yaml | 2 +- docs/Configuration-Guide.md | 106 ++++++++++-- docs/Training-On-Your-Data.md | 3 +- docs/{ => dev}/Port-Models.md | 4 +- docs/dev/Trackers.md | 104 ++++++++++++ examples/alpaca-lora/alpaca_lora.py | 91 +++++----- examples/alpaca/alpaca.py | 4 +- mkdocs.yml | 3 +- pyproject.toml | 2 +- src/levanter/__init__.py | 2 + src/levanter/callbacks.py | 44 ++--- src/levanter/data/shard_cache.py | 28 ++-- src/levanter/data/text.py | 4 +- src/levanter/logging.py | 230 ++------------------------ src/levanter/lora.py | 16 +- src/levanter/main/cache_dataset.py | 17 +- src/levanter/main/lora_lm.py | 11 +- src/levanter/main/train_lm.py | 19 ++- src/levanter/tracker/__init__.py | 16 ++ src/levanter/tracker/helpers.py | 71 ++++++++ src/levanter/tracker/tensorboard.py | 86 ++++++++++ src/levanter/tracker/tracker.py | 117 +++++++++++++ src/levanter/tracker/tracker_fns.py | 152 +++++++++++++++++ src/levanter/tracker/wandb.py | 213 ++++++++++++++++++++++++ src/levanter/trainer.py | 170 +++++++++++-------- src/levanter/utils/jax_utils.py | 5 + tests/test_eval_lm.py | 2 +- tests/test_logging.py | 4 +- tests/test_tracker.py | 80 +++++++++ tests/test_train_lm.py | 2 +- tests/test_viz_lm.py | 2 +- 51 files changed, 1242 insertions(+), 443 deletions(-) create mode 100644 config/gpt2_nano_tb.yaml rename docs/{ => dev}/Port-Models.md (98%) create mode 100644 docs/dev/Trackers.md create mode 100644 src/levanter/tracker/__init__.py create mode 100644 src/levanter/tracker/helpers.py create mode 100644 src/levanter/tracker/tensorboard.py create mode 100644 src/levanter/tracker/tracker.py create mode 100644 src/levanter/tracker/tracker_fns.py create mode 100644 src/levanter/tracker/wandb.py create mode 100644 tests/test_tracker.py diff --git a/README.md b/README.md index 3f41614e3..bbc5cc6c6 100644 --- a/README.md +++ b/README.md @@ -149,7 +149,8 @@ model: gradient_checkpointing: true scale_attn_by_inverse_layer_idx: true trainer: - wandb: + tracker: + type: wandb project: "levanter" tags: [ "openwebtext", "gpt2"] diff --git a/config/backpack.yaml b/config/backpack.yaml index 5b6cef3cb..493be77a3 100644 --- a/config/backpack.yaml +++ b/config/backpack.yaml @@ -10,7 +10,7 @@ model: num_senses: 16 sense_intermediate_scale: 4 trainer: - wandb: + tracker: project: "levanter" tags: [ "openwebtext", "backpack" ] diff --git a/config/gpt2_1536.yaml b/config/gpt2_1536.yaml index 50ccbd882..a3633bf65 100644 --- a/config/gpt2_1536.yaml +++ b/config/gpt2_1536.yaml @@ -8,7 +8,7 @@ model: gradient_checkpointing: true scale_attn_by_inverse_layer_idx: true trainer: - wandb: + tracker: project: "levanter" tags: [ "openwebtext", "gpt2"] diff --git a/config/gpt2_20b.yaml b/config/gpt2_20b.yaml index 76bf6ba96..6f5f40e1b 100644 --- a/config/gpt2_20b.yaml +++ b/config/gpt2_20b.yaml @@ -12,7 +12,7 @@ model: use_bias: false fcm_prob: 0.15 trainer: - wandb: + tracker: project: 
"levanter" tags: ["pile", "gpt2"] diff --git a/config/gpt2_7b.yaml b/config/gpt2_7b.yaml index affb67aa5..36a3d4fd2 100644 --- a/config/gpt2_7b.yaml +++ b/config/gpt2_7b.yaml @@ -11,7 +11,7 @@ model: resid_pdrop: 0.0 fcm_prob: 0.15 trainer: - wandb: + tracker: project: "levanter" tags: ["pile", "gpt2"] diff --git a/config/gpt2_large.yaml b/config/gpt2_large.yaml index 525a92c99..d772f9fdf 100644 --- a/config/gpt2_large.yaml +++ b/config/gpt2_large.yaml @@ -8,7 +8,7 @@ model: gradient_checkpointing: true scale_attn_by_inverse_layer_idx: true trainer: - wandb: + tracker: project: "levanter" tags: [ "openwebtext", "gpt2"] diff --git a/config/gpt2_medium.yaml b/config/gpt2_medium.yaml index 9ea4408bc..47e21799c 100644 --- a/config/gpt2_medium.yaml +++ b/config/gpt2_medium.yaml @@ -8,7 +8,7 @@ model: gradient_checkpointing: true scale_attn_by_inverse_layer_idx: true trainer: - wandb: + tracker: project: "levanter" tags: [ "openwebtext", "gpt2"] diff --git a/config/gpt2_micro.yaml b/config/gpt2_micro.yaml index 274ecddaa..0a8283e78 100644 --- a/config/gpt2_micro.yaml +++ b/config/gpt2_micro.yaml @@ -6,7 +6,7 @@ model: num_heads: 8 num_layers: 4 trainer: - wandb: + tracker: project: "levanter" tags: [ "openwebtext", "gpt2"] diff --git a/config/gpt2_nano_tb.yaml b/config/gpt2_nano_tb.yaml new file mode 100644 index 000000000..9ada16aa3 --- /dev/null +++ b/config/gpt2_nano_tb.yaml @@ -0,0 +1,26 @@ +data: + id: dlwh/wikitext_103_detokenized +model: + type: gpt2 + hidden_dim: 32 + num_heads: 4 + num_layers: 2 +trainer: + mp: f32 + num_train_steps: 100 + + checkpointer: + keep: + - every: 50 + save_interval: 5m + + per_device_eval_parallelism: 1 + per_device_parallelism: 1 + train_batch_size: 32 + + tensor_parallel_axes: ["mlp", "heads"] + fsdp_axis: "embed" + batch_axis: "batch" + tracker: + type: tensorboard + logdir: tb_logs/ diff --git a/config/gpt2_small.yaml b/config/gpt2_small.yaml index 74d0e031a..c657fe787 100644 --- a/config/gpt2_small.yaml +++ b/config/gpt2_small.yaml @@ -8,7 +8,7 @@ model: gradient_checkpointing: true scale_attn_by_inverse_layer_idx: true trainer: - wandb: + tracker: project: "levanter" tags: [ "openwebtext", "gpt2"] diff --git a/config/gpt2_small_fast.yaml b/config/gpt2_small_fast.yaml index 4c8434f38..a7375de2f 100644 --- a/config/gpt2_small_fast.yaml +++ b/config/gpt2_small_fast.yaml @@ -8,9 +8,12 @@ model: gradient_checkpointing: true scale_attn_by_inverse_layer_idx: true trainer: - wandb: - project: "levanter" - tags: [ "openwebtext", "gpt2", "itest"] + tracker: + - type: wandb + project: "levanter" + tags: [ "openwebtext", "gpt2", "itest"] + - type: tensorboard + logdir: gs://levanter-checkpoints/tblogs/ mp: p=f32,c=bfloat16 model_axis_size: 1 diff --git a/config/gpt2_small_fast_mix.yaml b/config/gpt2_small_fast_mix.yaml index 0785e9103..ca9fa2ca6 100644 --- a/config/gpt2_small_fast_mix.yaml +++ b/config/gpt2_small_fast_mix.yaml @@ -21,7 +21,7 @@ model: gradient_checkpointing: true scale_attn_by_inverse_layer_idx: true trainer: - wandb: + tracker: project: "levanter" tags: [ "openwebtext+wiki", "gpt2", "itest"] diff --git a/config/gpt2_small_fast_pile.yaml b/config/gpt2_small_fast_pile.yaml index f30743c1d..a0336da45 100644 --- a/config/gpt2_small_fast_pile.yaml +++ b/config/gpt2_small_fast_pile.yaml @@ -8,7 +8,7 @@ model: gradient_checkpointing: true scale_attn_by_inverse_layer_idx: true trainer: - wandb: + tracker: project: "levanter" tags: [ "pile", "gpt2", "itest"] diff --git a/config/gpt2_small_fast_wiki.yaml b/config/gpt2_small_fast_wiki.yaml index 
407d8705b..a25736434 100644 --- a/config/gpt2_small_fast_wiki.yaml +++ b/config/gpt2_small_fast_wiki.yaml @@ -9,7 +9,7 @@ model: gradient_checkpointing: true scale_attn_by_inverse_layer_idx: true trainer: - wandb: + tracker: project: "levanter" tags: [ "openwebtext", "gpt2", "itest"] diff --git a/config/gpt2_xl.yaml b/config/gpt2_xl.yaml index 8230b56a5..026fc077e 100644 --- a/config/gpt2_xl.yaml +++ b/config/gpt2_xl.yaml @@ -8,7 +8,7 @@ model: gradient_checkpointing: true scale_attn_by_inverse_layer_idx: true trainer: - wandb: + tracker: project: "levanter" tags: [ "openwebtext", "gpt2"] mp: p=f32,c=bfloat16 diff --git a/config/llama2_7b.yaml b/config/llama2_7b.yaml index 68931f3fa..b4ebe705f 100644 --- a/config/llama2_7b.yaml +++ b/config/llama2_7b.yaml @@ -11,7 +11,8 @@ model: # initialize_from_hf: "meta-llama/Llama-2-7b-hf" # use_hf_model_config: true trainer: - wandb: + tracker: + type: wandb project: "levanter" tags: ["openwebtext", "llama"] diff --git a/config/llama2_7b_continued.yaml b/config/llama2_7b_continued.yaml index e03be7168..edb72a7e4 100644 --- a/config/llama2_7b_continued.yaml +++ b/config/llama2_7b_continued.yaml @@ -6,7 +6,8 @@ model: initialize_from_hf: true use_hf_model_config: true trainer: - wandb: + tracker: + type: wandb project: "levanter" tags: ["pile", "llama2"] diff --git a/config/llama2_nano.yaml b/config/llama2_nano.yaml index d7196c59b..6b6f8d93f 100644 --- a/config/llama2_nano.yaml +++ b/config/llama2_nano.yaml @@ -11,7 +11,7 @@ model: num_heads: 4 num_layers: 2 trainer: - wandb: + tracker: project: "levanter" tags: ["openwebtext", "llama"] mp: p=f32 diff --git a/config/lora/mpt_biomed.yaml b/config/lora/mpt_biomed.yaml index f49267ca1..6b19d0ab5 100644 --- a/config/lora/mpt_biomed.yaml +++ b/config/lora/mpt_biomed.yaml @@ -11,7 +11,8 @@ lora: alpha: 32.0 target_modules: ["Wqkv"] trainer: - wandb: + tracker: + type: wandb project: "levanter" tags: ["mpt", "lora", "pubmed"] diff --git a/config/mpt_7b_continued.yaml b/config/mpt_7b_continued.yaml index a7eaf800b..980b4aaaf 100644 --- a/config/mpt_7b_continued.yaml +++ b/config/mpt_7b_continued.yaml @@ -4,7 +4,7 @@ model: initialize_from_hf: true use_hf_model_config: true trainer: - wandb: + tracker: project: "levanter" tags: ["pile", "mpt"] diff --git a/config/mpt_7b_continued_biomedlm.yaml b/config/mpt_7b_continued_biomedlm.yaml index 44961df46..504f1a3ba 100644 --- a/config/mpt_7b_continued_biomedlm.yaml +++ b/config/mpt_7b_continued_biomedlm.yaml @@ -10,7 +10,7 @@ model: initialize_from_hf: "mosaicml/mpt-7b@68e1a8e0ebb9b30f3c45c1ef6195980f29063ae2" use_hf_model_config: true trainer: - wandb: + tracker: project: "levanter" tags: ["pubmed", "mpt", "continued"] diff --git a/docs/Configuration-Guide.md b/docs/Configuration-Guide.md index c7891e1e9..daa7d3da7 100644 --- a/docs/Configuration-Guide.md +++ b/docs/Configuration-Guide.md @@ -35,7 +35,8 @@ model: gradient_checkpointing: true scale_attn_by_inverse_layer_idx: true trainer: - wandb: + tracker: + type: wandb project: "levanter" tags: [ "openwebtext", "gpt2"] @@ -178,12 +179,34 @@ The default step-based checkpoint policy is to save a checkpoint every 10,000 st -## WandB +## Trackers and Logging -We mostly use wandb for logging, including using wandb for allocating the run id. We may change this. -These all live in a nested object `wandb` inside `trainer`. Most of these are the same as the corresponding `wandb.init` -parameters. +We mostly use [W&B](https://wandb.ai/site) for tracking values and other metadata about a run. 
However, we also support +Tensorboard and a few other trackers. You can also use multiple trackers at once, or even write your own. +See [Trackers](dev/Trackers.md) for more information. + +### W&B + +Wandb is the default tracker and is installed by default. To use it, you can configure it in your config file: + +```yaml +trainer: + tracker: + type: wandb + project: my-project + entity: my-entity +``` + +Because wandb is the default, you can also just do: + +```yaml +trainer: + tracker: + project: my-project + entity: my-entity +``` + | Parameter | Description | Default | @@ -205,6 +228,35 @@ of your main script. To use it, you must also set the right environment variables. Something like `XLA_FLAGS="--xla_dump_to=/tmp/output_folder/xla_dumps --xla_dump_hlo_pass_re=.*`. We will automatically parse out the env variable. +### Tensorboard + +Tensorboard is also supported. To use it, you can configure it in your config file: + +```yaml +trainer: + tracker: + type: tensorboard + logdir: logs +``` + +### Multiple Trackers + +In some cases, you may want to use multiple trackers at once. +For example, you may want to use both W&B and Tensorboard. + +To do this, you can use the [levanter.tracker.tracker.CompositeTracker][] class, or, if using a config file, you +can specify multiple trackers: + +```yaml +trainer: + tracker: + - type: wandb + project: my-project + entity: my-entity + - type: tensorboard + logdir: logs +``` + ## Ray Config Levanter will by default automatically start a Ray cluster with all @@ -212,11 +264,11 @@ the machines being used for training. This is useful for distributed preprocessing. You can disable this behavior using `auto_start_cluster: false`. -| Parameter | Description | Default | -|---------------------|-----------------------------------------------------------------------------|---------| -| `address` | The address of the Ray cluster to connect to. | `None` | -| `start_workers` | Whether to start Ray workers. If `False`, you must start them yourself. | `True` | -| `auto_start_cluster`| Whether to start a Ray cluster automatically. | `True` | +| Parameter | Description | Default | +|----------------------|-------------------------------------------------------------------------|---------| +| `address` | The address of the Ray cluster to connect to. | `None` | +| `start_workers` | Whether to start Ray workers. If `False`, you must start them yourself. | `True` | +| `auto_start_cluster` | Whether to start a Ray cluster automatically. | `True` | ## Distributed Config @@ -226,12 +278,12 @@ If you're not using SLURM or TPUs, you can specify the cluster manually using th **Don't use this on TPU, and possibly not on SLURM either.** -| Parameter | Description | Default | -|---------------------|-----------------------------------------------------------------------------|-------------------------| -| `coordinator_address`| The address of the coordinator. If `None`, we'll use the default address. | `None` | -| `num_processes` | The number of processes in the cluster. | `None` | -| `process_id` | The process id of this process. | `None` | -| `local_device_ids` | The local device ids of this process. | ${CUDA_VISIBLE_DEVICES} | +| Parameter | Description | Default | +|-----------------------|---------------------------------------------------------------------------|-------------------------| +| `coordinator_address` | The address of the coordinator. If `None`, we'll use the default address. | `None` | +| `num_processes` | The number of processes in the cluster. 
| `None` | +| `process_id` | The process id of this process. | `None` | +| `local_device_ids` | The local device ids of this process. | ${CUDA_VISIBLE_DEVICES} | @@ -276,8 +328,26 @@ We won't go into detail here. You can see the auto-generated docs below. ::: levanter.checkpoint.Checkpointer -### Wandb -::: levanter.logging.WandbConfig +### Trackers and Metrics + +See also [Trackers](dev/Trackers.md) for more information. Basic configuration is shown below. + +#### Single Tracker + +```yaml +trainer: + tracker: + type: wandb + project: my-project + entity: my-entity +``` + + + +::: levanter.tracker.wandb.WandbConfig + +::: levanter.tracker.tensorboard.TensorboardConfig + ### Distributed and Ray diff --git a/docs/Training-On-Your-Data.md b/docs/Training-On-Your-Data.md index edf33e0af..4c543b04f 100644 --- a/docs/Training-On-Your-Data.md +++ b/docs/Training-On-Your-Data.md @@ -214,7 +214,8 @@ model: gradient_checkpointing: true scale_attn_by_inverse_layer_idx: true trainer: - wandb: + tracker: + type: wandb project: "levanter" # TODO tags: ["gpt2"] diff --git a/docs/Port-Models.md b/docs/dev/Port-Models.md similarity index 98% rename from docs/Port-Models.md rename to docs/dev/Port-Models.md index f75fa7534..41228c2a3 100644 --- a/docs/Port-Models.md +++ b/docs/dev/Port-Models.md @@ -287,7 +287,7 @@ model: num_layers: 2 ``` -For more details on the training configuration, please refer to [Configuration Guide](./Configuration-Guide.md). +For more details on the training configuration, please refer to [Configuration Guide](../Configuration-Guide.md). ### Launch Training Job Once you have your training configuration ready and your training environment set up, you can launch a training job with the following command: @@ -299,7 +299,7 @@ HUGGING_FACE_HUB_TOKEN=$HUGGING_FACE_HUB_TOKEN \ python levanter/src/levanter/main/train_lm.py --config_path $CONFIG_PATH ``` -Check out [Training on Your Own Data](./Training-On-Your-Data.md) for more detailed guide on how to spin off a training cluster and launch a training job. +Check out [Training on Your Own Data](../Training-On-Your-Data.md) for more detailed guide on how to spin off a training cluster and launch a training job. ### Profile Your Model If you are interested in profiling the training throughput of your model, good news is that it comes for free with automatic job monitoring in Levanter, powered through Weights & Biases. diff --git a/docs/dev/Trackers.md b/docs/dev/Trackers.md new file mode 100644 index 000000000..1f1677d52 --- /dev/null +++ b/docs/dev/Trackers.md @@ -0,0 +1,104 @@ +# Trackers and Metrics + +Logging values and other metadata about a run is a core requirement for any ML framework. +Until recently, Levanter had a hard dependency on [W&B](https://wandb.ai/site) for tracking such values. + +In the latest version, we introduce the [levanter.tracker.Tracker][] interface, which allows you to use any tracking backend you want. +The interface name is taken from the [HuggingFace Accelerate](https://github.com/huggingface/accelerate/blob/0f2686c8d3e6d949c4b7efa15d7f2dee44f7ce91/src/accelerate/tracking.py#L395) +framework. + +Given Levanter's historical dependency on W&B, the interface is designed to look similar to W&B's API. +The methods currently exposed are: + +* [levanter.tracker.current_tracker][]: returns the current tracker instance or sets it. +* [levanter.tracker.log_metrics][]: logs a dictionary of metrics for a given step. 
+* [levanter.tracker.log_summary][]: logs a dictionary of "summary" information, analogous to W&B's version. +* [levanter.tracker.get_tracker][]: returns a tracker with the given name. +* [levanter.tracker.jit_log_metrics][]: a version of [levanter.tracker.log_metrics][] that works inside JAX jit. + +A basic example of using the tracker interface is shown below: + +```python +import wandb +from levanter.tracker import current_tracker, log_metrics, log_summary +from levanter.tracker.wandb import WandbTracker + +with current_tracker(WandbTracker(wandb.init())): + for step in range(100): + log_metrics({"loss": 100 -0.01 * step}, step=step) + + log_summary({"best_loss": 0.0}) +``` + +A more typical example would be to use it in a config file, as we do with Trainer: + +```yaml +trainer: + tracker: + type: wandb + project: my-project + entity: my-entity +``` + +### Multiple Trackers + +In some cases, you may want to use multiple trackers at once. +For example, you may want to use both W&B and Tensorboard. + +To do this, you can use the [levanter.tracker.tracker.CompositeTracker][] class, or, if using a config file, you +can specify multiple trackers: + +```yaml +trainer: + tracker: + - type: wandb + project: my-project + entity: my-entity + - type: tensorboard + logdir: logs +``` + +## Adding your own tracker + +To add your own tracker, you need to implement the [levanter.tracker.Tracker][] interface. +You will also want to register your config with TrackerConfig as a "choice" in the choice type. +Follow the pattern for Tensorboard and W&B. + +TODO: expand this section. + + +## API Reference + +### Core Functions + +::: levanter.tracker.current_tracker + +::: levanter.tracker.log_metrics + +::: levanter.tracker.log_summary + +::: levanter.tracker.get_tracker + +::: levanter.tracker.jit_log_metrics + +### Trackers + +::: levanter.tracker.Tracker + +::: levanter.tracker.tracker.CompositeTracker + +::: levanter.tracker.tracker.NoopTracker + +::: levanter.tracker.tensorboard.TensorboardTracker + +::: levanter.tracker.wandb.WandbTracker + +### Tracker Config + +::: levanter.tracker.TrackerConfig + +::: levanter.tracker.tracker.NoopConfig + +::: levanter.tracker.tensorboard.TensorboardConfig + +::: levanter.tracker.wandb.WandbConfig diff --git a/examples/alpaca-lora/alpaca_lora.py b/examples/alpaca-lora/alpaca_lora.py index a4380a92b..31de93252 100644 --- a/examples/alpaca-lora/alpaca_lora.py +++ b/examples/alpaca-lora/alpaca_lora.py @@ -8,7 +8,6 @@ import jax.random as jrandom import transformers -import wandb import haliax as hax @@ -101,53 +100,59 @@ def loraize_hf_model(model): def compute_loss(model: LmHeadModel, example: LmExample, key=None): return model.compute_loss(example, key=key).scalar() - trainer = Trainer(config.trainer, optimizer, compute_loss, is_trainable=lora_param_filter) - # end major difference from Alpaca - trainer.add_default_hooks() - state = trainer.initial_state(training_key, model=model) - - # log some info about the model - all_param_count = parameter_count(state.model) - just_lora_params = parameter_count(trainer.trainable_params_only(state.model)) - - wandb.summary["parameter_count"] = all_param_count - wandb.summary["trainable_parameter_count"] = just_lora_params - logger.info(f"Total parameter count: {all_param_count}") - logger.info(f"Trainable parameter count: {just_lora_params}") - logger.info(f"Fraction of parameters that are trainable: {just_lora_params * 1.0 / all_param_count%.3}") - - # Levanter has two kinds of data loaders: sharded and replicated. 
Replicated is simpler and allows for - # single pass training. Sharded only loads a subset of the data on each device, and is more efficient for large - # datasets. We use replicated here since the dataset is small. - loader = trainer.replicated_loader(train_dataset, trainer.TrainBatch) - loader = non_caching_cycle(loader) - - if state.step != 0: - logger.info(f"Resuming training from step {state.step}") - for i in range(state.step): - next(loader) # type: ignore - - # Save HF PEFT checkpoints periodically (and at the end of training), which is just the lora weights - if config.hf_save_path is not None: - full_save_path = os.path.join(config.hf_save_path, trainer.run_id) - trainer.add_hook( - save_peft_checkpoint_callback( - full_save_path, config.lora, config.model_name_or_path, tokenizer, config.hf_upload - ), - every=config.hf_save_steps, - ) + with Trainer(config.trainer, optimizer, compute_loss, is_trainable=lora_param_filter) as trainer: + + trainer.add_default_hooks() + state = trainer.initial_state(training_key, model=model) + + # log some info about the model + all_param_count = parameter_count(state.model) + just_lora_params = parameter_count(trainer.trainable_params_only(state.model)) - # Save merged HF checkpoints if requested - if config.merged_hf_save_path is not None: - full_save_path = os.path.join(config.merged_hf_save_path, trainer.run_id) - trainer.add_hook( - save_merged_hf_checkpoint_callback(full_save_path, converter, config.merged_hf_upload), - every=config.hf_save_steps, + levanter.log_summary( + { + "parameter_count": all_param_count, + "trainable_parameter_count": just_lora_params, + "fraction_trainable": just_lora_params * 1.0 / all_param_count, + } ) - trainer.train(state, loader) + logger.info(f"Total parameter count: {all_param_count}") + logger.info(f"Trainable parameter count: {just_lora_params}") + logger.info(f"Fraction of parameters that are trainable: {just_lora_params * 1.0 / all_param_count%.3}") + + # Levanter has two kinds of data loaders: sharded and replicated. Replicated is simpler and allows for + # single pass training. Sharded only loads a subset of the data on each device, and is more efficient for large + # datasets. We use replicated here since the dataset is small. 
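        # Hypothetical alternative, not part of this change: for a large dataset the
        # sharded loader would be used instead, e.g.
        #   loader = trainer.sharded_loader(train_dataset, trainer.TrainBatch)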
+ loader = trainer.replicated_loader(train_dataset, trainer.TrainBatch) + loader = non_caching_cycle(loader) + + if state.step != 0: + logger.info(f"Resuming training from step {state.step}") + for i in range(state.step): + next(loader) # type: ignore + + # Save HF PEFT checkpoints periodically (and at the end of training), which is just the lora weights + if config.hf_save_path is not None: + full_save_path = os.path.join(config.hf_save_path, trainer.run_id) + trainer.add_hook( + save_peft_checkpoint_callback( + full_save_path, config.lora, config.model_name_or_path, tokenizer, config.hf_upload + ), + every=config.hf_save_steps, + ) + + # Save merged HF checkpoints if requested + if config.merged_hf_save_path is not None: + full_save_path = os.path.join(config.merged_hf_save_path, trainer.run_id) + trainer.add_hook( + save_merged_hf_checkpoint_callback(full_save_path, converter, config.merged_hf_upload), + every=config.hf_save_steps, + ) + + trainer.train(state, loader) if __name__ == "__main__": diff --git a/examples/alpaca/alpaca.py b/examples/alpaca/alpaca.py index 5240d9861..20cb98a33 100644 --- a/examples/alpaca/alpaca.py +++ b/examples/alpaca/alpaca.py @@ -217,9 +217,7 @@ def train(config: TrainArgs): def compute_loss(model: LmHeadModel, example: LmExample, key=None): return model.compute_loss(example, key=key).scalar() - trainer = Trainer(config.trainer, optimizer, compute_loss) - - with trainer.device_mesh: + with Trainer(config.trainer, optimizer, compute_loss) as trainer: # how we shard parameters across devices parameter_axis_mapping = trainer.parameter_axis_mapping diff --git a/mkdocs.yml b/mkdocs.yml index 7e175d199..53ef6a35c 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -89,7 +89,8 @@ nav: - "Fine-Tuning.md" - "LoRA.md" - 'Developer Guide': - - 'Port-Models.md' + - 'dev/Port-Models.md' + - 'dev/Trackers.md' - Other: - 'Levanter-1.0-Release.md' - "design/Data-Loader-Design.md" diff --git a/pyproject.toml b/pyproject.toml index df2eb7533..552e80bd7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,7 @@ dependencies = [ "transformers>=4.22.0", "optax", "wandb", - "draccus>=0.6", + "draccus>=0.7.1", "pyarrow>=11.0.0", "zstandard>=0.20.0", "datasets==2.11.0", diff --git a/src/levanter/__init__.py b/src/levanter/__init__.py index 33bcd249d..d89ea4945 100644 --- a/src/levanter/__init__.py +++ b/src/levanter/__init__.py @@ -3,4 +3,6 @@ import levanter.data as data import levanter.distributed as distributed import levanter.logging as logging +import levanter.tracker as tracker import levanter.visualization as visualization +from levanter.tracker import current_tracker, get_tracker, log_metrics, log_summary diff --git a/src/levanter/callbacks.py b/src/levanter/callbacks.py index 2292c714a..154099e8a 100644 --- a/src/levanter/callbacks.py +++ b/src/levanter/callbacks.py @@ -1,5 +1,5 @@ import copy -import logging +import logging as pylogging import os import re import subprocess @@ -11,16 +11,18 @@ import humanfriendly import jax -import wandb from tqdm import tqdm -from levanter.logging import WandbConfig, log_optimizer_hyperparams, save_xla_dumps_to_wandb +import levanter.tracker +from levanter.logging import save_xla_dumps_to_wandb +from levanter.tracker.helpers import log_optimizer_hyperparams +from levanter.tracker.wandb import WandbConfig from levanter.trainer import StepInfo from levanter.utils.jax_utils import jnp_to_python from levanter.visualization import compute_and_visualize_log_probs as viz_probs -logger = logging.getLogger(__name__) +logger = 
pylogging.getLogger(__name__) def eval_loss_loop(loss_fn, model, dataset, max_batches: Optional[int] = None, name: Optional[str] = None): @@ -57,11 +59,10 @@ def compute_validation_loss( def compute_loss(info: StepInfo): loss = eval_loss_loop(loss_fn, info.model, dataset, max_batches=max_batches, name=name) - if wandb.run is not None: - prefix = "eval" - if name: - prefix += "/" + name - wandb.log({f"{prefix}/loss": loss}, step=info.step) + prefix = "eval" + if name: + prefix += "/" + name + levanter.log_metrics({f"{prefix}/loss": loss}, step=info.step) if name: logger.info(f"{name} validation loss: {loss:.3f}") @@ -73,12 +74,14 @@ def compute_loss(info: StepInfo): return compute_loss -def log_to_wandb(step: StepInfo): - wandb.log({"train/loss": step.loss, "global_step": step.step}, step=step.step) +def log_step_info(step: StepInfo): + levanter.tracker.log_metrics({"train/loss": step.loss, "global_step": step.step}, step=step.step) log_optimizer_hyperparams(step.opt_state, step=step.step, prefix="optim") def wandb_xla_logger(config: WandbConfig): + import wandb + last_mtime = wandb.run and wandb.run.start_time or time.time() def log_xla_to_wandb(step: StepInfo): @@ -108,14 +111,14 @@ def log_performance_stats(step_info: StepInfo): # log these totals because it's useful for comparing different seqlens, batch sizes, etc total_tokens = tokens_per_example * batch_size * step_info.step - wandb.log({wrap_key("total_tokens"): total_tokens}, step=step_info.step) + levanter.tracker.log_metrics({wrap_key("total_tokens"): total_tokens}, step=step_info.step) if flops_per_example: total_flops = flops_per_example * batch_size * step_info.step - wandb.log({wrap_key("total_gflops"): total_flops / 1e9}, step=step_info.step) + levanter.tracker.log_metrics({wrap_key("total_gflops"): total_flops / 1e9}, step=step_info.step) if step_info.step_duration != 0.0: - wandb.log( + levanter.tracker.log_metrics( { wrap_key("examples_per_second"): float(batch_size) / step_info.step_duration, wrap_key("tokens_per_second"): float(tokens_per_example) / step_info.step_duration * batch_size, @@ -125,7 +128,7 @@ def log_performance_stats(step_info: StepInfo): ) if flops_per_example is not None: - wandb.log( + levanter.tracker.log_metrics( { wrap_key("gflops_per_second"): flops_per_example / 1e9 / step_info.step_duration * batch_size, }, @@ -152,7 +155,7 @@ def update_pbar(step: StepInfo): def log_memory_usage(sample_interval: float = 1.0, log_individual_devices: bool = False): """ - Logs memory usage to wandb. This runs a loop that samples memory usage every `sample_interval` seconds. + Logs memory usage. This runs a loop that samples memory usage every `sample_interval` seconds. We only log when hooks are invoked, so there's not much point in running this much more frequently than you invoke the hook. 
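# The migration pattern throughout this file: direct `wandb.log(...)` calls become
# `levanter.tracker.log_metrics(...)`, which fans out to whichever tracker(s) the
# trainer installed as the global tracker. A minimal sketch, assuming the tracker_fns
# API introduced later in this patch:
#
#     import levanter.tracker
#     levanter.tracker.log_metrics({"train/loss": 0.123}, step=10)
#     levanter.tracker.log_summary({"best/loss": 0.1})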
@@ -218,7 +221,7 @@ def log_memory_usage(step: StepInfo): match = regex.search(by_kind) if match: memory_usage = humanfriendly.parse_size(match.group(1)) - wandb.log({"memory/total": memory_usage / 1e6}, step=step.step) + levanter.tracker.log_metrics({"memory/total": memory_usage / 1e6}, step=step.step) # this works for the "kind" and the individual devices regex = re.compile(r"([\d.]+[a-zA-Z]+) \(([\d.]+)%\): ([\w\d:_]+)") @@ -229,14 +232,14 @@ def log_memory_usage(step: StepInfo): for match in regex.finditer(per_device): memory_usage = humanfriendly.parse_size(match.group(1)) device_name = match.group(3) - wandb.log({f"memory/device/{device_name}": memory_usage / 1e6}, step=step.step) + levanter.tracker.log_metrics({f"memory/device/{device_name}": memory_usage / 1e6}, step=step.step) # now, get the memory usage per kind. # same regex as above for match in regex.finditer(by_kind): memory_usage = match.group(1) memory_usage = humanfriendly.parse_size(memory_usage) - wandb.log({f"memory/{match.group(3)}": memory_usage / 1e6}, step=step.step) + levanter.tracker.log_metrics({f"memory/{match.group(3)}": memory_usage / 1e6}, step=step.step) return log_memory_usage @@ -262,6 +265,9 @@ def compute_and_viz_log_probs(step: StepInfo): path = os.path.join(html_dir, f"step_{step}.html") viz_probs(path, model, tokenizer, log_prob_fn, test_data, max_docs=max_docs) + # TODO: convert to generic logging + import wandb + wandb.log({"log_probs": wandb.Html(path)}, step=step.step) return compute_and_viz_log_probs diff --git a/src/levanter/data/shard_cache.py b/src/levanter/data/shard_cache.py index 4607459c1..83e4a6c85 100644 --- a/src/levanter/data/shard_cache.py +++ b/src/levanter/data/shard_cache.py @@ -1,7 +1,7 @@ # Dataset for preprocessing data, tokenizing, and caching to disk. import asyncio import dataclasses -import logging +import logging as pylogging import os import sys import threading @@ -31,7 +31,6 @@ import pyarrow.parquet as pq import ray import tblib -import wandb from dataclasses_json import dataclass_json from fsspec import AbstractFileSystem from ray.actor import ActorHandle @@ -46,6 +45,9 @@ TimeRemainingColumn, ) +import levanter.tracker + +from .. import logging from . import ShardableDataset from ._preprocessor import BatchProcessor, BatchResult, as_record_batch, dict_from_record_batch from .sharded_dataset import ShardedDataset @@ -55,7 +57,7 @@ T_co = TypeVar("T_co", covariant=True) _ExcInfo = Tuple[Optional[BaseException], tblib.Traceback] -logger = logging.getLogger(__name__) +logger = pylogging.getLogger(__name__) DEFAULT_ROWS_PER_CHUNK = 1024 * 32 LEDGER_FILE_NAME = "cache_ledger.json" @@ -265,7 +267,7 @@ def _produce_cache_for_shard( """Produces chunks of preprocessed data from a single shard and writes them to disk. 
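    # Downstream, cache progress is surfaced as InProgressCacheMetrics to
    # MetricsMonitor implementations such as the LoggingMetricsMonitor below,
    # which now reports via levanter.tracker.log_metrics instead of wandb.log.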
Chunks are written to sink, which is an actor of ChunkCacheBuilder.""" # TODO: thread logging level through calls - logging.basicConfig(level=logging.INFO) + pylogging.basicConfig(level=pylogging.INFO) # load or create shard metadata (for recovery) try: shard_name = source.shard_names[shard_idx] @@ -475,7 +477,7 @@ def _init_progress(self, metrics): self.progress.start() -class WandbMetricsMonitor(MetricsMonitor): +class LoggingMetricsMonitor(MetricsMonitor): last_metrics: Optional[InProgressCacheMetrics] last_time: Optional[float] @@ -517,16 +519,16 @@ def __call__(self, metrics: InProgressCacheMetrics): self.last_metrics = metrics self.last_time = time.time() - wandb.log(to_log, commit=self.commit) + levanter.tracker.log_metrics(to_log, step=None, commit=self.commit) class LoggerMetricsMonitor(MetricsMonitor): # TODO: I'd like to get the trainer pbar migrated to rich and just use rich everywhere, but until then, # we have separate logging - def __init__(self, logger: Optional[Union[logging.Logger, str]] = None, level=logging.INFO): + def __init__(self, logger: Optional[Union[pylogging.Logger, str]] = None, level=pylogging.INFO): if isinstance(logger, str): - logger = logging.getLogger(logger) - self.logger = logger or logging.getLogger(__name__) + logger = pylogging.getLogger(logger) + self.logger = logger or pylogging.getLogger(__name__) self.level = level def __call__(self, metrics: InProgressCacheMetrics): @@ -570,7 +572,7 @@ def is_producing(self): def _mk_process_task(processor: BatchProcessor[T]): @ray.remote(num_cpus=processor.num_cpus, num_gpus=processor.num_gpus, resources=processor.resources) def process_task(batch: List[T]) -> pa.RecordBatch: - logging.basicConfig(level=logging.INFO) + pylogging.basicConfig(level=pylogging.INFO) return processor(batch) return process_task @@ -579,7 +581,7 @@ def process_task(batch: List[T]) -> pa.RecordBatch: def _mk_queue_aware_process_task(processor: BatchProcessor[T], queue: ActorHandle): @ray.remote(num_cpus=processor.num_cpus, num_gpus=processor.num_gpus, resources=processor.resources) def process_task(batch: List[T]) -> pa.RecordBatch: - logging.basicConfig(level=logging.INFO) + pylogging.basicConfig(level=pylogging.INFO) ray.get(queue.task_running.remote()) result = processor(batch) del batch @@ -674,7 +676,7 @@ def __init__( processor: BatchProcessor[T], rows_per_chunk: int, ): - logging.basicConfig(level=logging.INFO) + pylogging.basicConfig(level=pylogging.INFO) self.broker_ref = broker_ref self.shard_status: Dict[str, _ShardStatus] = dict() self._current_round_robin = [] @@ -813,7 +815,7 @@ class ChunkCacheBroker: _finished_promise: asyncio.Future[None] def __init__(self, cache_dir: str, source: ShardedDataset[T], processor: BatchProcessor[T], rows_per_chunk: int): - logging.basicConfig(level=logging.INFO) + pylogging.basicConfig(level=pylogging.INFO) self.chunks = [] self._reader_promises = {} self._is_finished = False diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index 5a4890efb..4ad535114 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -45,9 +45,9 @@ from levanter.data.shard_cache import ( # noqa ChunkMetadata, LoggerMetricsMonitor, + LoggingMetricsMonitor, MetricsMonitor, ShardCache, - WandbMetricsMonitor, _serialize_json_and_commit, build_cache, ) @@ -604,7 +604,7 @@ def build_or_load_cache( if monitors is True: monitors = [ - WandbMetricsMonitor(prefix=f"preprocessing/{split}", commit=False), + LoggingMetricsMonitor(prefix=f"preprocessing/{split}", commit=False), 
LoggerMetricsMonitor(f"preprocessing.{split}"), ] elif monitors is False: diff --git a/src/levanter/logging.py b/src/levanter/logging.py index 4906d0484..23cf63047 100644 --- a/src/levanter/logging.py +++ b/src/levanter/logging.py @@ -1,43 +1,16 @@ import contextlib -import dataclasses import logging as pylogging -import os -import tempfile import time -import warnings -from dataclasses import dataclass from pathlib import Path -from typing import List, Optional, Union +from typing import List, Union -import draccus import jax -import wandb -from draccus import field -from git import InvalidGitRepositoryError, NoSuchPathError, Repo -from optax import MultiStepsState -from levanter.utils import jax_utils -from levanter.utils.jax_utils import jnp_to_python +pylogger = pylogging.getLogger(__name__) -logger = pylogging.getLogger(__name__) - -def log_optimizer_hyperparams(opt_state, prefix: Optional[str] = None, *, step=None): - if isinstance(opt_state, MultiStepsState): - opt_state = opt_state.inner_opt_state - - def wrap_key(key): - if prefix: - return f"{prefix}/{key}" - return key - - if hasattr(opt_state, "hyperparams"): - params = {wrap_key(k): jnp_to_python(v) for k, v in opt_state.hyperparams.items()} - wandb.log(params, step=step) - - -def init_logger(path: Union[str, Path], level: int = pylogging.INFO) -> None: +def init_logging(path: Union[str, Path], level: int = pylogging.INFO) -> None: """ Initialize logging.Logger with the appropriate name, console, and file handlers. @@ -61,13 +34,21 @@ def init_logger(path: Union[str, Path], level: int = pylogging.INFO) -> None: def save_xla_dumps_to_wandb(initial_time: float): import os + from levanter.tracker.wandb import is_wandb_available + + if not is_wandb_available(): + pylogger.warning("Wandb is not available, so we can't save XLA dumps") + return + + import wandb + # attempt to parse xla_flags to see if we're dumping assembly files flags = os.getenv("XLA_FLAGS", None) if flags is not None and "xla_dump_to" in flags: # parse the path # this isn't robust to quotes path = flags.split("xla_dump_to=")[1].split(" ")[0] - logger.info(f"Found xla_dump_to={path}, logging to wandb") + pylogger.info(f"Found xla_dump_to={path}, logging to wandb") if wandb.run: # only want to save the files that were generated during this run # XLA_FLAGS has to be set before the first jax call, so we can't just set it in the middle of the run @@ -79,7 +60,7 @@ def include_file(path: str): wandb.run.log_code(root=path, name="xla_dumps", include_fn=include_file) else: - logger.warning("XLA_FLAGS is not set to dump to a path, so we can't save the dumps to wandb") + pylogger.warning("XLA_FLAGS is not set to dump to a path, so we can't save the dumps to wandb") @contextlib.contextmanager @@ -97,23 +78,6 @@ def fn(): end = time.time() -@contextlib.contextmanager -def log_time_to_wandb(name: str, *, step=None): - with capture_time() as fn: - yield fn - wandb.log({name: fn()}, step=step) - - -def jittable_wandb_log(data, *, step=None): - """uses jax effect callback to log to wandb from the host""" - if is_wandb_available(): - jax.debug.callback(wandb.log, data, step=step) - - -def is_wandb_available(): - return wandb is not None and wandb.run is not None - - def silence_transformer_nag(): # this is a hack to silence the transformers' "None of PyTorch, TensorFlow 2.0 or Flax have been found..." 
thing # which is annoying and not useful @@ -123,171 +87,3 @@ def silence_transformer_nag(): # log propagation bites us here when using ray logger.propagate = False - - -@dataclass -class WandbConfig: - """ - Configuration for wandb. - """ - - entity: Optional[str] = None # An entity is a username or team name where you send runs - project: Optional[str] = None # The name of the project where you are sending the enw run. - name: Optional[str] = None # A short display name for this run, which is how you'll identify this run in the UI. - tags: List[str] = field(default_factory=list) # Will populate the list of tags on this run in the UI. - id: Optional[str] = None # A unique ID for this run, used for resuming. It must be unique in the project - group: Optional[str] = None # Specify a group to organize individual runs into a larger experiment. - mode: Optional[str] = None # Can be "online", "offline" or "disabled". If None, it will be online. - resume: Optional[Union[bool, str]] = None # - """ - Set the resume behavior. Options: "allow", "must", "never", "auto" or None. - By default, if the new run has the same ID as a previous run, this run overwrites that data. - Please refer to [init](https://docs.wandb.ai/ref/python/init) and [resume](https://docs.wandb.ai/guides/runs/resuming) - document for more details. - """ - - save_code: Union[bool, str] = True - """If string, will save code from that directory. If True, will attempt to sniff out the main directory (since we - typically don't run from the root of the repo).""" - - save_xla_dumps: bool = False - """If True, will save the XLA code to wandb (as configured by XLA_FLAGS). This is useful for debugging.""" - - def init(self, run_id: Optional[str], hparams=None, **extra_hparams): - import wandb - - if run_id is not None and self.id is not None and run_id != self.id: - warnings.warn( - f"Both trainer's id {run_id} and WandB's id {self.id} are set. WandB will use the id set in its" - " config." - ) - - id = self.id - if id is None: - id = run_id - - if hparams is None: - hparams_to_save = {} - elif dataclasses.is_dataclass(hparams): - hparams_to_save = dataclasses.asdict(hparams) - else: - hparams_to_save = dict(hparams) - - if extra_hparams: - hparams_to_save.update(extra_hparams) - - # for distributed runs, we only want the primary worker to use wandb, so we make everyone else be disabled - # however, we do share information about the run id, so that we can link to it from the other workers - mode = self.mode - if jax.process_index() != 0: - mode = "disabled" - - if isinstance(self.save_code, str): - code_dir = self.save_code - elif self.save_code: - code_dir = WandbConfig._infer_experiment_git_root() or "." 
# type: ignore - else: - code_dir = None - - other_settings = dict() - if code_dir is not None: - logger.info(f"Setting wandb code_dir to {code_dir}") - other_settings["code_dir"] = code_dir - other_settings["git_root"] = code_dir - # for some reason, wandb isn't populating the git commit, so we do it here - try: - repo = Repo(code_dir) - other_settings["git_commit"] = repo.head.commit.hexsha - hparams_to_save["git_commit"] = repo.head.commit.hexsha - except (NoSuchPathError, InvalidGitRepositoryError): - logger.warning(f"Could not find git repo at {code_dir}") - pass - - r = wandb.init( - entity=self.entity, - project=self.project, - name=self.name, - tags=self.tags, - id=id, - group=self.group, - resume=self.resume, - mode=mode, - config=hparams_to_save, - settings=other_settings, - ) - - assert r is not None - - if jax.process_count() > 1: - # we need to share wandb run information across all hosts, because we use it for checkpoint paths and things - metadata_to_share = dict( - entity=r.entity, - project=r.project, - name=r.name, - tags=r.tags, - id=r.id, - group=r.group, - ) - metadata_to_share = jax_utils.multihost_broadcast_sync( - metadata_to_share, is_source=jax.process_index() == 0 - ) - - if jax.process_index() != 0: - assert r.mode == "disabled" - for k, v in metadata_to_share.items(): - setattr(r, k, v) - - logger.info(f"Synced wandb run information from process 0: {r.name} {r.id}") - - if dataclasses.is_dataclass(hparams): - with tempfile.TemporaryDirectory() as tmpdir: - config_path = os.path.join(tmpdir, "config.yaml") - with open(config_path, "w") as f: - draccus.dump(hparams, f, encoding="utf-8") - if wandb.run is not None: - wandb.run.log_artifact(str(config_path), name="config.yaml", type="config") - - # generate a pip freeze - with tempfile.TemporaryDirectory() as tmpdir: - requirements_path = os.path.join(tmpdir, "requirements.txt") - requirements = _generate_pip_freeze() - with open(requirements_path, "w") as f: - f.write(requirements) - if wandb.run is not None: - wandb.run.log_artifact(str(requirements_path), name="requirements.txt", type="requirements") - - wandb.summary["num_devices"] = jax.device_count() - wandb.summary["num_hosts"] = jax.process_count() - wandb.summary["backend"] = jax.default_backend() - - @staticmethod - def _infer_experiment_git_root() -> Optional[str | os.PathLike[str]]: - # sniff out the main directory (since we typically don't run from the root of the repo) - # we'll walk the stack and directories for the files in the stack the until we're at a git root - import os - import traceback - - stack = traceback.extract_stack() - # start from the top of the stack and work our way down since we want to hit the main file first - top_git_root = None - for frame in stack: - dirname = os.path.dirname(frame.filename) - # bit hacky but we want to skip anything that's in the python env - if any(x in dirname for x in ["site-packages", "dist-packages", "venv", "opt/homebrew", "conda", "pyenv"]): - continue - # see if it's under a git root - try: - repo = Repo(dirname, search_parent_directories=True) - top_git_root = repo.working_dir - break - except (NoSuchPathError, InvalidGitRepositoryError): - logger.debug(f"Skipping {dirname} since it's not a git root") - pass - return top_git_root - - -def _generate_pip_freeze(): - from importlib.metadata import distributions - - dists = distributions() - return "\n".join(f"{dist.name}=={dist.version}" for dist in dists) diff --git a/src/levanter/lora.py b/src/levanter/lora.py index cf7480510..fb89fe1f7 100644 --- 
a/src/levanter/lora.py +++ b/src/levanter/lora.py @@ -363,14 +363,14 @@ def save_peft_checkpoint_callback( If hf_repo is provided, this will upload the checkpoint to the huggingface hub, passing any additional kwargs to the huggingface_hub.upload_folder function. - Args - base_path: the base path to save the checkpoint to. `/step-` will be appended to this. base_path - may be a GCS bucket path, in which case the checkpoint will be uploaded to GCS after being written to a tmp - config: the LoRA config to use - base_model_name_or_path: the name or path of the base model - tokenizer: If provided, will save the tokenizer to the checkpoint - upload_to_hf: the repo to upload to. If a string, will be interpreted as a repo name + branch - hf_upload_kwargs: kwargs to pass to the upload function + Args: + base_path: the base path to save the checkpoint to. `/step-` will be appended to this. base_path + may be a GCS bucket path, in which case the checkpoint will be uploaded to GCS after being written to a tmp + config: the LoRA config to use + base_model_name_or_path: the name or path of the base model + tokenizer: If provided, will save the tokenizer to the checkpoint + upload_to_hf: the repo to upload to. If a string, will be interpreted as a repo name + branch + hf_upload_kwargs: kwargs to pass to the upload function """ def cb(step: StepInfo): diff --git a/src/levanter/main/cache_dataset.py b/src/levanter/main/cache_dataset.py index 84fad654c..077da674d 100644 --- a/src/levanter/main/cache_dataset.py +++ b/src/levanter/main/cache_dataset.py @@ -1,14 +1,13 @@ import logging import os -from dataclasses import dataclass - -import wandb +from dataclasses import dataclass, field import levanter -from levanter.data.shard_cache import RichMetricsMonitor, WandbMetricsMonitor, build_cache +from levanter.data.shard_cache import LoggingMetricsMonitor, RichMetricsMonitor, build_cache from levanter.data.text import BatchTokenizer, LMDatasetConfig from levanter.distributed import RayConfig -from levanter.logging import init_logger +from levanter.logging import init_logging +from levanter.tracker import NoopConfig, TrackerConfig logger = logging.getLogger(__name__) @@ -16,19 +15,17 @@ @dataclass class RayCachedLMDatasetConfig(LMDatasetConfig, RayConfig): - pass + tracker: TrackerConfig = field(default_factory=NoopConfig) @levanter.config.main() def main(args: RayCachedLMDatasetConfig): """Caches two different kinds of datasets. 
It can cache a dataset from a list of urls, or a dataset from a hf dataset""" - init_logger("cache_dataset.log") + init_logging("cache_dataset.log") args.initialize() tokenizer = args.the_tokenizer - wandb.init(mode="offline") - for split in ["train", "validation"]: print(f"Caching {split} to {args.cache_dir}.") # connect or start the actor @@ -40,7 +37,7 @@ def main(args: RayCachedLMDatasetConfig): logger.warning(f"Skipping {split} because it is empty.") continue - monitors = [RichMetricsMonitor(source.num_shards), WandbMetricsMonitor("preprocess/" + split, commit=True)] + monitors = [RichMetricsMonitor(source.num_shards), LoggingMetricsMonitor("preprocess/" + split, commit=True)] cache = build_cache( cache_dir=split_cache_dir, diff --git a/src/levanter/main/lora_lm.py b/src/levanter/main/lora_lm.py index dbe597e30..4e621239e 100644 --- a/src/levanter/main/lora_lm.py +++ b/src/levanter/main/lora_lm.py @@ -4,7 +4,6 @@ from typing import Optional import jax.random as jrandom -import wandb import haliax.random @@ -95,8 +94,14 @@ def compute_loss(model, example: LmExample, key=None): all_param_count = parameter_count(state.model) just_lora_params = parameter_count(trainer.trainable_params_only(state.model)) - wandb.summary["parameter_count"] = all_param_count - wandb.summary["trainable_parameter_count"] = just_lora_params + levanter.tracker.log_summary( + { + "parameter_count": all_param_count, + "trainable_parameter_count": just_lora_params, + "fraction_trainable": just_lora_params * 1.0 / all_param_count, + } + ) + logger.info(f"Total parameter count: {all_param_count}") logger.info(f"Trainable parameter count: {just_lora_params}") logger.info(f"Fraction of parameters that are trainable: {just_lora_params * 1.0 / all_param_count%.3}") diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py index 187a8d92d..60d5dbbb6 100644 --- a/src/levanter/main/train_lm.py +++ b/src/levanter/main/train_lm.py @@ -5,7 +5,6 @@ from typing import Optional, Union import jax.random as jrandom -import wandb import haliax as hax from haliax import Axis @@ -97,12 +96,14 @@ def compute_loss(model: LmHeadModel, example: LmExample, key=None): optimizer = config.optimizer.build(config.trainer.num_train_steps) # Our trainer is a wrapper around the optimizer and compute_loss function that handles checkpointing and fsdp - trainer = Trainer(config.trainer, optimizer, compute_loss) + # Using the trainer as a context manager does 3 things: + # 1. Sets the device mesh + # 2. Sets the axis mapping (for fsdp) + # 3. Sets the global metrics logger + with Trainer(config.trainer, optimizer, compute_loss) as trainer: + eval_datasets = config.data.validation_sets(Pos.size) + train_dataset = CausalLmDataset(config.data.train_set(Pos.size), Pos, KeyPos) - eval_datasets = config.data.validation_sets(Pos.size) - train_dataset = CausalLmDataset(config.data.train_set(Pos.size), Pos, KeyPos) - - with trainer.device_mesh: # to do partitioning, our dimensions have to be divisible by the size of the physical axes they're mapped to # For most things, we just insist you specify the config right, but tokenizers often have strange numbers of # tokens: gpt-2 has 50257, for example. So we round up. @@ -129,7 +130,11 @@ def compute_loss(model: LmHeadModel, example: LmExample, key=None): else: logger.info("No checkpoint found. 
Starting from scratch.") - wandb.summary["parameter_count"] = parameter_count(state.model) + levanter.tracker.log_summary( + { + "parameter_count": parameter_count(state.model), + } + ) # boilerplate hooks and such trainer.add_default_hooks() diff --git a/src/levanter/tracker/__init__.py b/src/levanter/tracker/__init__.py new file mode 100644 index 000000000..02edfc9d2 --- /dev/null +++ b/src/levanter/tracker/__init__.py @@ -0,0 +1,16 @@ +from levanter.tracker.helpers import log_optimizer_hyperparams +from levanter.tracker.tracker import CompositeTracker, NoopConfig, NoopTracker, Tracker, TrackerConfig +from levanter.tracker.tracker_fns import current_tracker, get_tracker, jit_log_metrics, log_metrics, log_summary + + +__all__ = [ + "Tracker", + "TrackerConfig", + "CompositeTracker", + "log_optimizer_hyperparams", + "NoopTracker", + "current_tracker", + "jit_log_metrics", + "log_metrics", + "log_summary", +] diff --git a/src/levanter/tracker/helpers.py b/src/levanter/tracker/helpers.py new file mode 100644 index 000000000..31131d1ac --- /dev/null +++ b/src/levanter/tracker/helpers.py @@ -0,0 +1,71 @@ +import dataclasses +import logging +import os +from typing import Optional + +from git import InvalidGitRepositoryError, NoSuchPathError, Repo +from optax._src.wrappers import MultiStepsState + +import levanter.tracker +from levanter.utils.jax_utils import jnp_to_python + + +logger = logging.getLogger(__name__) + + +def log_optimizer_hyperparams(opt_state, prefix: Optional[str] = None, *, step=None): + if isinstance(opt_state, MultiStepsState): + opt_state = opt_state.inner_opt_state + + def wrap_key(key): + if prefix: + return f"{prefix}/{key}" + return key + + if hasattr(opt_state, "hyperparams"): + params = {wrap_key(k): jnp_to_python(v) for k, v in opt_state.hyperparams.items()} + levanter.tracker.log_metrics(params, step=step) + + +def hparams_to_dict(hparams, **extra_hparams): + if hparams is None: + hparams_to_save = {} + elif dataclasses.is_dataclass(hparams): + hparams_to_save = dataclasses.asdict(hparams) + else: + hparams_to_save = dict(hparams) + if extra_hparams: + hparams_to_save.update(extra_hparams) + return hparams_to_save + + +def infer_experiment_git_root() -> Optional[str | os.PathLike[str]]: + # sniff out the main directory (since we typically don't run from the root of the repo) + # we'll walk the stack and directories for the files in the stack the until we're at a git root + import os + import traceback + + stack = traceback.extract_stack() + # start from the top of the stack and work our way down since we want to hit the main file first + top_git_root = None + for frame in stack: + dirname = os.path.dirname(frame.filename) + # bit hacky but we want to skip anything that's in the python env + if any(x in dirname for x in ["site-packages", "dist-packages", "venv", "opt/homebrew", "conda", "pyenv"]): + continue + # see if it's under a git root + try: + repo = Repo(dirname, search_parent_directories=True) + top_git_root = repo.working_dir + break + except (NoSuchPathError, InvalidGitRepositoryError): + logger.debug(f"Skipping {dirname} since it's not a git root") + pass + return top_git_root + + +def generate_pip_freeze(): + from importlib.metadata import distributions + + dists = distributions() + return "\n".join(f"{dist.name}=={dist.version}" for dist in dists) diff --git a/src/levanter/tracker/tensorboard.py b/src/levanter/tracker/tensorboard.py new file mode 100644 index 000000000..a028eb24a --- /dev/null +++ b/src/levanter/tracker/tensorboard.py @@ -0,0 +1,86 @@ 
+import logging +import os +import typing +from dataclasses import dataclass +from typing import Any, Optional + +from levanter.tracker import Tracker, TrackerConfig, helpers + + +pylogger = logging.getLogger(__name__) + +if typing.TYPE_CHECKING: + from tensorboardX import SummaryWriter # noqa: F401 + + +class TensorboardTracker(Tracker): + name: str = "tensorboard" + + def __init__(self, writer: "SummaryWriter"): + self.writer = writer + + def log_hyperparameters(self, hparams: dict[str, Any]): + self.writer.add_hparams(hparams, {"dummy": 0}) + + def log(self, metrics: dict[str, Any], *, step, commit=None): + del commit + for k, v in metrics.items(): + self.writer.add_scalar(k, v, step) + + def log_summary(self, metrics: dict[str, Any]): + for k, v in metrics.items(): + self.writer.add_scalar(k, v, global_step=None) + + def log_artifact(self, artifact, *, name: Optional[str] = None, type: Optional[str] = None): + pylogger.warning("TensorboardLogger does not support logging artifacts yet") + pass + + +@TrackerConfig.register_subclass("tensorboard") +@dataclass +class TensorboardConfig(TrackerConfig): + logdir: str = "tblogs" + comment: Optional[str] = "" + purge_step: Optional[int] = None + max_queue: Optional[int] = 10 + flush_secs: Optional[int] = 120 + filename_suffix: Optional[str] = "" + write_to_disk: Optional[bool] = True + + def init(self, run_id: Optional[str], hparams=None) -> TensorboardTracker: + dir_to_write = self.logdir + if run_id is not None: + dir_to_write = os.path.join(dir_to_write, run_id) + + pylogger.info(f"Writing Tensorboard logs to {dir_to_write}") + + from tensorboardX import SummaryWriter # noqa: F811 + + writer = SummaryWriter( + dir_to_write, + comment=self.comment, + purge_step=self.purge_step, + max_queue=self.max_queue, + flush_secs=self.flush_secs, + filename_suffix=self.filename_suffix, + write_to_disk=self.write_to_disk, + ) + + hparams_dict = helpers.hparams_to_dict(hparams) + hparams_dict = _flatten_nested_dict(hparams_dict) + + writer.add_hparams(hparams_dict, {"dummy": 0}) + + return TensorboardTracker(writer) + + +def _flatten_nested_dict(d): + def items(): + for key, value in d.items(): + if isinstance(value, dict): + for subkey, subvalue in _flatten_nested_dict(value).items(): + yield key + "/" + subkey, subvalue + else: + yield key, value + + return dict(items()) diff --git a/src/levanter/tracker/tracker.py b/src/levanter/tracker/tracker.py new file mode 100644 index 000000000..b9f0b427d --- /dev/null +++ b/src/levanter/tracker/tracker.py @@ -0,0 +1,117 @@ +import abc +import dataclasses +import typing +from typing import Any, List, Optional + +import draccus + + +class Tracker(abc.ABC): + """ + A tracker is responsible for logging metrics, hyperparameters, and artifacts. + Meant to be used with the [levanter.tracker.current_tracker][] context manager, but can also be used directly. + + The name is borrowed from HF Accelerate. + + Examples: + >>> from levanter.tracker import current_tracker, log_metrics + >>> from levanter.tracker.wandb import WandbTracker + >>> with current_tracker(WandbTracker()): + ... log_metrics({"foo": 1}, step=0) + """ + + name: str + + @abc.abstractmethod + def log_hyperparameters(self, hparams: dict[str, Any]): + pass + + @abc.abstractmethod + def log(self, metrics: dict[str, typing.Any], *, step: Optional[int], commit: Optional[bool] = None): + """ + Log metrics to the tracker. Step is always required. + + Args: + metrics: Metrics to log + step: Step to log at + commit: Whether to commit the metrics. 
If None, uses the default for the tracker. + """ + pass + + @abc.abstractmethod + def log_summary(self, metrics: dict[str, Any]): + pass + + @abc.abstractmethod + def log_artifact(self, artifact, *, name: Optional[str] = None, type: Optional[str] = None): + pass + + def __enter__(self): + import levanter.tracker.tracker_fns as tracker_fns + + if hasattr(self, "_tracker_cm"): + raise RuntimeError("This tracker is already set as the global tracker") + setattr(self, "_tracker_cm", tracker_fns.current_tracker(self)) + self._tracker_cm.__enter__() + + def __exit__(self, exc_type, exc_val, exc_tb): + if not hasattr(self, "_tracker_cm"): + raise RuntimeError("This tracker is not set as the global tracker") + self._tracker_cm.__exit__(exc_type, exc_val, exc_tb) + delattr(self, "_tracker_cm") + + +class CompositeTracker(Tracker): + def __init__(self, loggers: List[Tracker]): + self.loggers = loggers + + def log_hyperparameters(self, hparams: dict[str, Any]): + for tracker in self.loggers: + tracker.log_hyperparameters(hparams) + + def log(self, metrics: dict[str, Any], *, step, commit=None): + for tracker in self.loggers: + tracker.log(metrics, step=step, commit=commit) + + def log_summary(self, metrics: dict[str, Any]): + for tracker in self.loggers: + tracker.log_summary(metrics) + + def log_artifact(self, artifact, *, name: Optional[str] = None, type: Optional[str] = None): + for tracker in self.loggers: + tracker.log_artifact(artifact, name=name, type=type) + + +class TrackerConfig(draccus.PluginRegistry, abc.ABC): + discover_packages_path = "levanter.tracker" + + @abc.abstractmethod + def init(self, run_id: Optional[str], hparams=None) -> Tracker: + raise NotImplementedError + + @classmethod + def default_choice_name(cls) -> Optional[str]: + return "wandb" + + +class NoopTracker(Tracker): + name: str = "noop" + + def log_hyperparameters(self, hparams: dict[str, Any]): + pass + + def log(self, metrics: dict[str, Any], *, step, commit: Optional[bool] = None): + pass + + def log_summary(self, metrics: dict[str, Any]): + pass + + def log_artifact(self, artifact, *, name: Optional[str] = None, type: Optional[str] = None): + pass + + +@TrackerConfig.register_subclass("noop") +@dataclasses.dataclass +class NoopConfig(TrackerConfig): + def init(self, run_id: Optional[str], hparams=None) -> Tracker: + return NoopTracker() diff --git a/src/levanter/tracker/tracker_fns.py b/src/levanter/tracker/tracker_fns.py new file mode 100644 index 000000000..69ab4ca0b --- /dev/null +++ b/src/levanter/tracker/tracker_fns.py @@ -0,0 +1,152 @@ +import typing +from contextlib import AbstractContextManager +from typing import Any, Literal, Optional + +import jax + +from levanter.tracker import CompositeTracker, Tracker +from levanter.tracker.tensorboard import TensorboardTracker +from levanter.tracker.wandb import WandbTracker +from levanter.utils.jax_utils import is_inside_jit + + +_global_tracker: Optional["Tracker"] = None + + +def log_metrics(metrics: dict[str, Any], *, step: Optional[int], commit: Optional[bool] = None): + """ + Log metrics to the global tracker. + + Args: + metrics: Metrics to log + step: Step to log at + commit: Whether to commit the metrics. If None, uses the default for the tracker. 
+ """ + global _global_tracker + if _global_tracker is None: + raise RuntimeError("No global tracker set") + + if is_inside_jit(): + # we're inside a jit, so we need to log from the host + if commit: + raise ValueError("Cannot commit from inside jit") + jit_log_metrics(metrics, step=step) + else: + # TODO: do we need to coerce to np here? + _global_tracker.log(metrics, step=step) + + +def jit_log_metrics(metrics, *, step=None): + """uses jax effect callback to log to wandb from the host""" + jax.debug.callback(log_metrics, metrics, step=step) + + +def log_summary(metrics: dict[str, Any]): + """ + Log summary metrics to the global tracker. + + Args: + metrics: Metrics to log + """ + global _global_tracker + if _global_tracker is None: + raise RuntimeError("No global tracker set") + _global_tracker.log_summary(metrics) + + +@typing.overload +def current_tracker() -> "Tracker": + ... + + +@typing.overload +def current_tracker(tracker: "Tracker") -> typing.ContextManager: + """Returns a context manager for setting the global tracker""" + ... + + +def current_tracker( + tracker: Optional[Tracker] = None, +) -> Tracker | typing.ContextManager: + """ + Get or set the global tracker. Note that setting the global tracker is not thread-safe, + and using a tracker from multiple threads is only supported if the tracker itself is thread-safe. + + Args: + tracker: If provided, returns a context manager that sets the global tracker to the provided tracker when used. + + Returns: + If no tracker is provided, returns the current global tracker. + If a tracker is provided, returns a context manager that sets the global tracker to the provided tracker when used. + + Examples: + >>> from levanter.tracker import current_tracker, log_metrics + >>> from levanter.tracker.wandb import WandbTracker + >>> with current_tracker(WandbTracker()): + ... log_metrics({"foo": 1}, step=0) + ... current_tracker().log_metrics({"foo": 2}, step=1) + """ + global _global_tracker + if tracker is None: + if _global_tracker is None: + raise RuntimeError("No global tracker set") + return _global_tracker + else: + return _GlobalLoggerContextManager(tracker) + + +@typing.overload +def get_tracker(name: Literal["wandb"]) -> WandbTracker: + ... + + +@typing.overload +def get_tracker(name: Literal["tensorboard"]) -> TensorboardTracker: + ... + + +@typing.overload +def get_tracker(name: str) -> Tracker: + ... + + +def get_tracker(name: str) -> Tracker: + """ + Lookup a tracker in the current global tracker with the provided name. + + Args: + name: Name of the tracker to lookup + + Returns: + The tracker with the provided name + + Examples: + >>> from levanter.tracker import get_tracker, log_metrics + >>> from levanter.tracker.wandb import WandbTracker + >>> with current_tracker(WandbTracker()): + ... log_metrics({"foo": 1}, step=0) + ... 
get_tracker("wandb").log_metrics({"foo": 2}, step=1) + """ + tracker = current_tracker() + if isinstance(tracker, CompositeTracker): + for t in tracker.loggers: + if t.name == name: + return t + elif tracker.name == name: + return tracker + + raise KeyError(f"Tracker with name {name} not found") + + +class _GlobalLoggerContextManager(AbstractContextManager): + def __init__(self, tracker: "Tracker"): + self.tracker = tracker + + def __enter__(self): + global _global_tracker + self.old_tracker = _global_tracker + _global_tracker = self.tracker + + def __exit__(self, exc_type, exc_val, exc_tb): + global _global_tracker + _global_tracker = self.old_tracker diff --git a/src/levanter/tracker/wandb.py b/src/levanter/tracker/wandb.py new file mode 100644 index 000000000..2d1422760 --- /dev/null +++ b/src/levanter/tracker/wandb.py @@ -0,0 +1,213 @@ +import dataclasses +import logging +import os +import tempfile +import typing +import warnings +from dataclasses import dataclass +from typing import Any, List, Optional, Union + +import draccus +import jax +from draccus import field +from git import InvalidGitRepositoryError, NoSuchPathError, Repo + +from levanter.tracker import Tracker +from levanter.tracker.helpers import generate_pip_freeze, hparams_to_dict, infer_experiment_git_root +from levanter.tracker.tracker import TrackerConfig +from levanter.utils import jax_utils + + +if typing.TYPE_CHECKING: + import wandb + import wandb.sdk.lib.disabled + + +logger = logging.getLogger(__name__) + +WandbRun = Union["wandb.sdk.wandb_run.Run", "wandb.sdk.lib.disabled.RunDisabled"] + + +class WandbTracker(Tracker): + name: str = "wandb" + run: WandbRun + + def __init__(self, run: Optional[WandbRun]): + import wandb + + if run is None: + if wandb.run is None: + logger.warning("Wandb run is not initialized. Initializing a new run.") + runx = wandb.init() + if runx is None: + raise RuntimeError("Wandb run is not initialized.") + self.run = runx + else: + self.run = wandb.run + else: + self.run = run + + def log_hyperparameters(self, hparams: dict[str, Any]): + if self.run is None: + raise RuntimeError("Must call init before logging hyperparameters") + self.run.config.update(hparams) + + def log(self, metrics: dict[str, Any], *, step, commit=None): + if self.run is None: + raise RuntimeError("Must call init before logging metrics") + self.run.log(metrics, step=step, commit=commit) + + def log_summary(self, metrics: dict[str, Any]): + if self.run is None: + raise RuntimeError("Must call init before logging summary") + self.run.summary.update(metrics) + + def log_artifact(self, artifact, *, name: Optional[str] = None, type: Optional[str] = None): + if self.run is None: + raise RuntimeError("Must call init before logging artifacts") + self.run.log_artifact(artifact, name=name, type=type) + + +def is_wandb_available(): + try: + import wandb + except ImportError: + return False + return wandb is not None and wandb.run is not None + + +@TrackerConfig.register_subclass("wandb") +@dataclass +class WandbConfig(TrackerConfig): + """ + Configuration for wandb. + """ + + entity: Optional[str] = None # An entity is a username or team name where you send runs + project: Optional[str] = None # The name of the project where you are sending the enw run. + name: Optional[str] = None # A short display name for this run, which is how you'll identify this run in the UI. + tags: List[str] = field(default_factory=list) # Will populate the list of tags on this run in the UI. 
+ id: Optional[str] = None # A unique ID for this run, used for resuming. It must be unique in the project + group: Optional[str] = None # Specify a group to organize individual runs into a larger experiment. + mode: Optional[str] = None # Can be "online", "offline" or "disabled". If None, it will be whatever W&B decides. + resume: Optional[Union[bool, str]] = None + """ + Set the resume behavior. Options: "allow", "must", "never", "auto" or None. + By default, if the new run has the same ID as a previous run, this run overwrites that data. + Please refer to [init](https://docs.wandb.ai/ref/python/init) and [resume](https://docs.wandb.ai/guides/runs/resuming) + document for more details. + """ + + save_code: Union[bool, str] = True + """If string, will save code from that directory. If True, will attempt to sniff out the main directory (since we + typically don't run from the root of the repo).""" + + save_xla_dumps: bool = False + """If True, will save the XLA code to wandb (as configured by XLA_FLAGS). This is useful for debugging.""" + + def init(self, run_id: Optional[str], hparams=None) -> WandbTracker: + import wandb + + if run_id is not None and self.id is not None and run_id != self.id: + warnings.warn( + f"Both trainer's id {run_id} and WandB's id {self.id} are set. WandB will use the id set in its" + " config." + ) + + id = self.id + if id is None: + id = run_id + + hparams_to_save = hparams_to_dict(hparams) + + # for distributed runs, we only want the primary worker to use wandb, so we make everyone else be disabled + # however, we do share information about the run id, so that we can link to it from the other workers + if jax.process_index() == 0: + mode = self.mode + else: + mode = "disabled" + + git_settings = self._git_settings() + + if "git_commit" in git_settings: + hparams_to_save["git_commit"] = git_settings["git_commit"] + + r = wandb.init( + entity=self.entity, + project=self.project, + name=self.name, + tags=self.tags, + id=id, + group=self.group, + resume=self.resume, + mode=mode, + config=hparams_to_save, + settings=git_settings, + ) + + assert r is not None + + if jax.process_count() > 1: + # we need to share wandb run information across all hosts, because we use it for checkpoint paths and things + metadata_to_share = dict( + entity=r.entity, + project=r.project, + name=r.name, + tags=r.tags, + id=r.id, + group=r.group, + ) + metadata_to_share = jax_utils.multihost_broadcast_sync( + metadata_to_share, is_source=jax.process_index() == 0 + ) + + if jax.process_index() != 0: + assert r.mode == "disabled" + for k, v in metadata_to_share.items(): + setattr(r, k, v) + + logger.info(f"Synced wandb run information from process 0: {r.name} {r.id}") + + if dataclasses.is_dataclass(hparams): + with tempfile.TemporaryDirectory() as tmpdir: + config_path = os.path.join(tmpdir, "config.yaml") + with open(config_path, "w") as f: + draccus.dump(hparams, f, encoding="utf-8") + if wandb.run is not None: + wandb.run.log_artifact(str(config_path), name="config.yaml", type="config") + + # generate a pip freeze + with tempfile.TemporaryDirectory() as tmpdir: + requirements_path = os.path.join(tmpdir, "requirements.txt") + requirements = generate_pip_freeze() + with open(requirements_path, "w") as f: + f.write(requirements) + if wandb.run is not None: + wandb.run.log_artifact(str(requirements_path), name="requirements.txt", type="requirements") + + wandb.summary["num_devices"] = jax.device_count() + wandb.summary["num_hosts"] = jax.process_count() + wandb.summary["backend"] = 
jax.default_backend() + + return WandbTracker(r) + + def _git_settings(self): + other_settings = dict() + if isinstance(self.save_code, str): + code_dir = self.save_code + elif self.save_code: + code_dir = infer_experiment_git_root() or "." # type: ignore + else: + code_dir = None + if code_dir is not None: + logger.info(f"Setting wandb code_dir to {code_dir}") + other_settings["code_dir"] = code_dir + other_settings["git_root"] = code_dir + # for some reason, wandb isn't populating the git commit, so we do it here + try: + repo = Repo(code_dir) + other_settings["git_commit"] = repo.head.commit.hexsha + except (NoSuchPathError, InvalidGitRepositoryError): + logger.warning(f"Could not find git repo at {code_dir}") + pass + return other_settings diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index aadeb97a8..4636b50d0 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -9,7 +9,7 @@ from dataclasses import dataclass from functools import cached_property from pathlib import Path -from typing import Any, Callable, Dict, Generic, Iterable, List, Mapping, Optional, Tuple, TypeVar, Union +from typing import Any, Callable, Dict, Generic, Iterable, List, Mapping, Optional, Sequence, Tuple, TypeVar, Union import equinox as eqx import jax @@ -17,7 +17,6 @@ import jmp import numpy as np import optax -import wandb from draccus import field from jax import ShapeDtypeStruct from jax.experimental import multihost_utils @@ -30,12 +29,16 @@ from haliax.partitioning import ResourceAxis, ResourceMapping, named_jit import levanter.logging +import levanter.tracker +import levanter.tracker.wandb +from levanter import logging, tracker from levanter.checkpoint import CheckpointerConfig from levanter.config import JsonAtom from levanter.data import Dataset, ReplicatedBatchLoader, ShardableDataset, ShardedBatchLoader from levanter.distributed import DistributedConfig, RayConfig from levanter.grad_accum import accumulate_gradients_sharded -from levanter.logging import WandbConfig, capture_time +from levanter.logging import capture_time +from levanter.tracker import TrackerConfig from levanter.types import FilterSpec from levanter.utils import cloud_utils from levanter.utils.jax_utils import is_inexact_arrayish @@ -114,8 +117,10 @@ class Trainer: config: "TrainerConfig" optimizer: GradientTransformation hooks: TrainerHooks + tracker: levanter.tracker.Tracker is_trainable_param: Optional[PyTree[FilterSpec]] _raw_loss_function: Callable + _cmanagers: List[typing.ContextManager] = [] def __init__( self, @@ -141,6 +146,11 @@ def __init__( self._raw_loss_function = loss_fn self.optimizer = optimizer self.is_trainable_param = is_trainable + if isinstance(config.tracker, Sequence): + self.tracker = levanter.tracker.CompositeTracker([c.init(self.run_id) for c in config.tracker]) + else: + self.tracker = config.tracker.init(self.run_id) + self._cmanagers = [] @cached_property def loss_fn(self): @@ -202,6 +212,34 @@ def TrainBatch(self): def EvalBatch(self): return self.config.EvalBatch + def __enter__(self): + if len(self._cmanagers) > 0: + raise RuntimeError("Trainer is already entered") + + self._cmanagers = [ + levanter.current_tracker(self.tracker), + self.device_mesh, + hax.axis_mapping(self.parameter_axis_mapping), + ] + + for cmanager in self._cmanagers: + cmanager.__enter__() + + return self + + def __exit__(self, *args): + problems = [] + for cmanager in reversed(self._cmanagers): + try: + cmanager.__exit__(*args) + except Exception as e: + problems.append(e) + + self._cmanagers = [] + 
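+        # The managers were exited in reverse order above, with exceptions
+        # collected so that one failing manager cannot prevent the others from
+        # closing; the first problem is re-raised below.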
+ if len(problems) > 0: + raise RuntimeError("Exception(s) occurred while exiting trainer", problems) from problems[0] + def initial_state( self, training_key: PRNGKeyArray, model: Optional[M] = None, model_init: Optional[Callable[[], M]] = None ) -> TrainerState: @@ -211,51 +249,51 @@ def initial_state( Returns: model, opt_state, key, resume_step """ + with levanter.tracker.current_tracker(self.tracker): + if model is not None and model_init is not None: + raise ValueError("only one of model and model_init should be specified") + elif model is None and model_init is None: + raise ValueError("one of model and model_init must be specified") - if model is not None and model_init is not None: - raise ValueError("only one of model and model_init should be specified") - elif model is None and model_init is None: - raise ValueError("one of model and model_init must be specified") - - if model is not None: - # we can't just use `lambda: model` because JAX jit can't see captures, but it can see partials - # We can't use plain partials because they aren't pytrees - model_init = jax.tree_util.Partial(lambda m: m, model) + if model is not None: + # we can't just use `lambda: model` because JAX jit can't see captures, but it can see partials + # We can't use plain partials because they aren't pytrees + model_init = jax.tree_util.Partial(lambda m: m, model) - model_shape, opt_state_shape = eqx.filter_eval_shape(self._init_model_and_opt_state, model_init) + model_shape, opt_state_shape = eqx.filter_eval_shape(self._init_model_and_opt_state, model_init) - # we only checkpoint the trainable parameters, so we need to filter out the non-trainable ones - trainable_model_shape = self.trainable_params_only(model_shape) + # we only checkpoint the trainable parameters, so we need to filter out the non-trainable ones + trainable_model_shape = self.trainable_params_only(model_shape) - ckpt = self.maybe_load_checkpoint( - trainable_model_shape, - (opt_state_shape, training_key), - axis_mapping=self.parameter_axis_mapping, - mesh=self.device_mesh, - ) + ckpt = self.maybe_load_checkpoint( + trainable_model_shape, + (opt_state_shape, training_key), + axis_mapping=self.parameter_axis_mapping, + mesh=self.device_mesh, + ) - if ckpt is not None: - trainable_model, (opt_state, training_key), completed_step = ckpt - if model is not None: - model = eqx.combine(trainable_model, model) - elif any(isinstance(leaf, ShapeDtypeStruct) for leaf in jax.tree_leaves(trainable_model)): - # if we're resuming, we need to re-initialize the non-trainable parameters to their original values - non_trainable = named_jit(self._init_non_trainable_params, self.parameter_axis_mapping)(model_init) - model = eqx.combine(trainable_model, non_trainable) + if ckpt is not None: + trainable_model, (opt_state, training_key), completed_step = ckpt + if model is not None: + model = eqx.combine(trainable_model, model) + elif any(isinstance(leaf, ShapeDtypeStruct) for leaf in jax.tree_leaves(trainable_model)): + # if we're resuming, we need to re-initialize the non-trainable parameters to their original values + non_trainable = named_jit(self._init_non_trainable_params, self.parameter_axis_mapping)(model_init) + model = eqx.combine(trainable_model, non_trainable) + else: + model = trainable_model + step = completed_step + 1 else: - model = trainable_model - step = completed_step + 1 - else: - model, opt_state = named_jit(self._init_model_and_opt_state, self.parameter_axis_mapping)(model_init) - step = 0 + model, opt_state = 
named_jit(self._init_model_and_opt_state, self.parameter_axis_mapping)(model_init) + step = 0 - return TrainerState(step, model, opt_state, training_key) + return TrainerState(step, model, opt_state, training_key) def train_step(self, state: TrainerState[M], *batch: X, **batch_kwargs) -> StepInfo[M]: """ Performs a single training step. """ - with capture_time() as step_time: + with capture_time() as step_time, levanter.current_tracker(self.tracker): key, new_key = jax.random.split(state.training_key) loss, new_model, new_optstate = self._train_step_fn( state.model, state.opt_state, *batch, **batch_kwargs, key=key @@ -272,24 +310,23 @@ def training_steps( Generator that yields training steps and runs hooks. """ iter_data = iter(train_loader) + with levanter.current_tracker(self.tracker): + while state.step < self.config.num_train_steps: + with capture_time() as loading_time: + example = next(iter_data) - while state.step < self.config.num_train_steps: - with capture_time() as loading_time: - example = next(iter_data) + levanter.tracker.log_metrics({"throughput/loading_time": loading_time()}, step=state.step) - # TODO: refactor logging - wandb.log({"throughput/loading_time": loading_time()}, step=state.step) + info = self.train_step(state, example) + state = info.state - info = self.train_step(state, example) - state = info.state + if run_hooks: + with capture_time() as hook_time: + self.run_hooks(info) - if run_hooks: - with capture_time() as hook_time: - self.run_hooks(info) + levanter.tracker.log_metrics({"throughput/hook_time": hook_time()}, step=state.step) - wandb.log({"throughput/hook_time": hook_time()}, step=state.step) - - yield info + yield info def train(self, state: TrainerState[M], train_loader: Iterable[X], run_hooks: bool = True) -> StepInfo[M]: """ @@ -308,10 +345,9 @@ def add_default_hooks(self, eval_dataset: Optional[Iterable[X]] = None): from levanter import callbacks self.add_hook(callbacks.pbar_logger(total=self.config.num_train_steps), every=1) - self.add_hook(callbacks.log_to_wandb, every=1) + self.add_hook(callbacks.log_step_info, every=1) if eval_dataset is not None: self.add_eval_hook(eval_dataset) - self.add_hook(callbacks.wandb_xla_logger(self.config.wandb), every=self.config.steps_per_eval) # engine.add_hook(callbacks.log_memory_usage(), every=1) checkpointer = self.config.checkpointer.create(self.run_id, self.is_trainable_param) self.add_hook(checkpointer.on_step, every=1) # checkpointer manages its own frequency @@ -465,11 +501,13 @@ class TrainerConfig: seed: int = 0 # random seed mp: jmp.Policy = jmp.get_policy("f32") # mixed precision policy - wandb: WandbConfig = field(default_factory=WandbConfig) + wandb: Optional[tracker.wandb.WandbConfig] = None log_dir: Path = Path("logs/") run_base_dir: Path = Path("runs/") id: Optional[str] = None # run id. if None, will be set to a random string + tracker: TrackerConfig | Tuple[TrackerConfig, ...] = field(default_factory=tracker.wandb.WandbConfig) + # config related to partitioning batch_axis: Optional[str] = "batch" # Batch axis for data parallel. @@ -516,15 +554,6 @@ class TrainerConfig: # whether or not to shutdown the tpu at exit. If a float, shutdown after that many seconds. 
True = 5 minutes shutdown_at_exit: Union[bool, float] = False - @property - def run_name(self) -> str: - try: - import wandb - - return wandb.run and (wandb.run.name or wandb.run.id) or "unnamed" - except ImportError: - return "unnamed" - @property def TrainBatch(self): return Axis("batch", self.train_batch_size) @@ -533,15 +562,20 @@ def TrainBatch(self): def EvalBatch(self): return Axis("batch", self.eval_batch_size) + def __post_init__(self): + if self.wandb is not None: + warnings.warn("wandb is deprecated. use tracker with type wandb instead", DeprecationWarning) + self.tracker = self.wandb + def initialize(self, all_config): - """Initializes jax, wandb, logging, setting the run name/id in the process""" - self.distributed.initialize() - self._maybe_set_id() - self.ray.initialize() + """Initializes jax, logging, setting the run name/id in the process""" self._initialize_jax_config() + self.distributed.initialize() self._validate_and_set_defaults() - self.wandb.init(self.id, all_config) + + self._maybe_set_id() self._initialize_logging() + self.ray.initialize() if self.require_accelerator is None: self.require_accelerator = not sys.platform.startswith("darwin") @@ -608,7 +642,7 @@ def _initialize_jax_config(self): def _initialize_logging(self): self.log_dir.mkdir(parents=True, exist_ok=True) - levanter.logging.init_logger(self.log_dir / f"{self.id}.log") + levanter.logging.init_logging(self.log_dir / f"{self.id}.log") def _maybe_set_id(self): # always do this so we don't get weird hangs if the id isn't set right @@ -622,7 +656,7 @@ def _maybe_set_id(self): # TODO: this doesn't work with wandb sweeps. need to reconcile when we merge if "RUN_ID" in os.environ: self.id = os.environ["RUN_ID"] - elif self.wandb.id is not None: + elif self.wandb is not None and self.wandb.id is not None: self.id = self.wandb.id else: # wandb run ids are 8 characters [a-z0-9], which we'll emulate here diff --git a/src/levanter/utils/jax_utils.py b/src/levanter/utils/jax_utils.py index cb9cd915d..038c5e9b5 100644 --- a/src/levanter/utils/jax_utils.py +++ b/src/levanter/utils/jax_utils.py @@ -41,6 +41,11 @@ def use_cpu_device(): yield +def is_inside_jit(): + """Returns True if we're currently inside a jit""" + return isinstance(jnp.zeros(()), jax.core.Tracer) + + def flops_estimate(fn, *args): """Estimates the flop count of a function using XLA/HLO fanciness. 
See https://github.com/google/flax/discussions/1854""" diff --git a/tests/test_eval_lm.py b/tests/test_eval_lm.py index f1193f4f4..178069f26 100644 --- a/tests/test_eval_lm.py +++ b/tests/test_eval_lm.py @@ -11,8 +11,8 @@ import tiny_test_corpus from levanter.checkpoint import save_checkpoint from levanter.distributed import RayConfig -from levanter.logging import WandbConfig from levanter.models.gpt2 import Gpt2LMHeadModel +from levanter.tracker.wandb import WandbConfig from levanter.utils.py_utils import logical_cpu_core_count diff --git a/tests/test_logging.py b/tests/test_logging.py index dc74c78ed..7c537b182 100644 --- a/tests/test_logging.py +++ b/tests/test_logging.py @@ -3,7 +3,7 @@ import pytest from git import InvalidGitRepositoryError, NoSuchPathError, Repo -from levanter.logging import WandbConfig +from levanter.tracker.helpers import infer_experiment_git_root def test_infer_experiment_git_root(): @@ -13,7 +13,7 @@ def test_infer_experiment_git_root(): except (InvalidGitRepositoryError, NoSuchPathError): pytest.skip("test not running in a git repo") - root = WandbConfig._infer_experiment_git_root() + root = infer_experiment_git_root() # ensure that 1) this is a git root and 2) this source file is underneath assert root is not None diff --git a/tests/test_tracker.py b/tests/test_tracker.py new file mode 100644 index 000000000..15485b83e --- /dev/null +++ b/tests/test_tracker.py @@ -0,0 +1,80 @@ +# NOTE: Do not explicitly import wandb/other trackers here, as this will cause the tests to trivially pass. +import dataclasses +from typing import Tuple + +import pytest +import yaml + +import levanter.tracker +from levanter.tracker import CompositeTracker, TrackerConfig + + +def test_tracker_plugin_stuff_works(): + assert TrackerConfig.get_choice_class("wandb") is not None + with pytest.raises(KeyError): + TrackerConfig.get_choice_class("foo") + + +def test_tracker_plugin_default_works(): + config = """ + tracker: + entity: foo + """ + parsed = yaml.safe_load(config) + + @dataclasses.dataclass + class ConfigHolder: + tracker: TrackerConfig + + import draccus + + tconfig = draccus.decode(ConfigHolder, parsed).tracker + + assert isinstance(tconfig, TrackerConfig.get_choice_class("wandb")) + + assert tconfig.entity == "foo" # type: ignore + + +def test_tracker_plugin_multi_parsing_work(): + config = """ + tracker: + type: noop + """ + parsed = yaml.safe_load(config) + + @dataclasses.dataclass + class ConfigHolder: + tracker: TrackerConfig | Tuple[TrackerConfig, ...] 
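+        # draccus resolves the `type:` key via the TrackerConfig plugin
+        # registry: a single mapping decodes to one TrackerConfig, while a
+        # YAML list decodes to a tuple of TrackerConfigs.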
+ + import draccus + + from levanter.tracker.tracker import NoopConfig + + assert isinstance(draccus.decode(ConfigHolder, parsed).tracker, NoopConfig) + + config = """ + tracker: + - type: noop + - type: wandb + """ + parsed = yaml.safe_load(config) + decoded = draccus.decode(ConfigHolder, parsed).tracker + assert decoded == (NoopConfig(), TrackerConfig.get_choice_class("wandb")()) + + +def test_get_tracker_by_name(): + wandb_config = TrackerConfig.get_choice_class("wandb") + if wandb_config is None: + pytest.skip("wandb not installed") + + from levanter.tracker import NoopTracker + + wandb1 = wandb_config(mode="disabled").init(None) + tracker = CompositeTracker([wandb1, NoopTracker()]) + + with tracker: + assert levanter.tracker.get_tracker("wandb") is wandb1 + assert levanter.tracker.get_tracker("noop") is not None + + with pytest.raises(KeyError): + levanter.tracker.get_tracker("foo") diff --git a/tests/test_train_lm.py b/tests/test_train_lm.py index 3cd762d8b..f95b27efb 100644 --- a/tests/test_train_lm.py +++ b/tests/test_train_lm.py @@ -8,7 +8,7 @@ import levanter.main.train_lm as train_lm import tiny_test_corpus from levanter.distributed import RayConfig -from levanter.logging import WandbConfig +from levanter.tracker.wandb import WandbConfig from levanter.utils.py_utils import logical_cpu_core_count diff --git a/tests/test_viz_lm.py b/tests/test_viz_lm.py index 665c98772..cf4fb74a6 100644 --- a/tests/test_viz_lm.py +++ b/tests/test_viz_lm.py @@ -11,8 +11,8 @@ import tiny_test_corpus from levanter.checkpoint import save_checkpoint from levanter.distributed import RayConfig -from levanter.logging import WandbConfig from levanter.models.gpt2 import Gpt2Config, Gpt2LMHeadModel +from levanter.tracker.wandb import WandbConfig from levanter.utils.py_utils import logical_cpu_core_count From 2387f26e8852c9f90f5cfcdecf7edd888710da83 Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 27 Nov 2023 22:43:36 -0800 Subject: [PATCH 048/205] wip --- src/levanter/tracker/tensorboard.py | 2 +- src/levanter/tracker/tracker.py | 4 ++-- src/levanter/tracker/wandb.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/levanter/tracker/tensorboard.py b/src/levanter/tracker/tensorboard.py index a028eb24a..e657ae47f 100644 --- a/src/levanter/tracker/tensorboard.py +++ b/src/levanter/tracker/tensorboard.py @@ -47,7 +47,7 @@ class TensorboardConfig(TrackerConfig): filename_suffix: Optional[str] = "" write_to_disk: Optional[bool] = True - def init(self, run_id: Optional[str], hparams=None) -> TensorboardTracker: + def init(self, run_id: Optional[str], hparams) -> TensorboardTracker: dir_to_write = self.logdir if run_id is not None: dir_to_write = os.path.join(dir_to_write, run_id) diff --git a/src/levanter/tracker/tracker.py b/src/levanter/tracker/tracker.py index b9f0b427d..b335dc6ad 100644 --- a/src/levanter/tracker/tracker.py +++ b/src/levanter/tracker/tracker.py @@ -86,7 +86,7 @@ class TrackerConfig(draccus.PluginRegistry, abc.ABC): discover_packages_path = "levanter.tracker" @abc.abstractmethod - def init(self, run_id: Optional[str], hparams=None) -> Tracker: + def init(self, run_id: Optional[str], hparams) -> Tracker: raise NotImplementedError @classmethod @@ -113,5 +113,5 @@ def log_artifact(self, artifact, *, name: Optional[str] = None, type: Optional[s @TrackerConfig.register_subclass("noop") @dataclasses.dataclass class NoopConfig(TrackerConfig): - def init(self, run_id: Optional[str], hparams=None) -> Tracker: + def init(self, run_id: Optional[str], hparams) -> Tracker: return 
NoopTracker() diff --git a/src/levanter/tracker/wandb.py b/src/levanter/tracker/wandb.py index 2d1422760..d6994abae 100644 --- a/src/levanter/tracker/wandb.py +++ b/src/levanter/tracker/wandb.py @@ -105,7 +105,7 @@ class WandbConfig(TrackerConfig): save_xla_dumps: bool = False """If True, will save the XLA code to wandb (as configured by XLA_FLAGS). This is useful for debugging.""" - def init(self, run_id: Optional[str], hparams=None) -> WandbTracker: + def init(self, run_id: Optional[str], hparams) -> WandbTracker: import wandb if run_id is not None and self.id is not None and run_id != self.id: From 6446bc04a4acc061530fd50a24d326b2cb178b10 Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 28 Nov 2023 12:41:49 -0800 Subject: [PATCH 049/205] just about workable logger stuff --- examples/alpaca-lora/alpaca_lora.py | 5 ++-- examples/alpaca/alpaca.py | 3 +- src/levanter/__init__.py | 2 +- src/levanter/callbacks.py | 2 +- src/levanter/main/eval_lm.py | 2 +- src/levanter/main/lora_lm.py | 2 +- src/levanter/main/train_lm.py | 43 +++++++++++++++-------------- src/levanter/main/viz_logprobs.py | 6 ++-- src/levanter/tracker/__init__.py | 11 +++++++- src/levanter/tracker/tensorboard.py | 11 ++------ src/levanter/tracker/tracker.py | 2 +- src/levanter/tracker/tracker_fns.py | 16 +++++++++++ src/levanter/tracker/wandb.py | 27 +++++++----------- src/levanter/trainer.py | 2 +- 14 files changed, 74 insertions(+), 60 deletions(-) diff --git a/examples/alpaca-lora/alpaca_lora.py b/examples/alpaca-lora/alpaca_lora.py index 31de93252..c1fa64394 100644 --- a/examples/alpaca-lora/alpaca_lora.py +++ b/examples/alpaca-lora/alpaca_lora.py @@ -48,7 +48,7 @@ class TrainArgs(alpaca.TrainArgs): def train(config: TrainArgs): - config.trainer.initialize(config) + config.trainer.initialize() # Since Levanter has different implementations of models from HF, we need to convert the HF checkpoint. # This class is a wrapper around the HF checkpoint converter that also downloads the checkpoint if necessary. @@ -103,6 +103,7 @@ def compute_loss(model: LmHeadModel, example: LmExample, key=None): # end major difference from Alpaca with Trainer(config.trainer, optimizer, compute_loss, is_trainable=lora_param_filter) as trainer: + levanter.tracker.log_hyperparameters(config) trainer.add_default_hooks() state = trainer.initial_state(training_key, model=model) @@ -111,7 +112,7 @@ def compute_loss(model: LmHeadModel, example: LmExample, key=None): all_param_count = parameter_count(state.model) just_lora_params = parameter_count(trainer.trainable_params_only(state.model)) - levanter.log_summary( + levanter.tracker.log_summary( { "parameter_count": all_param_count, "trainable_parameter_count": just_lora_params, diff --git a/examples/alpaca/alpaca.py b/examples/alpaca/alpaca.py index 20cb98a33..6ceafaa1a 100644 --- a/examples/alpaca/alpaca.py +++ b/examples/alpaca/alpaca.py @@ -181,7 +181,7 @@ def get_prompts(prompt_path): def train(config: TrainArgs): - config.trainer.initialize(config) + config.trainer.initialize() # Since Levanter has different implementations of models from HF, we need to convert the HF checkpoint. # This class is a wrapper around the HF checkpoint converter that also downloads the checkpoint if necessary. 
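(Aside, illustrative only: at this point in the series, user code logs through the module-level functions in levanter.tracker instead of calling wandb directly, and a tracker can be installed for a scope with current_tracker. A minimal sketch using the built-in NoopTracker, with names as defined in the tracker package above:

import levanter.tracker
from levanter.tracker import NoopTracker

# current_tracker(t) returns a context manager that installs `t` as the
# global tracker; log_metrics/log_summary route to whichever tracker is
# current when they are called.
with levanter.tracker.current_tracker(NoopTracker()):
    levanter.tracker.log_metrics({"train/loss": 1.23}, step=0)
    levanter.tracker.log_summary({"parameter_count": 125_000_000})

End of aside.)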
@@ -218,6 +218,7 @@ def compute_loss(model: LmHeadModel, example: LmExample, key=None): return model.compute_loss(example, key=key).scalar() with Trainer(config.trainer, optimizer, compute_loss) as trainer: + levanter.tracker.log_hyperparameters(config) # how we shard parameters across devices parameter_axis_mapping = trainer.parameter_axis_mapping diff --git a/src/levanter/__init__.py b/src/levanter/__init__.py index d89ea4945..d53046522 100644 --- a/src/levanter/__init__.py +++ b/src/levanter/__init__.py @@ -5,4 +5,4 @@ import levanter.logging as logging import levanter.tracker as tracker import levanter.visualization as visualization -from levanter.tracker import current_tracker, get_tracker, log_metrics, log_summary +from levanter.tracker import current_tracker, get_tracker diff --git a/src/levanter/callbacks.py b/src/levanter/callbacks.py index 154099e8a..a80d0619e 100644 --- a/src/levanter/callbacks.py +++ b/src/levanter/callbacks.py @@ -62,7 +62,7 @@ def compute_loss(info: StepInfo): prefix = "eval" if name: prefix += "/" + name - levanter.log_metrics({f"{prefix}/loss": loss}, step=info.step) + levanter.tracker.log_metrics({f"{prefix}/loss": loss}, step=info.step) if name: logger.info(f"{name} validation loss: {loss:.3f}") diff --git a/src/levanter/main/eval_lm.py b/src/levanter/main/eval_lm.py index bea7a5e2b..d27d04eb6 100644 --- a/src/levanter/main/eval_lm.py +++ b/src/levanter/main/eval_lm.py @@ -41,7 +41,7 @@ class EvalLmConfig: def main(config: EvalLmConfig): - config.trainer.initialize(config) + config.trainer.initialize() tokenizer = config.data.the_tokenizer Batch = Axis("batch", config.trainer.eval_batch_size) diff --git a/src/levanter/main/lora_lm.py b/src/levanter/main/lora_lm.py index 4e621239e..e29091356 100644 --- a/src/levanter/main/lora_lm.py +++ b/src/levanter/main/lora_lm.py @@ -53,7 +53,7 @@ def main(config: LoraLmConfig): converter = converter.replaced(tokenizer=tokenizer) - config.trainer.initialize(config) + config.trainer.initialize() model_config = converter.default_config # randomness in jax is tightly controlled by "keys" which are the states of the random number generators diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py index 60d5dbbb6..dd0274e73 100644 --- a/src/levanter/main/train_lm.py +++ b/src/levanter/main/train_lm.py @@ -46,6 +46,8 @@ class TrainLmConfig: def main(config: TrainLmConfig): + config.trainer.initialize() + tokenizer = config.data.the_tokenizer # this is some unpleasant code to allow us to initialize from a hf checkpoint. 
If this is your first read through, @@ -71,36 +73,35 @@ def main(config: TrainLmConfig): else: converter = None - # initialize training config *after* we've done the hf stuff b/c we might have changed the model config - config.trainer.initialize(config) - - # randomness in jax is tightly controlled by "keys" which are the states of the random number generators - # this makes deterministic training pretty easy - seed = config.trainer.seed - data_key, loader_key, model_key, training_key = jrandom.split(jrandom.PRNGKey(seed), 4) - - # some axes we need - Batch = config.trainer.TrainBatch - EvalBatch = config.trainer.EvalBatch - Pos = config.model.Pos - KeyPos = config.model.KeyPos - - # We have two axis_mappings: one for storing the model and optimizer states, and one for compute - # This allows Zero-3-style parameter sharding, where we shard the parameters and optimizer state across the mesh - compute_axis_mapping = config.trainer.compute_axis_mapping - parameter_axis_mapping = config.trainer.parameter_axis_mapping + optimizer = config.optimizer.build(config.trainer.num_train_steps) def compute_loss(model: LmHeadModel, example: LmExample, key=None): return model.compute_loss(example, key=key).scalar() - optimizer = config.optimizer.build(config.trainer.num_train_steps) - # Our trainer is a wrapper around the optimizer and compute_loss function that handles checkpointing and fsdp # Using the trainer as a context manager does 3 things: # 1. Sets the device mesh # 2. Sets the axis mapping (for fsdp) - # 3. Sets the global metrics logger + # 3. Sets the global metrics tracker with Trainer(config.trainer, optimizer, compute_loss) as trainer: + levanter.tracker.log_hyperparameters(config) + + # randomness in jax is tightly controlled by "keys" which are the states of the random number generators + # this makes deterministic training pretty easy + seed = config.trainer.seed + data_key, loader_key, model_key, training_key = jrandom.split(jrandom.PRNGKey(seed), 4) + + # We have two axis_mappings: one for storing the model and optimizer states, and one for compute + # This allows Zero-3-style parameter sharding, where we shard the parameters and optimizer state across the mesh + compute_axis_mapping = trainer.compute_axis_mapping + parameter_axis_mapping = trainer.parameter_axis_mapping + + # some axes we need + Batch = config.trainer.TrainBatch + EvalBatch = config.trainer.EvalBatch + Pos = config.model.Pos + KeyPos = config.model.KeyPos + eval_datasets = config.data.validation_sets(Pos.size) train_dataset = CausalLmDataset(config.data.train_set(Pos.size), Pos, KeyPos) diff --git a/src/levanter/main/viz_logprobs.py b/src/levanter/main/viz_logprobs.py index ad85a0c7d..61903d01f 100644 --- a/src/levanter/main/viz_logprobs.py +++ b/src/levanter/main/viz_logprobs.py @@ -36,12 +36,11 @@ class VizGpt2Config: def main(config: VizGpt2Config): - config.trainer.initialize(config) + config.trainer.initialize() tokenizer = config.data.the_tokenizer - EvalBatch = Axis("batch", config.trainer.eval_batch_size) - # some axes we use outside the model proper + EvalBatch = config.trainer.EvalBatch Pos = config.model.Pos KeyPos = config.model.KeyPos @@ -53,7 +52,6 @@ def main(config: VizGpt2Config): # some axes we use outside the model proper Pos = config.model.Pos - KeyPos = config.model.KeyPos compute_axis_mapping = config.trainer.compute_axis_mapping parameter_axis_mapping = config.trainer.parameter_axis_mapping diff --git a/src/levanter/tracker/__init__.py b/src/levanter/tracker/__init__.py index 
02edfc9d2..b9d5ce61c 100644 --- a/src/levanter/tracker/__init__.py +++ b/src/levanter/tracker/__init__.py @@ -1,6 +1,13 @@ from levanter.tracker.helpers import log_optimizer_hyperparams from levanter.tracker.tracker import CompositeTracker, NoopConfig, NoopTracker, Tracker, TrackerConfig -from levanter.tracker.tracker_fns import current_tracker, get_tracker, jit_log_metrics, log_metrics, log_summary +from levanter.tracker.tracker_fns import ( + current_tracker, + get_tracker, + jit_log_metrics, + log_hyperparameters, + log_metrics, + log_summary, +) __all__ = [ @@ -10,7 +17,9 @@ "log_optimizer_hyperparams", "NoopTracker", "current_tracker", + "get_tracker", "jit_log_metrics", "log_metrics", "log_summary", + "log_hyperparameters", ] diff --git a/src/levanter/tracker/tensorboard.py b/src/levanter/tracker/tensorboard.py index e657ae47f..0e716a91b 100644 --- a/src/levanter/tracker/tensorboard.py +++ b/src/levanter/tracker/tensorboard.py @@ -4,7 +4,7 @@ from dataclasses import dataclass from typing import Any, Optional -from levanter.tracker import Tracker, TrackerConfig, helpers +from levanter.tracker import Tracker, TrackerConfig pylogger = logging.getLogger(__name__) @@ -32,7 +32,7 @@ def log_summary(self, metrics: dict[str, Any]): self.writer.add_scalar(k, v, global_step=None) def log_artifact(self, artifact, *, name: Optional[str] = None, type: Optional[str] = None): - pylogger.warning("TensorboardLogger does not support logging artifacts yet") + pylogger.error("TensorboardLogger does not support logging artifacts yet") pass @@ -47,7 +47,7 @@ class TensorboardConfig(TrackerConfig): filename_suffix: Optional[str] = "" write_to_disk: Optional[bool] = True - def init(self, run_id: Optional[str], hparams) -> TensorboardTracker: + def init(self, run_id: Optional[str]) -> TensorboardTracker: dir_to_write = self.logdir if run_id is not None: dir_to_write = os.path.join(dir_to_write, run_id) @@ -66,11 +66,6 @@ def init(self, run_id: Optional[str], hparams) -> TensorboardTracker: write_to_disk=self.write_to_disk, ) - hparams_dict = helpers.hparams_to_dict(hparams) - hparams_dict = _flatten_nested_dict(hparams_dict) - - writer.add_hparams(hparams_dict, {"dummy": 0}) - return TensorboardTracker(writer) diff --git a/src/levanter/tracker/tracker.py b/src/levanter/tracker/tracker.py index b335dc6ad..c213fb227 100644 --- a/src/levanter/tracker/tracker.py +++ b/src/levanter/tracker/tracker.py @@ -86,7 +86,7 @@ class TrackerConfig(draccus.PluginRegistry, abc.ABC): discover_packages_path = "levanter.tracker" @abc.abstractmethod - def init(self, run_id: Optional[str], hparams) -> Tracker: + def init(self, run_id: Optional[str]) -> Tracker: raise NotImplementedError @classmethod diff --git a/src/levanter/tracker/tracker_fns.py b/src/levanter/tracker/tracker_fns.py index 69ab4ca0b..0d4a32415 100644 --- a/src/levanter/tracker/tracker_fns.py +++ b/src/levanter/tracker/tracker_fns.py @@ -5,6 +5,7 @@ import jax from levanter.tracker import CompositeTracker, Tracker +from levanter.tracker.helpers import hparams_to_dict from levanter.tracker.tensorboard import TensorboardTracker from levanter.tracker.wandb import WandbTracker from levanter.utils.jax_utils import is_inside_jit @@ -54,6 +55,21 @@ def log_summary(metrics: dict[str, Any]): _global_tracker.log_summary(metrics) +def log_hyperparameters(hparams: Any): + """ + Log hyperparameters to the global tracker. 
+ + Args: + hparams: Hyperparameters to log + """ + global _global_tracker + if _global_tracker is None: + raise RuntimeError("No global tracker set") + + hparams_dict = hparams_to_dict(hparams) + _global_tracker.log_hyperparameters(hparams_dict) + + @typing.overload def current_tracker() -> "Tracker": ... diff --git a/src/levanter/tracker/wandb.py b/src/levanter/tracker/wandb.py index d6994abae..ab7f13f34 100644 --- a/src/levanter/tracker/wandb.py +++ b/src/levanter/tracker/wandb.py @@ -48,23 +48,15 @@ def __init__(self, run: Optional[WandbRun]): self.run = run def log_hyperparameters(self, hparams: dict[str, Any]): - if self.run is None: - raise RuntimeError("Must call init before logging hyperparameters") self.run.config.update(hparams) def log(self, metrics: dict[str, Any], *, step, commit=None): - if self.run is None: - raise RuntimeError("Must call init before logging metrics") self.run.log(metrics, step=step, commit=commit) def log_summary(self, metrics: dict[str, Any]): - if self.run is None: - raise RuntimeError("Must call init before logging summary") self.run.summary.update(metrics) def log_artifact(self, artifact, *, name: Optional[str] = None, type: Optional[str] = None): - if self.run is None: - raise RuntimeError("Must call init before logging artifacts") self.run.log_artifact(artifact, name=name, type=type) @@ -105,7 +97,7 @@ class WandbConfig(TrackerConfig): save_xla_dumps: bool = False """If True, will save the XLA code to wandb (as configured by XLA_FLAGS). This is useful for debugging.""" - def init(self, run_id: Optional[str], hparams) -> WandbTracker: + def init(self, run_id: Optional[str]) -> WandbTracker: import wandb if run_id is not None and self.id is not None and run_id != self.id: @@ -118,7 +110,7 @@ def init(self, run_id: Optional[str], hparams) -> WandbTracker: if id is None: id = run_id - hparams_to_save = hparams_to_dict(hparams) + hparams_to_save = {} # for distributed runs, we only want the primary worker to use wandb, so we make everyone else be disabled # however, we do share information about the run id, so that we can link to it from the other workers @@ -168,13 +160,14 @@ def init(self, run_id: Optional[str], hparams) -> WandbTracker: logger.info(f"Synced wandb run information from process 0: {r.name} {r.id}") - if dataclasses.is_dataclass(hparams): - with tempfile.TemporaryDirectory() as tmpdir: - config_path = os.path.join(tmpdir, "config.yaml") - with open(config_path, "w") as f: - draccus.dump(hparams, f, encoding="utf-8") - if wandb.run is not None: - wandb.run.log_artifact(str(config_path), name="config.yaml", type="config") + # TODO: bring this back? + # if dataclasses.is_dataclass(hparams): + # with tempfile.TemporaryDirectory() as tmpdir: + # config_path = os.path.join(tmpdir, "config.yaml") + # with open(config_path, "w") as f: + # draccus.dump(hparams, f, encoding="utf-8") + # if wandb.run is not None: + # wandb.run.log_artifact(str(config_path), name="config.yaml", type="config") # generate a pip freeze with tempfile.TemporaryDirectory() as tmpdir: diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index 4636b50d0..e85b7f74e 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -567,7 +567,7 @@ def __post_init__(self): warnings.warn("wandb is deprecated. 
use tracker with type wandb instead", DeprecationWarning) self.tracker = self.wandb - def initialize(self, all_config): + def initialize(self): """Initializes jax, logging, setting the run name/id in the process""" self._initialize_jax_config() self.distributed.initialize() From 1b821d17f6e9b2d3b682ad7934b98b466a5f8bbf Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 28 Nov 2023 15:21:17 -0800 Subject: [PATCH 050/205] fix logging of config with a new levanter.initialize --- examples/alpaca-lora/alpaca_lora.py | 4 +- examples/alpaca/alpaca.py | 3 +- src/levanter/__init__.py | 2 + src/levanter/logging.py | 6 +- src/levanter/main/eval_lm.py | 2 +- src/levanter/main/lora_lm.py | 2 +- src/levanter/main/train_lm.py | 4 +- src/levanter/main/viz_logprobs.py | 2 +- src/levanter/tracker/__init__.py | 4 + src/levanter/tracker/tensorboard.py | 2 +- src/levanter/tracker/tracker.py | 10 +- src/levanter/tracker/tracker_fns.py | 56 +++++++++- src/levanter/tracker/wandb.py | 17 +-- src/levanter/trainer.py | 160 +++++++++++++++++----------- 14 files changed, 178 insertions(+), 96 deletions(-) diff --git a/examples/alpaca-lora/alpaca_lora.py b/examples/alpaca-lora/alpaca_lora.py index c1fa64394..0e7c5790e 100644 --- a/examples/alpaca-lora/alpaca_lora.py +++ b/examples/alpaca-lora/alpaca_lora.py @@ -48,7 +48,7 @@ class TrainArgs(alpaca.TrainArgs): def train(config: TrainArgs): - config.trainer.initialize() + levanter.initialize(config) # Since Levanter has different implementations of models from HF, we need to convert the HF checkpoint. # This class is a wrapper around the HF checkpoint converter that also downloads the checkpoint if necessary. @@ -103,8 +103,6 @@ def compute_loss(model: LmHeadModel, example: LmExample, key=None): # end major difference from Alpaca with Trainer(config.trainer, optimizer, compute_loss, is_trainable=lora_param_filter) as trainer: - levanter.tracker.log_hyperparameters(config) - trainer.add_default_hooks() state = trainer.initial_state(training_key, model=model) diff --git a/examples/alpaca/alpaca.py b/examples/alpaca/alpaca.py index 6ceafaa1a..647b1d10f 100644 --- a/examples/alpaca/alpaca.py +++ b/examples/alpaca/alpaca.py @@ -181,7 +181,7 @@ def get_prompts(prompt_path): def train(config: TrainArgs): - config.trainer.initialize() + levanter.initialize(config) # Since Levanter has different implementations of models from HF, we need to convert the HF checkpoint. # This class is a wrapper around the HF checkpoint converter that also downloads the checkpoint if necessary. 
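(Aside, illustrative only: a minimal sketch of the single entry point this commit introduces. `MyConfig` is a hypothetical stand-in for a script's config dataclass; any object with a `.trainer` attribute satisfies the AllConfig protocol defined in trainer.py:

import dataclasses

import levanter
from levanter.tracker import NoopConfig
from levanter.trainer import TrainerConfig


@dataclasses.dataclass
class MyConfig:  # hypothetical; satisfies the AllConfig protocol
    trainer: TrainerConfig


config = MyConfig(
    trainer=TrainerConfig(tracker=NoopConfig(), require_accelerator=False)
)
# Sets up jax/distributed, logging, and the global tracker, then logs the
# config as hyperparameters and as a config.yaml artifact.
levanter.initialize(config)
levanter.tracker.log_metrics({"setup/ok": 1}, step=0)

End of aside.)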
@@ -218,7 +218,6 @@ def compute_loss(model: LmHeadModel, example: LmExample, key=None): return model.compute_loss(example, key=key).scalar() with Trainer(config.trainer, optimizer, compute_loss) as trainer: - levanter.tracker.log_hyperparameters(config) # how we shard parameters across devices parameter_axis_mapping = trainer.parameter_axis_mapping diff --git a/src/levanter/__init__.py b/src/levanter/__init__.py index d53046522..4153c0711 100644 --- a/src/levanter/__init__.py +++ b/src/levanter/__init__.py @@ -4,5 +4,7 @@ import levanter.distributed as distributed import levanter.logging as logging import levanter.tracker as tracker +import levanter.trainer as trainer import levanter.visualization as visualization from levanter.tracker import current_tracker, get_tracker +from levanter.trainer import initialize diff --git a/src/levanter/logging.py b/src/levanter/logging.py index 23cf63047..1ce936dbe 100644 --- a/src/levanter/logging.py +++ b/src/levanter/logging.py @@ -10,13 +10,17 @@ pylogger = pylogging.getLogger(__name__) -def init_logging(path: Union[str, Path], level: int = pylogging.INFO) -> None: +def init_logging(log_dir: Union[str, Path], run_id: str, level: int = pylogging.INFO) -> None: """ Initialize logging.Logger with the appropriate name, console, and file handlers. :param path: Path for writing log file :param level: Default logging level """ + log_dir = Path(log_dir) + log_dir.mkdir(parents=True, exist_ok=True) + path = log_dir / f"{run_id}.log" + process_index = jax.process_index() log_format = f"%(asctime)s - {process_index} - %(name)s - %(filename)s:%(lineno)d - %(levelname)s :: %(message)s" # use ISO 8601 format for timestamps, except no TZ, because who cares diff --git a/src/levanter/main/eval_lm.py b/src/levanter/main/eval_lm.py index d27d04eb6..9202f8c4b 100644 --- a/src/levanter/main/eval_lm.py +++ b/src/levanter/main/eval_lm.py @@ -41,7 +41,7 @@ class EvalLmConfig: def main(config: EvalLmConfig): - config.trainer.initialize() + levanter.initialize(config) tokenizer = config.data.the_tokenizer Batch = Axis("batch", config.trainer.eval_batch_size) diff --git a/src/levanter/main/lora_lm.py b/src/levanter/main/lora_lm.py index e29091356..dd005e5f4 100644 --- a/src/levanter/main/lora_lm.py +++ b/src/levanter/main/lora_lm.py @@ -45,6 +45,7 @@ class LoraLmConfig: def main(config: LoraLmConfig): + levanter.initialize(config) tokenizer = config.data.the_tokenizer converter = HFCheckpointConverter.from_hf(config.initialize_from_hf, trust_remote_code=config.trust_remote_code) @@ -53,7 +54,6 @@ def main(config: LoraLmConfig): converter = converter.replaced(tokenizer=tokenizer) - config.trainer.initialize() model_config = converter.default_config # randomness in jax is tightly controlled by "keys" which are the states of the random number generators diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py index dd0274e73..6ce28ec05 100644 --- a/src/levanter/main/train_lm.py +++ b/src/levanter/main/train_lm.py @@ -46,7 +46,7 @@ class TrainLmConfig: def main(config: TrainLmConfig): - config.trainer.initialize() + levanter.initialize(config) tokenizer = config.data.the_tokenizer @@ -84,8 +84,6 @@ def compute_loss(model: LmHeadModel, example: LmExample, key=None): # 2. Sets the axis mapping (for fsdp) # 3. 
Sets the global metrics tracker with Trainer(config.trainer, optimizer, compute_loss) as trainer: - levanter.tracker.log_hyperparameters(config) - # randomness in jax is tightly controlled by "keys" which are the states of the random number generators # this makes deterministic training pretty easy seed = config.trainer.seed diff --git a/src/levanter/main/viz_logprobs.py b/src/levanter/main/viz_logprobs.py index 61903d01f..28c8e0294 100644 --- a/src/levanter/main/viz_logprobs.py +++ b/src/levanter/main/viz_logprobs.py @@ -36,7 +36,7 @@ class VizGpt2Config: def main(config: VizGpt2Config): - config.trainer.initialize() + levanter.initialize(config) tokenizer = config.data.the_tokenizer # some axes we use outside the model proper diff --git a/src/levanter/tracker/__init__.py b/src/levanter/tracker/__init__.py index b9d5ce61c..69156c6a6 100644 --- a/src/levanter/tracker/__init__.py +++ b/src/levanter/tracker/__init__.py @@ -4,9 +4,11 @@ current_tracker, get_tracker, jit_log_metrics, + log_configuration, log_hyperparameters, log_metrics, log_summary, + set_global_tracker, ) @@ -19,7 +21,9 @@ "current_tracker", "get_tracker", "jit_log_metrics", + "log_configuration", "log_metrics", "log_summary", "log_hyperparameters", + "set_global_tracker", ] diff --git a/src/levanter/tracker/tensorboard.py b/src/levanter/tracker/tensorboard.py index 0e716a91b..bd3ee70ba 100644 --- a/src/levanter/tracker/tensorboard.py +++ b/src/levanter/tracker/tensorboard.py @@ -31,7 +31,7 @@ def log_summary(self, metrics: dict[str, Any]): for k, v in metrics.items(): self.writer.add_scalar(k, v, global_step=None) - def log_artifact(self, artifact, *, name: Optional[str] = None, type: Optional[str] = None): + def log_artifact(self, artifact_path, *, name: Optional[str] = None, type: Optional[str] = None): pylogger.error("TensorboardLogger does not support logging artifacts yet") pass diff --git a/src/levanter/tracker/tracker.py b/src/levanter/tracker/tracker.py index c213fb227..8b6816f17 100644 --- a/src/levanter/tracker/tracker.py +++ b/src/levanter/tracker/tracker.py @@ -43,7 +43,7 @@ def log_summary(self, metrics: dict[str, Any]): pass @abc.abstractmethod - def log_artifact(self, artifact, *, name: Optional[str] = None, type: Optional[str] = None): + def log_artifact(self, artifact_path, *, name: Optional[str] = None, type: Optional[str] = None): pass def __enter__(self): @@ -77,9 +77,9 @@ def log_summary(self, metrics: dict[str, Any]): for tracker in self.loggers: tracker.log_summary(metrics) - def log_artifact(self, artifact, *, name: Optional[str] = None, type: Optional[str] = None): + def log_artifact(self, artifact_path, *, name: Optional[str] = None, type: Optional[str] = None): for tracker in self.loggers: - tracker.log_artifact(artifact, name=name, type=type) + tracker.log_artifact(artifact_path, name=name, type=type) class TrackerConfig(draccus.PluginRegistry, abc.ABC): @@ -106,12 +106,12 @@ def log(self, metrics: dict[str, Any], *, step, commit: Optional[bool] = None): def log_summary(self, metrics: dict[str, Any]): pass - def log_artifact(self, artifact, *, name: Optional[str] = None, type: Optional[str] = None): + def log_artifact(self, artifact_path, *, name: Optional[str] = None, type: Optional[str] = None): pass @TrackerConfig.register_subclass("noop") @dataclasses.dataclass class NoopConfig(TrackerConfig): - def init(self, run_id: Optional[str], hparams) -> Tracker: + def init(self, run_id: Optional[str]) -> Tracker: return NoopTracker() diff --git a/src/levanter/tracker/tracker_fns.py 
b/src/levanter/tracker/tracker_fns.py index 0d4a32415..accd16b68 100644 --- a/src/levanter/tracker/tracker_fns.py +++ b/src/levanter/tracker/tracker_fns.py @@ -1,7 +1,12 @@ +import dataclasses +import os +import tempfile import typing +import warnings from contextlib import AbstractContextManager from typing import Any, Literal, Optional +import draccus import jax from levanter.tracker import CompositeTracker, Tracker @@ -55,7 +60,7 @@ def log_summary(metrics: dict[str, Any]): _global_tracker.log_summary(metrics) -def log_hyperparameters(hparams: Any): +def log_hyperparameters(hparams: dict[str, Any]): """ Log hyperparameters to the global tracker. @@ -66,9 +71,56 @@ def log_hyperparameters(hparams: Any): if _global_tracker is None: raise RuntimeError("No global tracker set") + _global_tracker.log_hyperparameters(hparams) + + +def log_configuration(hparams: Any, config_name: Optional[str] = None): + """ + Logs a configuration object to the global tracker. If the configuration object is a dataclass, + it is dumped to a yaml file and logged as an artifact. + + Args: + hparams: Hyperparameters to log + """ + global _global_tracker + if _global_tracker is None: + raise RuntimeError("No global tracker set") + hparams_dict = hparams_to_dict(hparams) _global_tracker.log_hyperparameters(hparams_dict) + if dataclasses.is_dataclass(hparams): + with tempfile.TemporaryDirectory() as tmpdir: + config_path = os.path.join(tmpdir, "config.yaml") + with open(config_path, "w") as f: + draccus.dump(hparams, f, encoding="utf-8") + name = config_name or "config.yaml" + _global_tracker.log_artifact(config_path, name=name, type="config") + + +def set_global_tracker(tracker: Tracker): + """ + Set the global tracker. Note that setting the global tracker is not thread-safe, + and using a tracker from multiple threads is only supported if the tracker itself is thread-safe. + + In general, it's preferred to use the context manager returned by `current_tracker` instead of this function + except for once at the beginning of the program. + + Args: + tracker: The tracker to set as the global tracker + force: Whether to force setting the global tracker even if it is already set + + Examples: + >>> from levanter.tracker import set_global_tracker, log_metrics + >>> from levanter.tracker.wandb import WandbTracker + >>> set_global_tracker(WandbTracker()) + >>> log_metrics({"foo": 1}, step=0) + """ + global _global_tracker + if _global_tracker is not None: + warnings.warn("Global tracker is already set. Overwriting it.") + _global_tracker = tracker + @typing.overload def current_tracker() -> "Tracker": @@ -100,7 +152,7 @@ def current_tracker( >>> from levanter.tracker.wandb import WandbTracker >>> with current_tracker(WandbTracker()): ... log_metrics({"foo": 1}, step=0) - ... current_tracker().log_metrics({"foo": 2}, step=1) + ... 
current_tracker().log({"foo": 2}, step=1) """ global _global_tracker if tracker is None: diff --git a/src/levanter/tracker/wandb.py b/src/levanter/tracker/wandb.py index ab7f13f34..cf4147351 100644 --- a/src/levanter/tracker/wandb.py +++ b/src/levanter/tracker/wandb.py @@ -1,4 +1,3 @@ -import dataclasses import logging import os import tempfile @@ -7,13 +6,12 @@ from dataclasses import dataclass from typing import Any, List, Optional, Union -import draccus import jax from draccus import field from git import InvalidGitRepositoryError, NoSuchPathError, Repo from levanter.tracker import Tracker -from levanter.tracker.helpers import generate_pip_freeze, hparams_to_dict, infer_experiment_git_root +from levanter.tracker.helpers import generate_pip_freeze, infer_experiment_git_root from levanter.tracker.tracker import TrackerConfig from levanter.utils import jax_utils @@ -56,8 +54,8 @@ def log(self, metrics: dict[str, Any], *, step, commit=None): def log_summary(self, metrics: dict[str, Any]): self.run.summary.update(metrics) - def log_artifact(self, artifact, *, name: Optional[str] = None, type: Optional[str] = None): - self.run.log_artifact(artifact, name=name, type=type) + def log_artifact(self, artifact_path, *, name: Optional[str] = None, type: Optional[str] = None): + self.run.log_artifact(artifact_path, name=name, type=type) def is_wandb_available(): @@ -160,15 +158,6 @@ def init(self, run_id: Optional[str]) -> WandbTracker: logger.info(f"Synced wandb run information from process 0: {r.name} {r.id}") - # TODO: bring this back? - # if dataclasses.is_dataclass(hparams): - # with tempfile.TemporaryDirectory() as tmpdir: - # config_path = os.path.join(tmpdir, "config.yaml") - # with open(config_path, "w") as f: - # draccus.dump(hparams, f, encoding="utf-8") - # if wandb.run is not None: - # wandb.run.log_artifact(str(config_path), name="config.yaml", type="config") - # generate a pip freeze with tempfile.TemporaryDirectory() as tmpdir: requirements_path = os.path.join(tmpdir, "requirements.txt") diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index e85b7f74e..f7dccb130 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -9,7 +9,21 @@ from dataclasses import dataclass from functools import cached_property from pathlib import Path -from typing import Any, Callable, Dict, Generic, Iterable, List, Mapping, Optional, Sequence, Tuple, TypeVar, Union +from typing import ( + Any, + Callable, + Dict, + Generic, + Iterable, + List, + Mapping, + Optional, + Protocol, + Sequence, + Tuple, + TypeVar, + Union, +) import equinox as eqx import jax @@ -31,7 +45,7 @@ import levanter.logging import levanter.tracker import levanter.tracker.wandb -from levanter import logging, tracker +from levanter import tracker from levanter.checkpoint import CheckpointerConfig from levanter.config import JsonAtom from levanter.data import Dataset, ReplicatedBatchLoader, ShardableDataset, ShardedBatchLoader @@ -146,10 +160,7 @@ def __init__( self._raw_loss_function = loss_fn self.optimizer = optimizer self.is_trainable_param = is_trainable - if isinstance(config.tracker, Sequence): - self.tracker = levanter.tracker.CompositeTracker([c.init(self.run_id) for c in config.tracker]) - else: - self.tracker = config.tracker.init(self.run_id) + self._cmanagers = [] @cached_property @@ -217,7 +228,7 @@ def __enter__(self): raise RuntimeError("Trainer is already entered") self._cmanagers = [ - levanter.current_tracker(self.tracker), + # levanter.current_tracker(self.tracker), self.device_mesh, 
hax.axis_mapping(self.parameter_axis_mapping), ] @@ -249,51 +260,50 @@ def initial_state( Returns: model, opt_state, key, resume_step """ - with levanter.tracker.current_tracker(self.tracker): - if model is not None and model_init is not None: - raise ValueError("only one of model and model_init should be specified") - elif model is None and model_init is None: - raise ValueError("one of model and model_init must be specified") + if model is not None and model_init is not None: + raise ValueError("only one of model and model_init should be specified") + elif model is None and model_init is None: + raise ValueError("one of model and model_init must be specified") + + if model is not None: + # we can't just use `lambda: model` because JAX jit can't see captures, but it can see partials + # We can't use plain partials because they aren't pytrees + model_init = jax.tree_util.Partial(lambda m: m, model) + + model_shape, opt_state_shape = eqx.filter_eval_shape(self._init_model_and_opt_state, model_init) + + # we only checkpoint the trainable parameters, so we need to filter out the non-trainable ones + trainable_model_shape = self.trainable_params_only(model_shape) + + ckpt = self.maybe_load_checkpoint( + trainable_model_shape, + (opt_state_shape, training_key), + axis_mapping=self.parameter_axis_mapping, + mesh=self.device_mesh, + ) + if ckpt is not None: + trainable_model, (opt_state, training_key), completed_step = ckpt if model is not None: - # we can't just use `lambda: model` because JAX jit can't see captures, but it can see partials - # We can't use plain partials because they aren't pytrees - model_init = jax.tree_util.Partial(lambda m: m, model) - - model_shape, opt_state_shape = eqx.filter_eval_shape(self._init_model_and_opt_state, model_init) - - # we only checkpoint the trainable parameters, so we need to filter out the non-trainable ones - trainable_model_shape = self.trainable_params_only(model_shape) - - ckpt = self.maybe_load_checkpoint( - trainable_model_shape, - (opt_state_shape, training_key), - axis_mapping=self.parameter_axis_mapping, - mesh=self.device_mesh, - ) - - if ckpt is not None: - trainable_model, (opt_state, training_key), completed_step = ckpt - if model is not None: - model = eqx.combine(trainable_model, model) - elif any(isinstance(leaf, ShapeDtypeStruct) for leaf in jax.tree_leaves(trainable_model)): - # if we're resuming, we need to re-initialize the non-trainable parameters to their original values - non_trainable = named_jit(self._init_non_trainable_params, self.parameter_axis_mapping)(model_init) - model = eqx.combine(trainable_model, non_trainable) - else: - model = trainable_model - step = completed_step + 1 + model = eqx.combine(trainable_model, model) + elif any(isinstance(leaf, ShapeDtypeStruct) for leaf in jax.tree_leaves(trainable_model)): + # if we're resuming, we need to re-initialize the non-trainable parameters to their original values + non_trainable = named_jit(self._init_non_trainable_params, self.parameter_axis_mapping)(model_init) + model = eqx.combine(trainable_model, non_trainable) else: - model, opt_state = named_jit(self._init_model_and_opt_state, self.parameter_axis_mapping)(model_init) - step = 0 + model = trainable_model + step = completed_step + 1 + else: + model, opt_state = named_jit(self._init_model_and_opt_state, self.parameter_axis_mapping)(model_init) + step = 0 - return TrainerState(step, model, opt_state, training_key) + return TrainerState(step, model, opt_state, training_key) def train_step(self, state: TrainerState[M], 
*batch: X, **batch_kwargs) -> StepInfo[M]: """ Performs a single training step. """ - with capture_time() as step_time, levanter.current_tracker(self.tracker): + with capture_time() as step_time: key, new_key = jax.random.split(state.training_key) loss, new_model, new_optstate = self._train_step_fn( state.model, state.opt_state, *batch, **batch_kwargs, key=key @@ -310,23 +320,23 @@ def training_steps( Generator that yields training steps and runs hooks. """ iter_data = iter(train_loader) - with levanter.current_tracker(self.tracker): - while state.step < self.config.num_train_steps: - with capture_time() as loading_time: - example = next(iter_data) + # with levanter.current_tracker(self.tracker): + while state.step < self.config.num_train_steps: + with capture_time() as loading_time: + example = next(iter_data) - levanter.tracker.log_metrics({"throughput/loading_time": loading_time()}, step=state.step) + levanter.tracker.log_metrics({"throughput/loading_time": loading_time()}, step=state.step) - info = self.train_step(state, example) - state = info.state + info = self.train_step(state, example) + state = info.state - if run_hooks: - with capture_time() as hook_time: - self.run_hooks(info) + if run_hooks: + with capture_time() as hook_time: + self.run_hooks(info) - levanter.tracker.log_metrics({"throughput/hook_time": hook_time()}, step=state.step) + levanter.tracker.log_metrics({"throughput/hook_time": hook_time()}, step=state.step) - yield info + yield info def train(self, state: TrainerState[M], train_loader: Iterable[X], run_hooks: bool = True) -> StepInfo[M]: """ @@ -496,6 +506,15 @@ def maybe_load_checkpoint( return None +def _initialize_global_tracker(config, run_id): + if isinstance(config, Sequence): + tracker = levanter.tracker.CompositeTracker([c.init(run_id) for c in config]) + else: + tracker = config.init(run_id) + + levanter.tracker.set_global_tracker(tracker) + + @dataclass class TrainerConfig: seed: int = 0 # random seed @@ -573,15 +592,18 @@ def initialize(self): self.distributed.initialize() self._validate_and_set_defaults() - self._maybe_set_id() - self._initialize_logging() + id = self._maybe_set_id() + levanter.logging.init_logging(self.log_dir, f"{id}.log") + _initialize_global_tracker(self.tracker, id) + self.ray.initialize() if self.require_accelerator is None: self.require_accelerator = not sys.platform.startswith("darwin") if self.require_accelerator: - assert jax.default_backend() != "cpu", "Accelerator required but not found" + if jax.default_backend() == "cpu": + raise RuntimeError("No accelerator found. 
Please run on a TPU or GPU.") if self.shutdown_at_exit is not False: if isinstance(self.shutdown_at_exit, bool): @@ -640,10 +662,6 @@ def _initialize_jax_config(self): for key, value in self.jax_config.items(): jax.config.update(key, value) - def _initialize_logging(self): - self.log_dir.mkdir(parents=True, exist_ok=True) - levanter.logging.init_logging(self.log_dir / f"{self.id}.log") - def _maybe_set_id(self): # always do this so we don't get weird hangs if the id isn't set right # for random ids, we want to ensure that all hosts have the same id @@ -667,6 +685,8 @@ def _maybe_set_id(self): logger.info(f"Setting run id to {self.id}") + return self.id + # we can't do this in post_init because we don't want to call jax.device_count before calling distributed.initialize def _validate_and_set_defaults(self): if jax.device_count() % self.model_axis_size != 0: @@ -694,6 +714,22 @@ def _validate_and_set_defaults(self): self.per_device_eval_parallelism = self.per_device_parallelism +class AllConfig(Protocol): + trainer: TrainerConfig + + +def initialize(config: TrainerConfig | AllConfig): + """Initializes jax, logging, setting the run name/id in the process. Also initializes tracking and saves config + as hyperparameters and an artifact""" + if isinstance(config, TrainerConfig): + trainer_config = config + else: + trainer_config = config.trainer + + trainer_config.initialize() + levanter.tracker.log_configuration(config) + + @dataclass class OptimizerConfig: # Config related to optimizer (always adam for now) From afb645976e0727f915f922ed159041ec29414da6 Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 28 Nov 2023 15:29:03 -0800 Subject: [PATCH 051/205] missed a sopt --- src/levanter/main/cache_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/levanter/main/cache_dataset.py b/src/levanter/main/cache_dataset.py index 077da674d..e22e33d18 100644 --- a/src/levanter/main/cache_dataset.py +++ b/src/levanter/main/cache_dataset.py @@ -21,7 +21,7 @@ class RayCachedLMDatasetConfig(LMDatasetConfig, RayConfig): @levanter.config.main() def main(args: RayCachedLMDatasetConfig): """Caches two different kinds of datasets. 
It can cache a dataset from a list of urls, or a dataset from a hf dataset""" - init_logging("cache_dataset.log") + init_logging(".", "cache_dataset.log") args.initialize() tokenizer = args.the_tokenizer From 9d916bd9b35b380f954f7e6954e9af965cc6af60 Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 28 Nov 2023 16:01:53 -0800 Subject: [PATCH 052/205] on second thought, don't use tb in small_fast --- config/gpt2_small_fast.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/config/gpt2_small_fast.yaml b/config/gpt2_small_fast.yaml index a7375de2f..8d12c7ed3 100644 --- a/config/gpt2_small_fast.yaml +++ b/config/gpt2_small_fast.yaml @@ -12,8 +12,8 @@ trainer: - type: wandb project: "levanter" tags: [ "openwebtext", "gpt2", "itest"] - - type: tensorboard - logdir: gs://levanter-checkpoints/tblogs/ +# - type: tensorboard +# logdir: gs://levanter-checkpoints/tblogs/ mp: p=f32,c=bfloat16 model_axis_size: 1 From 4d8cd68fbfe2658094df510bb7d8f80aa18b7b89 Mon Sep 17 00:00:00 2001 From: David Hall Date: Thu, 30 Nov 2023 21:13:46 -0800 Subject: [PATCH 053/205] main->dev (#375) * Tweaks to improve behavior on slurm clusters (#373) * giving docker a shot * wip prepare_cache * fix deps * push slurm stuff into logical_cpu_count * add release for build so it's the default * roll slurm nodelist expansion ourselves * don't use cluster stuff for initializing ray * cleanup * cleanu * no levanter slurm cluster * fix validation set when it's an HF dataset * do basic logger initialization up front * add more logging for distributed init * pretty sure this was breaking multihost levanter * better slurm nodename * use the LevanterSlurmCluster thing again * fix commented out line * less obnoxious logging --- src/levanter/distributed.py | 78 +++++++++++++------- src/levanter/models/longformer_scale_test.py | 49 ------------ src/levanter/trainer.py | 2 + tests/test_distributed.py | 24 ++++++ 4 files changed, 78 insertions(+), 75 deletions(-) delete mode 100644 src/levanter/models/longformer_scale_test.py create mode 100644 tests/test_distributed.py diff --git a/src/levanter/distributed.py b/src/levanter/distributed.py index 7de53117e..c88eedb34 100644 --- a/src/levanter/distributed.py +++ b/src/levanter/distributed.py @@ -1,14 +1,14 @@ import atexit +import itertools import logging import os import re -import subprocess from dataclasses import dataclass from typing import List, Optional, Union import jax import ray -from jax._src import clusters +from jax._src import clusters, distributed from jax._src.clusters import SlurmCluster, TpuCluster from levanter.utils.py_utils import logical_cpu_core_count @@ -25,7 +25,7 @@ _NUM_NODES = "SLURM_STEP_NUM_NODES" _TASKS_PER_NODE = "SLURM_STEP_TASKS_PER_NODE" _VISIBLE_DEVICES = "CUDA_VISIBLE_DEVICES" -_NODE_NAME = "SLURM_TOPOLOGY_ADDR" +_NODE_NAME = "SLURMD_NODENAME" class LevanterSlurmCluster(clusters.SlurmCluster): @@ -97,18 +97,14 @@ def get_local_device_ids_for_process(cls) -> Optional[List[int]]: # now we can figure out which node we are on. This is also annoying because the node list # is a comma separated list of nodes, but they collapse the list if there are multiple nodes # with the same name e.g. node001,node002,node003,node004,node007 -> node[001-004,007] - # thankfully slurm exposes a command to expand this list for us + # slurm exposes a command to expand this list for us, but it's not always available node_list = LevanterSlurmCluster._node_list() if node_list is None: raise ValueError( "Could not find node list in environment variables. 
You must set coordinator_address manually." ) - node_list = ( - subprocess.check_output(["scontrol", "show", "hostnames", node_list], input=b"") - .decode("utf-8") - .splitlines() - ) + node_list = _square_brace_expand(node_list) # finally, we can figure out which node we are on local_node = os.environ[_NODE_NAME] @@ -133,6 +129,38 @@ def get_local_device_ids_for_process(cls) -> Optional[List[int]]: return all_visible_devices[begin : begin + num_devices_per_local_process] +def _square_brace_expand(node_list): + # Find all parts of the sequence including text and number ranges + parts = re.findall(r"(\[.*?\]|[^\[\]]+)", node_list) + + # This function will generate numbers from a range or a single number string + def generate_numbers(number_string): + if "-" in number_string: # it's a range + start, end = map(int, number_string.split("-")) + return [str(i).zfill(len(number_string.split("-")[0])) for i in range(start, end + 1)] + else: # it's a single number + return [number_string] + + # This function will process each part and return a list of strings or a list of lists of strings + # Process each part to create lists of possible variations + processed_parts = [] + for part in parts: + if part.startswith("[") and part.endswith("]"): + # Extract the number sequences and expand each one + number_sequences = part.strip("[]").split(",") + processed_parts.append( + list(itertools.chain.from_iterable(generate_numbers(seq) for seq in number_sequences)) + ) + else: + processed_parts.append([part]) + + # Compute the Cartesian product of all parts to generate all combinations + expanded_nodes = ["".join(combination) for combination in itertools.product(*processed_parts)] + + # Join the nodes with commas + return expanded_nodes + + def _choose_port(id): port = int(id) % 2**12 + (65535 - 2**12 + 1) return port @@ -172,20 +200,13 @@ def _munge_address_port(address: str): address = os.getenv("RAY_ADDRESS") logger.info("Auto-discovered ray address using RAY_ADDRESS: %s", address) else: - cluster_types = [LevanterSlurmCluster, TpuCluster] - found = False - for cluster_type in cluster_types: - if cluster_type.is_env_present(): - found = True - break - - if not found: + coord_address = getattr(distributed.global_state, "coordinator_address", None) + + if coord_address is None: logger.info("No auto-discovered ray address found. Using default ray.init()") address = None else: - logger.info(f"Auto-discovered ray address using {cluster_type.__name__}") - - coord_address = cluster_type.get_coordinator_address() + logger.info(f"Auto-discovered ray address using JAX coordinator address: {coord_address}") host, port = _munge_address_port(coord_address) ray_port = _choose_port(port + 10234) @@ -194,7 +215,7 @@ def _munge_address_port(address: str): # Explicitly setting the number of CPUs on ray init stops init errors num_cpus = logical_cpu_core_count() - if cluster_type.get_process_id() == 0: + if jax.process_index() == 0: logger.info(f"Starting ray head on port {ray_port}. We are process 0.") logger.info(f"Starting ray with num_cpus set to {num_cpus}.") os.system(f"ray start --head --port {ray_port} --num-cpus {num_cpus}") @@ -202,13 +223,12 @@ def _munge_address_port(address: str): atexit.register(lambda: os.system("ray stop -g 10 --force")) elif start_workers: logger.info( - f"Starting ray worker and connecting to {address}." - f" We are process {cluster_type.get_process_id()}." + f"Starting ray worker and connecting to {address}. We are process {jax.process_index()}." 
) logger.info(f"Starting ray with num_cpus set to {num_cpus}.") os.system(f"ray start --address {address} --num-cpus {num_cpus}") - logger.info(f"ray.init(address='{address}', **{kwargs})") + logger.info(f"ray.init(address={repr(address)}, namespace={repr(namespace)}, **{repr(kwargs)})") # Ray has retry logic, so we don't need to retry here :fingers-crossed: ray.init(address=address, namespace=namespace, **kwargs) atexit.register(lambda: ray.shutdown()) @@ -252,8 +272,14 @@ def initialize(self): jax.distributed.initialize(coordinator_address, self.num_processes, self.process_id, device_ids) logger.info( - f"Initialized jax.distributed with {jax.device_count()} devices, {jax.process_count()} hosts" - f", coordinator_address={coordinator_address}, process_id={self.process_id}" + f"Initialized jax.distributed with {jax.device_count()} devices, {jax.process_count()} processes," + f" coordinator_address={coordinator_address}, process_id={self.process_id}, my" + f" device_ids={device_ids}." + ) + else: + logger.info( + "Not initializing jax.distributed because no distributed config " + "was provided, and no cluster was detected." ) diff --git a/src/levanter/models/longformer_scale_test.py b/src/levanter/models/longformer_scale_test.py deleted file mode 100644 index f097afe46..000000000 --- a/src/levanter/models/longformer_scale_test.py +++ /dev/null @@ -1,49 +0,0 @@ -import time - -import jax -import numpy as np -from jax.sharding import Mesh - -import haliax as hax -from haliax.partitioning import named_jit - -from levanter.models.longformer import causal_sliding_window_attention - - -Len = hax.Axis("Len", 8192) -W = hax.Axis("W", 512) -D = hax.Axis("D", 4096) -B = hax.Axis("B", 256) - -# Len = hax.Axis("Len", 64) -# W = hax.Axis("W", 4) -# D = hax.Axis("D", 8) -# B = hax.Axis("B", 4) - - -devices = np.array(jax.devices()) -mesh = Mesh(devices, ("data",)) - -axis_resources = {"B": "data"} - - -@named_jit(axis_resources=axis_resources) -def do_attn(inputs): - return causal_sliding_window_attention(Len, W, D, inputs, inputs, inputs) - - -@named_jit(axis_resources=axis_resources) -def init(): - return hax.random.uniform(jax.random.PRNGKey(0), (B, Len, D)) - - -if __name__ == "__main__": - with mesh: - data = init() - result = do_attn(data) - result.array.block_until_ready() - time_in = time.time() - result2 = do_attn(data) - result2.array.block_until_ready() - time_out = time.time() - print(time_out - time_in) diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index f7dccb130..649bf6d0e 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -589,6 +589,8 @@ def __post_init__(self): def initialize(self): """Initializes jax, logging, setting the run name/id in the process""" self._initialize_jax_config() + # Can't do full logging setup until we've initialized jax b/c we use jax for rank id + pylogging.basicConfig(level=pylogging.INFO) self.distributed.initialize() self._validate_and_set_defaults() diff --git a/tests/test_distributed.py b/tests/test_distributed.py new file mode 100644 index 000000000..359b8b77a --- /dev/null +++ b/tests/test_distributed.py @@ -0,0 +1,24 @@ +from levanter.distributed import _square_brace_expand + + +def test_square_brace_expand(): + custom_sequence = "node[001-004,007]suffix" + expanded_nodes = _square_brace_expand(custom_sequence) + assert expanded_nodes == ["node001suffix", "node002suffix", "node003suffix", "node004suffix", "node007suffix"] + + custom_sequence_2 = "prefix[001-002]node[005-006]suffix" + expanded_nodes_2 = 
_square_brace_expand(custom_sequence_2) + assert expanded_nodes_2 == [ + "prefix001node005suffix", + "prefix001node006suffix", + "prefix002node005suffix", + "prefix002node006suffix", + ] + + custom_sequence_3 = "node[1-11]suffix" + expanded_nodes_3 = _square_brace_expand(custom_sequence_3) + assert expanded_nodes_3 == [f"node{i}suffix" for i in range(1, 12)] + + custom_sequence_3 = "node[1-11,21]suffix" + expanded_nodes_3 = _square_brace_expand(custom_sequence_3) + assert expanded_nodes_3 == [f"node{i}suffix" for i in range(1, 12)] + ["node21suffix"] From 3b27a08dad46dfba24989dcf7a782e2dfb345d3d Mon Sep 17 00:00:00 2001 From: David Hall Date: Fri, 1 Dec 2023 15:16:40 -0800 Subject: [PATCH 054/205] supporting new trainer in gsm8k example --- examples/gsm8k-lora/gsm8k_lora.py | 91 ++++++++++++++++--------------- 1 file changed, 47 insertions(+), 44 deletions(-) diff --git a/examples/gsm8k-lora/gsm8k_lora.py b/examples/gsm8k-lora/gsm8k_lora.py index 6b369bf77..b4eaa3ec9 100644 --- a/examples/gsm8k-lora/gsm8k_lora.py +++ b/examples/gsm8k-lora/gsm8k_lora.py @@ -9,7 +9,6 @@ import jax.random as jrandom import numpy as np import transformers -import wandb import haliax as hax @@ -126,7 +125,7 @@ def format_output(ex): def train(config: TrainArgs): - config.trainer.initialize(config) + levanter.initialize(config) # Since Levanter has different implementations of models from HF, we need to convert the HF checkpoint. # This class is a wrapper around the HF checkpoint converter that also downloads the checkpoint if necessary. @@ -168,53 +167,57 @@ def loraize_hf_model(model): def compute_loss(model: LmHeadModel, example: LmExample, key=None): return model.compute_loss(example, key=key).scalar() - trainer = Trainer(config.trainer, optimizer, compute_loss, is_trainable=lora_param_filter) - # end major difference from Alpaca - trainer.add_default_hooks() - state = trainer.initial_state(training_key, model=model) - - # log some info about the model - all_param_count = parameter_count(state.model) - just_lora_params = parameter_count(trainer.trainable_params_only(state.model)) - - wandb.summary["parameter_count"] = all_param_count - wandb.summary["trainable_parameter_count"] = just_lora_params - logger.info(f"Total parameter count: {all_param_count}") - logger.info(f"Trainable parameter count: {just_lora_params}") - logger.info(f"Fraction of parameters that are trainable: {just_lora_params * 1.0 / all_param_count%.3}") - - # Levanter has two kinds of data loaders: sharded and replicated. Replicated is simpler and allows for - # single pass training. Sharded only loads a subset of the data on each device, and is more efficient for large - # datasets. We use replicated here since the dataset is small. 
-    loader = trainer.replicated_loader(train_dataset, trainer.TrainBatch)
-    loader = non_caching_cycle(loader)
-
-    if state.step != 0:
-        logger.info(f"Resuming training from step {state.step}")
-        for i in range(state.step):
-            next(loader)  # type: ignore
-
-    # Save HF PEFT checkpoints periodically (and at the end of training), which is just the lora weights
-    if config.hf_save_path is not None:
-        full_save_path = os.path.join(config.hf_save_path, trainer.run_id)
-        trainer.add_hook(
-            save_peft_checkpoint_callback(
-                full_save_path, config.lora, config.model_name_or_path, tokenizer, config.hf_upload
-            ),
-            every=config.hf_save_steps,
-        )
+    with Trainer(config.trainer, optimizer, compute_loss, is_trainable=lora_param_filter) as trainer:
+        state = trainer.initial_state(training_key, model=model)
+
+        # log some info about the model
+        all_param_count = parameter_count(state.model)
+        just_lora_params = parameter_count(trainer.trainable_params_only(state.model))
 
-    # Save merged HF checkpoints if requested
-    if config.merged_hf_save_path is not None:
-        full_save_path = os.path.join(config.merged_hf_save_path, trainer.run_id)
-        trainer.add_hook(
-            save_merged_hf_checkpoint_callback(full_save_path, converter, config.merged_hf_upload),
-            every=config.hf_save_steps,
+        levanter.tracker.log_summary(
+            {
+                "parameter_count": all_param_count,
+                "trainable_parameter_count": just_lora_params,
+                "fraction_trainable": just_lora_params * 1.0 / all_param_count,
+            }
         )
-    trainer.train(state, loader)
+        logger.info(f"Total parameter count: {all_param_count}")
+        logger.info(f"Trainable parameter count: {just_lora_params}")
+        logger.info(f"Fraction of parameters that are trainable: {just_lora_params * 1.0 / all_param_count:.3}")
+
+        # Levanter has two kinds of data loaders: sharded and replicated. Replicated is simpler and allows for
+        # single pass training. Sharded only loads a subset of the data on each device, and is more efficient for large
+        # datasets. We use replicated here since the dataset is small. 
+ loader = trainer.replicated_loader(train_dataset, trainer.TrainBatch) + loader = non_caching_cycle(loader) + + if state.step != 0: + logger.info(f"Resuming training from step {state.step}") + for i in range(state.step): + next(loader) # type: ignore + + # Save HF PEFT checkpoints periodically (and at the end of training), which is just the lora weights + if config.hf_save_path is not None: + full_save_path = os.path.join(config.hf_save_path, trainer.run_id) + trainer.add_hook( + save_peft_checkpoint_callback( + full_save_path, config.lora, config.model_name_or_path, tokenizer, config.hf_upload + ), + every=config.hf_save_steps, + ) + + # Save merged HF checkpoints if requested + if config.merged_hf_save_path is not None: + full_save_path = os.path.join(config.merged_hf_save_path, trainer.run_id) + trainer.add_hook( + save_merged_hf_checkpoint_callback(full_save_path, converter, config.merged_hf_upload), + every=config.hf_save_steps, + ) + + trainer.train(state, loader) if __name__ == "__main__": From f2842e98ab839c74f2a5ef6639abb967a71e1f9f Mon Sep 17 00:00:00 2001 From: David Hall Date: Thu, 7 Dec 2023 11:40:10 -0800 Subject: [PATCH 055/205] Add Sophia-H, some WIP support for Sophia-G (#372) * wip * initial commit of hero * switch to autodiff hvp * make test run anywhere * remove stub adam thing * fix missing returns in text.pyu i don't understand how this code made it in/wasn't caught * fix "replicated" arrays in named_pjit * support hessian updates for hero in gpt2_example * track time for hessian computation * attempt to add some logging from inside jax * stupidly was squaring the LR * make global batch loader invariant to reorderings of how jax.array_from_callback call the callback (#132) * default parallelism * add decay schedule for gamma * rename hero to sofia * wip * mucking around with sharding b/c of some weird behaviro on cpu * let's see if this works ons TPU * fix sofia * fix test for named_jit * wip * wip making sophia configurable * wip * wip * mdakd * wip * fix arg name * update configs * explicit sophia_g config * wip * not sure how this didn't get committed * fix sophia-g (dumb dtype issue) * try logging to see what the hold up is * sigh * try to jit the eval loader and see if that fixes it * refactor loaders to have more in common, should hopefully improve perf * fix sophia test * update sophia large configs * missed a sofia * tweak lr * small_fast sophia_h config * got sophia_h test working again * please pre-commit * fix alpaca ocnfig for llama * almost there for sophia-h * switch to 0.4.19 b/c there's a bug in jaxlib that sophia triggers * ok, see if this is correct-ish * try a fancier config * put clipping into the sophia scale so it's easier to manage * try clipping the hessian on a per-component basis? * more clip * fix tests by default to adam * try a full gpt2 small run * add tag * remove clip * wip * bump version * disable logging for now * try this? 
* cleanup, delete old optimizerconfig * reorganize optim package * remove type: adam stuff since it's default * forgot the hessian_update in the migration * fix LR being reported as 0 * oops --- README.md | 10 +- config/gpt2_data_mix.yaml | 22 -- config/gpt2_large.yaml | 2 +- config/gpt2_large_sophia_g.yaml | 21 ++ config/gpt2_large_sophia_h.yaml | 21 ++ config/gpt2_nano.yaml | 3 +- config/gpt2_small_fast.yaml | 2 - config/gpt2_small_fast_sophia_g.yaml | 24 ++ config/gpt2_small_fast_sophia_h.yaml | 24 ++ config/gpt2_small_fast_sophiah.yaml | 26 ++ config/gpt2_small_sophiah.yaml | 19 ++ config/mpt_7b_continued.yaml | 22 -- config/mpt_7b_continued_biomedlm.yaml | 27 -- config/optim/sophia-h_large.yaml | 7 + config/optim/sophia-h_medium.yaml | 7 + config/optim/sophia-h_small.yaml | 7 + config/optim/sophia-h_xl.yaml | 7 + docs/Levanter-1.0-Release.md | 4 +- examples/alpaca/alpaca.py | 3 +- src/levanter/__init__.py | 2 + src/levanter/callbacks.py | 17 +- src/levanter/main/lora_lm.py | 5 +- src/levanter/main/train_lm.py | 7 +- src/levanter/optim/__init__.py | 16 + src/levanter/optim/config.py | 131 ++++++++ src/levanter/optim/second_order.py | 228 ++++++++++++++ src/levanter/optim/sophia.py | 412 ++++++++++++++++++++++++++ src/levanter/optim/util.py | 18 ++ src/levanter/tracker/tracker_fns.py | 15 +- src/levanter/trainer.py | 113 +------ tests/data/hero_data.npy | Bin 0 -> 32128 bytes tests/test_export_to_hf.py | 3 +- tests/test_hf_gpt2_serialize.py | 4 +- tests/test_logging.py | 1 - tests/test_mpt.py | 11 +- tests/test_sophia.py | 56 ++++ 36 files changed, 1086 insertions(+), 211 deletions(-) delete mode 100644 config/gpt2_data_mix.yaml create mode 100644 config/gpt2_large_sophia_g.yaml create mode 100644 config/gpt2_large_sophia_h.yaml create mode 100644 config/gpt2_small_fast_sophia_g.yaml create mode 100644 config/gpt2_small_fast_sophia_h.yaml create mode 100644 config/gpt2_small_fast_sophiah.yaml create mode 100644 config/gpt2_small_sophiah.yaml delete mode 100644 config/mpt_7b_continued.yaml delete mode 100644 config/mpt_7b_continued_biomedlm.yaml create mode 100644 config/optim/sophia-h_large.yaml create mode 100644 config/optim/sophia-h_medium.yaml create mode 100644 config/optim/sophia-h_small.yaml create mode 100644 config/optim/sophia-h_xl.yaml create mode 100644 src/levanter/optim/__init__.py create mode 100644 src/levanter/optim/config.py create mode 100644 src/levanter/optim/second_order.py create mode 100644 src/levanter/optim/sophia.py create mode 100644 src/levanter/optim/util.py create mode 100644 tests/data/hero_data.npy create mode 100644 tests/test_sophia.py diff --git a/README.md b/README.md index bbc5cc6c6..e035193f6 100644 --- a/README.md +++ b/README.md @@ -24,20 +24,20 @@ Levanter is a framework for training large language models (LLMs) and other foun 2. **Scalable**: Levanter scales to large models, and to be able to train on a variety of hardware, including GPUs and TPUs. 3. **Reproducible**: Levanter is bitwise deterministic, meaning that the same configuration will always produce the same results, even in the face of preemption and resumption. -We built Levanter with [JAX](https:://github.com/google/jax), [Equinox](https://github.com/patrick-kidger/equinox), -and [Haliax](https://github.com/stanford-crfm/haliax). +We built Levanter with [JAX](https:://github.com/google/jax), [Equinox](https://github.com/patrick-kidger/equinox), and [Haliax](https://github.com/stanford-crfm/haliax). 
## Features

* **Distributed Training**: We support distributed training on TPUs (and soon, GPUs), including FSDP and tensor parallelism.
* **Compatibility**: Levanter supports importing and exporting models to/from the Hugging Face ecosystem, including tokenizers, datasets, and models via [SafeTensors](https://github.com/huggingface/safetensors).
* **Performance**: Levanter's performance rivals commercially-backed frameworks like MosaicML's Composer or Google's MaxText.
-* **Reproducibility**: Levanter is bitwise deterministic, meaning that the same configuration will always produce the same results, even in the face of preemption and resumption.
* **Cached On-Demand Data Preprocessing**: We preprocess corpora online, but we cache the results of preprocessing so that resumes are much faster and so that subsequent runs are even faster. As soon as the first part of the cache is complete, Levanter will start training.
-* **Logging**: Logging is done with [WandB](https://wandb.ai/), complete with a fancy online visualization of the validation set during training.
+* **Optimization**: Levanter supports the new [Sophia](https://arxiv.org/abs/2305.14342) optimizer, which can be 2x as fast as Adam. We also support [Optax](https://github.com/deepmind/optax) for optimization with AdamW, etc.
+* **Logging**: Levanter supports a few different logging backends, including [WandB](https://wandb.ai/site) and [TensorBoard](https://www.tensorflow.org/tensorboard). (Adding a new logging backend is easy!) Levanter even exposes the ability
+to log inside of JAX `jit`-ted functions.
+* **Reproducibility**: On TPU, Levanter is bitwise deterministic, meaning that the same configuration will always produce the same results, even in the face of preemption and resumption.
* **Distributed Checkpointing**: Distributed checkpointing is supported via Google's [TensorStore](https://google.github.io/tensorstore/) library. Training can even be resumed on a different number of hosts, though this breaks reproducibility for now.
-* **Optimization**: Levanter uses [Optax](https://github.com/deepmind/optax) for optimization. Our new optimizer, [Sophia](https://arxiv.org/abs/2305.14342), is coming to Levanter soon!
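(The in-`jit` logging mentioned in the new Logging bullet is provided by `levanter.tracker.jit_log_metrics`, which this same patch calls from inside Sophia's jitted update step in the `sophia.py` hunk below. A minimal sketch of what that looks like from user code; the toy loss and metric name are illustrative assumptions, only the `jit_log_metrics(metrics, step=...)` call shape is taken from this patch:

    import jax
    import jax.numpy as jnp

    import levanter.tracker


    @jax.jit
    def toy_step(params, batch):
        loss = jnp.mean((params - batch) ** 2)
        # queue metrics from inside a traced function; this mirrors the
        # levanter.tracker.jit_log_metrics call in sophia.py's update_fn
        levanter.tracker.jit_log_metrics({"train/loss": loss}, step=0)
        return loss
)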
diff --git a/config/gpt2_data_mix.yaml b/config/gpt2_data_mix.yaml deleted file mode 100644 index 073e3b46b..000000000 --- a/config/gpt2_data_mix.yaml +++ /dev/null @@ -1,22 +0,0 @@ -data: - configs: - owt: - train_urls: - - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_train.{1..128}-of-128.jsonl.gz" - validation_urls: - - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_val.{1..8}-of-8.jsonl.gz" - wikitext: - id: dlwh/wikitext_103_detokenized - train_weights: - owt: 0.6 - wikitext: 0.4 - tokenizer: gpt2 - cache_dir: "gs://levanter-data/tokenized/data_mix" -model: - type: gpt2 - hidden_dim: 32 - num_heads: 4 - num_layers: 2 -trainer: - num_train_steps: 100 - train_batch_size: 32 diff --git a/config/gpt2_large.yaml b/config/gpt2_large.yaml index d772f9fdf..8a8aea8d7 100644 --- a/config/gpt2_large.yaml +++ b/config/gpt2_large.yaml @@ -14,7 +14,7 @@ trainer: mp: p=f32,c=bfloat16 model_axis_size: 1 - per_device_parallelism: 16 + per_device_parallelism: -1 optimizer: learning_rate: 2E-4 weight_decay: 0.1 diff --git a/config/gpt2_large_sophia_g.yaml b/config/gpt2_large_sophia_g.yaml new file mode 100644 index 000000000..53a1d0806 --- /dev/null +++ b/config/gpt2_large_sophia_g.yaml @@ -0,0 +1,21 @@ +data: !include data/openwebtext_source.yaml +model: + type: gpt2 + hidden_dim: 1280 + num_heads: 20 + num_layers: 36 + seq_len: 1024 + gradient_checkpointing: true + scale_attn_by_inverse_layer_idx: true +trainer: + wandb: + project: "levanter" + tags: [ "openwebtext", "gpt2", "sophia-g"] + + num_train_steps: 200000 + mp: p=f32,c=bfloat16 + +optimizer: + type: sophia-g + learning_rate: 2E-4 + weight_decay: 0.15 diff --git a/config/gpt2_large_sophia_h.yaml b/config/gpt2_large_sophia_h.yaml new file mode 100644 index 000000000..314801728 --- /dev/null +++ b/config/gpt2_large_sophia_h.yaml @@ -0,0 +1,21 @@ +data: !include data/openwebtext_source.yaml +model: + type: gpt2 + hidden_dim: 1280 + num_heads: 20 + num_layers: 36 + seq_len: 1024 + gradient_checkpointing: true + scale_attn_by_inverse_layer_idx: true +trainer: + wandb: + project: "levanter" + tags: [ "openwebtext", "gpt2", "sophia-h"] + + num_train_steps: 200000 + mp: p=f32,c=bfloat16 + +optimizer: + type: sophia-h + learning_rate: 1.7E-4 + weight_decay: 0.2 diff --git a/config/gpt2_nano.yaml b/config/gpt2_nano.yaml index 993302670..5612fc104 100644 --- a/config/gpt2_nano.yaml +++ b/config/gpt2_nano.yaml @@ -14,8 +14,7 @@ trainer: - every: 50 save_interval: 5m - per_device_eval_parallelism: 1 - per_device_parallelism: 1 + per_device_parallelism: 16 train_batch_size: 32 tensor_parallel_axes: ["mlp", "heads"] diff --git a/config/gpt2_small_fast.yaml b/config/gpt2_small_fast.yaml index 8d12c7ed3..6242a37bc 100644 --- a/config/gpt2_small_fast.yaml +++ b/config/gpt2_small_fast.yaml @@ -12,8 +12,6 @@ trainer: - type: wandb project: "levanter" tags: [ "openwebtext", "gpt2", "itest"] -# - type: tensorboard -# logdir: gs://levanter-checkpoints/tblogs/ mp: p=f32,c=bfloat16 model_axis_size: 1 diff --git a/config/gpt2_small_fast_sophia_g.yaml b/config/gpt2_small_fast_sophia_g.yaml new file mode 100644 index 000000000..0f86ac503 --- /dev/null +++ b/config/gpt2_small_fast_sophia_g.yaml @@ -0,0 +1,24 @@ +data: !include data/openwebtext_source.yaml +model: + type: gpt2 + hidden_dim: 768 + num_heads: 12 + num_layers: 12 + seq_len: 1024 + gradient_checkpointing: true + scale_attn_by_inverse_layer_idx: true +trainer: + wandb: + project: "levanter" + tags: [ "openwebtext", "gpt2", "itest", "sophia-g"] + + mp: p=f32,c=bfloat16 + model_axis_size: 1 + 
per_device_parallelism: 8 + + train_batch_size: 256 + num_train_steps: 20000 +optimizer: + type: sophia-g + learning_rate: 1E-3 + weight_decay: 0.15 diff --git a/config/gpt2_small_fast_sophia_h.yaml b/config/gpt2_small_fast_sophia_h.yaml new file mode 100644 index 000000000..671acec8f --- /dev/null +++ b/config/gpt2_small_fast_sophia_h.yaml @@ -0,0 +1,24 @@ +data: !include data/openwebtext_source.yaml +model: + type: gpt2 + hidden_dim: 768 + num_heads: 12 + num_layers: 12 + seq_len: 1024 + gradient_checkpointing: true + scale_attn_by_inverse_layer_idx: true +trainer: + wandb: + project: "levanter" + tags: [ "openwebtext", "gpt2", "itest", "sophia-h"] + + mp: p=f32,c=bfloat16 + model_axis_size: 1 + per_device_parallelism: 8 + + train_batch_size: 256 + num_train_steps: 20000 +optimizer: + type: sophia-h + learning_rate: .85E-3 + weight_decay: 0.2 diff --git a/config/gpt2_small_fast_sophiah.yaml b/config/gpt2_small_fast_sophiah.yaml new file mode 100644 index 000000000..71675312c --- /dev/null +++ b/config/gpt2_small_fast_sophiah.yaml @@ -0,0 +1,26 @@ +data: !include data/openwebtext_source.yaml +model: + type: gpt2 + hidden_dim: 768 + num_heads: 12 + num_layers: 12 + seq_len: 1024 + gradient_checkpointing: true + scale_attn_by_inverse_layer_idx: true +trainer: + wandb: + project: "levanter" + tags: [ "openwebtext", "gpt2", "itest"] + + mp: p=f32,c=bfloat16 + model_axis_size: 1 + per_device_parallelism: -1 + + train_batch_size: 256 + num_train_steps: 20000 +optimizer: + type: sophia-h + learning_rate: 0.8E-3 + weight_decay: 0.1 + warmup: 0.01 + gamma: 0.005 diff --git a/config/gpt2_small_sophiah.yaml b/config/gpt2_small_sophiah.yaml new file mode 100644 index 000000000..fd82ab226 --- /dev/null +++ b/config/gpt2_small_sophiah.yaml @@ -0,0 +1,19 @@ +data: !include data/openwebtext_source.yaml +model: + type: gpt2 + hidden_dim: 768 + num_heads: 12 + num_layers: 12 + seq_len: 1024 + gradient_checkpointing: true + scale_attn_by_inverse_layer_idx: true +trainer: + tracker: + project: "levanter" + tags: [ "openwebtext", "gpt2", "sophia-h"] + + mp: p=f32,c=bfloat16 + model_axis_size: 1 + + train_batch_size: 512 +optimizer: !include optim/sophia-h_small.yaml diff --git a/config/mpt_7b_continued.yaml b/config/mpt_7b_continued.yaml deleted file mode 100644 index 980b4aaaf..000000000 --- a/config/mpt_7b_continued.yaml +++ /dev/null @@ -1,22 +0,0 @@ -data: !include data/pile_source_old.yaml -model: - type: mpt -initialize_from_hf: true -use_hf_model_config: true -trainer: - tracker: - project: "levanter" - tags: ["pile", "mpt"] - - mp: p=f32,c=bfloat16 - - model_axis_size: 1 - per_device_parallelism: 4 - per_device_eval_parallelism: 4 - - train_batch_size: 1024 - num_train_steps: 10000 - steps_per_eval: 500 -optimizer: - learning_rate: 1.2e-4 - weight_decay: 0.1 diff --git a/config/mpt_7b_continued_biomedlm.yaml b/config/mpt_7b_continued_biomedlm.yaml deleted file mode 100644 index 504f1a3ba..000000000 --- a/config/mpt_7b_continued_biomedlm.yaml +++ /dev/null @@ -1,27 +0,0 @@ -data: - train_urls: - - "gs://pubmed-mosaic/pubmed-sharded/pubmedRandomized_train.{1..128}-of-128.jsonl.gz" - validation_urls: - - "gs://pubmed-mosaic/pubmed-sharded/pubmedRandomized_val.{1..8}-of-8.jsonl.gz" - cache_dir: "gs://pubmed-mosaic/tokenized/pubmed-sharded-neox/" - tokenizer: "EleutherAI/gpt-neox-20b" -model: - type: mpt -initialize_from_hf: "mosaicml/mpt-7b@68e1a8e0ebb9b30f3c45c1ef6195980f29063ae2" -use_hf_model_config: true -trainer: - tracker: - project: "levanter" - tags: ["pubmed", "mpt", "continued"] - - mp: 
p=f32,c=bfloat16 - - model_axis_size: 1 - per_device_parallelism: 8 - - train_batch_size: 2048 - num_train_steps: 50000 - steps_per_eval: 1000 -optimizer: - learning_rate: 1.2e-5 - weight_decay: 0.1 diff --git a/config/optim/sophia-h_large.yaml b/config/optim/sophia-h_large.yaml new file mode 100644 index 000000000..6644f20b8 --- /dev/null +++ b/config/optim/sophia-h_large.yaml @@ -0,0 +1,7 @@ +type: sophia-h +learning_rate: 3E-4 +weight_decay: 0.2 +min_lr_ratio: 0.1 +gamma: 0.01 +# sophia needs a minimum amount of warmup or it doesn't do well +warmup: 2000 diff --git a/config/optim/sophia-h_medium.yaml b/config/optim/sophia-h_medium.yaml new file mode 100644 index 000000000..5c411f109 --- /dev/null +++ b/config/optim/sophia-h_medium.yaml @@ -0,0 +1,7 @@ +type: sophia-h +learning_rate: 4E-4 +weight_decay: 0.2 +min_lr_ratio: 0.1 +gamma: 0.01 +# sophia needs a minimum amount of warmup or it doesn't do well +warmup: 2000 diff --git a/config/optim/sophia-h_small.yaml b/config/optim/sophia-h_small.yaml new file mode 100644 index 000000000..0bb8ea2a7 --- /dev/null +++ b/config/optim/sophia-h_small.yaml @@ -0,0 +1,7 @@ +type: sophia-h +learning_rate: 6E-4 +weight_decay: 0.2 +min_lr_ratio: 0.1 +gamma: 0.01 +# sophia needs a minimum amount of warmup or it doesn't do well +warmup: 2000 diff --git a/config/optim/sophia-h_xl.yaml b/config/optim/sophia-h_xl.yaml new file mode 100644 index 000000000..fe2c868b3 --- /dev/null +++ b/config/optim/sophia-h_xl.yaml @@ -0,0 +1,7 @@ +type: sophia-h +learning_rate: 1.2E-4 +weight_decay: 0.2 +min_lr_ratio: 0.1 +gamma: 0.01 +# sophia needs a minimum amount of warmup or it doesn't do well +warmup: 2000 diff --git a/docs/Levanter-1.0-Release.md b/docs/Levanter-1.0-Release.md index 8fed293dd..05c66683a 100644 --- a/docs/Levanter-1.0-Release.md +++ b/docs/Levanter-1.0-Release.md @@ -539,7 +539,7 @@ learn differently from Transformers. ## A few other features * **Training**: Levanter uses [Optax](https://github.com/deepmind/optax) for optimization, - though our new optimizer, [Sofia](https://arxiv.org/abs/2305.14342), is coming to Levanter soon! + though our new optimizer, [Sophia](https://arxiv.org/abs/2305.14342), is coming to Levanter soon! * **Logging**: Logging is done with [WandB](https://wandb.ai/), complete with a fancy online visualization of the validation set during training. * **Checkpointing**: Distributed checkpointing is supported via Google's [TensorStore](https://google.github.io/tensorstore/) library. Training can even be resumed on a different number of hosts, though this breaks reproducibility for now. * **Export**: We also support exporting models to the Hugging Face Hub, with export compatible with Pytorch and Transformers via [SafeTensors](https://github.com/huggingface/safetensors). @@ -627,7 +627,7 @@ trained on the [Lakh MIDI](https://colinraffel.com/projects/lmd/) corpus. The la This is just the beginning for Levanter. In the future, look for: * more models on interesting problem domains, * scaled up versions of new architectures developed here at Stanford and elsewhere, -* new training techniques, including the newly released [Sofia](https://arxiv.org/abs/2305.14342) optimizer, +* new training techniques, including the newly released [Sophia](https://arxiv.org/abs/2305.14342) optimizer, * and larger models! Levanter is still a work in progress, but we are excited to share it with the community. 
We hope that Levanter will be diff --git a/examples/alpaca/alpaca.py b/examples/alpaca/alpaca.py index f63896d09..85ce758e3 100644 --- a/examples/alpaca/alpaca.py +++ b/examples/alpaca/alpaca.py @@ -18,7 +18,8 @@ from levanter.data import Dataset from levanter.data.sharded_dataset import JsonDataset, JsonlDataset, WrappedHFDataset from levanter.models.lm_model import LmExample, LmHeadModel -from levanter.trainer import OptimizerConfig, Trainer, TrainerConfig +from levanter.optim import OptimizerConfig +from levanter.trainer import Trainer, TrainerConfig from levanter.utils import fsspec_utils from levanter.utils.hf_utils import num_cpus_used_by_tokenizer from levanter.utils.py_utils import non_caching_cycle diff --git a/src/levanter/__init__.py b/src/levanter/__init__.py index 4153c0711..ecabba8df 100644 --- a/src/levanter/__init__.py +++ b/src/levanter/__init__.py @@ -3,6 +3,8 @@ import levanter.data as data import levanter.distributed as distributed import levanter.logging as logging +import levanter.models as models +import levanter.optim as optim import levanter.tracker as tracker import levanter.trainer as trainer import levanter.visualization as visualization diff --git a/src/levanter/callbacks.py b/src/levanter/callbacks.py index a80d0619e..409b235f9 100644 --- a/src/levanter/callbacks.py +++ b/src/levanter/callbacks.py @@ -27,6 +27,8 @@ def eval_loss_loop(loss_fn, model, dataset, max_batches: Optional[int] = None, name: Optional[str] = None): total_loss = 0.0 + total_load_time = 0.0 + total_loss_time = 0.0 n = 0 if name is not None: @@ -35,10 +37,20 @@ def eval_loss_loop(loss_fn, model, dataset, max_batches: Optional[int] = None, n desc = "eval" pbar = tqdm(dataset, desc=desc, position=1, leave=False, total=max_batches) - for batch in pbar: + iter_ = iter(pbar) + while True: + time_in = time.time() + batch = next(iter_, None) + if batch is None: + break + load_time = time.time() - time_in + total_load_time += load_time loss = loss_fn(model, batch) total_loss += loss.item() n += 1 + loss_time = time.time() - time_in - load_time + total_loss_time += loss_time + pbar.set_postfix(loss=total_loss / n) if max_batches is not None and n >= max_batches: @@ -47,6 +59,9 @@ def eval_loss_loop(loss_fn, model, dataset, max_batches: Optional[int] = None, n if n > 0: total_loss /= n + logger.info(f"eval loading time: {total_load_time / n:.3f} s/ba") + logger.info(f"eval loss time: {total_loss_time / n:.3f} s/ba") + return total_loss diff --git a/src/levanter/main/lora_lm.py b/src/levanter/main/lora_lm.py index dca79918a..babe7d2fa 100644 --- a/src/levanter/main/lora_lm.py +++ b/src/levanter/main/lora_lm.py @@ -18,7 +18,8 @@ save_merged_hf_checkpoint_callback, save_peft_checkpoint_callback, ) -from levanter.trainer import OptimizerConfig, Trainer, TrainerConfig +from levanter.optim import AdamConfig, OptimizerConfig +from levanter.trainer import Trainer, TrainerConfig from levanter.utils.jax_utils import parameter_count from levanter.utils.py_utils import non_caching_cycle @@ -32,7 +33,7 @@ class LoraLmConfig: lora: LoraConfig = field(default_factory=LoraConfig) data: LMDatasetConfig = field(default_factory=LMDatasetConfig) trainer: TrainerConfig = field(default_factory=TrainerConfig) - optimizer: OptimizerConfig = field(default_factory=OptimizerConfig) + optimizer: OptimizerConfig = field(default_factory=AdamConfig) peft_save_path: Optional[str] = None # path to save peft-compatible checkpoints peft_hf_upload: Optional[str] = None diff --git a/src/levanter/main/train_lm.py 
b/src/levanter/main/train_lm.py index 6ce28ec05..390e2e4af 100644 --- a/src/levanter/main/train_lm.py +++ b/src/levanter/main/train_lm.py @@ -16,7 +16,8 @@ from levanter.data.text import CausalLmDataset, LMDatasetConfig, LMMixtureDatasetConfig from levanter.models.gpt2 import Gpt2Config from levanter.models.lm_model import LmConfig, LmExample, LmHeadModel -from levanter.trainer import OptimizerConfig, Trainer, TrainerConfig +from levanter.optim import AdamConfig, OptimizerConfig +from levanter.trainer import Trainer, TrainerConfig from levanter.utils.jax_utils import parameter_count @@ -28,7 +29,7 @@ class TrainLmConfig: data: Union[LMDatasetConfig, LMMixtureDatasetConfig] = field(default_factory=LMDatasetConfig) trainer: TrainerConfig = field(default_factory=TrainerConfig) model: LmConfig = field(default_factory=Gpt2Config) - optimizer: OptimizerConfig = field(default_factory=OptimizerConfig) + optimizer: OptimizerConfig = field(default_factory=AdamConfig) # config related to continued pretraining initialize_from_hf: Union[bool, str] = False @@ -44,6 +45,8 @@ class TrainLmConfig: hf_upload: Optional[str] = None hf_save_steps: int = 10000 + update_hessian_steps: int = 10 + def main(config: TrainLmConfig): levanter.initialize(config) diff --git a/src/levanter/optim/__init__.py b/src/levanter/optim/__init__.py new file mode 100644 index 000000000..319ddf84d --- /dev/null +++ b/src/levanter/optim/__init__.py @@ -0,0 +1,16 @@ +from .config import AdamConfig, OptimizerConfig +from .second_order import ( + AnySecondOrderTransformation, + HessianUpdateFn, + SecondOrderTransformation, + chain_second_order, + inject_hyperparams, +) +from .sophia import ( + ScaleBySophiaState, + SophiaGConfig, + SophiaGObjective, + SophiaHConfig, + scale_by_sophia_g, + scale_by_sophia_h, +) diff --git a/src/levanter/optim/config.py b/src/levanter/optim/config.py new file mode 100644 index 000000000..68a02e09e --- /dev/null +++ b/src/levanter/optim/config.py @@ -0,0 +1,131 @@ +import abc +import warnings +from dataclasses import dataclass +from typing import Optional + +import draccus +import optax +from jax import numpy as jnp + + +@dataclass +class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC): + learning_rate: float = 6e-4 + weight_decay: float = 0.0 + + min_lr_ratio: float = 0.1 + warmup_ratio: Optional[float] = None # Deprecated. fraction of training steps to use as warmup + warmup: float = 0.01 + """fraction of training steps to use as warmup, or steps to use. 0.0 means no warmup""" + cooldown: float = 0.0 + """fraction of training steps to use as cooldown, or steps to use. 
0.0 means no cooldown""" + lr_schedule: str = "cosine" # constant, cosine, linear + + @classmethod + def default_choice_name(cls) -> Optional[str]: + return "adam" + + @abc.abstractmethod + def build(self, num_train_steps: int): + raise NotImplementedError + + def lr_scheduler(self, num_train_steps): + warmup_steps = self._convert_warmup(num_train_steps) + cooldown_steps = _convert_ratio_or_steps(self.cooldown, num_train_steps) + lr_decay_steps = num_train_steps - warmup_steps - cooldown_steps + min_lr = self.learning_rate * self.min_lr_ratio + + match self.lr_schedule: + case "constant": + schedule = optax.constant_schedule(self.learning_rate) + case "cosine": + schedule = optax.cosine_decay_schedule(self.learning_rate, lr_decay_steps, self.min_lr_ratio) + case "linear": + schedule = optax.linear_schedule(self.learning_rate, min_lr, lr_decay_steps - warmup_steps) + case "inv_sqrt": + schedule = _inv_sqrt_decay_schedule(self.learning_rate, min_lr, warmup_steps, 10000) + case _: + raise ValueError(f"Unknown lr_schedule: {self.lr_schedule}") + + schedules = [] + boundaries = [] + + if warmup_steps != 0: + warmup = optax.linear_schedule(0.0, self.learning_rate, warmup_steps) + schedules.append(warmup) + boundaries.append(warmup_steps) + + schedules.append(schedule) + + if cooldown_steps != 0: + final_main_lr = schedule(lr_decay_steps) + cooldown = optax.linear_schedule(final_main_lr, min_lr, cooldown_steps) + schedules.append(cooldown) + boundaries.append(num_train_steps - cooldown_steps) + + if len(schedules) > 1: + schedule = optax.join_schedules(schedules, boundaries) + + return schedule + + def _convert_warmup(self, num_train_steps: int): + if self.warmup_ratio is not None: + warnings.warn("warmup_ratio is deprecated. Use warmup instead") + return int(self.warmup_ratio * num_train_steps) + else: + return _convert_ratio_or_steps(self.warmup, num_train_steps) + + +def _inv_sqrt_decay_schedule(lr: float, min_lr: float, warmup_steps: int, timescale: float = 10000): + def schedule(count): + decay = jnp.minimum(1.0, 1.0 / jnp.sqrt(jnp.maximum(count + warmup_steps, 1) / timescale)) + return jnp.maximum(lr * decay, min_lr) + + return schedule + + +def _convert_ratio_or_steps(ratio_or_steps: float, num_train_steps: int): + if ratio_or_steps < 1.0: + return int(ratio_or_steps * num_train_steps) + else: + return int(ratio_or_steps) + + +@dataclass +class HessianOptConfig(OptimizerConfig, abc.ABC): + update_interval: int = 10 + """How often to update the hessian approximation.""" + + +@OptimizerConfig.register_subclass("adam") +@dataclass +class AdamConfig(OptimizerConfig): + weight_decay: float = 0.1 + beta1: float = 0.9 + beta2: float = 0.999 + epsilon: float = 1e-8 + max_grad_norm: Optional[float] = 1.0 + + def build(self, num_train_steps): + """Creates the optimizer""" + # indirection makes it work with optax.inject_hyperparams so we can log the learning rate + def _optimizer(learning_rate): + components = [] + + if self.max_grad_norm: + components.append(optax.clip_by_global_norm(self.max_grad_norm)) + + components.append(optax.scale_by_adam(self.beta1, self.beta2, self.epsilon)) + + if self.weight_decay > 0: + # TODO: add weight decay masking?? 
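+                # (AdamW-style decoupled weight decay; optax.scale(-learning_rate)
+                # runs later in the chain, so the decay term follows the LR schedule)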
+ components.append(optax.add_decayed_weights(self.weight_decay)) + + # - learning rate for descent + components.append(optax.scale(-learning_rate)) + + optimizer = optax.chain(*components) + + return optimizer + + return optax.inject_hyperparams(_optimizer)(learning_rate=self.lr_scheduler(num_train_steps)) diff --git a/src/levanter/optim/second_order.py b/src/levanter/optim/second_order.py new file mode 100644 index 000000000..fd0da7325 --- /dev/null +++ b/src/levanter/optim/second_order.py @@ -0,0 +1,228 @@ +import functools +import inspect +import typing +from typing import Callable, Iterable, List, NamedTuple, Optional, Union + +import chex +import jax +import optax +from jax import numpy as jnp +from optax._src import numerics +from optax._src.schedule import InjectHyperparamsState, _convert_floats + + +class HessianUpdateFn(typing.Protocol): + """A callable type for the""" + + def __call__( + self, + state, + fn, + model, + *batch, + **batch_kwargs, + ) -> optax.OptState: + """Returns the updated `state` given the `hessian` and `state`.""" + pass + + +class SecondOrderTransformation(NamedTuple): + """A triple of pure functions that together define a second-order optimizer.""" + + init: optax.TransformInitFn + update: optax.TransformUpdateFn + update_hessian: HessianUpdateFn + + +AnySecondOrderTransformation = Union[SecondOrderTransformation, optax.GradientTransformation] +"""A type that can be used to represent either a first or second order transformation.""" + + +def chain_second_order(*args: AnySecondOrderTransformation) -> SecondOrderTransformation: + """Applies a list of chainable update transformations. Analogous to optax.chain, + but for second order transformations. + """ + + init_fns = [] + update_fns = [] + update_hessian_fns: List[Optional[HessianUpdateFn]] = [] + + for arg in args: + if isinstance(arg, SecondOrderTransformation): + init_fns.append(arg.init) + update_fns.append(arg.update) + update_hessian_fns.append(arg.update_hessian) + else: + init_fns.append(arg.init) + update_fns.append(arg.update) + update_hessian_fns.append(None) + + def init_fn(params): + return tuple(fn(params) for fn in init_fns) + + def update_fn(updates, state, params=None): + if len(update_fns) != len(state): + raise ValueError( + "The number of updates and states has to be the same in chain! Make sure you have called init first!" + ) + + new_state = [] + for s, fn in zip(state, update_fns): + updates, new_s = fn(updates, s, params) + new_state.append(new_s) + return updates, tuple(new_state) + + def update_hessian_fn(state, fn, model, *batch, **batch_kwargs): + if len(update_hessian_fns) != len(state): + raise ValueError( + "The number of updates and states has to be the same in chain! Make sure you have called init first!" + ) + + new_state = [] + for s, update_fn in zip(state, update_hessian_fns): + if update_fn is None: + new_state.append(s) + else: + new_s = update_fn(s, fn, model, *batch, **batch_kwargs) + new_state.append(new_s) + return tuple(new_state) + + return SecondOrderTransformation(init_fn, update_fn, update_hessian_fn) + + +def inject_hyperparams( + inner_factory: Callable[..., SecondOrderTransformation], + static_args: Union[str, Iterable[str]] = (), + hyperparam_dtype: Optional[jnp.dtype] = None, +) -> Callable[..., SecondOrderTransformation]: + """ + Second Order version of optax.inject_hyperparams. + + Original docstring: + + Wrapper that injects hyperparameters into the inner GradientTransformation. + + This wrapper allows you to pass schedules (i.e. 
a function that returns a + numeric value given a step count) instead of constants for + hyperparameters. You may only schedule numeric hyperparameters (i.e. boolean + flags cannot be scheduled). + + For example, to use ``scale_by_adam`` with a piecewise linear + schedule for beta_1 and constant for beta_2:: + + scheduled_adam = optax.inject_hyperparams(optax.scale_by_adam)( + b1=optax.piecewise_linear_schedule(...), + b2=0.99) + + You may manually change numeric hyperparameters that were not scheduled + through the ``hyperparams`` dict in the ``InjectHyperparamState``:: + + state = scheduled_adam.init(params) + updates, state = scheduled_adam.update(grads, state) + state.hyperparams['b2'] = 0.95 + updates, state = scheduled_adam.update(updates, state) # uses b2 = 0.95 + + Manually overriding scheduled hyperparameters will have no effect (e.g. + in the code sample above, you cannot manually adjust ``b1``). + + Args: + inner_factory: a function that returns the inner + ``optax.GradientTransformation`` given the hyperparameters. + static_args: a string or iterable of strings specifying which + callable parameters are not schedules. inject_hyperparams treats all + callables as schedules by default, so if a hyperparameter is a + non-schedule callable, you must specify that using this argument. + hyperparam_dtype: Optional datatype override. If specified, all float + hyperparameters will be cast to this type. + + Returns: + A callable that returns a ``optax.GradientTransformation``. This callable + accepts the same arguments as ``inner_factory``, except you may provide + schedules in place of the constant arguments. + """ + static_args = {static_args} if isinstance(static_args, str) else set(static_args) + inner_signature = inspect.signature(inner_factory) + + if not static_args.issubset(inner_signature.parameters): + raise ValueError( + "`static_args` must specify a subset of `inner_factory`'s parameters. " + f"Given `static_args`: {static_args}. 
`inner_factory` parameters: " + f"{set(inner_signature.parameters.keys())}" + ) + + @functools.wraps(inner_factory) + def wrapped_transform(*args, **kwargs) -> SecondOrderTransformation: + bound_arguments = inner_signature.bind(*args, **kwargs) + bound_arguments.apply_defaults() + + sched_hps, numeric_hps, other_hps = {}, {}, {} + for name, value in bound_arguments.arguments.items(): + if name in static_args or isinstance(value, bool): + other_hps[name] = value + elif callable(value): + sched_hps[name] = value + elif isinstance(value, (int, float, chex.Array)): + numeric_hps[name] = value + else: + other_hps[name] = value + + def schedule_fn(count, dtype): + return {k: _convert_floats(f(count), dtype) for k, f in sched_hps.items()} + + def init_fn(params): + count = jnp.zeros([], jnp.int32) + if hyperparam_dtype is None: + dtype = _find_first_floating_dtype(numeric_hps) + else: + dtype = hyperparam_dtype + hparams = {k: jnp.asarray(_convert_floats(v, dtype)) for k, v in numeric_hps.items()} + hparams.update(schedule_fn(count, dtype)) + return InjectHyperparamsState( # pylint:disable=too-many-function-args + count, hparams, inner_factory(**other_hps, **hparams).init(params) + ) + + def update_fn(updates, state, params=None): + if hyperparam_dtype is None: + dtype = _find_first_floating_dtype(updates) + else: + dtype = hyperparam_dtype + hparams = {k: _convert_floats(v, dtype) for k, v in state.hyperparams.items()} + hparams.update(schedule_fn(state.count, dtype)) + updates, inner_state = inner_factory(**other_hps, **hparams).update(updates, state.inner_state, params) + count_inc = numerics.safe_int32_increment(state.count) + + # pylint:disable=too-many-function-args + return updates, InjectHyperparamsState(count_inc, hparams, inner_state) + # pylint:enable=too-many-function-args + + def _find_first_floating_dtype(updates): + dtype = jnp.float32 + for v in jax.tree_util.tree_leaves(updates): + if isinstance(v, jnp.ndarray): + if isinstance(v.dtype, jnp.floating): + dtype = v.dtype + break + return dtype + + def update_hessian(state, fn, model, *batch, **batch_kwargs): + if hyperparam_dtype is None: + dtype = _find_first_floating_dtype(batch) + else: + dtype = hyperparam_dtype + hparams = {k: _convert_floats(v, dtype) for k, v in state.hyperparams.items()} + hparams.update(schedule_fn(state.count, dtype)) + new_inner_state = inner_factory(**other_hps, **hparams).update_hessian( + state.inner_state, + fn, + model, + *batch, + **batch_kwargs, + ) + + # pylint:disable=too-many-function-args + return InjectHyperparamsState(state.count, hparams, new_inner_state) + # pylint:enable=too-many-function-args + + return SecondOrderTransformation(init_fn, update_fn, update_hessian) + + return wrapped_transform diff --git a/src/levanter/optim/sophia.py b/src/levanter/optim/sophia.py new file mode 100644 index 000000000..3cb07044c --- /dev/null +++ b/src/levanter/optim/sophia.py @@ -0,0 +1,412 @@ +import abc +import typing +from dataclasses import dataclass +from typing import Any, NamedTuple, Optional, TypeVar, runtime_checkable + +import equinox as eqx +import jax +import jaxtyping +import optax +from jax import numpy as jnp +from jax.random import PRNGKey +from jaxtyping import PRNGKeyArray + +# TODO: remove dependency on _src internals +from optax._src import numerics +from optax._src.transform import bias_correction, update_moment + +import levanter.tracker +from levanter.optim.config import HessianOptConfig, OptimizerConfig +from levanter.optim.second_order import SecondOrderTransformation, 
chain_second_order, inject_hyperparams +from levanter.optim.util import hvp, tree_gaussian +from levanter.utils.jax_utils import parameter_count + + +M = TypeVar("M") +Ex = TypeVar("Ex") + +GAMMA_SOPHIA_G = 0.05 +GAMMA_SOPHIA_H = 0.01 + + +class ScaleBySophiaState(NamedTuple): + """State for Sophia and similar.""" + + count: jaxtyping.Array # shape=(), dtype=jnp.int32. + hessian_count: jaxtyping.Array # shape=(), dtype=jnp.int32. + mu: optax.Updates # momentum + h: optax.Updates # EMA of hessian diagonal + hess_key: PRNGKey + + +@runtime_checkable +class SophiaGObjective(typing.Protocol): + """ + Class for objective functions that can be used with Sophia-G + + Sophia-G is a second order optimizer that uses the Gauss-Newton-Bartlett approximation to the Hessian + to compute the second order update. This requires the objective function be of the form loss(logits(x)) + where logits(x) is the activation of the model for the given example x. This is the case for most models + that are trained with "typical" losses. + """ + + def logits(self, parameters: M, example: Ex, *args, **kwargs) -> Any: + """ + Returns the logits/activations of the model for the given example, + or just sufficient statistics for the example for non-categorical models. + """ + ... + + def sample(self, logits, example: Ex, *, key: PRNGKey) -> Ex: + """ + Samples a new example with the same shape as the original example, but with + the "labels" replaced with some sampled values + """ + ... + + def loss(self, logits, example: Ex): + """ + Just computes the loss, e.g. cross entropy. + + Should return the mean loss over the batch, not the sum. + + TODO: should we reconsider this? + """ + ... + + def __call__(self, parameters: M, example: Ex, *args, **kwargs): + """ + Just a convenience method for invoking the objective for "normal" training w/o sophia-g + """ + logits = self.logits(parameters, example, *args, **kwargs) + return self.loss(logits, example) + + def num_data_points(self, example: Ex) -> int: + """ + Returns the number of data points in the example. This should take into account the loss mask + or any other masking that might be applied to the example. + + By default, we just return 1, and you can just pull the term into the hyperparams of Sophia if you want. + + Returns: + The number of data points in the example + """ + return 1 + + +@dataclass +class BaseSophiaConfig(HessianOptConfig): + """Base class for sophia variants. 
Doesn't implement the state update""" + + weight_decay: float = 0.1 + beta1: float = 0.96 + beta2: float = 0.99 + + epsilon: float = 1e-12 + clip_threshold: Optional[float] = 1.0 + rng_seed: int = 0 + + @abc.abstractmethod + def compute_hessian( + self, + fn, + model, + *batch, + hess_key: PRNGKey, + **batch_kwargs, + ): + raise NotImplementedError + + def build(self, num_train_steps: int): + def _optimizer(learning_rate, gamma) -> SecondOrderTransformation: + components = [] + key = jax.random.PRNGKey(self.rng_seed) + + components.append( + _sophia_gradient_transform( + sophia_hess_fn=self.compute_hessian, + update_interval=self.update_interval, + b1=self.beta1, + b2=self.beta2, + eps=self.epsilon, + gamma=gamma, + initial_key=key, + clip_threshold=self.clip_threshold, + ) + ) + + # Algorithm 3, step 11 (Note, this comes after clipping b/c it's not supposed to be clipped) + # In the paper, it comes as a prior step, but doesn't get clipped + if self.weight_decay > 0: + components.append(optax.add_decayed_weights(self.weight_decay)) + + # - learning rate for descent + components.append(optax.scale(-learning_rate)) + + optimizer = chain_second_order(*components) + + return optimizer + + # Hong suggested using cosine decay for gamma + # gamma_decay_schedule = optax.cosine_decay_schedule(self.gamma, num_train_steps // 2, 0) # type: ignore + constant_gamma_schedule = optax.constant_schedule(self.gamma) # type: ignore + # gamma_schedule = optax.join_schedules([constant_gamma_schedule, gamma_decay_schedule], [num_train_steps // 2]) + + return inject_hyperparams(_optimizer)( + learning_rate=self.lr_scheduler(num_train_steps), gamma=constant_gamma_schedule + ) + + +@OptimizerConfig.register_subclass("sophia-g") +@dataclass +class SophiaGConfig(BaseSophiaConfig): + gamma: float = GAMMA_SOPHIA_G + + def compute_hessian(self, fn, model, *batch, hess_key: PRNGKey, **batch_kwargs): + return stochastic_diag_gauss_newton(fn, model, *batch, **batch_kwargs, hess_key=hess_key) + + +@OptimizerConfig.register_subclass("sophia-h") +@dataclass +class SophiaHConfig(BaseSophiaConfig): + gamma: float = GAMMA_SOPHIA_H + + def compute_hessian(self, fn, model, *batch, hess_key: PRNGKey, **batch_kwargs): + return stochastic_hessian_diagonal(fn, model, *batch, **batch_kwargs, hess_key=hess_key) + + +def sophia_h( + lr: float = 0.85e-3, + *, + b1: float = 0.965, + b2: float = 0.99, + eps: float = 1e-8, + gamma: float = GAMMA_SOPHIA_H, + weight_decay: float = 0.0, + clip_threshold: Optional[float] = 1.0, + update_interval: int = 10, + key: PRNGKey, +) -> SecondOrderTransformation: + """Sophia-H: https://arxiv.org/pdf/2305.14342.pdf Algorithm 1&3""" + components = [] + + components.append(scale_by_sophia_h(b1, b2, eps, gamma, clip_threshold, update_interval, key=key)) + + if weight_decay > 0: + components.append(optax.add_decayed_weights(weight_decay)) + + components.append(optax.scale(-lr)) + + return chain_second_order(*components) + + +def scale_by_sophia_h( + b1=0.965, + b2=0.99, + eps=1e-8, + gamma=GAMMA_SOPHIA_H, + clip_threshold: Optional[float] = 1.0, + update_interval=10, + *, + key: PRNGKey, +): + + return _sophia_gradient_transform( + sophia_hess_fn=stochastic_hessian_diagonal, + update_interval=update_interval, + b1=b1, + b2=b2, + eps=eps, + gamma=gamma, + clip_threshold=clip_threshold, + initial_key=key, + ) + + +def sophia_g( + lr: float = 1e-3, + *, + b1: float = 0.99, + b2: float = 0.99, + eps: float = 1e-8, + gamma: float = GAMMA_SOPHIA_G, + weight_decay: float = 0.0, + clip_threshold: Optional[float] = 
+    update_interval: int = 10,
+    key: PRNGKey,
+) -> SecondOrderTransformation:
+    """Sophia-G: https://arxiv.org/pdf/2305.14342.pdf Algorithms 2 & 3"""
+    components = []
+
+    components.append(scale_by_sophia_g(b1, b2, eps, gamma, clip_threshold, update_interval, key=key))
+
+    if weight_decay > 0:
+        components.append(optax.add_decayed_weights(weight_decay))
+
+    components.append(optax.scale(-lr))
+
+    return chain_second_order(*components)
+
+
+def scale_by_sophia_g(
+    b1: float = 0.99,
+    b2: float = 0.99,
+    eps: float = 1e-8,
+    gamma: float = GAMMA_SOPHIA_G,
+    clip_threshold: Optional[float] = 1.0,
+    update_interval=10,
+    *,
+    key: PRNGKeyArray,
+):
+
+    return _sophia_gradient_transform(
+        sophia_hess_fn=stochastic_diag_gauss_newton,
+        update_interval=update_interval,
+        b1=b1,
+        b2=b2,
+        eps=eps,
+        gamma=gamma,
+        clip_threshold=clip_threshold,
+        initial_key=key,
+    )
+
+
+def _sophia_gradient_transform(
+    sophia_hess_fn,
+    update_interval: int,
+    b1: float,
+    b2: float,
+    eps: float,
+    gamma: float,
+    clip_threshold: Optional[float],
+    initial_key: PRNGKeyArray,
+    mu_dtype: Optional[Any] = None,
+) -> SecondOrderTransformation:
+    mu_dtype = jax.canonicalize_dtype(mu_dtype) if mu_dtype is not None else None
+
+    def init_fn(params):
+        mu = jax.tree_util.tree_map(lambda t: jnp.zeros_like(t, dtype=mu_dtype), params)  # First moment
+        h = jax.tree_util.tree_map(jnp.zeros_like, params)  # Second moment
+        return ScaleBySophiaState(
+            count=jnp.zeros([], jnp.int32), hessian_count=jnp.zeros([], jnp.int32), mu=mu, h=h, hess_key=initial_key
+        )
+
+    def update_fn(updates, state, params=None):
+        mu = update_moment(updates, state.mu, b1, 1)
+        # nu = update_moment_per_elem_norm(updates, state.nu, b2, 2)
+        count_inc = numerics.safe_int32_increment(state.count)
+        mu_hat = bias_correction(mu, b1, count_inc)
+        h_hat = state.h
+        # track how often hessian is used
+        mu_leaves = jax.tree_util.tree_leaves(mu_hat)
+        h_leaves = jax.tree_util.tree_leaves(h_hat)
+
+        stats: dict[str, Any] = {
+            "optim/param_norm": jnp.sqrt(sum(jnp.sum(p**2) for p in jax.tree_util.tree_leaves(params))),
+            "optim/momentum_norm": jnp.sqrt(sum(jnp.sum(m**2) for m in mu_leaves)),
+            "optim/hessian_norm": jnp.sqrt(sum(jnp.sum(h**2) for h in h_leaves)),
+        }
+
+        # with sophia-g the max(h, 0) is not needed but no harm
+        updates = jax.tree_util.tree_map(
+            # lambda m, v: m / jnp.maximum(jnp.maximum(jnp.abs(m), gamma * jnp.maximum(v, 0)), eps), mu_hat, h_hat
+            lambda m, h: m / jnp.maximum(gamma * h, eps),
+            mu_hat,
+            h_hat,
+        )
+
+        if clip_threshold is not None:
+            unclipped_count = sum(jnp.sum(jnp.abs(u) < clip_threshold) for u in jax.tree_util.tree_leaves(updates))
+            updates = jax.tree_util.tree_map(lambda u: jnp.clip(u, -clip_threshold, clip_threshold), updates)
+            stats["optim/unclipped_fraction"] = unclipped_count / parameter_count(updates)
+
+        # this doesn't work well on CPU, so skip if cpu
+        if jax.lib.xla_bridge.get_backend().platform != "cpu":
+            levanter.tracker.jit_log_metrics(stats, step=state.count)
+
+        if mu_dtype is not None:
+            mu = jax.tree_util.tree_map(lambda t: t.astype(mu_dtype), mu)
+
+        return updates, ScaleBySophiaState(
+            count=count_inc, hessian_count=state.hessian_count, mu=mu, h=h_hat, hess_key=state.hess_key
+        )
+
+    def update_hessian(state, fn, model, *batch, **batch_kwargs):
+        def _do_update():
+            key, next_key = jax.random.split(state.hess_key)
+            new_hess = sophia_hess_fn(fn, model, *batch, hess_key=key, **batch_kwargs)
+            # new_hess = jax.tree_util.tree_map(lambda h: jnp.clip(h, -1, 1), new_hess)
+
+            # EMAs of hessian
+            hessian_count_inc = numerics.safe_int32_increment(state.hessian_count)
+            nu = update_moment(new_hess, state.h, b2, 1)
+            return ScaleBySophiaState(
+                count=state.count, hessian_count=hessian_count_inc, mu=state.mu, h=nu, hess_key=next_key
+            )
+
+        def _dont_update():
+            return state
+
+        return jax.lax.cond(
+            jnp.equal(state.count % update_interval, 0),
+            lambda _: _do_update(),
+            lambda _: _dont_update(),
+            state.count,
+        )
+
+    return SecondOrderTransformation(init_fn, update_fn, update_hessian)
+
+
+# use this for Sophia-G
+def stochastic_diag_gauss_newton(fn: SophiaGObjective, model, example, *args, hess_key: PRNGKey, **kwargs):
+    """
+    Approximate the diagonal of the Hessian using the Gauss-Newton-Bartlett estimator of the Gauss-Newton matrix.
+    This is Algorithm 2 of https://arxiv.org/pdf/2305.14342.pdf
+
+    Args:
+        fn (SophiaGObjective): objective function
+        model: model whose Hessian to compute
+        hess_key: key for sampling
+        *args, **kwargs: passed to fn's logits
+    """
+    if not isinstance(fn, SophiaGObjective):
+        raise ValueError("objective must be a SophiaGObjective")
+
+    # Step 3
+    logits, model_backward = eqx.filter_vjp(lambda model: fn.logits(model, example, *args, **kwargs), model)
+
+    # Step 4
+    y_hat = fn.sample(logits, example, key=hess_key)
+
+    # Step 5
+    grad_loss_logits = eqx.filter_grad(fn.loss)(logits, y_hat)
+    pseudo_g = model_backward(grad_loss_logits)[0]
+
+    # Step 6
+    bs = fn.num_data_points(example)
+    h = jax.tree_util.tree_map(lambda x: x**2 * bs, pseudo_g)
+
+    return h
+
+
+# Use this for Sophia-H
+def stochastic_hessian_diagonal(fn, model, *args, hess_key: PRNGKey, **kwargs):
+    """Estimate the diagonal of the Hessian of a function with Hutchinson's estimator, using Gaussian probe vectors.
+
+    https://arxiv.org/pdf/2305.14342.pdf Algorithm 1
+
+    Args:
+        fn: function whose Hessian diagonal to estimate
+        model: model whose parameters the Hessian is taken with respect to
+        hess_key: key for sampling the Gaussian probe
+    """
+    # cf https://arxiv.org/pdf/2006.00719.pdf eqn 9
+    # https://www-users.cse.umn.edu/~saad/PDF/umsi-2005-082.pdf
+    # https://arxiv.org/pdf/2208.03268.pdf
+    g = tree_gaussian(hess_key, model)
+    # TODO: consider allowing for n > 1 gaussians?
+    product = hvp(lambda m: fn(m, *args, **kwargs), model, g)
+    hessian = jax.tree_util.tree_map(lambda grad, gaussian: grad * gaussian, product, g)
+
+    return hessian
diff --git a/src/levanter/optim/util.py b/src/levanter/optim/util.py
new file mode 100644
index 000000000..fccb427a2
--- /dev/null
+++ b/src/levanter/optim/util.py
@@ -0,0 +1,18 @@
+import equinox as eqx
+import jax
+
+
+# TODO: filter_jvp?
+def hvp(f, x, v): + """Compute the Hessian-vector product of a function.""" + return jax.jvp(eqx.filter_grad(f), (x,), (v,))[1] + + +def tree_gaussian(key, tree): + """Samples a tree of gaussian noise with the same structure as `tree`.""" + leaves, structure = jax.tree_util.tree_flatten(tree) + keys = jax.random.split(key, len(leaves)) + g = jax.tree_util.tree_map(lambda x, key: jax.random.normal(key, x.shape), leaves, list(keys)) + g = jax.tree_util.tree_unflatten(structure, g) + + return g diff --git a/src/levanter/tracker/tracker_fns.py b/src/levanter/tracker/tracker_fns.py index accd16b68..c890fca74 100644 --- a/src/levanter/tracker/tracker_fns.py +++ b/src/levanter/tracker/tracker_fns.py @@ -1,4 +1,5 @@ import dataclasses +import logging import os import tempfile import typing @@ -16,6 +17,9 @@ from levanter.utils.jax_utils import is_inside_jit +logger = logging.getLogger(__name__) + + _global_tracker: Optional["Tracker"] = None @@ -42,9 +46,18 @@ def log_metrics(metrics: dict[str, Any], *, step: Optional[int], commit: Optiona _global_tracker.log(metrics, step=step) +def _no_throw_log_metrics(metrics: dict[str, Any], *, step: Optional[int], commit: Optional[bool] = None): + try: + if _global_tracker is None: + raise RuntimeError("No global tracker set") + _global_tracker.log(metrics, step=step) + except Exception: + logger.exception("Error logging metrics") + + def jit_log_metrics(metrics, *, step=None): """uses jax effect callback to log to wandb from the host""" - jax.debug.callback(log_metrics, metrics, step=step) + jax.debug.callback(_no_throw_log_metrics, metrics, step=step) def log_summary(metrics: dict[str, Any]): diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index 696e31a76..38b25223d 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -27,10 +27,8 @@ import equinox as eqx import jax -import jax.numpy as jnp import jmp import numpy as np -import optax from draccus import field from jax import ShapeDtypeStruct from jax.experimental import multihost_utils @@ -52,6 +50,7 @@ from levanter.distributed import DistributedConfig, RayConfig from levanter.grad_accum import accumulate_gradients_sharded from levanter.logging import capture_time +from levanter.optim import SecondOrderTransformation from levanter.tracker import TrackerConfig from levanter.types import FilterSpec from levanter.utils import cloud_utils @@ -431,6 +430,12 @@ def split_loss_fn(trainable_model, *batch, **batch_kwargs): )(trainable_model, *batch, **batch_kwargs) updates, opt_state = self.optimizer.update(grads, opt_state, params=trainable_model) + + if isinstance(self.optimizer, SecondOrderTransformation): + opt_state = self.optimizer.update_hessian( + opt_state, split_loss_fn, trainable_model, *batch, **batch_kwargs + ) + model = eqx.apply_updates(model, updates) return loss, model, opt_state @@ -734,109 +739,5 @@ def initialize(config: TrainerConfig | AllConfig): levanter.tracker.log_configuration(config) -@dataclass -class OptimizerConfig: - # Config related to optimizer (always adam for now) - learning_rate: float = 6e-4 - weight_decay: float = 0.0 - beta1: float = 0.9 - beta2: float = 0.999 - epsilon: float = 1e-8 - max_grad_norm: Optional[float] = 1.0 - - min_lr_ratio: float = 0.1 - warmup_ratio: Optional[float] = None # Deprecated. fraction of training steps to use as warmup - warmup: float = 0.01 - """fraction of training steps to use as warmup, or steps to use. 
0.0 means no warmup""" - cooldown: float = 0.0 - """fraction of training steps to use as cooldown, or steps to use. 0.0 means no cooldown""" - lr_schedule: str = "cosine" # constant, cosine, linear - - def build(self, num_train_steps: int) -> GradientTransformation: - """Creates the optimizer""" - # indirection makes it work with optax.inject_hyperparams so we can log the learning rate - def _optimizer(learning_rate): - components = [] - - if self.max_grad_norm: - components.append(optax.clip_by_global_norm(self.max_grad_norm)) - - components.append(optax.scale_by_adam(self.beta1, self.beta2, self.epsilon)) - - if self.weight_decay > 0: - # TODO: add weight decay masking?? - components.append(optax.add_decayed_weights(self.weight_decay)) - - # - learning rate for descent - components.append(optax.scale(-learning_rate)) - - optimizer = optax.chain(*components) - - return optimizer - - return optax.inject_hyperparams(_optimizer)(learning_rate=self.lr_scheduler(num_train_steps)) - - def lr_scheduler(self, num_train_steps): - warmup_steps = self._convert_warmup(num_train_steps) - cooldown_steps = _convert_ratio_or_steps(self.cooldown, num_train_steps) - lr_decay_steps = num_train_steps - warmup_steps - cooldown_steps - min_lr = self.learning_rate * self.min_lr_ratio - - match self.lr_schedule: - case "constant": - schedule = optax.constant_schedule(self.learning_rate) - case "cosine": - schedule = optax.cosine_decay_schedule(self.learning_rate, lr_decay_steps, self.min_lr_ratio) - case "linear": - schedule = optax.linear_schedule(self.learning_rate, min_lr, lr_decay_steps - warmup_steps) - case "inv_sqrt": - schedule = _inv_sqrt_decay_schedule(self.learning_rate, min_lr, warmup_steps, 10000) - case _: - raise ValueError(f"Unknown lr_schedule: {self.lr_schedule}") - - schedules = [] - boundaries = [] - - if warmup_steps != 0: - warmup = optax.linear_schedule(0.0, self.learning_rate, warmup_steps) - schedules.append(warmup) - boundaries.append(warmup_steps) - - schedules.append(schedule) - - if cooldown_steps != 0: - final_main_lr = schedule(lr_decay_steps) - cooldown = optax.linear_schedule(final_main_lr, min_lr, cooldown_steps) - schedules.append(cooldown) - boundaries.append(num_train_steps - cooldown_steps) - - if len(schedules) > 1: - schedule = optax.join_schedules(schedules, boundaries) - - return schedule - - def _convert_warmup(self, num_train_steps: int): - if self.warmup_ratio is not None: - warnings.warn("warmup_ratio is deprecated. 
Use warmup instead") - return int(self.warmup_ratio * num_train_steps) - else: - return _convert_ratio_or_steps(self.warmup, num_train_steps) - - -def _inv_sqrt_decay_schedule(lr: float, min_lr: float, warmup_steps: int, timescale: float = 10000): - def schedule(count): - decay = jnp.minimum(1.0, 1.0 / jnp.sqrt(jnp.maximum(count + warmup_steps, 1) / timescale)) - return jnp.maximum(lr * decay, min_lr) - - return schedule - - def _params_only(t): return eqx.filter(t, is_inexact_arrayish) - - -def _convert_ratio_or_steps(ratio_or_steps: float, num_train_steps: int): - if ratio_or_steps < 1.0: - return int(ratio_or_steps * num_train_steps) - else: - return int(ratio_or_steps) diff --git a/tests/data/hero_data.npy b/tests/data/hero_data.npy new file mode 100644 index 0000000000000000000000000000000000000000..f39678d793efddbd3b2df287f35ce5e38dccfe2a GIT binary patch literal 32128 zcmbSS_dnI|`?u0SL(!6zRYsCGlDa4=GNO`|Qg$e$fy&5gXc!G!B&&>UxsJUukG=Oe z*5Pn$zUT8Vd|$u4UO$}2x$o<`ulstAD?syx##McKy2o@b!lq`9_Z)=fd4(0s<%Gp} zh0Py0JaRCxyZyky)a-xXFPhjmn$f;HTADmGqrHoql$4Yd}NWMsXPfn(1X%#0Zf1^~jdEl1T`dVH1pSpMm(CE1~627ce#H(Zk(-YbaJ! zyq&dR0y{F_>Pm2xfRy^(E#}Ajz|Tu6y768%{0Io==6XB}0U;lH4E2XG=bIJDv~?73 z@zGz@c(aInUiNpmz72q&!OxYxgf*Oam3QH``#cP0N<5c#-aywK*&gM?gXs3#^`P_A zFivcvYe^E#gM>Fa2e`VrQE*vD?6K7V=AJzu$*M7cf%30jXR}nprN|^=d3y(5xE!LC z{%IME52(KAcJ08rJItTmU!(&QJ4>rU3mL4{1`Ja^H(-xFd#_np1|*8Q=l*i=Lp|?L z$G**P!uhbz{El3mC|e_&M&u^@aFO*zH-feDrobL<%J62pCL4 zP)oG2c}WX!oR+xmr?!>gG*JIF#eE+4>KLau@9M%={K3~6_mSb6eE;5`=Ud>e)3@zO zcP269^lw@E!V0)JAhI(kp%Y4T!{VeR8?o)A3VRZN7IYh==WD#8SFB*{sqolbMB;=2 zSG~z7E{5cm+|BBMpXv+}{+lz%_aQ$$Wup(2iv;_|zRZH@(U&9NOlHyImiw4k?jRm| zEo^qol#bxI!%U_5X&=6ht)m>AnZreTXC6PnVeC1`f3I9>5zg@>bybZ|3|4BzjFptc&VGnP_yUN3b z1_oWY-ka=V5-dl$jNQ$voMEgkqD4EVCUjA5N7V?@`pO;o%yTPM7G9Fk1x zG&_t&Q8wYogy7;Ba6OCW%B2iorP0&>PDf9}biva5Bl{Q?zn#7s7b`IfH(s*P-w$LW zytX*6dgkvqvIqi`Q*SlsXf=zK%;dplVNG^})d=Ld^H14ZPa#2!kT~vLgPa<6>ySHN~lG7%0foTwKi|X+?OK!rh&L^w3 zXO_Tj)Q2O}sRNA%FXn!FLqIm85M^Pm7Wk5~_~^1wC^j@SH|rZv&_ynwUO;IcIpjSY zOD3kV;iaE8-)IlqU<+4O(CCNsbN@xm_*3DyZ)T;*+7RxkIOsPMTZECP^dF_4>xV~4 z3G90!2eDA_`#-x!Wbn}1p?^?k3XF61%kO+LkMp6u%hLz*VWFf#Lyx@}_*W|nEwkFu z`?N5TDKZ0=4~G_hc|41DhZcpvdj*R>oHZb5rQrEY&5stxTA_I@r=e$a0zWfWRXRRR zfjgZGO<#|VqY(9G%BrKaaut$YzFTqIR8B0F^osHeo23$L5sJ5SxFal79%Bh z7^CP6Ofo8yx3TtPO|1QLbHE@n7>QikB3=*wO`U%&<~|DvdtGIhN*{uq!FKNg#|cnQ z724i>ijJ_PIn`E|MZ#~Q>Nibv$WU_LJa4>r0M^+SLz^Thu;YuXOgrx=2nzeA)|aJ$ z+VOP%^Zx!w$5A-yn3;jvJ8vEJ`;~>rGJ1aaPb5gscIkSa9mg}7{_iJNhvD?WXCp_& z$++)?s&Z0)DZXWsx|Yf@hpqk>KImrTqrcuypY3^M*!m>ASEVKc+{$j@(_ z?&bh=ZQAwk!U8a>WlSoU%Y}#{r~ZWeK71H|Q>AF33q|$jxa8I{h&Iyr%c9vI5$ehqd}VDWGh!UnV8K3T)-p`dqmNaH5I(vizYj zpbuhwFwFW7=Izg>%5EEji<-(hVJRyZf6zG1bdCsK5(ga@Bo|?>cp)gpeio{X_w0;J zoktbI-9NAId_)G7mxI{?xe)BG7oVF%MRp+Niur z!Izh+Ez;|zph#I-Z`QmE&K--)=8Y;u@leid=jmr4;1S)@FaBXnSnY4ia$kb!;O&3f zdTIMobE7G9+dR}?m-PE6G5~UKw@O^yItTI-skhV3)}ZxUwDULeEIc$8C{dSr3F~O8g(Ad~nvKA6#@;!S0V_f6k6!xX&p0frpXEv}o7*TWw34^?6Yg%Fk~v)EF; z2tL1jgo;!aaD^Ycx3 z`gqVJn{Ee!cGuB67J2wbkWakt_Y^7%`NuVV97Pd}-eT+MeoQ_4eRL14&VC%*{PslR z7yf<6G%M@4gwmB8l@df|f?($Hr%reKVfT%V%!6uV6l#j37oqh7)!O=5ZB>h?c3)Oe z#I_&)@<{D@wtWJ_T2FN7%=H5{y?s|9I|T=Xk`kR>BiNH76T?Hgv2&8j{gA5~cO7sE zoNFN9q9kD(F|ZTwwq}kmNKb%V!z=6BZT;9mA1(1fbqUk69CeM(Y$7u?CF(C{EA|jR zY**}V!Na7KMcwZth&|n<7qLbHKK4=KYg~r#!f(`~tqmY!N=m#&9mWR{2?jn_8qxP2 z(~f;5ZMgl3_Bky+3O@51P9EV}0pT>GBh9a(3@ZcPB$pc!Jsi z*Fp@kDy?{5M#hOJ6#c*&GCaIBox>Vfh*^ap4z({Qf#uQbu1u*FoDuoHHSkF}(m#+b z&;8p0GVf$(6B1@Is=>u8g|=P|3F*!6+$|X7e!oAutrdQ1FxUAXr&Cl)5Rlh6-j2ym z)F6E|Dukt#JuaQeLB34?z-#ww@!c_Nspb3vkYf0(n!mppx&APg>~5!E*UdjYz0WC7 
zjyEf!lzlM$*Gy^P+ziyKQ`E0YtpN4Gs+856OQ_0MU!irn5*LmY_Z}LXg(+ckzPUX^ zU{Lci&q?_!`X~K*Zj?_Um3oZocs|WIdbrT`=}OOHd%TR`C!+s$g7MOscR^E7aOb2r z#p@X;XP0f_G8^UvDqEqoCvH(0a~1 zyuNy6%<5Y&eq%|eDIs)V!`i{^b(w4!omx&kz&H(KrG1B`-p`@%`LDr0vk-)Rp2>D* zzUneC#ZKI5f4ExY=LFpPka5@f8{uoZOR;3MOo4gF-nq3$%|Pu# z7Ca?2fe%8<8r9Ropi?t${2bQ9dmT%C9>H`BzW1NsE~gp{J{IH8*xCutIX!%6tN?DlAk@Fz=pnaqbFEA+Q% zZ$>?4m$&Et2(2UTJF?u)^)AMmuLg{5FB`zU=I=#gqF>+Ivm9$D-vzuCh0soioKeeay4#*M$A9^keiIZE++S8>DJ)h zI=hs7o*xi@f6q#!)HJ;FQK=j%rI30$A!>FilY*4r)oT`}-~`9-T6a1n<-xW3Q^xt_ z6X0@YuV=bF;gc)y(Of@12dN>4_zvYKqWBGI>nYk1oKBW>zs7BeS7q}8FUweAE~BVe zk?<&z+C~*r2))LT;!vxh>oSaqYfk>!)dHM{Ue{XI&SCfOFS8QK6|g-XVH>M63%&iR z=N_M$2JxX4JSe{g^27huvw+}?aRlXOTJ&OJzej-HrAg=y{I%#TGzlajL90sXbsV=@ z^mSyeK*x%Gl)SuM_^Cd`DWk0h1I%~*Zuja(Zi)I~Bh7k<;{DjoYBdE+okDqX)Ps0p zflG14pOR#oMYW@rXK<4G<>SU3Z79en-dKX}5){n|*ZsnmX>o3NM11qS%?Xq|OjY)i!Q#=uGMMYYX_-fu!p8)SdgR+fU7g3!h zvgT9qIJB=;FS%8xLrK2>x`o*!>{dF&e7Jz%V5$2wNR>>0a=8aNE|SP=-fiW>y;vPkZFhU;GW_TN{mgA$TcolvEw{Pf G2mb?DtmEqd literal 0 HcmV?d00001 diff --git a/tests/test_export_to_hf.py b/tests/test_export_to_hf.py index b50bde9cb..3ce092789 100644 --- a/tests/test_export_to_hf.py +++ b/tests/test_export_to_hf.py @@ -50,8 +50,7 @@ def test_export_lm_to_hf(): export_lm_to_hf.main(config) if has_torch(): - m = AutoModelForCausalLM.from_pretrained(f"{tmpdir}/output") - print(m) + AutoModelForCausalLM.from_pretrained(f"{tmpdir}/output") finally: try: diff --git a/tests/test_hf_gpt2_serialize.py b/tests/test_hf_gpt2_serialize.py index c9c2bbc61..34c4bb941 100644 --- a/tests/test_hf_gpt2_serialize.py +++ b/tests/test_hf_gpt2_serialize.py @@ -17,7 +17,7 @@ from levanter.compat.hf_checkpoints import HFCheckpointConverter, RepoRef from levanter.models.gpt2 import Gpt2Config, Gpt2LMHeadModel from levanter.models.loss import next_token_loss -from levanter.trainer import OptimizerConfig +from levanter.optim import AdamConfig from levanter.utils.tree_utils import inference_mode from test_utils import skip_if_no_torch @@ -142,7 +142,7 @@ def compute_loss(model, input_ids): assert onp.isclose(jax_g, torch_g.detach().cpu().numpy(), rtol=1e-2, atol=1e-2).all(), f"{jax_g} != {torch_g}" # now we also want to check that the optimizers do similar things - optimizer_config = OptimizerConfig(weight_decay=0.0, learning_rate=1e-3, warmup_ratio=0.0, lr_schedule="constant") + optimizer_config = AdamConfig(weight_decay=0.0, learning_rate=1e-3, warmup_ratio=0.0, lr_schedule="constant") if optimizer_config.max_grad_norm is not None: torch.nn.utils.clip_grad_norm_(torch_model.parameters(), optimizer_config.max_grad_norm) diff --git a/tests/test_logging.py b/tests/test_logging.py index 7c537b182..ab7cc35f2 100644 --- a/tests/test_logging.py +++ b/tests/test_logging.py @@ -20,5 +20,4 @@ def test_infer_experiment_git_root(): assert pathlib.Path(root).exists() repo = Repo(root) assert repo.working_dir == root - print(root, __file__) assert pathlib.Path(__file__).is_relative_to(root), f"{__file__} is not relative to {root}" diff --git a/tests/test_mpt.py b/tests/test_mpt.py index d4ef74084..aafcb4e1d 100644 --- a/tests/test_mpt.py +++ b/tests/test_mpt.py @@ -10,7 +10,7 @@ from levanter.models.mpt import MptConfig, MptLmHeadModel from levanter.utils.tree_utils import inference_mode -from test_utils import check_load_config, check_model_works_with_seqlen, parameterize_with_configs, skip_if_no_torch +from test_utils import check_model_works_with_seqlen, skip_if_no_torch @pytest.mark.skip(reason="MPT is broken in the latest version of transformers") @@ -102,15 +102,6 @@ def 
test_mpt_nano_compare(attn_impl): # lev_model = MptLmHeadModel.from_hf_pretrained("mosaicml/mpt-7b") -@parameterize_with_configs("mpt*.yaml") -def test_mpt_configs(config_file): - from levanter.main.train_lm import TrainLmConfig - - config_class = TrainLmConfig - - check_load_config(config_class, config_file) - - def test_pass_different_length_seq(): config = MptConfig( max_seq_len=32, diff --git a/tests/test_sophia.py b/tests/test_sophia.py new file mode 100644 index 000000000..7e759c330 --- /dev/null +++ b/tests/test_sophia.py @@ -0,0 +1,56 @@ +import os + +import equinox as eqx +import equinox.nn as nn +import jax +import jax.numpy as jnp +import numpy as np + +import levanter +import levanter.optim.sophia + + +def test_sophia_h(): + key = jax.random.PRNGKey(0) + model = nn.Linear(4, 4, use_bias=False, key=key) + data = np.load(f"{os.path.dirname(__file__)}/data/hero_data.npy").astype("float32") + optimizer = levanter.optim.sophia.sophia_h( + lr=1, b1=0, b2=0.99, gamma=2, weight_decay=0.0, clip_threshold=1, key=key + ) + model = jax.tree_util.tree_map(lambda x: jnp.ones_like(x), model) + + opt_state = optimizer.init(model) + + def loss_fn(model, data): + out = eqx.filter_vmap(model)(data) + return jnp.mean(out**2) * 4 + + jit_update = eqx.filter_jit(optimizer.update_hessian) + + for i in range(1000): + opt_state = jit_update(opt_state, loss_fn, model, data) + + # print('Test-estimated hessian: most coordinates should be approximately 2') + # print('Estimated hessian:', opt_state[0].h.weight) + assert jnp.allclose(opt_state[0].h.weight, 2, rtol=0.2, atol=0.3) # this is very approximate + + grad_loss_fn = eqx.filter_jit(eqx.filter_value_and_grad(loss_fn)) + + loss, grad = grad_loss_fn(model, data) + model_updates, opt_state = optimizer.update(grad, opt_state) + model = eqx.apply_updates(model, model_updates) + + # loss should be 15.74834156036377 + assert jnp.allclose(loss, 15.74834156036377) + + # print("Test-model param after 1 step: most coordinates should be very loosely 0.5") + assert jnp.allclose(model.weight, 0.5, rtol=0.2, atol=0.1) # this is very approximate + + # print("Test-loss: loss should shrink by approximately 75% after each iteration") + for i in range(10): + loss, grad = grad_loss_fn(model, data) + model_updates, opt_state = optimizer.update(grad, opt_state) + model = eqx.apply_updates(model, model_updates) + + # print('Step:', i , "Loss:", loss.item()) + assert loss < 15.74834156036377 * 0.75 ** (i + 1) From 83bea6e4ec79512c7035dedd48e17e190f612714 Mon Sep 17 00:00:00 2001 From: David Hall Date: Sat, 9 Dec 2023 20:35:48 -0800 Subject: [PATCH 056/205] fix missing test changes --- tests/test_backpack.py | 4 +++- tests/test_hf_checkpoints.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/test_backpack.py b/tests/test_backpack.py index ed02c43ac..e3cccc4fc 100644 --- a/tests/test_backpack.py +++ b/tests/test_backpack.py @@ -104,7 +104,9 @@ def test_backpack_nano_compare(): # now test round trip with tempfile.TemporaryDirectory() as tmpdir: - converter._save_pretrained_local(lev_model, tmpdir) + converter._save_pretrained_local( + lev_model, tmpdir, save_tokenizer=True, save_reference_code=True, max_shard_size=1e8 + ) model = AutoModelForCausalLM.from_pretrained(tmpdir, trust_remote_code=True) model.eval() diff --git a/tests/test_hf_checkpoints.py b/tests/test_hf_checkpoints.py index 28de5e6b4..169bb3999 100644 --- a/tests/test_hf_checkpoints.py +++ b/tests/test_hf_checkpoints.py @@ -49,7 +49,9 @@ def test_save_backpack_model_with_code(): 
lev_model = inference_mode(lev_model, True) with tempfile.TemporaryDirectory() as tmpdir: - converter._save_pretrained_local(lev_model, tmpdir) + converter._save_pretrained_local( + lev_model, tmpdir, save_tokenizer=True, save_reference_code=True, max_shard_size=1e8 + ) new_converter = converter.replaced(reference_checkpoint=tmpdir, trust_remote_code=True) From 92a615fb7b4a801b20405566ff996699403c29a3 Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 11 Dec 2023 00:23:50 -0800 Subject: [PATCH 057/205] should use a tempdir --- tests/test_text.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/tests/test_text.py b/tests/test_text.py index 21e2887db..9469999e1 100644 --- a/tests/test_text.py +++ b/tests/test_text.py @@ -1,11 +1,15 @@ +import tempfile + from levanter.data.text import LMDatasetConfig def test_dont_blow_up_without_validation_set(): - config = LMDatasetConfig( - train_urls=["kaa"], - validation_urls=[], - ) + with tempfile.TemporaryDirectory() as tmpdir: + config = LMDatasetConfig( + train_urls=["kaa"], + validation_urls=[], + cache_dir=tmpdir, + ) # mostly just making sure this doesn't blow up assert config.validation_set(10) is None From cbee4277e2619702d9ed88ae2ea5f22c03552c43 Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 11 Dec 2023 00:25:15 -0800 Subject: [PATCH 058/205] update gsm8k lora for sophia refactors --- examples/gsm8k-lora/gsm8k_lora.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/gsm8k-lora/gsm8k_lora.py b/examples/gsm8k-lora/gsm8k_lora.py index b4eaa3ec9..febfd2013 100644 --- a/examples/gsm8k-lora/gsm8k_lora.py +++ b/examples/gsm8k-lora/gsm8k_lora.py @@ -24,7 +24,8 @@ save_peft_checkpoint_callback, ) from levanter.models.lm_model import LmExample, LmHeadModel -from levanter.trainer import OptimizerConfig, Trainer, TrainerConfig +from levanter.optim import OptimizerConfig +from levanter.trainer import Trainer, TrainerConfig from levanter.utils.hf_utils import num_cpus_used_by_tokenizer from levanter.utils.jax_utils import parameter_count from levanter.utils.py_utils import non_caching_cycle From e0485810884ec171238f0327cdc7edaf3154b9ae Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 13 Dec 2023 10:52:43 -0800 Subject: [PATCH 059/205] Allow val change wandb dev (#384) * allow_val_change in wandb * sanity check epochs --- src/levanter/tracker/wandb.py | 3 ++- src/levanter/utils/py_utils.py | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/levanter/tracker/wandb.py b/src/levanter/tracker/wandb.py index cf4147351..83866656e 100644 --- a/src/levanter/tracker/wandb.py +++ b/src/levanter/tracker/wandb.py @@ -46,7 +46,7 @@ def __init__(self, run: Optional[WandbRun]): self.run = run def log_hyperparameters(self, hparams: dict[str, Any]): - self.run.config.update(hparams) + self.run.config.update(hparams, allow_val_change=True) def log(self, metrics: dict[str, Any], *, step, commit=None): self.run.log(metrics, step=step, commit=commit) @@ -133,6 +133,7 @@ def init(self, run_id: Optional[str]) -> WandbTracker: mode=mode, config=hparams_to_save, settings=git_settings, + allow_val_change=True, ) assert r is not None diff --git a/src/levanter/utils/py_utils.py b/src/levanter/utils/py_utils.py index a172b4498..735073999 100644 --- a/src/levanter/utils/py_utils.py +++ b/src/levanter/utils/py_utils.py @@ -19,6 +19,7 @@ def non_caching_cycle(iterable): """Like itertools.cycle, but doesn't cache the iterable.""" while True: yield from iterable + print("epoch XXX") # 
https://stackoverflow.com/a/58336722/1736826 CC-BY-SA 4.0 From 2bdf08b5301916a04c901c6ebf381a01f4e1f527 Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 13 Dec 2023 10:53:24 -0800 Subject: [PATCH 060/205] oops --- src/levanter/utils/py_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/levanter/utils/py_utils.py b/src/levanter/utils/py_utils.py index 735073999..a172b4498 100644 --- a/src/levanter/utils/py_utils.py +++ b/src/levanter/utils/py_utils.py @@ -19,7 +19,6 @@ def non_caching_cycle(iterable): """Like itertools.cycle, but doesn't cache the iterable.""" while True: yield from iterable - print("epoch XXX") # https://stackoverflow.com/a/58336722/1736826 CC-BY-SA 4.0 From 8f4aff307bc56da9a932376a1ec8f66d78397f20 Mon Sep 17 00:00:00 2001 From: David Hall Date: Thu, 14 Dec 2023 09:48:51 -0800 Subject: [PATCH 061/205] do loss in fp32 --- src/levanter/models/lm_model.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/levanter/models/lm_model.py b/src/levanter/models/lm_model.py index 665137846..439df5f96 100644 --- a/src/levanter/models/lm_model.py +++ b/src/levanter/models/lm_model.py @@ -3,6 +3,7 @@ import draccus import equinox as eqx +import jax.numpy as jnp from jax.random import PRNGKey import haliax as hax @@ -98,7 +99,7 @@ def compute_loss( across the reduction axis (with reduction_axis=None meaning all axes). If reduction is None, the loss is not reduced, and the result is a named array with axes (*batch axes, sequence_length). """ - logits = self(example.tokens, example.attn_mask, key=key) + logits = self(example.tokens, example.attn_mask, key=key).astype(jnp.float32) targets = hax.roll(example.tokens, -1, axis=self.Pos.name) target_y = hax.nn.one_hot(targets, self.Vocab, dtype=logits.dtype) return cross_entropy_loss( From 2002832925c07507e8d43a50b139899c7b311552 Mon Sep 17 00:00:00 2001 From: David Hall Date: Sun, 17 Dec 2023 00:16:46 -0800 Subject: [PATCH 062/205] more dead code removal --- src/levanter/trainer.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index eb31397e7..99db1ceb8 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -361,13 +361,11 @@ def train(self, state: TrainerState[M], train_loader: Iterable[X], run_hooks: bo return info - def _add_default_hooks(self, eval_dataset: Optional[Iterable[X]] = None): + def _add_default_hooks(self): from levanter import callbacks self.add_hook(callbacks.pbar_logger(total=self.config.num_train_steps), every=1) self.add_hook(callbacks.log_step_info, every=1) - if eval_dataset is not None: - self.add_eval_hook(eval_dataset) # engine.add_hook(callbacks.log_memory_usage(), every=1) checkpointer = self.config.checkpointer.create(self.run_id) self.add_hook(checkpointer.on_step, every=1) # checkpointer manages its own frequency @@ -468,13 +466,6 @@ def _initialize_state_from_scratch(self, model_init: Callable[[], M], training_k return TrainerState(0, model, opt_state, training_key, is_trainable) - def _init_non_trainable_params(self, model_init): - model = model_init() - # only force trainable params to param precision. 
Other params are cast to compute precision - trainable, non_trainable = self.partition_trainable_params(model) - non_trainable = self.mp.cast_to_compute(non_trainable) - return non_trainable - def _initialize_global_tracker(config, run_id): if isinstance(config, Sequence): From efa70a12c744bd85fff56c4ee74dfa18a003468a Mon Sep 17 00:00:00 2001 From: David Hall Date: Sun, 17 Dec 2023 00:22:27 -0800 Subject: [PATCH 063/205] refix merge issues --- src/levanter/trainer.py | 36 +++++++++++++++++------------------- 1 file changed, 17 insertions(+), 19 deletions(-) diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index 99db1ceb8..43c361518 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -286,29 +286,27 @@ def initial_state( # we can't just use `lambda: model` because JAX jit can't see captures, but it can see jax partials model_init = jax.tree_util.Partial(lambda m: m, model) - with self: - load_checkpoint_path = self.config.load_checkpoint_path - - if load_checkpoint_path is None: - load_checkpoint_path = self.config.checkpointer.expanded_path(self.run_id) + load_checkpoint_path = self.config.load_checkpoint_path - # if we're loading a checkpoint, we need to know which parameters are trainable - is_checkpointed = TrainerState(True, is_trainable, True, True, is_trainable) # type: ignore + if load_checkpoint_path is None: + load_checkpoint_path = self.config.checkpointer.expanded_path(self.run_id) assert model_init is not None - state = load_from_checkpoint_or_initialize( - self._initialize_state_from_scratch, - load_checkpoint_path, - axis_mapping=self.parameter_axis_mapping, - mesh=self.device_mesh, - force_load_checkpoint=self.config.load_checkpoint, - is_checkpointed=is_checkpointed, - )( - model_init, - training_key, - is_trainable, - ) + with self: + state = load_from_checkpoint_or_initialize( + self._initialize_state_from_scratch, + load_checkpoint_path, + axis_mapping=self.parameter_axis_mapping, + mesh=self.device_mesh, + force_load_checkpoint=self.config.load_checkpoint, + # if we're loading a checkpoint, we need to know which parameters are trainable + is_checkpointed=TrainerState(True, is_trainable, True, True, is_trainable), # type: ignore + )( + model_init, + training_key, + is_trainable, + ) return state From 2a90f57565127268cee1dd1cac55a7abc7020d3d Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 19 Dec 2023 00:24:54 -0800 Subject: [PATCH 064/205] allow train_batch_size to be -1 if per_device_parallelism isn't -1 --- src/levanter/trainer.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index 43c361518..1d234a76a 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -661,8 +661,13 @@ def _validate_and_set_defaults(self): ): raise ValueError("either model_axis_size or local_device_count must be divisible by the other") + assert self.train_batch_size != -1 or self.per_device_parallelism != -1 + if self.per_device_parallelism == -1: - self.per_device_parallelism = self.train_batch_size // jax.device_count() + self.per_device_parallelism = self.train_batch_size // self.data_axis_size + + if self.train_batch_size == -1: + self.train_batch_size = self.per_device_parallelism * self.data_axis_size # validate size of per_device_parallelism if self.train_batch_size % (self.per_device_parallelism * self.data_axis_size) != 0: From f05739aa094a801a2ca30a2b45b2bad1554320ad Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 20 Dec 2023 16:04:47 -0800 Subject: [PATCH 
065/205] wip

---
 src/levanter/trainer.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py
index 1d234a76a..27bea94e6 100644
--- a/src/levanter/trainer.py
+++ b/src/levanter/trainer.py
@@ -75,7 +75,7 @@ class TrainerState(eqx.Module, Generic[M]):
     model: M
     opt_state: OptState
     training_key: PRNGKeyArray
-    is_trainable: PyTree[FilterSpec] = eqx.field(static=True)
+    is_trainable: PyTree[FilterSpec]
@@ -301,7 +301,7 @@ def initial_state(
             mesh=self.device_mesh,
             force_load_checkpoint=self.config.load_checkpoint,
             # if we're loading a checkpoint, we need to know which parameters are trainable
-            is_checkpointed=TrainerState(True, is_trainable, True, True, is_trainable),  # type: ignore
+            is_checkpointed=TrainerState(True, is_trainable, True, True, False),  # type: ignore
         )(
             model_init,
             training_key,

From 38db3d5776543789c3472b71b30139db1b7c34da Mon Sep 17 00:00:00 2001
From: David Hall
Date: Wed, 20 Dec 2023 20:43:53 -0800
Subject: [PATCH 066/205] fix performance regression in trainer.py

---
 src/levanter/trainer.py | 26 ++++++++++++++++++--------
 1 file changed, 18 insertions(+), 8 deletions(-)

diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py
index 6a438abeb..f5502bc39 100644
--- a/src/levanter/trainer.py
+++ b/src/levanter/trainer.py
@@ -28,6 +28,7 @@
 import equinox as eqx
 import jax
+import jax.numpy as jnp
 import jmp
 import numpy as np
 from draccus import field
@@ -39,7 +40,7 @@
 import haliax as hax
 from haliax import Axis
 from haliax.partitioning import ResourceAxis, ResourceMapping, named_jit
-from haliax.types import Scalar
+from haliax.types import IntScalar, Scalar
 
 import levanter.logging
 import levanter.tracker
@@ -71,11 +72,19 @@
 
 
 class TrainerState(eqx.Module, Generic[M]):
-    step: int
+    """
+    This is the state of the trainer. It contains the model, optimizer state, and random key.
+    It is an equinox Module because it is a PyTree that gets passed to the core `train_step` method
+    of the Trainer. This unfortunately means that `step` is an Array and not an int, hence the IntScalar.
+
+    It's designed to be extended by subclasses.
+    """
+
+    step: IntScalar = eqx.field(converter=lambda x: jnp.asarray(x) if not isinstance(x, bool) else x)
     model: M
     opt_state: OptState
     training_key: PRNGKeyArray
-    is_trainable: PyTree[FilterSpec]
+    is_trainable: PyTree[FilterSpec]  # = eqx.field(static=True)
 
 
 S = TypeVar("S", bound=TrainerState)
@@ -93,11 +102,11 @@ class StepInfo(Generic[M]):
 
     model = property(lambda self: self.state.model)
     opt_state = property(lambda self: self.state.opt_state)
 
-    step = property(lambda self: self.state.step - 1)
+    step = property(lambda self: int((self.state.step - 1).item()))
     """
    The step that was just completed. If you want the next step, use `next_step`.
""" - next_step = property(lambda self: self.state.step) + next_step = property(lambda self: int(self.state.step.item())) @dataclass @@ -358,16 +367,17 @@ def training_steps( with capture_time() as loading_time: example = next(iter_data) - levanter.tracker.log_metrics({"throughput/loading_time": loading_time()}, step=state.step) + levanter.tracker.log_metrics({"throughput/loading_time": loading_time()}, step=int(state.step)) info = self.train_step(state, example) - state = info.state if run_hooks: with capture_time() as hook_time: self.run_hooks(info) - levanter.tracker.log_metrics({"throughput/hook_time": hook_time()}, step=state.step) + levanter.tracker.log_metrics({"throughput/hook_time": hook_time()}, step=int(state.step)) + + state = info.state yield info From 9b2813ba86cd5e5757893a4c956aae33d36dcadd Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 20 Dec 2023 21:04:31 -0800 Subject: [PATCH 067/205] wth --- src/levanter/main/train_lm.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py index 2cca76284..f0e05c9ff 100644 --- a/src/levanter/main/train_lm.py +++ b/src/levanter/main/train_lm.py @@ -121,6 +121,8 @@ def compute_loss(model: LmHeadModel, example: LmExample, key=None): state = trainer.initial_state(training_key, model_init=lambda: config.model.build(Vocab, key=model_key)) + print(state.step.sharding, state.step) + if state.step == 0: # TODO: I don't love that we init the model twice, but it's not a big deal i think? if config.initialize_from_hf: From e014c45ed53835a727b2b60544d2ed3d541c25a3 Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 20 Dec 2023 21:10:24 -0800 Subject: [PATCH 068/205] mdkladmlkad --- src/levanter/main/train_lm.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py index f0e05c9ff..c7617a433 100644 --- a/src/levanter/main/train_lm.py +++ b/src/levanter/main/train_lm.py @@ -121,7 +121,12 @@ def compute_loss(model: LmHeadModel, example: LmExample, key=None): state = trainer.initial_state(training_key, model_init=lambda: config.model.build(Vocab, key=model_key)) - print(state.step.sharding, state.step) + print( + state.step.sharding, + state.step, + state.step.sharding.is_fully_addressable, + state.step.sharding.is_fully_replicated, + ) if state.step == 0: # TODO: I don't love that we init the model twice, but it's not a big deal i think? From 95a391f7b387bedbbbae7fc5cba64c4ccecdd871 Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 20 Dec 2023 21:17:17 -0800 Subject: [PATCH 069/205] jfakmfa --- src/levanter/main/train_lm.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py index c7617a433..7937a0fb6 100644 --- a/src/levanter/main/train_lm.py +++ b/src/levanter/main/train_lm.py @@ -121,14 +121,8 @@ def compute_loss(model: LmHeadModel, example: LmExample, key=None): state = trainer.initial_state(training_key, model_init=lambda: config.model.build(Vocab, key=model_key)) - print( - state.step.sharding, - state.step, - state.step.sharding.is_fully_addressable, - state.step.sharding.is_fully_replicated, - ) - - if state.step == 0: + # TODO: I do not love that we have to coerce to int here. + if int(state.step) == 0: # TODO: I don't love that we init the model twice, but it's not a big deal i think? 
if config.initialize_from_hf: # initialize from an hf pretrained model From 9f40f10dcb739d7693eb3d255446a55db5c9ae9c Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 20 Dec 2023 21:40:59 -0800 Subject: [PATCH 070/205] try this other approach to steps in TrainerState --- src/levanter/checkpoint.py | 2 +- src/levanter/main/train_lm.py | 3 +-- src/levanter/trainer.py | 16 ++++++++++------ 3 files changed, 12 insertions(+), 9 deletions(-) diff --git a/src/levanter/checkpoint.py b/src/levanter/checkpoint.py index eac0ded93..fed3f40b2 100644 --- a/src/levanter/checkpoint.py +++ b/src/levanter/checkpoint.py @@ -343,7 +343,7 @@ def load_checkpoint( # TODO: pretty sure this is right, but should verify step = metadata["step"] new_state = dataclasses.replace( - tree, step=step + 1, model=model, opt_state=opt_state, training_key=key # type: ignore + tree, _step=step + 1, model=model, opt_state=opt_state, training_key=key # type: ignore ) return new_state diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py index 7937a0fb6..2cca76284 100644 --- a/src/levanter/main/train_lm.py +++ b/src/levanter/main/train_lm.py @@ -121,8 +121,7 @@ def compute_loss(model: LmHeadModel, example: LmExample, key=None): state = trainer.initial_state(training_key, model_init=lambda: config.model.build(Vocab, key=model_key)) - # TODO: I do not love that we have to coerce to int here. - if int(state.step) == 0: + if state.step == 0: # TODO: I don't love that we init the model twice, but it's not a big deal i think? if config.initialize_from_hf: # initialize from an hf pretrained model diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index f5502bc39..60edff34e 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -80,12 +80,16 @@ class TrainerState(eqx.Module, Generic[M]): It's designed to be extended by subclasses. """ - step: IntScalar = eqx.field(converter=lambda x: jnp.asarray(x) if not isinstance(x, bool) else x) + _step: IntScalar = eqx.field(converter=lambda x: jnp.asarray(x) if not isinstance(x, bool) else x) model: M opt_state: OptState training_key: PRNGKeyArray is_trainable: PyTree[FilterSpec] # = eqx.field(static=True) + @cached_property + def step(self) -> int: + return int(self._step) + S = TypeVar("S", bound=TrainerState) @@ -102,11 +106,11 @@ class StepInfo(Generic[M]): model = property(lambda self: self.state.model) opt_state = property(lambda self: self.state.opt_state) - step = property(lambda self: int((self.state.step - 1).item())) + step = property(lambda self: self.state.step - 1) """ The step that was just completed. If you want the next step, use `next_step`. 
""" - next_step = property(lambda self: int(self.state.step.item())) + next_step = property(lambda self: self.state.step) @dataclass @@ -367,7 +371,7 @@ def training_steps( with capture_time() as loading_time: example = next(iter_data) - levanter.tracker.log_metrics({"throughput/loading_time": loading_time()}, step=int(state.step)) + levanter.tracker.log_metrics({"throughput/loading_time": loading_time()}, step=state.step) info = self.train_step(state, example) @@ -375,7 +379,7 @@ def training_steps( with capture_time() as hook_time: self.run_hooks(info) - levanter.tracker.log_metrics({"throughput/hook_time": hook_time()}, step=int(state.step)) + levanter.tracker.log_metrics({"throughput/hook_time": hook_time()}, step=state.step) state = info.state @@ -481,7 +485,7 @@ def split_loss_fn(trainable_model, rest_model, *batch, **batch_kwargs): model = eqx.apply_updates(model, updates) new_state = dataclasses.replace( - state, step=state.step + 1, model=model, opt_state=opt_state, training_key=new_key + state, _step=state._step + 1, model=model, opt_state=opt_state, training_key=new_key ) return loss, new_state From 6df53f45fcf0f517465399d6e1b6323f7af8c5d3 Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 20 Dec 2023 22:02:43 -0800 Subject: [PATCH 071/205] fix checkpoint tests --- tests/test_checkpoint.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index bca7000c3..14668aa5f 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -33,7 +33,7 @@ def _dummy_step_info(step): return StepInfo( state=TrainerState( # + 1 b/c step here is next step - step=step + 1, + _step=step + 1, model=None, opt_state=(), training_key=(), @@ -204,7 +204,7 @@ def loss_fn(model, data): grad = loss_fn(state.model, data) updates, new_state = optim.update(grad, state.opt_state) model = eqx.apply_updates(state.model, updates) - state = dataclasses.replace(state, step=state.step + 1, model=model, opt_state=new_state) + state = dataclasses.replace(state, _step=state.step + 1, model=model, opt_state=new_state) assert_trees_not_close(state, initial_state) From 94aa8fa5ee81c9c319b0ad6143c9739527f97b2d Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 20 Dec 2023 22:32:49 -0800 Subject: [PATCH 072/205] fix gsm8k --- examples/gsm8k-lora/gsm8k_lora.py | 106 +++++++++++++++--------------- 1 file changed, 53 insertions(+), 53 deletions(-) diff --git a/examples/gsm8k-lora/gsm8k_lora.py b/examples/gsm8k-lora/gsm8k_lora.py index febfd2013..b889568cf 100644 --- a/examples/gsm8k-lora/gsm8k_lora.py +++ b/examples/gsm8k-lora/gsm8k_lora.py @@ -5,6 +5,7 @@ from dataclasses import dataclass from typing import Optional, Union +import equinox as eqx import jax import jax.random as jrandom import numpy as np @@ -88,7 +89,7 @@ def __iter__(self): else: loss_mask = 1 - hax.nn.one_hot(-1, Pos, dtype=jax.numpy.float32) - yield LmExample(input_ids, loss_mask) + yield LmExample.causal(input_ids, loss_mask=loss_mask) def mk_dataset(config: TrainArgs, tokenizer: transformers.PreTrainedTokenizerBase): @@ -147,7 +148,12 @@ def train(config: TrainArgs): optimizer = config.optimizer.build(config.trainer.num_train_steps) - with config.trainer.device_mesh: + def compute_loss(model: LmHeadModel, example: LmExample, key=None): + return model.compute_loss(example, key=key).scalar() + + # end major difference from Alpaca + + with Trainer(config.trainer, optimizer, compute_loss) as trainer: # how we shard parameters across devices parameter_axis_mapping = 
config.trainer.parameter_axis_mapping @@ -165,60 +171,54 @@ def loraize_hf_model(model): lora_param_filter = lora_trainable_params_filter(model) - def compute_loss(model: LmHeadModel, example: LmExample, key=None): - return model.compute_loss(example, key=key).scalar() - - # end major difference from Alpaca - - with Trainer(config.trainer, optimizer, compute_loss, is_trainable=lora_param_filter) as trainer: - state = trainer.initial_state(training_key, model=model) - - # log some info about the model - all_param_count = parameter_count(state.model) - just_lora_params = parameter_count(trainer.trainable_params_only(state.model)) + state = trainer.initial_state(training_key, model=model, is_trainable=lora_param_filter) + + # log some info about the model + all_param_count = parameter_count(state.model) + just_lora_params = parameter_count(eqx.filter(state.model, lora_param_filter)) + + levanter.tracker.log_summary( + { + "parameter_count": all_param_count, + "trainable_parameter_count": just_lora_params, + "fraction_trainable": just_lora_params * 1.0 / all_param_count, + } + ) + + logger.info(f"Total parameter count: {all_param_count}") + logger.info(f"Trainable parameter count: {just_lora_params}") + logger.info(f"Fraction of parameters that are trainable: {just_lora_params * 1.0 / all_param_count%.3}") + + # Levanter has two kinds of data loaders: sharded and replicated. Replicated is simpler and allows for + # single pass training. Sharded only loads a subset of the data on each device, and is more efficient for large + # datasets. We use replicated here since the dataset is small. + loader = trainer.replicated_loader(train_dataset, trainer.TrainBatch) + loader = non_caching_cycle(loader) + + if state.step != 0: + logger.info(f"Resuming training from step {state.step}") + for i in range(state.step): + next(loader) # type: ignore + + # Save HF PEFT checkpoints periodically (and at the end of training), which is just the lora weights + if config.hf_save_path is not None: + full_save_path = os.path.join(config.hf_save_path, trainer.run_id) + trainer.add_hook( + save_peft_checkpoint_callback( + full_save_path, config.lora, config.model_name_or_path, tokenizer, config.hf_upload + ), + every=config.hf_save_steps, + ) - levanter.tracker.log_summary( - { - "parameter_count": all_param_count, - "trainable_parameter_count": just_lora_params, - "fraction_trainable": just_lora_params * 1.0 / all_param_count, - } + # Save merged HF checkpoints if requested + if config.merged_hf_save_path is not None: + full_save_path = os.path.join(config.merged_hf_save_path, trainer.run_id) + trainer.add_hook( + save_merged_hf_checkpoint_callback(full_save_path, converter, config.merged_hf_upload), + every=config.hf_save_steps, ) - logger.info(f"Total parameter count: {all_param_count}") - logger.info(f"Trainable parameter count: {just_lora_params}") - logger.info(f"Fraction of parameters that are trainable: {just_lora_params * 1.0 / all_param_count%.3}") - - # Levanter has two kinds of data loaders: sharded and replicated. Replicated is simpler and allows for - # single pass training. Sharded only loads a subset of the data on each device, and is more efficient for large - # datasets. We use replicated here since the dataset is small. 
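# A toy sketch of the eqx.filter-based trainable-parameter accounting a few
# lines above. The real code filters state.model with lora_param_filter; the
# two-leaf module and hand-written spec below are illustrative stand-ins.

import equinox as eqx
import jax
import jax.numpy as jnp

class Tiny(eqx.Module):
    w: jax.Array
    b: jax.Array

m = Tiny(w=jnp.ones((4, 4)), b=jnp.ones(4))
spec = Tiny(w=True, b=False)     # stands in for lora_param_filter
trainable = eqx.filter(m, spec)  # non-trainable leaves are replaced with None
assert sum(x.size for x in jax.tree_util.tree_leaves(trainable)) == 16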
- loader = trainer.replicated_loader(train_dataset, trainer.TrainBatch) - loader = non_caching_cycle(loader) - - if state.step != 0: - logger.info(f"Resuming training from step {state.step}") - for i in range(state.step): - next(loader) # type: ignore - - # Save HF PEFT checkpoints periodically (and at the end of training), which is just the lora weights - if config.hf_save_path is not None: - full_save_path = os.path.join(config.hf_save_path, trainer.run_id) - trainer.add_hook( - save_peft_checkpoint_callback( - full_save_path, config.lora, config.model_name_or_path, tokenizer, config.hf_upload - ), - every=config.hf_save_steps, - ) - - # Save merged HF checkpoints if requested - if config.merged_hf_save_path is not None: - full_save_path = os.path.join(config.merged_hf_save_path, trainer.run_id) - trainer.add_hook( - save_merged_hf_checkpoint_callback(full_save_path, converter, config.merged_hf_upload), - every=config.hf_save_steps, - ) - - trainer.train(state, loader) + trainer.train(state, loader) if __name__ == "__main__": From 5af6cb2ff14c80b0e5e9ef5c524dc3291e912447 Mon Sep 17 00:00:00 2001 From: David Hall Date: Sun, 24 Dec 2023 12:59:39 -0800 Subject: [PATCH 073/205] update for new Haliax reduction functions --- docs/LoRA.md | 2 +- examples/alpaca-lora/alpaca_lora.py | 2 +- examples/alpaca/alpaca.py | 2 +- examples/gsm8k-lora/gsm8k_lora.py | 2 +- src/levanter/main/lora_lm.py | 2 +- src/levanter/main/train_lm.py | 2 +- src/levanter/models/lm_model.py | 2 +- tests/test_flash_attention.py | 2 +- tests/test_grad_accum.py | 2 +- 9 files changed, 9 insertions(+), 9 deletions(-) diff --git a/docs/LoRA.md b/docs/LoRA.md index 06ee44c1a..2432ff335 100644 --- a/docs/LoRA.md +++ b/docs/LoRA.md @@ -83,7 +83,7 @@ def train(config: TrainArgs): lora_param_filter = lora_trainable_params_filter(model) def compute_loss(model: LmHeadModel, example: LmExample, key=None): - return model.compute_loss(example, key=key).scalar() + return model.compute_loss(example, key=key) trainer = Trainer(config.trainer, optimizer, compute_loss, is_trainable=lora_param_filter) ``` diff --git a/examples/alpaca-lora/alpaca_lora.py b/examples/alpaca-lora/alpaca_lora.py index 3011058b6..11f3b6134 100644 --- a/examples/alpaca-lora/alpaca_lora.py +++ b/examples/alpaca-lora/alpaca_lora.py @@ -81,7 +81,7 @@ def train(config: TrainArgs): optimizer = config.optimizer.build(config.trainer.num_train_steps) def compute_loss(model: LmHeadModel, example: LmExample, key=None): - return model.compute_loss(example, key=key).scalar() + return model.compute_loss(example, key=key) # end major difference from Alpaca diff --git a/examples/alpaca/alpaca.py b/examples/alpaca/alpaca.py index c74900a8d..de4128ac9 100644 --- a/examples/alpaca/alpaca.py +++ b/examples/alpaca/alpaca.py @@ -227,7 +227,7 @@ def train(config: TrainArgs): optimizer = config.optimizer.build(config.trainer.num_train_steps) def compute_loss(model: LmHeadModel, example: LmExample, key=None): - return model.compute_loss(example, key=key).scalar() + return model.compute_loss(example, key=key) with Trainer(config.trainer, optimizer, compute_loss) as trainer: # how we shard parameters across devices diff --git a/examples/gsm8k-lora/gsm8k_lora.py b/examples/gsm8k-lora/gsm8k_lora.py index b889568cf..7531d2411 100644 --- a/examples/gsm8k-lora/gsm8k_lora.py +++ b/examples/gsm8k-lora/gsm8k_lora.py @@ -149,7 +149,7 @@ def train(config: TrainArgs): optimizer = config.optimizer.build(config.trainer.num_train_steps) def compute_loss(model: LmHeadModel, example: LmExample, key=None): - 
return model.compute_loss(example, key=key).scalar() + return model.compute_loss(example, key=key) # end major difference from Alpaca diff --git a/src/levanter/main/lora_lm.py b/src/levanter/main/lora_lm.py index 224a66b0a..5e9f55e99 100644 --- a/src/levanter/main/lora_lm.py +++ b/src/levanter/main/lora_lm.py @@ -71,7 +71,7 @@ def main(config: LoraLmConfig): optimizer = config.optimizer.build(config.trainer.num_train_steps) def compute_loss(model, example: LmExample, key=None): - return model.compute_loss(example, key=key).scalar() + return model.compute_loss(example, key=key) with Trainer(config.trainer, optimizer, compute_loss) as trainer: # how we shard parameters across devices diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py index 2cca76284..baa50b0c9 100644 --- a/src/levanter/main/train_lm.py +++ b/src/levanter/main/train_lm.py @@ -82,7 +82,7 @@ def main(config: TrainLmConfig): optimizer = config.optimizer.build(config.trainer.num_train_steps) def compute_loss(model: LmHeadModel, example: LmExample, key=None): - return model.compute_loss(example, key=key).scalar() + return model.compute_loss(example, key=key) # Our trainer is a wrapper around the optimizer and compute_loss function that handles checkpointing and fsdp # Using the trainer as a context manager does 3 things: diff --git a/src/levanter/models/lm_model.py b/src/levanter/models/lm_model.py index c492b2321..322c978c5 100644 --- a/src/levanter/models/lm_model.py +++ b/src/levanter/models/lm_model.py @@ -117,7 +117,7 @@ def compute_loss( key=None, reduction: Optional[hax.ReductionFunction] = hax.mean, reduction_axis: Optional[hax.AxisSelection] = None, - ) -> NamedArray: + ) -> jnp.ndarray | NamedArray: """ Computes the cross-entropy loss for a language modeling example. If reduction is not None, the loss is reduced across the reduction axis (with reduction_axis=None meaning all axes). 
If reduction is None, the loss is not diff --git a/tests/test_flash_attention.py b/tests/test_flash_attention.py index 8d1f5aab0..d4b1f08b4 100644 --- a/tests/test_flash_attention.py +++ b/tests/test_flash_attention.py @@ -60,7 +60,7 @@ def test_grad_attention(): def d_attn(qkv, fn): q, k, v = qkv x_out = fn(KPos, Key, q, k, v, mask=mask) - return (x_out * x_out).sum().scalar() + return (x_out * x_out).sum() hax_val, (hax_dq, hax_dk, hax_dv) = d_attn((q, k, v), hnn.attention.dot_product_attention) fa_val, (fa_dq, fa_dk, fa_dv) = d_attn((q, k, v), functools.partial(flash_attention, QPos, inference=True)) diff --git a/tests/test_grad_accum.py b/tests/test_grad_accum.py index aec568a21..74854dee4 100644 --- a/tests/test_grad_accum.py +++ b/tests/test_grad_accum.py @@ -44,7 +44,7 @@ def test_accumulate_gradients_sharded(parallelism, accum_steps): mlp = Mlp.init(In, Out, Mid, key=jax.random.PRNGKey(0)) def loss_fn(mlp, x): - return mlp(x).mean().scalar() + return mlp(x).mean() x = hax.random.normal(jax.random.PRNGKey(0), (Batch, In)) From 84d3b33155199d71b8720462e99f26538f935428 Mon Sep 17 00:00:00 2001 From: David Hall Date: Sun, 24 Dec 2023 13:10:24 -0800 Subject: [PATCH 074/205] wip --- src/levanter/doremi.py | 60 ++++++++++++++++++++++++++++++++++-------- 1 file changed, 49 insertions(+), 11 deletions(-) diff --git a/src/levanter/doremi.py b/src/levanter/doremi.py index 4ddb04298..dd054c8c9 100644 --- a/src/levanter/doremi.py +++ b/src/levanter/doremi.py @@ -1,11 +1,12 @@ import dataclasses -from typing import Callable, Iterator, TypeVar +from typing import Callable, Iterator, Optional, TypeVar import equinox as eqx import jax import jax.numpy as jnp import jax.random as jrandom import optax +from haliax import Scalar from jaxtyping import PRNGKeyArray import haliax as hax @@ -15,7 +16,7 @@ from levanter.data.mixture import MixtureDataset from levanter.trainer import M, StepInfo, Trainer, TrainerState from levanter.types import ComputeLossFunction - +from optax._src.base import GradientTransformation T = TypeVar("T") @@ -23,20 +24,57 @@ class DoremiState(TrainerState): alpha: hax.NamedArray average_alpha: hax.NamedArray - def __init__(self, step: int, model, opt_state, training_key, alpha): - super().__init__(step, model, opt_state, training_key) - self.alpha = alpha - def update_alpha(self, alpha): - # make it stable average_alpha = self.average_alpha + (alpha - self.average_alpha) / (self.step + 1) return dataclasses.replace(self, alpha=alpha, average_alpha=average_alpha) -# class DoReMiTrainer(Trainer): -# def _initialize_state_from_scratch(self, model_init: Callable[[], M], training_key, is_trainable): -# base_state = super()._initialize_state_from_scratch(model_init, training_key, is_trainable) -# +class DoReMiTrainer(Trainer): + + + def __init__(self, config: "TrainerConfig", optimizer: GradientTransformation, + initial_alpha: hax.NamedArray, + loss_fn: Optional[ComputeLossFunction] = None, + ): + super().__init__(config, optimizer, loss_fn) + self.initial_alpha = initial_alpha + + + def _initialize_state_from_scratch(self, model_init: Callable[[], M], training_key, is_trainable): + base_state = super()._initialize_state_from_scratch(model_init, training_key, is_trainable) + return DoremiState(base_state.step, + base_state.model, + base_state.opt_state, + base_state.training_key, + self.initial_alpha) + + def _train_step(self, state: TrainerState, *batch, **batch_kwargs) -> tuple[Scalar, TrainerState]: + key, new_key = jax.random.split(state.training_key) + opt_state = state.opt_state 
+ model = inference_mode(state.model, False) + + # we do this so that we only take the gradients of the trainable parameters + trainable_model, rest_model = _partition_trainable_params(model, state.is_trainable) + + def split_loss_fn(trainable_model, rest_model, *batch, **batch_kwargs): + model = eqx.combine(trainable_model, rest_model) + return self.loss_fn(model, *batch, **batch_kwargs, key=key) + + loss, grads = accumulate_gradients_sharded( + split_loss_fn, self.TrainBatch, self.config.per_device_parallelism, self.parameter_axis_mapping + )(trainable_model, rest_model, *batch, **batch_kwargs) + + updates, opt_state = self.optimizer.update(grads, opt_state, params=trainable_model) + if isinstance(self.optimizer, SecondOrderTransformation): + opt_state = self.optimizer.update_hessian( + opt_state, split_loss_fn, trainable_model, *batch, **batch_kwargs + ) + + model = eqx.apply_updates(model, updates) + + new_state = dataclasses.replace( + state, _step=state._step + 1, model=model, opt_state=opt_state, training_key=new_key + ) def estimate_mixture_weights( From 85b42b0fed25e9bb7033818078182db8965bb4cc Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 25 Dec 2023 01:27:14 -0800 Subject: [PATCH 075/205] refactor grad_accum to have a separate microbatched --- src/levanter/grad_accum.py | 137 +++++++++++++++++++++++-------------- 1 file changed, 86 insertions(+), 51 deletions(-) diff --git a/src/levanter/grad_accum.py b/src/levanter/grad_accum.py index 9280e5234..c8f821e45 100644 --- a/src/levanter/grad_accum.py +++ b/src/levanter/grad_accum.py @@ -1,4 +1,5 @@ import functools +from typing import Callable, Optional, ParamSpec, TypeVar import equinox as eqx import jax @@ -10,91 +11,115 @@ from haliax import Axis from haliax.jax_utils import named_call from haliax.partitioning import ResourceAxis -from haliax.util import is_named_array +from haliax.util import is_jax_array_like, is_named_array from levanter.types import M, ValAndGradFn, ValFn, X -# cf https://github.com/google-research/t5x/blob/main/t5x/trainer.py#L617 -@named_call -def accumulate_gradients_sharded( - f: ValFn[M, X], +Args = ParamSpec("Args") +R = TypeVar("R") + + +def microbatched_mean( + fn: Callable[Args, R], Batch: Axis, per_device_parallelism: int, - parameter_axis_mapping, -) -> ValAndGradFn[M, X]: + accum_axis_mapping, + compute_axis_mapping, + patch_in_rng_key: Optional[str] = "key", +) -> Callable[Args, R]: """ - Accumulate gradients across a sharded batch, keeping a local copy of the gradient on each row of the data - parallel axis. (If the model is not sharded, then a copy of the gradient is on each individual device.) + Wraps a function that takes a batch and changes it to instead take microbatches and accumulate the results + This function takes the *mean* of the microbatched results, so it only does what you want if the function + is taking the mean of the batch axis. - Parameters: - f: a function whose gradients are to be accumulated + Args: + fn: a function to wrap + Batch: the batch axis per_device_parallelism: how many examples to process at once on each device - inputs: inputs with the batch axis. non-named arrays assume that the 0th axis is the batch axis. - parameter_axis_mapping: the axis mapping for the model parameters - key: an optional PRNG key for the random number generator. 
- If provided, this key will be split, 1 for each accum step - kwargs: passed to f - + accum_axis_mapping: the axis mapping for the accumulator (typically this is the same as the params) + compute_axis_mapping: the axis mapping for the computation (typically this is the same as the inputs) + patch_in_rng_key: if provided, this kwarg will be split, 1 for each accum step. It won't work if the + PRNGKey is passed in as a positional argument. + + Returns: + a function that splits the batch into microbatches, calls the function on each microbatch, and + accumulates the results. """ batch_size = Batch.size - data_axis_size = hax.partitioning.physical_axis_size(Batch, parameter_axis_mapping) + data_axis_size = hax.partitioning.physical_axis_size(Batch, compute_axis_mapping) if data_axis_size is None: raise ValueError(f"{Batch} axis must be sharded") - physical_axis_name = hax.partitioning.physical_axis_name(Batch, parameter_axis_mapping) + physical_axis_name = hax.partitioning.physical_axis_name(Batch, compute_axis_mapping) assert physical_axis_name is not None microbatch_size = data_axis_size * per_device_parallelism num_micro_steps = batch_size // microbatch_size - - assert batch_size % data_axis_size == 0, f"batch_size % data_axis_size != 0: {batch_size} % {data_axis_size} != 0" - assert ( - batch_size % microbatch_size == 0 - ), f"batch_size % microbatch_size != 0: {batch_size} % {microbatch_size} != 0" - - Microbatch = Axis(Batch.name, microbatch_size) + Microbatch = Batch.resize(microbatch_size) AccumStep = Axis("accum_step", num_micro_steps) assert num_micro_steps * microbatch_size == batch_size - grad_fn = eqx.filter_value_and_grad(f, has_aux=False) - - @functools.wraps(grad_fn) - def fn(model, *inputs, key=None, **batch_kwargs): + @functools.wraps(fn) + def wrapped_fn(*args, **kwargs): + key = kwargs.get(patch_in_rng_key, None) if key is not None: key = jax.random.split(key, num_micro_steps) + # first, determine the shape and make accumulator arrays + r_shape = eqx.filter_eval_shape(fn, *args, **kwargs) + acc = jax.tree_util.tree_map( + functools.partial(_zeros_like, accum_axis_mapping), r_shape, is_leaf=is_named_array + ) - # first things first, we want a copy of our gradient sharded like our model, along with a loss value - loss = jnp.zeros(()) - with jax.named_scope("zeros"): - grad = jax.tree_util.tree_map(jnp.zeros_like, eqx.filter(model, eqx.is_inexact_array_like)) - grad = hax.shard_with_axis_mapping(grad, parameter_axis_mapping) - - # second, we want to reshape our data to (num_micro_steps, micro_batch_size, ...), sharded along the data axis - inputs = _reshape_for_microbatch(Batch, Microbatch, AccumStep, inputs, parameter_axis_mapping) + args = _reshape_for_microbatch(Batch, Microbatch, AccumStep, args, compute_axis_mapping) - # third, we want to do compute. 
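# A plain-jnp sketch of the (Batch, ...) -> (AccumStep, Microbatch, ...) reshape
# that the loop below iterates over; the real helper does the same for
# NamedArrays and keeps everything sharded along the data axis.

import jax.numpy as jnp

batch = jnp.arange(8 * 3, dtype=jnp.float32).reshape(8, 3)  # (Batch, features)
num_micro_steps, microbatch_size = 4, 2
stacked = batch.reshape(num_micro_steps, microbatch_size, 3)
assert stacked.shape == (4, 2, 3)  # one (microbatch, features) slice per step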
def loop(acc, microbatch_and_key): - loss, grad = acc microbatch, microbatch_kwargs, key = microbatch_and_key - with jax.named_scope("grad"): + with jax.named_scope("compute_microbatch"): microbatch_kwargs = microbatch_kwargs.copy() if key is not None: - microbatch_kwargs["key"] = key - this_loss, this_grad = grad_fn(model, *microbatch, **microbatch_kwargs) - this_grad = hax.shard_with_axis_mapping(this_grad, parameter_axis_mapping) + microbatch_kwargs[patch_in_rng_key] = key + this_r = fn(*microbatch, **microbatch_kwargs) with jax.named_scope("accum"): - loss += this_loss - grad = eqx.apply_updates(grad, this_grad) - grad = hax.shard_with_axis_mapping(grad, parameter_axis_mapping) + acc = eqx.apply_updates(acc, this_r) + acc = hax.shard_with_axis_mapping(acc, accum_axis_mapping) - return loss, grad + return acc - loss, grad = hax.fold(loop, AccumStep)((loss, grad), (inputs, batch_kwargs, key)) + acc = hax.fold(loop, AccumStep)(acc, (args, kwargs, key)) + acc = jax.tree_util.tree_map(lambda x: x / num_micro_steps, acc) - return loss / num_micro_steps, jax.tree_map(lambda x: x / num_micro_steps, grad) + return acc - return fn + return wrapped_fn + + +# cf https://github.com/google-research/t5x/blob/main/t5x/trainer.py#L617 +@named_call +def accumulate_gradients_sharded( + f: ValFn[M, X], + Batch: Axis, + per_device_parallelism: int, + parameter_axis_mapping, +) -> ValAndGradFn[M, X]: + """ + Accumulate gradients across a sharded batch, keeping a local copy of the gradient on each row of the data + parallel axis. (If the model is not sharded, then a copy of the gradient is on each individual device.) + + Parameters: + f: a function whose gradients are to be accumulated + per_device_parallelism: how many examples to process at once on each device + inputs: inputs with the batch axis. non-named arrays assume that the 0th axis is the batch axis. + parameter_axis_mapping: the axis mapping for the model parameters + key: an optional PRNG key for the random number generator. 
+ If provided, this key will be split, 1 for each accum step + kwargs: passed to f + + """ + grad_fn = eqx.filter_value_and_grad(f, has_aux=False) + grad_fn = microbatched_mean(grad_fn, Batch, per_device_parallelism, parameter_axis_mapping, parameter_axis_mapping) + + return grad_fn def _reshape_for_microbatch(Batch: Axis, Microbatch: Axis, AccumStep: Axis, inputs, axis_mapping): @@ -112,3 +137,13 @@ def _reshape(x): return x return jax.tree_util.tree_map(_reshape, inputs, is_leaf=is_named_array) + + +def _zeros_like(mapping, n): + if isinstance(n, hax.NamedArray): + return hax.auto_sharded(hax.zeros_like(n), mapping) + elif is_jax_array_like(n): + return jnp.zeros_like(n) + else: + assert jnp.isscalar(n) + return 0.0 From c47c188fc9932ea0021777eb575f6b781e1a22ae Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 26 Dec 2023 16:51:49 -0800 Subject: [PATCH 076/205] remove accumulate_gradients_sharded and just use microbatched directly --- src/levanter/grad_accum.py | 58 ++++++++++++++------------------------ src/levanter/trainer.py | 14 ++++++--- tests/test_grad_accum.py | 8 +++--- 3 files changed, 35 insertions(+), 45 deletions(-) diff --git a/src/levanter/grad_accum.py b/src/levanter/grad_accum.py index c8f821e45..3b2ae92c6 100644 --- a/src/levanter/grad_accum.py +++ b/src/levanter/grad_accum.py @@ -1,3 +1,4 @@ +import enum import functools from typing import Callable, Optional, ParamSpec, TypeVar @@ -9,29 +10,33 @@ import haliax as hax from haliax import Axis -from haliax.jax_utils import named_call from haliax.partitioning import ResourceAxis from haliax.util import is_jax_array_like, is_named_array -from levanter.types import M, ValAndGradFn, ValFn, X - Args = ParamSpec("Args") R = TypeVar("R") -def microbatched_mean( +class AccumType(enum.Enum): + SUM = enum.auto() + MEAN = enum.auto() + # TODO: add MAX? + + +# cf https://github.com/google-research/t5x/blob/main/t5x/trainer.py#L617 +def microbatched( fn: Callable[Args, R], Batch: Axis, per_device_parallelism: int, accum_axis_mapping, compute_axis_mapping, patch_in_rng_key: Optional[str] = "key", + accum_type: AccumType = AccumType.MEAN, ) -> Callable[Args, R]: """ Wraps a function that takes a batch and changes it to instead take microbatches and accumulate the results - This function takes the *mean* of the microbatched results, so it only does what you want if the function - is taking the mean of the batch axis. + This function has to reduce the batch axis, so it can't be used for functions that need to keep the batch axis. Args: fn: a function to wrap @@ -41,6 +46,7 @@ def microbatched_mean( compute_axis_mapping: the axis mapping for the computation (typically this is the same as the inputs) patch_in_rng_key: if provided, this kwarg will be split, 1 for each accum step. It won't work if the PRNGKey is passed in as a positional argument. 
+ accum_type: whether to sum or average the results Returns: a function that splits the batch into microbatches, calls the function on each microbatch, and @@ -59,6 +65,9 @@ def microbatched_mean( AccumStep = Axis("accum_step", num_micro_steps) assert num_micro_steps * microbatch_size == batch_size + if accum_type not in AccumType: + raise ValueError(f"accum_type must be one of {AccumType}") + @functools.wraps(fn) def wrapped_fn(*args, **kwargs): key = kwargs.get(patch_in_rng_key, None) @@ -74,7 +83,7 @@ def wrapped_fn(*args, **kwargs): def loop(acc, microbatch_and_key): microbatch, microbatch_kwargs, key = microbatch_and_key - with jax.named_scope("compute_microbatch"): + with jax.named_scope("compute"): microbatch_kwargs = microbatch_kwargs.copy() if key is not None: microbatch_kwargs[patch_in_rng_key] = key @@ -86,42 +95,17 @@ def loop(acc, microbatch_and_key): return acc - acc = hax.fold(loop, AccumStep)(acc, (args, kwargs, key)) - acc = jax.tree_util.tree_map(lambda x: x / num_micro_steps, acc) + with jax.named_scope("microbatched"): + acc = hax.fold(loop, AccumStep)(acc, (args, kwargs, key)) + + if accum_type == AccumType.MEAN: + acc = jax.tree_util.tree_map(lambda x: x / num_micro_steps, acc) return acc return wrapped_fn -# cf https://github.com/google-research/t5x/blob/main/t5x/trainer.py#L617 -@named_call -def accumulate_gradients_sharded( - f: ValFn[M, X], - Batch: Axis, - per_device_parallelism: int, - parameter_axis_mapping, -) -> ValAndGradFn[M, X]: - """ - Accumulate gradients across a sharded batch, keeping a local copy of the gradient on each row of the data - parallel axis. (If the model is not sharded, then a copy of the gradient is on each individual device.) - - Parameters: - f: a function whose gradients are to be accumulated - per_device_parallelism: how many examples to process at once on each device - inputs: inputs with the batch axis. non-named arrays assume that the 0th axis is the batch axis. - parameter_axis_mapping: the axis mapping for the model parameters - key: an optional PRNG key for the random number generator. 
- If provided, this key will be split, 1 for each accum step - kwargs: passed to f - - """ - grad_fn = eqx.filter_value_and_grad(f, has_aux=False) - grad_fn = microbatched_mean(grad_fn, Batch, per_device_parallelism, parameter_axis_mapping, parameter_axis_mapping) - - return grad_fn - - def _reshape_for_microbatch(Batch: Axis, Microbatch: Axis, AccumStep: Axis, inputs, axis_mapping): def _reshape(x): if isinstance(x, hax.NamedArray): diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index 60edff34e..d0733ae65 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -50,7 +50,7 @@ from levanter.config import JsonAtom from levanter.data import Dataset, ReplicatedBatchLoader, ShardableDataset, ShardedBatchLoader from levanter.distributed import DistributedConfig, RayConfig -from levanter.grad_accum import accumulate_gradients_sharded +from levanter.grad_accum import microbatched from levanter.logging import capture_time from levanter.optim import SecondOrderTransformation from levanter.tracker import TrackerConfig @@ -472,9 +472,15 @@ def split_loss_fn(trainable_model, rest_model, *batch, **batch_kwargs): model = eqx.combine(trainable_model, rest_model) return self.loss_fn(model, *batch, **batch_kwargs, key=key) - loss, grads = accumulate_gradients_sharded( - split_loss_fn, self.TrainBatch, self.config.per_device_parallelism, self.parameter_axis_mapping - )(trainable_model, rest_model, *batch, **batch_kwargs) + grad_fn = eqx.filter_value_and_grad(split_loss_fn, has_aux=False) + grad_fn = microbatched( + grad_fn, + self.TrainBatch, + self.config.per_device_parallelism, + self.parameter_axis_mapping, + self.parameter_axis_mapping, + ) + loss, grads = grad_fn(trainable_model, rest_model, *batch, **batch_kwargs) updates, opt_state = self.optimizer.update(grads, opt_state, params=trainable_model) if isinstance(self.optimizer, SecondOrderTransformation): diff --git a/tests/test_grad_accum.py b/tests/test_grad_accum.py index 74854dee4..dd6bfa761 100644 --- a/tests/test_grad_accum.py +++ b/tests/test_grad_accum.py @@ -7,7 +7,7 @@ import haliax as hax import haliax.nn as hnn -from levanter.grad_accum import accumulate_gradients_sharded +from levanter.grad_accum import microbatched class Mlp(eqx.Module): @@ -56,9 +56,9 @@ def loss_fn(mlp, x): @hax.partitioning.named_jit(axis_resources=axis_mapping) def jit_grad_accum(mlp, x): - acc_v, acc_g = accumulate_gradients_sharded( - loss_fn, Batch, per_device_parallelism=parallelism, parameter_axis_mapping=axis_mapping - )( + grad_fn = eqx.filter_value_and_grad(loss_fn, has_aux=False) + grad_fn = microbatched(grad_fn, Batch, parallelism, axis_mapping, axis_mapping) + acc_v, acc_g = grad_fn( mlp, x, ) From 70b766fd4665c0a5de1a3b6632344eeecadb6efc Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 26 Dec 2023 23:13:31 -0800 Subject: [PATCH 077/205] add dtype for grad accum --- src/levanter/grad_accum.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/src/levanter/grad_accum.py b/src/levanter/grad_accum.py index 3b2ae92c6..10e9a4520 100644 --- a/src/levanter/grad_accum.py +++ b/src/levanter/grad_accum.py @@ -18,7 +18,7 @@ R = TypeVar("R") -class AccumType(enum.Enum): +class ReductionType(enum.Enum): SUM = enum.auto() MEAN = enum.auto() # TODO: add MAX? 
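# A small numeric check (illustration only) of why MEAN accumulation is sound:
# with equal-sized microbatches, which the wrapper enforces via
# batch_size % microbatch_size == 0, the mean of per-microbatch means equals
# the full-batch mean.

import jax.numpy as jnp

xs = jnp.arange(8.0)
per_micro = xs.reshape(4, 2).mean(axis=1)  # one mean per microbatch
assert jnp.allclose(per_micro.mean(), xs.mean())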
@@ -32,7 +32,8 @@ def microbatched( accum_axis_mapping, compute_axis_mapping, patch_in_rng_key: Optional[str] = "key", - accum_type: AccumType = AccumType.MEAN, + reduce: ReductionType = ReductionType.MEAN, + accum_dtype: Optional[jnp.dtype] = None, ) -> Callable[Args, R]: """ Wraps a function that takes a batch and changes it to instead take microbatches and accumulate the results @@ -46,7 +47,8 @@ def microbatched( compute_axis_mapping: the axis mapping for the computation (typically this is the same as the inputs) patch_in_rng_key: if provided, this kwarg will be split, 1 for each accum step. It won't work if the PRNGKey is passed in as a positional argument. - accum_type: whether to sum or average the results + reduce: whether to sum or average the results + accum_dtype: the dtype of floating point values in the accumulator. If None, this will be inferred from the return type of `fn`. Returns: a function that splits the batch into microbatches, calls the function on each microbatch, and @@ -65,8 +67,8 @@ def microbatched( AccumStep = Axis("accum_step", num_micro_steps) assert num_micro_steps * microbatch_size == batch_size - if accum_type not in AccumType: - raise ValueError(f"accum_type must be one of {AccumType}") + if reduce not in ReductionType: + raise ValueError(f"accum_type must be one of {ReductionType}") @functools.wraps(fn) def wrapped_fn(*args, **kwargs): @@ -75,9 +77,9 @@ def wrapped_fn(*args, **kwargs): key = jax.random.split(key, num_micro_steps) # first, determine the shape and make accumulator arrays r_shape = eqx.filter_eval_shape(fn, *args, **kwargs) - acc = jax.tree_util.tree_map( - functools.partial(_zeros_like, accum_axis_mapping), r_shape, is_leaf=is_named_array - ) + + _zeros = functools.partial(_zeros_like, accum_axis_mapping, accum_dtype) + acc = jax.tree_util.tree_map(_zeros, r_shape, is_leaf=is_named_array) args = _reshape_for_microbatch(Batch, Microbatch, AccumStep, args, compute_axis_mapping) @@ -98,7 +100,7 @@ def loop(acc, microbatch_and_key): with jax.named_scope("microbatched"): acc = hax.fold(loop, AccumStep)(acc, (args, kwargs, key)) - if accum_type == AccumType.MEAN: + if reduce == ReductionType.MEAN: acc = jax.tree_util.tree_map(lambda x: x / num_micro_steps, acc) return acc @@ -123,11 +125,11 @@ def _reshape(x): return jax.tree_util.tree_map(_reshape, inputs, is_leaf=is_named_array) -def _zeros_like(mapping, n): +def _zeros_like(mapping, dtype, n): if isinstance(n, hax.NamedArray): - return hax.auto_sharded(hax.zeros_like(n), mapping) + return hax.auto_sharded(hax.zeros_like(n, dtype=dtype), mapping) elif is_jax_array_like(n): - return jnp.zeros_like(n) + return jnp.zeros_like(n, dtype) else: assert jnp.isscalar(n) return 0.0 From 57725ea9eb8f6ddc6ff5274ad0dc321e1df1078b Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 26 Dec 2023 23:31:09 -0800 Subject: [PATCH 078/205] small refactor --- src/levanter/grad_accum.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/levanter/grad_accum.py b/src/levanter/grad_accum.py index 10e9a4520..9153ca3ba 100644 --- a/src/levanter/grad_accum.py +++ b/src/levanter/grad_accum.py @@ -72,15 +72,16 @@ def microbatched( @functools.wraps(fn) def wrapped_fn(*args, **kwargs): + # Special handling for PRNGKey key = kwargs.get(patch_in_rng_key, None) if key is not None: key = jax.random.split(key, num_micro_steps) + # first, determine the shape and make accumulator arrays r_shape = eqx.filter_eval_shape(fn, *args, **kwargs) + acc = _zeros_like_tree(r_shape, accum_axis_mapping, 
accum_dtype) - _zeros = functools.partial(_zeros_like, accum_axis_mapping, accum_dtype) - acc = jax.tree_util.tree_map(_zeros, r_shape, is_leaf=is_named_array) - + # then, reshape the inputs from (Batch, ...) to (AccumStep, Microbatch, ...) args = _reshape_for_microbatch(Batch, Microbatch, AccumStep, args, compute_axis_mapping) def loop(acc, microbatch_and_key): @@ -125,6 +126,12 @@ def _reshape(x): return jax.tree_util.tree_map(_reshape, inputs, is_leaf=is_named_array) +def _zeros_like_tree(r_shape, axis_mapping, accum_dtype): + _zeros = functools.partial(_zeros_like, axis_mapping, accum_dtype) + acc = jax.tree_util.tree_map(_zeros, r_shape, is_leaf=is_named_array) + return acc + + def _zeros_like(mapping, dtype, n): if isinstance(n, hax.NamedArray): return hax.auto_sharded(hax.zeros_like(n, dtype=dtype), mapping) From 85f777b8ad4248ec2a4480b962ee357038802c94 Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 26 Dec 2023 23:46:47 -0800 Subject: [PATCH 079/205] small refactor --- src/levanter/trainer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index d0733ae65..0f692ae14 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -462,7 +462,6 @@ def _train_step_fn(self): def _train_step(self, state: TrainerState, *batch, **batch_kwargs) -> tuple[Scalar, TrainerState]: key, new_key = jax.random.split(state.training_key) - opt_state = state.opt_state model = inference_mode(state.model, False) # we do this so that we only take the gradients of the trainable parameters @@ -482,7 +481,7 @@ def split_loss_fn(trainable_model, rest_model, *batch, **batch_kwargs): ) loss, grads = grad_fn(trainable_model, rest_model, *batch, **batch_kwargs) - updates, opt_state = self.optimizer.update(grads, opt_state, params=trainable_model) + updates, opt_state = self.optimizer.update(grads, state.opt_state, params=trainable_model) if isinstance(self.optimizer, SecondOrderTransformation): opt_state = self.optimizer.update_hessian( opt_state, split_loss_fn, trainable_model, *batch, **batch_kwargs From f8d98fc8b207c3b3996fc134d2df3ce54ca467ab Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 27 Dec 2023 00:03:37 -0800 Subject: [PATCH 080/205] fix key handling in grad accum --- src/levanter/grad_accum.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/levanter/grad_accum.py b/src/levanter/grad_accum.py index 9153ca3ba..9b92e7b3a 100644 --- a/src/levanter/grad_accum.py +++ b/src/levanter/grad_accum.py @@ -72,16 +72,20 @@ def microbatched( @functools.wraps(fn) def wrapped_fn(*args, **kwargs): - # Special handling for PRNGKey - key = kwargs.get(patch_in_rng_key, None) - if key is not None: - key = jax.random.split(key, num_micro_steps) # first, determine the shape and make accumulator arrays r_shape = eqx.filter_eval_shape(fn, *args, **kwargs) acc = _zeros_like_tree(r_shape, accum_axis_mapping, accum_dtype) # then, reshape the inputs from (Batch, ...) to (AccumStep, Microbatch, ...) 
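# The accumulator built a few lines up relies on eqx.filter_eval_shape, which
# traces the wrapped function abstractly so zeros of the right shapes and
# dtypes can be allocated without a real call. A self-contained sketch,
# assuming only equinox and jax (the toy function is an illustrative stand-in):

import equinox as eqx
import jax
import jax.numpy as jnp

def loss_and_grads_like(x):
    return x.sum(), x * 2.0

shapes = eqx.filter_eval_shape(loss_and_grads_like, jnp.ones((4, 3)))
acc0 = jax.tree_util.tree_map(lambda s: jnp.zeros(s.shape, s.dtype), shapes)
assert acc0[0].shape == () and acc0[1].shape == (4, 3)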
+ + # Special handling for PRNGKey: it comes in as a single key, but we need to split it for each microbatch + key = kwargs.get(patch_in_rng_key, None) + if key is not None: + key = jax.random.split(key, num_micro_steps) + kwargs = kwargs.copy() + kwargs.pop(patch_in_rng_key) + args = _reshape_for_microbatch(Batch, Microbatch, AccumStep, args, compute_axis_mapping) def loop(acc, microbatch_and_key): From 5a8c77aac06701ae60b3cdd47786e3bc8e70c0f1 Mon Sep 17 00:00:00 2001 From: David Hall Date: Thu, 28 Dec 2023 11:15:20 -0800 Subject: [PATCH 081/205] make sophia work with non-trainables again --- .flake8 | 2 +- src/levanter/optim/sophia.py | 9 +++++---- src/levanter/optim/util.py | 15 ++++++++++----- src/levanter/trainer.py | 29 ++++++++++++----------------- src/levanter/utils/jax_utils.py | 19 +++++++++++++++++++ 5 files changed, 47 insertions(+), 27 deletions(-) diff --git a/.flake8 b/.flake8 index d067c43ce..636dc598f 100644 --- a/.flake8 +++ b/.flake8 @@ -1,7 +1,7 @@ [flake8] exclude = .git max-line-length = 120 -ignore = E203, E501, W503, W605, F821, E266 +ignore = E203, E501, W503, W605, F821, E266, E731 per-file-ignores = */__init__.py: F401 examples/*.py: E402 diff --git a/src/levanter/optim/sophia.py b/src/levanter/optim/sophia.py index 3cb07044c..ce41758cd 100644 --- a/src/levanter/optim/sophia.py +++ b/src/levanter/optim/sophia.py @@ -18,8 +18,8 @@ import levanter.tracker from levanter.optim.config import HessianOptConfig, OptimizerConfig from levanter.optim.second_order import SecondOrderTransformation, chain_second_order, inject_hyperparams -from levanter.optim.util import hvp, tree_gaussian -from levanter.utils.jax_utils import parameter_count +from levanter.optim.util import hvp, tree_gaussian_like +from levanter.utils.jax_utils import parameter_count, tree_filter_like M = TypeVar("M") @@ -335,7 +335,8 @@ def update_hessian(state, fn, model, *batch, **batch_kwargs): def _do_update(): key, next_key = jax.random.split(state.hess_key) new_hess = sophia_hess_fn(fn, model, *batch, hess_key=key, **batch_kwargs) - # new_hess = jax.tree_util.tree_map(lambda h: jnp.clip(h, -1, 1), new_hess) + + new_hess = tree_filter_like(state.h, new_hess) # EMAs of hessian hessian_count_inc = numerics.safe_int32_increment(state.hessian_count) @@ -404,7 +405,7 @@ def stochastic_hessian_diagonal(fn, model, *args, hess_key: PRNGKey, **kwargs): # cf https://arxiv.org/pdf/2006.00719.pdf eqn 9 # https://www-users.cse.umn.edu/~saad/PDF/umsi-2005-082.pdf # https://arxiv.org/pdf/2208.03268.pdf - g = tree_gaussian(hess_key, model) + g = tree_gaussian_like(hess_key, model) # TODO: consider allowing for n > 1 gaussians? product = hvp(lambda m: fn(m, *args, **kwargs), model, g) hessian = jax.tree_util.tree_map(lambda grad, gaussian: grad * gaussian, product, g) diff --git a/src/levanter/optim/util.py b/src/levanter/optim/util.py index fccb427a2..7fd3a41df 100644 --- a/src/levanter/optim/util.py +++ b/src/levanter/optim/util.py @@ -1,18 +1,23 @@ import equinox as eqx import jax +from levanter.utils.jax_utils import is_inexact_arrayish + -# TODO: filter_jvp? 
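# The hvp change below rests on the standard jvp-of-grad identity; a plain-jax
# numeric sketch (the patched version uses the equinox filtered variants so
# that non-array leaves pass through untouched):

import jax
import jax.numpy as jnp

f = lambda x: (x ** 3).sum()  # grad: 3x**2, Hessian: diag(6x)
x = jnp.array([1.0, 2.0])
v = jnp.array([1.0, 0.0])
hvp_v = jax.jvp(jax.grad(f), (x,), (v,))[1]  # equals H(x) @ v
assert jnp.allclose(hvp_v, jnp.array([6.0, 0.0]))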
def hvp(f, x, v): """Compute the Hessian-vector product of a function.""" - return jax.jvp(eqx.filter_grad(f), (x,), (v,))[1] + return eqx.filter_jvp(eqx.filter_grad(f), (x,), (v,))[1] -def tree_gaussian(key, tree): - """Samples a tree of gaussian noise with the same structure as `tree`.""" +def tree_gaussian_like(key, tree): + """ + Samples a tree of gaussian noise with the same structure as `tree`, except for leaves which are not inexact arrays, + for which it returns None + """ leaves, structure = jax.tree_util.tree_flatten(tree) keys = jax.random.split(key, len(leaves)) - g = jax.tree_util.tree_map(lambda x, key: jax.random.normal(key, x.shape), leaves, list(keys)) + rand_n = lambda x, key: jax.random.normal(key, x.shape) if is_inexact_arrayish(x) else None + g = jax.tree_util.tree_map(rand_n, leaves, list(keys)) g = jax.tree_util.tree_unflatten(structure, g) return g diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index 0f692ae14..1a7866e22 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -90,6 +90,10 @@ class TrainerState(eqx.Module, Generic[M]): def step(self) -> int: return int(self._step) + @property + def trainable_model(self) -> M: + return eqx.filter(self.model, self.is_trainable) + S = TypeVar("S", bound=TrainerState) @@ -464,14 +468,7 @@ def _train_step(self, state: TrainerState, *batch, **batch_kwargs) -> tuple[Scal key, new_key = jax.random.split(state.training_key) model = inference_mode(state.model, False) - # we do this so that we only take the gradients of the trainable parameters - trainable_model, rest_model = _partition_trainable_params(model, state.is_trainable) - - def split_loss_fn(trainable_model, rest_model, *batch, **batch_kwargs): - model = eqx.combine(trainable_model, rest_model) - return self.loss_fn(model, *batch, **batch_kwargs, key=key) - - grad_fn = eqx.filter_value_and_grad(split_loss_fn, has_aux=False) + grad_fn = eqx.filter_value_and_grad(self.loss_fn, has_aux=False) grad_fn = microbatched( grad_fn, self.TrainBatch, @@ -479,13 +476,15 @@ def split_loss_fn(trainable_model, rest_model, *batch, **batch_kwargs): self.parameter_axis_mapping, self.parameter_axis_mapping, ) - loss, grads = grad_fn(trainable_model, rest_model, *batch, **batch_kwargs) + loss, grads = grad_fn(model, *batch, **batch_kwargs, key=key) - updates, opt_state = self.optimizer.update(grads, state.opt_state, params=trainable_model) + # only train on the trainable parameters. We're leaning on JAX to do dead code elimination for us + train_grads = _partition_trainable_params(grads, state.is_trainable)[0] + trainable_model = state.trainable_model + + updates, opt_state = self.optimizer.update(train_grads, state.opt_state, params=trainable_model) if isinstance(self.optimizer, SecondOrderTransformation): - opt_state = self.optimizer.update_hessian( - opt_state, split_loss_fn, trainable_model, *batch, **batch_kwargs - ) + opt_state = self.optimizer.update_hessian(opt_state, self.loss_fn, model, *batch, **batch_kwargs) model = eqx.apply_updates(model, updates) @@ -745,10 +744,6 @@ def _params_only(t): return eqx.filter(t, is_inexact_arrayish) -def _trainable_params_only(model: M, filter: PyTree[FilterSpec]) -> M: - return _partition_trainable_params(model, filter)[0] - - def _partition_trainable_params(model, filter): """ Partitions the model into trainable and non-trainable parameters. 
This is used internally diff --git a/src/levanter/utils/jax_utils.py b/src/levanter/utils/jax_utils.py index 3e3f4d667..cf4875105 100644 --- a/src/levanter/utils/jax_utils.py +++ b/src/levanter/utils/jax_utils.py @@ -1,5 +1,6 @@ import contextlib import json +import warnings from dataclasses import fields from typing import Any, Callable, Optional, TypeVar @@ -177,3 +178,21 @@ def is_inexact_arrayish(x): return jnp.issubdtype(x.dtype, jnp.inexact) else: return False + + +def tree_filter_like(template: X, tree: X) -> X: + """ + Filters a tree to only include the leaves that are not None in the template. + + This is useful for filtering out nontrainable parameters from a tree. + """ + + def match_like(templ_leaf, tree_leaf): + if templ_leaf is None: + return None + else: + if tree_leaf is None: + warnings.warn(f"Template has a non-None value where tree is None. Template value: {templ_leaf}") + return tree_leaf + + return jax.tree_util.tree_map(match_like, template, tree, is_leaf=lambda x: x is None) From ff59e51faec143c9fce39d04b780d8f9d90a4e1e Mon Sep 17 00:00:00 2001 From: David Hall Date: Thu, 28 Dec 2023 11:24:30 -0800 Subject: [PATCH 082/205] factor out some methods in train_step --- src/levanter/trainer.py | 36 ++++++++++++++++++++---------------- 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index 1a7866e22..177b65993 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -357,7 +357,7 @@ def train_step(self, state: TrainerState[M], *batch: X, **batch_kwargs) -> StepI Performs a single training step. """ with capture_time() as step_time: - loss, new_state = self._train_step_fn(state, *batch, **batch_kwargs) + loss, new_state = self._jit_train_step_fn(state, *batch, **batch_kwargs) # force the loss so timing numbers are accurate. laziness isn't going to help here (i think?) loss = loss.item() # type: ignore @@ -457,7 +457,7 @@ def sharded_loader(self, dataset: ShardableDataset[X], batch_axis: Axis) -> Shar return ShardedBatchLoader(dataset, self.device_mesh, batch_axis, self.compute_axis_mapping) @cached_property - def _train_step_fn(self): + def _jit_train_step_fn(self): return named_jit( axis_resources=self.parameter_axis_mapping, out_axis_resources=self.parameter_axis_mapping, @@ -468,6 +468,14 @@ def _train_step(self, state: TrainerState, *batch, **batch_kwargs) -> tuple[Scal key, new_key = jax.random.split(state.training_key) model = inference_mode(state.model, False) + loss, grads = self._compute_gradients_microbatched(model, batch, **batch_kwargs, key=key) + + new_state = self._take_train_step(state, model, grads, batch, batch_kwargs) + new_state = dataclasses.replace(new_state, _step=state._step + 1, training_key=new_key) + + return loss, new_state + + def _compute_gradients_microbatched(self, model: M, batch, **batch_kwargs) -> tuple[Scalar, M]: grad_fn = eqx.filter_value_and_grad(self.loss_fn, has_aux=False) grad_fn = microbatched( grad_fn, @@ -476,23 +484,23 @@ def _train_step(self, state: TrainerState, *batch, **batch_kwargs) -> tuple[Scal self.parameter_axis_mapping, self.parameter_axis_mapping, ) - loss, grads = grad_fn(model, *batch, **batch_kwargs, key=key) + return grad_fn(model, *batch, **batch_kwargs) + def _take_train_step(self, state, model, grads, batch, batch_kwargs) -> TrainerState: + """ + Takes a training step. This is a separate method so that it can be overridden or used in a subclass. + """ # only train on the trainable parameters. 
We're leaning on JAX to do dead code elimination for us train_grads = _partition_trainable_params(grads, state.is_trainable)[0] - trainable_model = state.trainable_model - + trainable_model = _partition_trainable_params(model, state.is_trainable)[0] updates, opt_state = self.optimizer.update(train_grads, state.opt_state, params=trainable_model) + + # Sophia, e.g. if isinstance(self.optimizer, SecondOrderTransformation): opt_state = self.optimizer.update_hessian(opt_state, self.loss_fn, model, *batch, **batch_kwargs) - model = eqx.apply_updates(model, updates) - new_state = dataclasses.replace( - state, _step=state._step + 1, model=model, opt_state=opt_state, training_key=new_key - ) - - return loss, new_state + return dataclasses.replace(state, model=model, opt_state=opt_state) def _initialize_state_from_scratch(self, model_init: Callable[[], M], training_key, is_trainable): model = model_init() @@ -730,7 +738,7 @@ class AllConfig(Protocol): def initialize(config: TrainerConfig | AllConfig): """Initializes jax, logging, setting the run name/id in the process. Also initializes tracking and saves config - as hyperparameters and an artifact""" + as hyperparameters and as an artifact""" if isinstance(config, TrainerConfig): trainer_config = config else: @@ -740,10 +748,6 @@ def initialize(config: TrainerConfig | AllConfig): levanter.tracker.log_configuration(config) -def _params_only(t): - return eqx.filter(t, is_inexact_arrayish) - - def _partition_trainable_params(model, filter): """ Partitions the model into trainable and non-trainable parameters. This is used internally From d7a060d89b240528c6d40c3af75c3481d393d470 Mon Sep 17 00:00:00 2001 From: David Hall Date: Fri, 29 Dec 2023 00:53:20 -0800 Subject: [PATCH 083/205] make the initialize_from logic just use load_checkpoint_or_initialize --- src/levanter/checkpoint.py | 23 +++++++++++++++++++---- src/levanter/trainer.py | 26 +++++++------------------- 2 files changed, 26 insertions(+), 23 deletions(-) diff --git a/src/levanter/checkpoint.py b/src/levanter/checkpoint.py index fed3f40b2..904b54f89 100644 --- a/src/levanter/checkpoint.py +++ b/src/levanter/checkpoint.py @@ -25,6 +25,7 @@ import haliax as hax import haliax.partitioning +from haliax.jax_utils import is_in_jit from haliax.partitioning import ResourceMapping from levanter.tensorstore_serialization import tree_deserialize_leaves_tensorstore, tree_serialize_leaves_tensorstore @@ -350,7 +351,6 @@ def load_checkpoint( def load_metadata(checkpoint_path, fs=None): if fs is None: - fs: AbstractFileSystem fs, _, _ = fsspec.get_fs_token_paths(str(checkpoint_path)) with fs.open(os.path.join(checkpoint_path, "metadata.json")) as metadata_in: metadata = json.load(metadata_in) @@ -439,16 +439,18 @@ def __post_init__(self): P = ParamSpec("P") +# TODO: add partial checkpoint loading + + def load_from_checkpoint_or_initialize( init_fn: Callable[P, M], checkpoint_path: str, axis_mapping: Optional[ResourceMapping] = None, mesh: Optional[Mesh] = None, *, - # TODO: add this back in - # allow_partial_checkpoint: bool, force_load_checkpoint: Optional[bool] = None, is_checkpointed: Optional[PyTree[FilterSpec]] = True, + subpath: Optional[str] = None, ) -> Callable[P, M]: """ Loads a checkpoint if it exists, otherwise initializes from scratch. @@ -465,7 +467,10 @@ def load_from_checkpoint_or_initialize( is_checkpointed: a filter spec for the checkpointed parameters. This is used to filter out non-checkpointed parameters for the initialization. 
If you don't specify this, all parameters are assumed to be checkpointed. + subpath: the subpath to load from the checkpoint. This is useful for loading, e.g., just the model and not + the entire training state. """ + if force_load_checkpoint is False: cmanager = mesh or contextlib.nullcontext() @@ -486,10 +491,20 @@ def fn(*args, **kwargs): if axis_mapping is not None: stack.enter_context(hax.axis_mapping(axis_mapping)) + if is_in_jit(): + # TODO: should we check if we're specifically in eval_shape? + logger.debug("In jit, not loading checkpoint. Assuming we're in eval_shape.") + # don't do io if we're in jit + return init_fn(*args, **kwargs) + ckpt_shape = eqx.filter_eval_shape(init_fn, *args, **kwargs) ckpt = load_checkpoint( - eqx.filter(ckpt_shape, is_checkpointed), checkpoint_path, axis_mapping=axis_mapping, mesh=mesh + eqx.filter(ckpt_shape, is_checkpointed), + checkpoint_path, + axis_mapping=axis_mapping, + mesh=mesh, + subpath=subpath, ) if ckpt is None: diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index 7973d259b..4033cdbc6 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -42,6 +42,7 @@ from haliax.partitioning import ResourceAxis, ResourceMapping, named_jit from haliax.types import IntScalar, Scalar +import levanter.checkpoint import levanter.logging import levanter.tracker import levanter.tracker.wandb @@ -314,29 +315,16 @@ def initial_state( assert model_init is not None if self.config.initialize_from is not None: - model_shape = eqx.filter_eval_shape(model_init) - model_shape = _partition_trainable_params(model_shape, is_trainable)[0] - # we always load the initial model b/c it might have different non-trainables - logger.info(f"Initializing model from checkpoint {self.config.initialize_from}") - ckpt_model = levanter.checkpoint.load_checkpoint( - model_shape, + model_init = load_from_checkpoint_or_initialize( + model_init, self.config.initialize_from, axis_mapping=self.parameter_axis_mapping, mesh=self.device_mesh, + force_load_checkpoint=True, + is_checkpointed=is_trainable, subpath="model", ) - if ckpt_model is not None: - if model is not None: - # populate any missing parameters from the passed in model - model = eqx.combine(ckpt_model, model) - model_init = jax.tree_util.Partial(lambda m: m, model) - else: - old_model_init = model_init - model_init = jax.tree_util.Partial(lambda m, f: eqx.combine(m, f()), ckpt_model, old_model_init) - else: - raise RuntimeError(f"Could not load model from checkpoint {self.config.initialize_from}") - load_checkpoint_path = self.config.load_checkpoint_path if load_checkpoint_path is None: @@ -478,7 +466,7 @@ def _train_step(self, state: TrainerState, *batch, **batch_kwargs) -> tuple[Scal loss, grads = self._compute_gradients_microbatched(model, batch, **batch_kwargs, key=key) - new_state = self._take_train_step(state, model, grads, batch, batch_kwargs) + new_state = self._take_train_step(state, model, grads, *batch, **batch_kwargs, key=key) new_state = dataclasses.replace(new_state, _step=state._step + 1, training_key=new_key) return loss, new_state @@ -494,7 +482,7 @@ def _compute_gradients_microbatched(self, model: M, batch, **batch_kwargs) -> tu ) return grad_fn(model, *batch, **batch_kwargs) - def _take_train_step(self, state, model, grads, batch, batch_kwargs) -> TrainerState: + def _take_train_step(self, state: S, model, grads, *batch, **batch_kwargs) -> S: """ Takes a training step. This is a separate method so that it can be overridden or used in a subclass. 
""" From 8c44e64d0aacaebd61c9cf5967afea1c9cfa3c5b Mon Sep 17 00:00:00 2001 From: David Hall Date: Fri, 29 Dec 2023 22:40:38 -0800 Subject: [PATCH 084/205] on second thought load_from_checkpoint_or_initialize is the wrong abstraction --- src/levanter/checkpoint.py | 106 +------------------------------- src/levanter/trainer.py | 93 ++++++++++++++++++---------- src/levanter/utils/jax_utils.py | 7 +++ tests/test_checkpoint.py | 53 ---------------- 4 files changed, 69 insertions(+), 190 deletions(-) diff --git a/src/levanter/checkpoint.py b/src/levanter/checkpoint.py index 904b54f89..d1076f325 100644 --- a/src/levanter/checkpoint.py +++ b/src/levanter/checkpoint.py @@ -1,7 +1,5 @@ -import contextlib import dataclasses import datetime -import functools import json import logging import os @@ -9,24 +7,18 @@ import urllib.parse from dataclasses import dataclass from datetime import timedelta -from typing import Callable, List, Optional, ParamSpec, Sequence, TypeVar, Union +from typing import Callable, List, Optional, Sequence, TypeVar, Union import equinox -import equinox as eqx import fsspec import jax import jax.numpy as jnp from draccus import field from fsspec import AbstractFileSystem -from jax import ShapeDtypeStruct -from jax._src.interpreters.pxla import Mesh from jax.experimental.multihost_utils import broadcast_one_to_all, sync_global_devices from jaxtyping import PyTree -import haliax as hax import haliax.partitioning -from haliax.jax_utils import is_in_jit -from haliax.partitioning import ResourceMapping from levanter.tensorstore_serialization import tree_deserialize_leaves_tensorstore, tree_serialize_leaves_tensorstore from levanter.types import FilterSpec @@ -278,7 +270,7 @@ def load_checkpoint( discover_latest=True, axis_mapping: Optional[haliax.partitioning.ResourceMapping] = None, mesh: Optional[jax.sharding.Mesh] = None, -) -> Optional[M]: +) -> M: """ Load a checkpoint from a given path. If discover_latest is True, then the latest checkpoint in a subdirectory of the given path will be loaded. If subpath is not None, then the checkpoint @@ -305,7 +297,7 @@ def load_checkpoint( checkpoint_path = discover_latest_checkpoint(checkpoint_path) # type: ignore if checkpoint_path is None or not fs.exists(checkpoint_path): - return None + raise FileNotFoundError(f"Could not find checkpoint at {checkpoint_path}") logger.info(f"Loading checkpoint from {checkpoint_path}") metadata = load_metadata(checkpoint_path, fs) @@ -436,96 +428,4 @@ def __post_init__(self): prev_interval = interval -P = ParamSpec("P") - - # TODO: add partial checkpoint loading - - -def load_from_checkpoint_or_initialize( - init_fn: Callable[P, M], - checkpoint_path: str, - axis_mapping: Optional[ResourceMapping] = None, - mesh: Optional[Mesh] = None, - *, - force_load_checkpoint: Optional[bool] = None, - is_checkpointed: Optional[PyTree[FilterSpec]] = True, - subpath: Optional[str] = None, -) -> Callable[P, M]: - """ - Loads a checkpoint if it exists, otherwise initializes from scratch. - - Args: - init_fn: the initialization function. This should be a function that takes some arguments and returns a - model or state. It should be jit-able and should not have any (destructive) side effects. It will - likely be called 2-3 times. - checkpoint_path: the path to the checkpoint - axis_mapping: the axis mapping to use for initialization. If None, the default axis mapping will be used. - mesh: the mesh to use for initialization. If None, the default mesh will be used. - force_load_checkpoint: if True, we must load a checkpoint. 
If False, we will not load a checkpoint. If None, - we will load a checkpoint if it exists. - is_checkpointed: a filter spec for the checkpointed parameters. This is used to filter out non-checkpointed - parameters for the initialization. If you don't specify this, all parameters are assumed to be - checkpointed. - subpath: the subpath to load from the checkpoint. This is useful for loading, e.g., just the model and not - the entire training state. - """ - - if force_load_checkpoint is False: - cmanager = mesh or contextlib.nullcontext() - - @functools.wraps(init_fn) - @hax.named_jit(axis_resources=axis_mapping) - def fn(*args, **kwargs): - with cmanager: - return init_fn(*args, **kwargs) - - return fn - else: - - @functools.wraps(init_fn) - def fn(*args, **kwargs): - with contextlib.ExitStack() as stack: - if mesh is not None: - stack.enter_context(mesh) - if axis_mapping is not None: - stack.enter_context(hax.axis_mapping(axis_mapping)) - - if is_in_jit(): - # TODO: should we check if we're specifically in eval_shape? - logger.debug("In jit, not loading checkpoint. Assuming we're in eval_shape.") - # don't do io if we're in jit - return init_fn(*args, **kwargs) - - ckpt_shape = eqx.filter_eval_shape(init_fn, *args, **kwargs) - - ckpt = load_checkpoint( - eqx.filter(ckpt_shape, is_checkpointed), - checkpoint_path, - axis_mapping=axis_mapping, - mesh=mesh, - subpath=subpath, - ) - - if ckpt is None: - if force_load_checkpoint is True: - raise ValueError(f"Could not load checkpoint from {checkpoint_path}") - else: - out = hax.named_jit(init_fn, axis_mapping)(*args, **kwargs) - return out - else: - ckpt = eqx.combine(ckpt, ckpt_shape) - if any(isinstance(leaf, ShapeDtypeStruct) for leaf in jax.tree_util.tree_leaves(ckpt)): - # if we're resuming, we need to initialize any non-checkpointed values - @hax.named_jit(axis_resources=axis_mapping) - def partial_init(init_fn, *args, **kwargs): - m = init_fn(*args, **kwargs) - return eqx.filter(m, is_checkpointed, inverse=True) - - non_checkpointed = partial_init(init_fn, *args, **kwargs) - ckpt = eqx.filter(ckpt, lambda x: not isinstance(x, ShapeDtypeStruct)) - return eqx.combine(ckpt, non_checkpointed) - else: - return ckpt - - return fn diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index 4033cdbc6..527ffea8f 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -28,7 +28,6 @@ import equinox as eqx import jax -import jax.numpy as jnp import jmp import numpy as np from draccus import field @@ -47,7 +46,7 @@ import levanter.tracker import levanter.tracker.wandb from levanter import tracker -from levanter.checkpoint import CheckpointerConfig, load_from_checkpoint_or_initialize +from levanter.checkpoint import CheckpointerConfig, load_checkpoint from levanter.config import JsonAtom from levanter.data import Dataset, ReplicatedBatchLoader, ShardableDataset, ShardedBatchLoader from levanter.distributed import DistributedConfig, RayConfig @@ -57,7 +56,7 @@ from levanter.tracker import TrackerConfig from levanter.types import ComputeLossFunction, FilterSpec, ModuleComputeLoss from levanter.utils import cloud_utils -from levanter.utils.jax_utils import is_inexact_arrayish +from levanter.utils.jax_utils import as_arrayish, is_inexact_arrayish from levanter.utils.tree_utils import inference_mode @@ -81,7 +80,7 @@ class TrainerState(eqx.Module, Generic[M]): It's designed to be extended by subclasses. 
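Subclasses can add extra fields and populate them by overriding the trainer's _initialize_state_from_scratch.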
""" - _step: IntScalar = eqx.field(converter=lambda x: jnp.asarray(x) if not isinstance(x, bool) else x) + _step: IntScalar = eqx.field(converter=lambda x: as_arrayish(x)) model: M opt_state: OptState training_key: PRNGKeyArray @@ -200,7 +199,7 @@ def loss_fn(self): Wrapped loss function that casts the model to compute precision and sets the context axis mapping to compute """ - @named_jit(in_axis_resources=self.parameter_axis_mapping, axis_resources=self.compute_axis_mapping) + @named_jit(axis_resources=self.compute_axis_mapping) @functools.wraps(self._raw_loss_function) def fn(model, *batch, **batch_kwargs): with hax.axis_mapping(self.compute_axis_mapping): @@ -293,7 +292,11 @@ def initial_state( is_trainable: PyTree[FilterSpec] = True, ) -> TrainerState[M]: """ - Initializes the model, optimizer state, and random key. Also handles loading a checkpoint if needed. + Either loads a checkpoint or initializes a fresh trainer state. This is the recommended way to initialize + a trainer state. + + This method is smart enough to handle subclasses of TrainerState. If you want to extend TrainerState, you + can override _initialize_state_from_scratch Args is_trainable: optional filter spec for the trainable parameters. This is used to filter out non-trainable @@ -312,39 +315,55 @@ def initial_state( # we can't just use `lambda: model` because JAX jit can't see captures, but it can see jax partials model_init = jax.tree_util.Partial(lambda m: m, model) + del model assert model_init is not None - if self.config.initialize_from is not None: - model_init = load_from_checkpoint_or_initialize( - model_init, + # we don't save the full trainer state, so we need to filter out the non-trainable parameters + trainer_state_shape = eqx.filter_eval_shape( + self._initialize_state_from_scratch, model_init, training_key, is_trainable + ) + saveable_state_shape = self._make_saveable_trainer_state(trainer_state_shape, is_trainable) + + # first try to load a full trainer state checkpoint + path = self.config.load_checkpoint_path + if path is None: + path = self.config.checkpointer.expanded_path(self.run_id) + + if self.config.load_checkpoint is not False: + try: + state = load_checkpoint( + saveable_state_shape, path, axis_mapping=self.parameter_axis_mapping, mesh=self.device_mesh + ) + except FileNotFoundError: + if self.config.load_checkpoint: + raise + else: + state = None + + # if that fails, try to load just a model from a checkpoint for initialization + if state is None and self.config.initialize_from is not None: + logger.info(f"Initializing from {self.config.initialize_from}") + # todo: we are potentially holding two models in memory at once here, if we pass in a model + # instead of a model_init and we use initialize_from. 
We could avoid this by deleting + # any to-be-loaded parameters from the model before loading, but that's a bit more complicated + loaded_model = load_checkpoint( + saveable_state_shape.model, self.config.initialize_from, axis_mapping=self.parameter_axis_mapping, mesh=self.device_mesh, - force_load_checkpoint=True, - is_checkpointed=is_trainable, subpath="model", ) - load_checkpoint_path = self.config.load_checkpoint_path + # we don't necessarily load the full model, so we need to combine it with the model init + model_init = jax.tree_util.Partial(lambda m, f: eqx.combine(m, f()), loaded_model, model_init) - if load_checkpoint_path is None: - load_checkpoint_path = self.config.checkpointer.expanded_path(self.run_id) + # now we initialize a fresh trainer state, possibly just to finish any missing fields + @named_jit(axis_resources=self.parameter_axis_mapping, donate_args=(True, True, True, False)) + def init_state(partial_state, model_init, training_key, is_trainable): + fresh_state = self._initialize_state_from_scratch(model_init, training_key, is_trainable) + return eqx.combine(partial_state, fresh_state) - with self: - assert model_init is not None - state = load_from_checkpoint_or_initialize( - self._initialize_state_from_scratch, - load_checkpoint_path, - axis_mapping=self.parameter_axis_mapping, - mesh=self.device_mesh, - force_load_checkpoint=self.config.load_checkpoint, - # if we're loading a checkpoint, we need to know which parameters are trainable - is_checkpointed=TrainerState(True, is_trainable, True, True, False), # type: ignore - )( - model_init, - training_key, - is_trainable, - ) + state = init_state(state, model_init, training_key, is_trainable) return state @@ -454,11 +473,7 @@ def sharded_loader(self, dataset: ShardableDataset[X], batch_axis: Axis) -> Shar @cached_property def _jit_train_step_fn(self): - return named_jit( - axis_resources=self.parameter_axis_mapping, - out_axis_resources=self.parameter_axis_mapping, - donate_args=(True,), - )(self._train_step) + return named_jit(self._train_step, axis_resources=self.parameter_axis_mapping, donate_args=(True,)) def _train_step(self, state: TrainerState, *batch, **batch_kwargs) -> tuple[Scalar, TrainerState]: key, new_key = jax.random.split(state.training_key) @@ -511,6 +526,16 @@ def _initialize_state_from_scratch(self, model_init: Callable[[], M], training_k return TrainerState(0, model, opt_state, training_key, is_trainable) + def _make_saveable_trainer_state(self, trainer_state: S, is_trainable) -> S: + """ + Returns the shape of the trainer state that we save to a checkpoint. This is used to load a checkpoint. + You can override if you really need custom checkpointing logic. 
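Non-trainable parameters can be rebuilt from model_init at load time, so there is no need to persist them.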
By default everything in the trainer state + is saved (except for non-trainable model parameters) + """ + saveable_model = eqx.filter(trainer_state.model, is_trainable) + saveable_state = dataclasses.replace(trainer_state, model=saveable_model) + return saveable_state + def _initialize_global_tracker(config, run_id): if isinstance(config, Sequence): diff --git a/src/levanter/utils/jax_utils.py b/src/levanter/utils/jax_utils.py index cf4875105..2cb275d70 100644 --- a/src/levanter/utils/jax_utils.py +++ b/src/levanter/utils/jax_utils.py @@ -196,3 +196,10 @@ def match_like(templ_leaf, tree_leaf): return tree_leaf return jax.tree_util.tree_map(match_like, template, tree, is_leaf=lambda x: x is None) + + +def as_arrayish(x): + if hasattr(x, "shape") and hasattr(x, "dtype"): + return x + else: + return jnp.asarray(x) diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index 14668aa5f..e511cb11d 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -9,19 +9,13 @@ import numpy as np import optax from chex import assert_trees_all_close -from jax import ShapeDtypeStruct from jax import numpy as jnp -from jax import tree_util as jtu - -import haliax.nn -from haliax import Axis from levanter.checkpoint import ( Checkpointer, CheckpointInterval, discover_latest_checkpoint, load_checkpoint, - load_from_checkpoint_or_initialize, load_metadata, save_checkpoint, ) @@ -231,50 +225,3 @@ def test_checkpoint_discovery(): assert latest == f"{tempdir}/step-30" assert discover_latest_checkpoint("file:///tmp/does-not-exist") is None - - -def test_load_from_checkpoint_or_initialize(): - In = Axis("in", 2) - Out = Axis("out", 1) - - def init_fn(key): - return haliax.nn.MLP.init(In, Out, 2, 3, key=key) - - k0 = jax.random.PRNGKey(0) - k1 = jax.random.PRNGKey(1) - - model0 = init_fn(k0) - model1 = init_fn(k1) - - is_checkpointed = jtu.tree_map(lambda _: False, model0) - is_checkpointed = eqx.tree_at(lambda t: t.layers[-1], is_checkpointed, replace=True) - is_checkpointed1 = jtu.tree_map(lambda _: False, model1) - is_checkpointed1 = eqx.tree_at(lambda t: t.layers[-1], is_checkpointed, replace=True) - with jax.sharding.Mesh(jax.devices(), ("devices",)): - with tempfile.TemporaryDirectory() as tmpdir: - filtered = eqx.filter(model0, is_checkpointed) - save_checkpoint(filtered, step=0, checkpoint_path=tmpdir, exist_ok=True) - - loaded = load_from_checkpoint_or_initialize(init_fn, tmpdir, is_checkpointed=is_checkpointed)(k1) - - assert not any(jax.tree_util.tree_leaves(eqx.filter(loaded, lambda x: isinstance(x, ShapeDtypeStruct)))) - - assert_trees_all_close( - jax.tree_util.tree_leaves(arrays_only(eqx.filter(loaded, is_checkpointed))), - jax.tree_util.tree_leaves(arrays_only(eqx.filter(model0, is_checkpointed))), - ) - - assert_trees_not_close( - jax.tree_util.tree_leaves(arrays_only(eqx.filter(loaded, is_checkpointed, inverse=True))), - jax.tree_util.tree_leaves(arrays_only(eqx.filter(model0, is_checkpointed, inverse=True))), - ) - - assert_trees_not_close( - jax.tree_util.tree_leaves(arrays_only(eqx.filter(loaded, is_checkpointed))), - jax.tree_util.tree_leaves(arrays_only(eqx.filter(model1, is_checkpointed))), - ) - - assert_trees_all_close( - jax.tree_util.tree_leaves(arrays_only(eqx.filter(loaded, is_checkpointed, inverse=True))), - jax.tree_util.tree_leaves(arrays_only(eqx.filter(model1, is_checkpointed1, inverse=True))), - ) From 72f1e47f9100b514f3b6a61cdc30e3daaa1c4394 Mon Sep 17 00:00:00 2001 From: David Hall Date: Fri, 29 Dec 2023 22:40:59 -0800 Subject: [PATCH 085/205] wip --- 
src/levanter/doremi.py | 73 ++++++++++++------------------------------ 1 file changed, 21 insertions(+), 52 deletions(-) diff --git a/src/levanter/doremi.py b/src/levanter/doremi.py index dd054c8c9..6c588453c 100644 --- a/src/levanter/doremi.py +++ b/src/levanter/doremi.py @@ -6,7 +6,9 @@ import jax.numpy as jnp import jax.random as jrandom import optax -from haliax import Scalar + +import levanter.tracker +from haliax import Scalar, named_jit from jaxtyping import PRNGKeyArray import haliax as hax @@ -14,10 +16,12 @@ from levanter.data import ShardableDataset from levanter.data.mixture import MixtureDataset -from levanter.trainer import M, StepInfo, Trainer, TrainerState +from levanter.trainer import M, StepInfo, Trainer, TrainerConfig, TrainerState from levanter.types import ComputeLossFunction from optax._src.base import GradientTransformation +from levanter.utils.tree_utils import inference_mode + T = TypeVar("T") class DoremiState(TrainerState): @@ -30,15 +34,7 @@ def update_alpha(self, alpha): class DoReMiTrainer(Trainer): - - - def __init__(self, config: "TrainerConfig", optimizer: GradientTransformation, - initial_alpha: hax.NamedArray, - loss_fn: Optional[ComputeLossFunction] = None, - ): - super().__init__(config, optimizer, loss_fn) - self.initial_alpha = initial_alpha - + # we just use the DoReMi trainer for state management def _initialize_state_from_scratch(self, model_init: Callable[[], M], training_key, is_trainable): base_state = super()._initialize_state_from_scratch(model_init, training_key, is_trainable) @@ -48,37 +44,9 @@ def _initialize_state_from_scratch(self, model_init: Callable[[], M], training_k base_state.training_key, self.initial_alpha) - def _train_step(self, state: TrainerState, *batch, **batch_kwargs) -> tuple[Scalar, TrainerState]: - key, new_key = jax.random.split(state.training_key) - opt_state = state.opt_state - model = inference_mode(state.model, False) - - # we do this so that we only take the gradients of the trainable parameters - trainable_model, rest_model = _partition_trainable_params(model, state.is_trainable) - - def split_loss_fn(trainable_model, rest_model, *batch, **batch_kwargs): - model = eqx.combine(trainable_model, rest_model) - return self.loss_fn(model, *batch, **batch_kwargs, key=key) - - loss, grads = accumulate_gradients_sharded( - split_loss_fn, self.TrainBatch, self.config.per_device_parallelism, self.parameter_axis_mapping - )(trainable_model, rest_model, *batch, **batch_kwargs) - - updates, opt_state = self.optimizer.update(grads, opt_state, params=trainable_model) - if isinstance(self.optimizer, SecondOrderTransformation): - opt_state = self.optimizer.update_hessian( - opt_state, split_loss_fn, trainable_model, *batch, **batch_kwargs - ) - - model = eqx.apply_updates(model, updates) - - new_state = dataclasses.replace( - state, _step=state._step + 1, model=model, opt_state=opt_state, training_key=new_key - ) - def estimate_mixture_weights( - trainer: Trainer, + trainer: TrainerConfig, loss_fn: ComputeLossFunction, initial_proxy, ref, @@ -98,20 +66,22 @@ def estimate_mixture_weights( domain_indices = list(data_sources.keys()) domain_to_index = {domain: index for index, domain in enumerate(domain_indices)} - # Initialize domain weights. # TODO: should we initialize to the ref or to uniform? 
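    # We start alpha uniform (1/num_domains per entry); doremi_step below then applies, schematically:
    #   alpha <- normalize(alpha * exp(step_size * clipped per-domain excess loss))
    #   alpha <- (1 - smoothing) * alpha + smoothing * uniform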
Domain = hax.Axis("domain", len(domain_indices)) initial_alpha = hax.ones(Domain) / Domain.size + trainer = DoReMiTrainer(trainer, optax.adamw(1e-3), ref, initial_alpha, loss_fn) + # calculate per-token losses for proxy and ref - def compute_excess_loss(proxy, ref, batch): + def compute_excess_loss(ref, proxy, batch): proxy_losses = loss_fn(proxy, batch, reduction_axis=()) - ref_losses = loss_fn(proxy, batch, reduction_axis=()) + ref_losses = loss_fn(ref, batch, reduction_axis=()) # calculate excess losses excess_losses = proxy_losses - ref_losses return excess_losses + # Loss is alpha_d * (proxy - ref) (basically the unclipped excess loss with the new alpha) def proxy_model_loss(excess_losses, domains, alpha): one_hot_domains = hax.nn.one_hot(domains, Domain) # Domain x Batch @@ -123,8 +93,9 @@ def proxy_model_loss(excess_losses, domains, alpha): return loss - @hax.named_jit(axis_resources=trainer.parameter_axis_mapping) - def doremi_step(proxy, opt_state, alpha, batch, domains): + @hax.named_jit(axis_resources=trainer.parameter_axis_mapping, donate_args=(True, )) + def doremi_step(state: DoremiState, batch, domains): + proxy = inference_mode(state.model, False) # this is one of those times when PyTorch's backward() is nice excess_losses, excess_backward = eqx.filter_vjp(lambda proxy: compute_excess_loss(proxy, ref, batch), proxy) @@ -134,22 +105,20 @@ def doremi_step(proxy, opt_state, alpha, batch, domains): one_hot_domains = hax.nn.one_hot(domains, Domain) # Domain x Batch per_domain_losses = hax.dot(excess_losses.axes, one_hot_domains, clipped_losses) - old_alpha = alpha - alpha = alpha * hax.exp(domain_weight_step * per_domain_losses) + alpha = state.alpha * hax.exp(domain_weight_step * per_domain_losses) alpha /= hax.sum(alpha) alpha = (1 - smoothing) * alpha + initial_alpha * smoothing - # TODO: log this - alpha_distance = hax.sum(hax.abs(alpha - old_alpha)) + alpha_distance = hax.sum(hax.abs(alpha - state.alpha)) + levanter.tracker.log_metrics({"alpha_distance": alpha_distance}, step=state.step) # Update proxy model weights θt for the objective L(θt−1, αt) (using Adam, Adafactor, etc.) loss, grad_loss = eqx.filter_value_and_grad(proxy_model_loss)(excess_losses, domains, alpha) grad = excess_backward(grad_loss) - updates, new_opt_state = trainer.optimizer.update(opt_state, grad, params=proxy) - proxy = optax.apply_updates(proxy, updates) + new_state = trainer._take_train_step(state, proxy, grad, batch) - return loss, proxy, new_opt_state, alpha, alpha_distance + return loss, new_state # TODO: we don't support serializing stuff from anything other than the model and the opt_state. should fix. 
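    # The reported mixture weights are the time-averaged alpha, not the final iterate, so we keep a running mean.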
running_alpha_mean = initial_alpha From add3df42b5b14aa9a64ad0bd400f0e2412b688bd Mon Sep 17 00:00:00 2001 From: David Hall Date: Fri, 29 Dec 2023 22:40:38 -0800 Subject: [PATCH 086/205] on second thought load_from_checkpoint_or_initialize is the wrong abstraction --- src/levanter/checkpoint.py | 92 ++------------------------- src/levanter/trainer.py | 106 ++++++++++++++++++-------------- src/levanter/utils/jax_utils.py | 7 +++ tests/test_checkpoint.py | 53 ---------------- 4 files changed, 70 insertions(+), 188 deletions(-) diff --git a/src/levanter/checkpoint.py b/src/levanter/checkpoint.py index fed3f40b2..b5c9e442a 100644 --- a/src/levanter/checkpoint.py +++ b/src/levanter/checkpoint.py @@ -1,7 +1,5 @@ -import contextlib import dataclasses import datetime -import functools import json import logging import os @@ -9,23 +7,18 @@ import urllib.parse from dataclasses import dataclass from datetime import timedelta -from typing import Callable, List, Optional, ParamSpec, Sequence, TypeVar, Union +from typing import Callable, List, Optional, Sequence, TypeVar, Union import equinox -import equinox as eqx import fsspec import jax import jax.numpy as jnp from draccus import field from fsspec import AbstractFileSystem -from jax import ShapeDtypeStruct -from jax._src.interpreters.pxla import Mesh from jax.experimental.multihost_utils import broadcast_one_to_all, sync_global_devices from jaxtyping import PyTree -import haliax as hax import haliax.partitioning -from haliax.partitioning import ResourceMapping from levanter.tensorstore_serialization import tree_deserialize_leaves_tensorstore, tree_serialize_leaves_tensorstore from levanter.types import FilterSpec @@ -277,7 +270,7 @@ def load_checkpoint( discover_latest=True, axis_mapping: Optional[haliax.partitioning.ResourceMapping] = None, mesh: Optional[jax.sharding.Mesh] = None, -) -> Optional[M]: +) -> M: """ Load a checkpoint from a given path. If discover_latest is True, then the latest checkpoint in a subdirectory of the given path will be loaded. If subpath is not None, then the checkpoint @@ -304,7 +297,7 @@ def load_checkpoint( checkpoint_path = discover_latest_checkpoint(checkpoint_path) # type: ignore if checkpoint_path is None or not fs.exists(checkpoint_path): - return None + raise FileNotFoundError(f"Could not find checkpoint at {checkpoint_path}") logger.info(f"Loading checkpoint from {checkpoint_path}") metadata = load_metadata(checkpoint_path, fs) @@ -436,81 +429,4 @@ def __post_init__(self): prev_interval = interval -P = ParamSpec("P") - - -def load_from_checkpoint_or_initialize( - init_fn: Callable[P, M], - checkpoint_path: str, - axis_mapping: Optional[ResourceMapping] = None, - mesh: Optional[Mesh] = None, - *, - # TODO: add this back in - # allow_partial_checkpoint: bool, - force_load_checkpoint: Optional[bool] = None, - is_checkpointed: Optional[PyTree[FilterSpec]] = True, -) -> Callable[P, M]: - """ - Loads a checkpoint if it exists, otherwise initializes from scratch. - - Args: - init_fn: the initialization function. This should be a function that takes some arguments and returns a - model or state. It should be jit-able and should not have any (destructive) side effects. It will - likely be called 2-3 times. - checkpoint_path: the path to the checkpoint - axis_mapping: the axis mapping to use for initialization. If None, the default axis mapping will be used. - mesh: the mesh to use for initialization. If None, the default mesh will be used. - force_load_checkpoint: if True, we must load a checkpoint. 
If False, we will not load a checkpoint. If None, - we will load a checkpoint if it exists. - is_checkpointed: a filter spec for the checkpointed parameters. This is used to filter out non-checkpointed - parameters for the initialization. If you don't specify this, all parameters are assumed to be - checkpointed. - """ - if force_load_checkpoint is False: - cmanager = mesh or contextlib.nullcontext() - - @functools.wraps(init_fn) - @hax.named_jit(axis_resources=axis_mapping) - def fn(*args, **kwargs): - with cmanager: - return init_fn(*args, **kwargs) - - return fn - else: - - @functools.wraps(init_fn) - def fn(*args, **kwargs): - with contextlib.ExitStack() as stack: - if mesh is not None: - stack.enter_context(mesh) - if axis_mapping is not None: - stack.enter_context(hax.axis_mapping(axis_mapping)) - - ckpt_shape = eqx.filter_eval_shape(init_fn, *args, **kwargs) - - ckpt = load_checkpoint( - eqx.filter(ckpt_shape, is_checkpointed), checkpoint_path, axis_mapping=axis_mapping, mesh=mesh - ) - - if ckpt is None: - if force_load_checkpoint is True: - raise ValueError(f"Could not load checkpoint from {checkpoint_path}") - else: - out = hax.named_jit(init_fn, axis_mapping)(*args, **kwargs) - return out - else: - ckpt = eqx.combine(ckpt, ckpt_shape) - if any(isinstance(leaf, ShapeDtypeStruct) for leaf in jax.tree_util.tree_leaves(ckpt)): - # if we're resuming, we need to initialize any non-checkpointed values - @hax.named_jit(axis_resources=axis_mapping) - def partial_init(init_fn, *args, **kwargs): - m = init_fn(*args, **kwargs) - return eqx.filter(m, is_checkpointed, inverse=True) - - non_checkpointed = partial_init(init_fn, *args, **kwargs) - ckpt = eqx.filter(ckpt, lambda x: not isinstance(x, ShapeDtypeStruct)) - return eqx.combine(ckpt, non_checkpointed) - else: - return ckpt - - return fn +# TODO: add partial checkpoint loading diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index 177b65993..aa1e4c40d 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -28,7 +28,6 @@ import equinox as eqx import jax -import jax.numpy as jnp import jmp import numpy as np from draccus import field @@ -46,7 +45,7 @@ import levanter.tracker import levanter.tracker.wandb from levanter import tracker -from levanter.checkpoint import CheckpointerConfig, load_from_checkpoint_or_initialize +from levanter.checkpoint import CheckpointerConfig, load_checkpoint from levanter.config import JsonAtom from levanter.data import Dataset, ReplicatedBatchLoader, ShardableDataset, ShardedBatchLoader from levanter.distributed import DistributedConfig, RayConfig @@ -56,7 +55,7 @@ from levanter.tracker import TrackerConfig from levanter.types import FilterSpec from levanter.utils import cloud_utils -from levanter.utils.jax_utils import is_inexact_arrayish +from levanter.utils.jax_utils import as_arrayish, is_inexact_arrayish from levanter.utils.tree_utils import inference_mode @@ -80,7 +79,7 @@ class TrainerState(eqx.Module, Generic[M]): It's designed to be extended by subclasses. 
""" - _step: IntScalar = eqx.field(converter=lambda x: jnp.asarray(x) if not isinstance(x, bool) else x) + _step: IntScalar = eqx.field(converter=lambda x: as_arrayish(x)) model: M opt_state: OptState training_key: PRNGKeyArray @@ -195,7 +194,7 @@ def loss_fn(self): Wrapped loss function that casts the model to compute precision and sets the context axis mapping to compute """ - @named_jit(in_axis_resources=self.parameter_axis_mapping, axis_resources=self.compute_axis_mapping) + @named_jit(axis_resources=self.compute_axis_mapping) @functools.wraps(self._raw_loss_function) def fn(model, *batch, **batch_kwargs): with hax.axis_mapping(self.compute_axis_mapping): @@ -284,7 +283,11 @@ def initial_state( is_trainable: PyTree[FilterSpec] = True, ) -> TrainerState[M]: """ - Initializes the model, optimizer state, and random key. Also handles loading a checkpoint if needed. + Either loads a checkpoint or initializes a fresh trainer state. This is the recommended way to initialize + a trainer state. + + This method is smart enough to handle subclasses of TrainerState. If you want to extend TrainerState, you + can override _initialize_state_from_scratch Args is_trainable: optional filter spec for the trainable parameters. This is used to filter out non-trainable @@ -303,52 +306,55 @@ def initial_state( # we can't just use `lambda: model` because JAX jit can't see captures, but it can see jax partials model_init = jax.tree_util.Partial(lambda m: m, model) + del model assert model_init is not None - if self.config.initialize_from is not None: - model_shape = eqx.filter_eval_shape(model_init) - model_shape = _partition_trainable_params(model_shape, is_trainable)[0] - # we always load the initial model b/c it might have different non-trainables - logger.info(f"Initializing model from checkpoint {self.config.initialize_from}") - ckpt_model = levanter.checkpoint.load_checkpoint( - model_shape, + # we don't save the full trainer state, so we need to filter out the non-trainable parameters + trainer_state_shape = eqx.filter_eval_shape( + self._initialize_state_from_scratch, model_init, training_key, is_trainable + ) + saveable_state_shape = self._make_saveable_trainer_state(trainer_state_shape, is_trainable) + + # first try to load a full trainer state checkpoint + path = self.config.load_checkpoint_path + if path is None: + path = self.config.checkpointer.expanded_path(self.run_id) + + if self.config.load_checkpoint is not False: + try: + state = load_checkpoint( + saveable_state_shape, path, axis_mapping=self.parameter_axis_mapping, mesh=self.device_mesh + ) + except FileNotFoundError: + if self.config.load_checkpoint: + raise + else: + state = None + + # if that fails, try to load just a model from a checkpoint for initialization + if state is None and self.config.initialize_from is not None: + logger.info(f"Initializing from {self.config.initialize_from}") + # todo: we are potentially holding two models in memory at once here, if we pass in a model + # instead of a model_init and we use initialize_from. 
We could avoid this by deleting + # any to-be-loaded parameters from the model before loading, but that's a bit more complicated + loaded_model = load_checkpoint( + saveable_state_shape.model, self.config.initialize_from, axis_mapping=self.parameter_axis_mapping, mesh=self.device_mesh, subpath="model", ) - if ckpt_model is not None: - if model is not None: - # populate any missing parameters from the passed in model - model = eqx.combine(ckpt_model, model) - model_init = jax.tree_util.Partial(lambda m: m, model) - else: - old_model_init = model_init - model_init = jax.tree_util.Partial(lambda m, f: eqx.combine(m, f()), ckpt_model, old_model_init) - else: - raise RuntimeError(f"Could not load model from checkpoint {self.config.initialize_from}") + # we don't necessarily load the full model, so we need to combine it with the model init + model_init = jax.tree_util.Partial(lambda m, f: eqx.combine(m, f()), loaded_model, model_init) - load_checkpoint_path = self.config.load_checkpoint_path + # now we initialize a fresh trainer state, possibly just to finish any missing fields + @named_jit(axis_resources=self.parameter_axis_mapping, donate_args=(True, True, True, False)) + def init_state(partial_state, model_init, training_key, is_trainable): + fresh_state = self._initialize_state_from_scratch(model_init, training_key, is_trainable) + return eqx.combine(partial_state, fresh_state) - if load_checkpoint_path is None: - load_checkpoint_path = self.config.checkpointer.expanded_path(self.run_id) - - with self: - assert model_init is not None - state = load_from_checkpoint_or_initialize( - self._initialize_state_from_scratch, - load_checkpoint_path, - axis_mapping=self.parameter_axis_mapping, - mesh=self.device_mesh, - force_load_checkpoint=self.config.load_checkpoint, - # if we're loading a checkpoint, we need to know which parameters are trainable - is_checkpointed=TrainerState(True, is_trainable, True, True, False), # type: ignore - )( - model_init, - training_key, - is_trainable, - ) + state = init_state(state, model_init, training_key, is_trainable) return state @@ -458,11 +464,7 @@ def sharded_loader(self, dataset: ShardableDataset[X], batch_axis: Axis) -> Shar @cached_property def _jit_train_step_fn(self): - return named_jit( - axis_resources=self.parameter_axis_mapping, - out_axis_resources=self.parameter_axis_mapping, - donate_args=(True,), - )(self._train_step) + return named_jit(self._train_step, axis_resources=self.parameter_axis_mapping, donate_args=(True,)) def _train_step(self, state: TrainerState, *batch, **batch_kwargs) -> tuple[Scalar, TrainerState]: key, new_key = jax.random.split(state.training_key) @@ -515,6 +517,16 @@ def _initialize_state_from_scratch(self, model_init: Callable[[], M], training_k return TrainerState(0, model, opt_state, training_key, is_trainable) + def _make_saveable_trainer_state(self, trainer_state: S, is_trainable) -> S: + """ + Returns the shape of the trainer state that we save to a checkpoint. This is used to load a checkpoint. + You can override if you really need custom checkpointing logic. 
By default everything in the trainer state + is saved (except for non-trainable model parameters) + """ + saveable_model = eqx.filter(trainer_state.model, is_trainable) + saveable_state = dataclasses.replace(trainer_state, model=saveable_model) + return saveable_state + def _initialize_global_tracker(config, run_id): if isinstance(config, Sequence): diff --git a/src/levanter/utils/jax_utils.py b/src/levanter/utils/jax_utils.py index cf4875105..2cb275d70 100644 --- a/src/levanter/utils/jax_utils.py +++ b/src/levanter/utils/jax_utils.py @@ -196,3 +196,10 @@ def match_like(templ_leaf, tree_leaf): return tree_leaf return jax.tree_util.tree_map(match_like, template, tree, is_leaf=lambda x: x is None) + + +def as_arrayish(x): + if hasattr(x, "shape") and hasattr(x, "dtype"): + return x + else: + return jnp.asarray(x) diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index 14668aa5f..e511cb11d 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -9,19 +9,13 @@ import numpy as np import optax from chex import assert_trees_all_close -from jax import ShapeDtypeStruct from jax import numpy as jnp -from jax import tree_util as jtu - -import haliax.nn -from haliax import Axis from levanter.checkpoint import ( Checkpointer, CheckpointInterval, discover_latest_checkpoint, load_checkpoint, - load_from_checkpoint_or_initialize, load_metadata, save_checkpoint, ) @@ -231,50 +225,3 @@ def test_checkpoint_discovery(): assert latest == f"{tempdir}/step-30" assert discover_latest_checkpoint("file:///tmp/does-not-exist") is None - - -def test_load_from_checkpoint_or_initialize(): - In = Axis("in", 2) - Out = Axis("out", 1) - - def init_fn(key): - return haliax.nn.MLP.init(In, Out, 2, 3, key=key) - - k0 = jax.random.PRNGKey(0) - k1 = jax.random.PRNGKey(1) - - model0 = init_fn(k0) - model1 = init_fn(k1) - - is_checkpointed = jtu.tree_map(lambda _: False, model0) - is_checkpointed = eqx.tree_at(lambda t: t.layers[-1], is_checkpointed, replace=True) - is_checkpointed1 = jtu.tree_map(lambda _: False, model1) - is_checkpointed1 = eqx.tree_at(lambda t: t.layers[-1], is_checkpointed, replace=True) - with jax.sharding.Mesh(jax.devices(), ("devices",)): - with tempfile.TemporaryDirectory() as tmpdir: - filtered = eqx.filter(model0, is_checkpointed) - save_checkpoint(filtered, step=0, checkpoint_path=tmpdir, exist_ok=True) - - loaded = load_from_checkpoint_or_initialize(init_fn, tmpdir, is_checkpointed=is_checkpointed)(k1) - - assert not any(jax.tree_util.tree_leaves(eqx.filter(loaded, lambda x: isinstance(x, ShapeDtypeStruct)))) - - assert_trees_all_close( - jax.tree_util.tree_leaves(arrays_only(eqx.filter(loaded, is_checkpointed))), - jax.tree_util.tree_leaves(arrays_only(eqx.filter(model0, is_checkpointed))), - ) - - assert_trees_not_close( - jax.tree_util.tree_leaves(arrays_only(eqx.filter(loaded, is_checkpointed, inverse=True))), - jax.tree_util.tree_leaves(arrays_only(eqx.filter(model0, is_checkpointed, inverse=True))), - ) - - assert_trees_not_close( - jax.tree_util.tree_leaves(arrays_only(eqx.filter(loaded, is_checkpointed))), - jax.tree_util.tree_leaves(arrays_only(eqx.filter(model1, is_checkpointed))), - ) - - assert_trees_all_close( - jax.tree_util.tree_leaves(arrays_only(eqx.filter(loaded, is_checkpointed, inverse=True))), - jax.tree_util.tree_leaves(arrays_only(eqx.filter(model1, is_checkpointed1, inverse=True))), - ) From b6535b53cde27070fd986fd32560650ba6d8520d Mon Sep 17 00:00:00 2001 From: David Hall Date: Fri, 29 Dec 2023 23:49:00 -0800 Subject: [PATCH 087/205] wip 
factoring out the initial state stuff, again --- src/levanter/trainer.py | 115 ++++++++++++++++++++++++---------------- 1 file changed, 69 insertions(+), 46 deletions(-) diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index 527ffea8f..fdfdb40fc 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -318,39 +318,46 @@ def initial_state( del model assert model_init is not None + # first try to load a full trainer state checkpoint + checkpoint_path = self.config.load_checkpoint_path + if checkpoint_path is None: + checkpoint_path = self.config.checkpointer.expanded_path(self.run_id) + + do_load_checkpoint = self.config.load_checkpoint + axis_mapping = self.parameter_axis_mapping + mesh = self.device_mesh + initial_model_path = self.config.initialize_from + # we don't save the full trainer state, so we need to filter out the non-trainable parameters - trainer_state_shape = eqx.filter_eval_shape( - self._initialize_state_from_scratch, model_init, training_key, is_trainable - ) - saveable_state_shape = self._make_saveable_trainer_state(trainer_state_shape, is_trainable) - # first try to load a full trainer state checkpoint - path = self.config.load_checkpoint_path - if path is None: - path = self.config.checkpointer.expanded_path(self.run_id) + def init_state_and_model(model_init, training_key, is_trainable): + model = model_init() + state = self._initialize_state_from_scratch(model, training_key, is_trainable) + return state - if self.config.load_checkpoint is not False: + trainer_state_shape = eqx.filter_eval_shape(init_state_and_model, model_init, training_key, is_trainable) + saveable_state_shape = _make_saveable_trainer_state(trainer_state_shape, is_trainable) + + if do_load_checkpoint is not False: try: - state = load_checkpoint( - saveable_state_shape, path, axis_mapping=self.parameter_axis_mapping, mesh=self.device_mesh - ) + state = load_checkpoint(saveable_state_shape, checkpoint_path, axis_mapping=axis_mapping, mesh=mesh) except FileNotFoundError: - if self.config.load_checkpoint: + if do_load_checkpoint: raise else: state = None # if that fails, try to load just a model from a checkpoint for initialization - if state is None and self.config.initialize_from is not None: - logger.info(f"Initializing from {self.config.initialize_from}") + if state is None and initial_model_path is not None: + logger.info(f"Initializing from {initial_model_path}") # todo: we are potentially holding two models in memory at once here, if we pass in a model # instead of a model_init and we use initialize_from. 
We could avoid this by deleting # any to-be-loaded parameters from the model before loading, but that's a bit more complicated loaded_model = load_checkpoint( saveable_state_shape.model, - self.config.initialize_from, - axis_mapping=self.parameter_axis_mapping, - mesh=self.device_mesh, + initial_model_path, + axis_mapping=axis_mapping, + mesh=mesh, subpath="model", ) @@ -358,9 +365,10 @@ def initial_state( model_init = jax.tree_util.Partial(lambda m, f: eqx.combine(m, f()), loaded_model, model_init) # now we initialize a fresh trainer state, possibly just to finish any missing fields - @named_jit(axis_resources=self.parameter_axis_mapping, donate_args=(True, True, True, False)) + @named_jit(axis_resources=axis_mapping, donate_args=(True, True, True, False)) def init_state(partial_state, model_init, training_key, is_trainable): - fresh_state = self._initialize_state_from_scratch(model_init, training_key, is_trainable) + model = model_init() + fresh_state = self._initialize_state_from_scratch(model, training_key, is_trainable) return eqx.combine(partial_state, fresh_state) state = init_state(state, model_init, training_key, is_trainable) @@ -502,39 +510,54 @@ def _take_train_step(self, state: S, model, grads, *batch, **batch_kwargs) -> S: Takes a training step. This is a separate method so that it can be overridden or used in a subclass. """ # only train on the trainable parameters. We're leaning on JAX to do dead code elimination for us - train_grads = _partition_trainable_params(grads, state.is_trainable)[0] - trainable_model = _partition_trainable_params(model, state.is_trainable)[0] - updates, opt_state = self.optimizer.update(train_grads, state.opt_state, params=trainable_model) - - # Sophia, e.g. - if isinstance(self.optimizer, SecondOrderTransformation): - opt_state = self.optimizer.update_hessian(opt_state, self.loss_fn, model, *batch, **batch_kwargs) - model = eqx.apply_updates(model, updates) + with hax.axis_mapping(self.parameter_axis_mapping): + train_grads = _partition_trainable_params(grads, state.is_trainable)[0] + trainable_model = _partition_trainable_params(model, state.is_trainable)[0] + updates, opt_state = self.optimizer.update(train_grads, state.opt_state, params=trainable_model) - return dataclasses.replace(state, model=model, opt_state=opt_state) + # Sophia, e.g. + if isinstance(self.optimizer, SecondOrderTransformation): + opt_state = self.optimizer.update_hessian(opt_state, self.loss_fn, model, *batch, **batch_kwargs) + model = eqx.apply_updates(model, updates) - def _initialize_state_from_scratch(self, model_init: Callable[[], M], training_key, is_trainable): - model = model_init() + return dataclasses.replace(state, model=model, opt_state=opt_state) + def _initialize_state_from_scratch(self, model, training_key, is_trainable): # only force trainable params to param precision. Other params are cast to compute precision - trainable, non_trainable = _partition_trainable_params(model, is_trainable) - trainable = self.mp.cast_to_param(trainable) - non_trainable = self.mp.cast_to_compute(non_trainable) - model = eqx.combine(trainable, non_trainable) - - opt_state = self.optimizer.init(trainable) + model = cast_params_by_trainability(model, self.mp, is_trainable) + opt_state = init_optimizer_for_trainables(self.optimizer, model, is_trainable) return TrainerState(0, model, opt_state, training_key, is_trainable) - def _make_saveable_trainer_state(self, trainer_state: S, is_trainable) -> S: - """ - Returns the shape of the trainer state that we save to a checkpoint. 
This is used to load a checkpoint. - You can override if you really need custom checkpointing logic. By default everything in the trainer state - is saved (except for non-trainable model parameters) - """ - saveable_model = eqx.filter(trainer_state.model, is_trainable) - saveable_state = dataclasses.replace(trainer_state, model=saveable_model) - return saveable_state + +def init_optimizer_for_trainables(optimizer, model, is_trainable): + trainable, _ = _partition_trainable_params(model, is_trainable) + opt_state = optimizer.init(trainable) + return opt_state + + +def cast_params_by_trainability(model, mp, is_trainable): + """ + Casts the parameters of a model to the appropriate precision based on the is_trainable filter spec. + Trainable parameters are cast to param precision, non-trainable parameters are cast to compute precision. + """ + + trainable, non_trainable = _partition_trainable_params(model, is_trainable) + trainable = mp.cast_to_param(trainable) + non_trainable = mp.cast_to_compute(non_trainable) + model = eqx.combine(trainable, non_trainable) + return model + + +def _make_saveable_trainer_state(trainer_state: S, is_trainable) -> S: + """ + Returns the shape of the trainer state that we save to a checkpoint. This is used to load a checkpoint. + You can override if you really need custom checkpointing logic. By default everything in the trainer state + is saved (except for non-trainable model parameters) + """ + saveable_model = eqx.filter(trainer_state.model, is_trainable) + saveable_state = dataclasses.replace(trainer_state, model=saveable_model) + return saveable_state def _initialize_global_tracker(config, run_id): From 0d6f357524b40900bb7fc6f426fb94c0d330969d Mon Sep 17 00:00:00 2001 From: David Hall Date: Fri, 29 Dec 2023 23:53:32 -0800 Subject: [PATCH 088/205] almost ready to try out doremi --- src/levanter/doremi.py | 165 +++++++++++++++++++++++++---------------- 1 file changed, 103 insertions(+), 62 deletions(-) diff --git a/src/levanter/doremi.py b/src/levanter/doremi.py index 6c588453c..c04134a12 100644 --- a/src/levanter/doremi.py +++ b/src/levanter/doremi.py @@ -1,29 +1,28 @@ import dataclasses -from typing import Callable, Iterator, Optional, TypeVar +from typing import Callable, Iterator, Optional, Tuple, TypeVar import equinox as eqx -import jax import jax.numpy as jnp import jax.random as jrandom import optax - -import levanter.tracker -from haliax import Scalar, named_jit from jaxtyping import PRNGKeyArray import haliax as hax from haliax.types import IntScalar +import levanter.tracker from levanter.data import ShardableDataset from levanter.data.mixture import MixtureDataset +from levanter.logging import capture_time from levanter.trainer import M, StepInfo, Trainer, TrainerConfig, TrainerState -from levanter.types import ComputeLossFunction -from optax._src.base import GradientTransformation - +from levanter.types import ComputeLossFunction, ModuleComputeLoss from levanter.utils.tree_utils import inference_mode + T = TypeVar("T") + +# TODO: should we put ref in the state? 
If so, need to tell it to not serialize it class DoremiState(TrainerState): alpha: hax.NamedArray average_alpha: hax.NamedArray @@ -36,32 +35,58 @@ def update_alpha(self, alpha): class DoReMiTrainer(Trainer): # we just use the DoReMi trainer for state management - def _initialize_state_from_scratch(self, model_init: Callable[[], M], training_key, is_trainable): - base_state = super()._initialize_state_from_scratch(model_init, training_key, is_trainable) - return DoremiState(base_state.step, - base_state.model, - base_state.opt_state, - base_state.training_key, - self.initial_alpha) + def __init__(self, trainer: TrainerConfig, optimizer: optax.GradientTransformation, initial_alpha: hax.NamedArray): + super().__init__(trainer, optimizer) + self.initial_alpha = initial_alpha + + # TODO: I'd like to not need to override trainer for this + def _initialize_state_from_scratch(self, model: Callable[[], M], training_key, is_trainable): + base_state = super()._initialize_state_from_scratch(model, training_key, is_trainable) + return DoremiState( + base_state.step, base_state.model, base_state.opt_state, base_state.training_key, self.initial_alpha + ) + + +@dataclasses.dataclass +class DoReMiConfig: + # This is designed to be used with estimate_mixture_weights + domain_weight_step_size: float = 1.0 + smoothing: float = 1e-3 + sampling_weights: Optional[dict[str, float]] = None def estimate_mixture_weights( - trainer: TrainerConfig, - loss_fn: ComputeLossFunction, - initial_proxy, - ref, + trainer_config: TrainerConfig, + initial_proxy: M, + ref: M, data_sources: dict[str, ShardableDataset[T]], - ref_weights: dict[str, float], + sampling_weights: Optional[dict[str, float]] = None, *, - domain_weight_step: float = 1.0, + loss_fn: ComputeLossFunction[M, T] = ModuleComputeLoss(), + domain_weight_step_size: float = 1.0, smoothing: float = 1e-3, - eps_alpha: float = 1e-6, key: PRNGKeyArray, ) -> dict[str, float]: """ Estimate the mixture weights for the data sources using DoReMi. https://arxiv.org/abs/2305.10429 + + Args: + trainer_config: Trainer config + initial_proxy: Initial proxy model + ref: Reference model + data_sources: Data sources to estimate the weights for + sampling_weights: Sampling weights for the data sources. If not provided, will use uniform sampling weights. + loss_fn: Loss function to use for the proxy and ref models. 
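            It must accept reduction_axis=() and return per-token losses, since excess losses are computed elementwise.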
If not provided, will use the model's compute_loss + domain_weight_step_size: Step size for the domain weights + smoothing: Smoothing for the domain weights + key: PRNG key """ + if len(data_sources) <= 1: + raise ValueError("Must have at least two data sources") + + ref = _prepare_ref_model(ref, trainer_config) + training_key, data_key = jrandom.split(key) domain_indices = list(data_sources.keys()) domain_to_index = {domain: index for index, domain in enumerate(domain_indices)} @@ -71,7 +96,15 @@ def estimate_mixture_weights( Domain = hax.Axis("domain", len(domain_indices)) initial_alpha = hax.ones(Domain) / Domain.size - trainer = DoReMiTrainer(trainer, optax.adamw(1e-3), ref, initial_alpha, loss_fn) + trainer = DoReMiTrainer(trainer_config, optax.adamw(1e-3), initial_alpha) + + if sampling_weights is not None: + assert set(sampling_weights.keys()) == set(data_sources.keys()) + sampling_weights = { + domain: weight / sum(sampling_weights.values()) for domain, weight in sampling_weights.items() + } + else: + sampling_weights = {domain: 1 / len(data_sources) for domain in data_sources.keys()} # calculate per-token losses for proxy and ref def compute_excess_loss(ref, proxy, batch): @@ -81,7 +114,6 @@ def compute_excess_loss(ref, proxy, batch): excess_losses = proxy_losses - ref_losses return excess_losses - # Loss is alpha_d * (proxy - ref) (basically the unclipped excess loss with the new alpha) def proxy_model_loss(excess_losses, domains, alpha): one_hot_domains = hax.nn.one_hot(domains, Domain) # Domain x Batch @@ -89,44 +121,49 @@ def proxy_model_loss(excess_losses, domains, alpha): # TODO: I'd like to make the syntax for this nicer. einsum would be like # einsum("d,bd,b... -> ()" ro something) # but it's really just collapsing all axes - loss = hax.dot(excess_losses.axes + (Domain,), alpha, one_hot_domains, excess_losses).scalar() + # maybe we could do something like + # loss = hax.contract_to(alpha, one_hot_domains, excess_losses, out_axes=()) + # or i guess we could just do + # hax.dot(excess_losses.axes + (Domain,), alpha, one_hot_domains, excess_losses, out_axes=()).scalar() - return loss + return hax.dot(excess_losses.axes + (Domain,), alpha, one_hot_domains, excess_losses).scalar() - @hax.named_jit(axis_resources=trainer.parameter_axis_mapping, donate_args=(True, )) - def doremi_step(state: DoremiState, batch, domains): + @hax.named_jit(axis_resources=trainer.parameter_axis_mapping, donate_args=(True,)) + def doremi_step(state: DoremiState, ref, batch, domains): proxy = inference_mode(state.model, False) # this is one of those times when PyTorch's backward() is nice - excess_losses, excess_backward = eqx.filter_vjp(lambda proxy: compute_excess_loss(proxy, ref, batch), proxy) + with hax.axis_mapping(trainer.compute_axis_mapping): + excess_losses, excess_backward = eqx.filter_vjp( + lambda proxy: compute_excess_loss(proxy, ref, batch), proxy + ) - # Update domain weights - ## Compute per-domain excess losses - clipped_losses = hax.maximum(excess_losses, 0) - one_hot_domains = hax.nn.one_hot(domains, Domain) # Domain x Batch - per_domain_losses = hax.dot(excess_losses.axes, one_hot_domains, clipped_losses) + # Compute per-domain excess losses + clipped_losses = hax.maximum(excess_losses, 0) + one_hot_domains = hax.nn.one_hot(domains, Domain) # Domain x Batch + per_domain_losses = hax.dot(excess_losses.axes, one_hot_domains, clipped_losses) - alpha = state.alpha * hax.exp(domain_weight_step * per_domain_losses) - alpha /= hax.sum(alpha) - alpha = (1 - smoothing) * alpha + 
initial_alpha * smoothing + # Update domain weights + alpha = state.alpha * hax.exp(domain_weight_step_size * per_domain_losses) + alpha /= hax.sum(alpha) + alpha = (1 - smoothing) * alpha + initial_alpha * smoothing - alpha_distance = hax.sum(hax.abs(alpha - state.alpha)) - levanter.tracker.log_metrics({"alpha_distance": alpha_distance}, step=state.step) + alpha_distance = hax.sum(hax.abs(alpha - state.alpha)) - # Update proxy model weights θt for the objective L(θt−1, αt) (using Adam, Adafactor, etc.) - loss, grad_loss = eqx.filter_value_and_grad(proxy_model_loss)(excess_losses, domains, alpha) - grad = excess_backward(grad_loss) + levanter.tracker.jit_log_metrics({"alpha_distance": alpha_distance}, step=state.step) - new_state = trainer._take_train_step(state, proxy, grad, batch) + # Update proxy model weights θt for the objective L(θt−1, αt) (using Adam, Adafactor, etc.) + loss, grad_loss = eqx.filter_value_and_grad(proxy_model_loss)(excess_losses, domains, alpha) + grad = excess_backward(grad_loss) - return loss, new_state + new_state = trainer._take_train_step(state, proxy, grad) + new_state = new_state.update_alpha(alpha) - # TODO: we don't support serializing stuff from anything other than the model and the opt_state. should fix. - running_alpha_mean = initial_alpha + return loss, new_state # we're not actually going to use the trainer for very much but it holds hooks and sets up contexts with trainer: - tagged_mixture = domain_tagged_mixture(data_sources, ref_weights, domain_to_index, key=data_key) - state = trainer.initial_state(training_key, model=initial_proxy) + tagged_mixture = domain_tagged_mixture(data_sources, sampling_weights, domain_to_index, key=data_key) + state: DoremiState = trainer.initial_state(training_key, model=initial_proxy) del initial_proxy train_loader = iter(trainer.sharded_loader(tagged_mixture, trainer.TrainBatch)) @@ -141,25 +178,29 @@ def doremi_step(state: DoremiState, batch, domains): while state.step < trainer.num_train_steps: example, ex_domains = next(train_loader) - key, new_key = jax.random.split(state.training_key) - proxy, alpha = state.model + with capture_time() as step_time: + loss, state = doremi_step(state, ref, example, ex_domains) + loss = loss.item() # type: ignore - loss, new_model, new_optstate = doremi_step( - proxy, state.opt_state, alpha, example, ex_domains, - ) - loss = loss.item() # type: ignore - - new_info = StepInfo(TrainerState(state.step + 1, new_model, new_optstate, new_key), loss, step_time()) + new_info = StepInfo(state, loss, step_time()) trainer.run_hooks(new_info) - state = new_info - + trainer.run_hooks(new_info, force=True) + final_weights = {domain: float(state.average_alpha[Domain, index]) for domain, index in domain_to_index.items()} + levanter.tracker.log_summary({"final_weights": final_weights}) + return final_weights +def _prepare_ref_model(ref, trainer): + return hax.named_jit( + lambda m: trainer.mp.cast_to_compute(inference_mode(m, True)), + axis_resources=trainer.parameter_axis_mapping, + donate_args=True, + )(ref) def domain_tagged_mixture( @@ -168,20 +209,20 @@ def domain_tagged_mixture( domain_to_index: dict[str, int], *, key: PRNGKeyArray, -) -> MixtureDataset[(T, IntScalar)]: +) -> MixtureDataset[Tuple[T, IntScalar]]: """ Domain tagged mixture dataset. This dataset will yield from the datasets according to the weights, and will yield the domain index as a second element of the tuple. 
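    The weights are passed through to MixtureDataset unchanged and are interpreted as sampling probabilities.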
""" tagged_datasets = { - domain_index: DomainTaggedDataset(data_sources[domain], domain_index) + domain: DomainTaggedDataset(data_sources[domain], domain_index) for domain, domain_index in domain_to_index.items() } return MixtureDataset(tagged_datasets, weights, key=key) -class DomainTaggedDataset(ShardableDataset[(T, hax.NamedArray)]): # named array is a scalar int +class DomainTaggedDataset(ShardableDataset[Tuple[T, hax.NamedArray]]): # named array is a scalar int def __init__( self, dataset: ShardableDataset[T], @@ -197,6 +238,6 @@ def __init__( def shard(self, shard_id: int, num_shards: int) -> "DomainTaggedDataset[T]": return DomainTaggedDataset(self.dataset.shard(shard_id, num_shards), self.domain_index) - def __iter__(self) -> Iterator[(T, IntScalar)]: + def __iter__(self) -> Iterator[Tuple[T, hax.NamedArray]]: for item in self.dataset: yield item, self.domain_index From 7395e3c482e3cd150c2cc2a7473e5eedd86857c9 Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 2 Jan 2024 00:25:37 -0800 Subject: [PATCH 089/205] almost ready to try out doremi --- src/levanter/doremi.py | 29 +++++++++++++++++++++-------- 1 file changed, 21 insertions(+), 8 deletions(-) diff --git a/src/levanter/doremi.py b/src/levanter/doremi.py index c04134a12..db17bdd54 100644 --- a/src/levanter/doremi.py +++ b/src/levanter/doremi.py @@ -65,6 +65,7 @@ def estimate_mixture_weights( loss_fn: ComputeLossFunction[M, T] = ModuleComputeLoss(), domain_weight_step_size: float = 1.0, smoothing: float = 1e-3, + weight_change_eps: float = 1e-3, key: PRNGKeyArray, ) -> dict[str, float]: """ @@ -114,7 +115,7 @@ def compute_excess_loss(ref, proxy, batch): excess_losses = proxy_losses - ref_losses return excess_losses - # Loss is alpha_d * (proxy - ref) (basically the unclipped excess loss with the new alpha) + # Loss is \sum_d alpha_d * (proxy - ref) (basically the unclipped excess loss with the new alpha) def proxy_model_loss(excess_losses, domains, alpha): one_hot_domains = hax.nn.one_hot(domains, Domain) # Domain x Batch # basically einsum(" * -> ", alpha, one_hot_domains, excess_losses) @@ -131,25 +132,21 @@ def proxy_model_loss(excess_losses, domains, alpha): @hax.named_jit(axis_resources=trainer.parameter_axis_mapping, donate_args=(True,)) def doremi_step(state: DoremiState, ref, batch, domains): proxy = inference_mode(state.model, False) - # this is one of those times when PyTorch's backward() is nice with hax.axis_mapping(trainer.compute_axis_mapping): + # this is one of those times when PyTorch's backward() is nice excess_losses, excess_backward = eqx.filter_vjp( lambda proxy: compute_excess_loss(proxy, ref, batch), proxy ) - # Compute per-domain excess losses clipped_losses = hax.maximum(excess_losses, 0) - one_hot_domains = hax.nn.one_hot(domains, Domain) # Domain x Batch - per_domain_losses = hax.dot(excess_losses.axes, one_hot_domains, clipped_losses) + per_domain_losses = _compute_per_domain_losses(Domain, domains, clipped_losses) # Update domain weights alpha = state.alpha * hax.exp(domain_weight_step_size * per_domain_losses) alpha /= hax.sum(alpha) alpha = (1 - smoothing) * alpha + initial_alpha * smoothing - alpha_distance = hax.sum(hax.abs(alpha - state.alpha)) - - levanter.tracker.jit_log_metrics({"alpha_distance": alpha_distance}, step=state.step) + distance_from_uniform = hax.sum(hax.abs(alpha - initial_alpha)) # Update proxy model weights θt for the objective L(θt−1, αt) (using Adam, Adafactor, etc.) 
loss, grad_loss = eqx.filter_value_and_grad(proxy_model_loss)(excess_losses, domains, alpha) @@ -158,6 +155,16 @@ def doremi_step(state: DoremiState, ref, batch, domains): new_state = trainer._take_train_step(state, proxy, grad) new_state = new_state.update_alpha(alpha) + alpha_distance = hax.sum(hax.abs(new_state.average_alpha - state.average_alpha)) + + levanter.tracker.jit_log_metrics( + { + "change_in_alpha": alpha_distance, + "alpha_distance_from_uniform": distance_from_uniform, + }, + step=state.step, + ) + return loss, new_state # we're not actually going to use the trainer for very much but it holds hooks and sets up contexts @@ -241,3 +248,9 @@ def shard(self, shard_id: int, num_shards: int) -> "DomainTaggedDataset[T]": def __iter__(self) -> Iterator[Tuple[T, hax.NamedArray]]: for item in self.dataset: yield item, self.domain_index + + +def _compute_per_domain_losses(Domain, domains, losses): + one_hot_domains = hax.nn.one_hot(domains, Domain) # Domain x Batch + per_domain_losses = hax.dot(losses.axes, one_hot_domains, losses) + return per_domain_losses From 08996e60a8ceef3d25ce97f3a410f440304a864e Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 3 Jan 2024 00:13:27 -0800 Subject: [PATCH 090/205] cleanup typing.overloads --- src/levanter/doremi.py | 54 ++++++++++++++++++++++++++++-------------- 1 file changed, 36 insertions(+), 18 deletions(-) diff --git a/src/levanter/doremi.py b/src/levanter/doremi.py index db17bdd54..9834541ea 100644 --- a/src/levanter/doremi.py +++ b/src/levanter/doremi.py @@ -1,4 +1,5 @@ import dataclasses +import logging from typing import Callable, Iterator, Optional, Tuple, TypeVar import equinox as eqx @@ -19,6 +20,9 @@ from levanter.utils.tree_utils import inference_mode +logger = logging.getLogger(__name__) + + T = TypeVar("T") @@ -116,19 +120,6 @@ def compute_excess_loss(ref, proxy, batch): return excess_losses # Loss is \sum_d alpha_d * (proxy - ref) (basically the unclipped excess loss with the new alpha) - def proxy_model_loss(excess_losses, domains, alpha): - one_hot_domains = hax.nn.one_hot(domains, Domain) # Domain x Batch - # basically einsum(" * -> ", alpha, one_hot_domains, excess_losses) - # TODO: I'd like to make the syntax for this nicer. einsum would be like - # einsum("d,bd,b... -> ()" ro something) - # but it's really just collapsing all axes - # maybe we could do something like - # loss = hax.contract_to(alpha, one_hot_domains, excess_losses, out_axes=()) - # or i guess we could just do - # hax.dot(excess_losses.axes + (Domain,), alpha, one_hot_domains, excess_losses, out_axes=()).scalar() - - return hax.dot(excess_losses.axes + (Domain,), alpha, one_hot_domains, excess_losses).scalar() - @hax.named_jit(axis_resources=trainer.parameter_axis_mapping, donate_args=(True,)) def doremi_step(state: DoremiState, ref, batch, domains): proxy = inference_mode(state.model, False) @@ -149,23 +140,25 @@ def doremi_step(state: DoremiState, ref, batch, domains): distance_from_uniform = hax.sum(hax.abs(alpha - initial_alpha)) # Update proxy model weights θt for the objective L(θt−1, αt) (using Adam, Adafactor, etc.) 
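        # (note on the vjp trick: filter_vjp above returned the per-token excess losses together with a pullback,
        #  excess_backward. Differentiating the scalar weighted loss with respect to those losses and feeding the
        #  resulting cotangent through the pullback yields the gradient with respect to the proxy parameters
        #  without a second forward pass)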
- loss, grad_loss = eqx.filter_value_and_grad(proxy_model_loss)(excess_losses, domains, alpha) + loss, grad_loss = eqx.filter_value_and_grad(_domain_weighted_loss)(excess_losses, Domain, domains, alpha) grad = excess_backward(grad_loss) new_state = trainer._take_train_step(state, proxy, grad) new_state = new_state.update_alpha(alpha) alpha_distance = hax.sum(hax.abs(new_state.average_alpha - state.average_alpha)) + alpha_dict = _alpha_weights_to_dict(Domain, new_state.average_alpha, domain_to_index) levanter.tracker.jit_log_metrics( { "change_in_alpha": alpha_distance, "alpha_distance_from_uniform": distance_from_uniform, + "alpha": alpha_dict, }, step=state.step, ) - return loss, new_state + return loss, alpha_distance, new_state # we're not actually going to use the trainer for very much but it holds hooks and sets up contexts with trainer: @@ -186,19 +179,30 @@ def doremi_step(state: DoremiState, ref, batch, domains): example, ex_domains = next(train_loader) with capture_time() as step_time: - loss, state = doremi_step(state, ref, example, ex_domains) + loss, alpha_distance, state = doremi_step(state, ref, example, ex_domains) loss = loss.item() # type: ignore new_info = StepInfo(state, loss, step_time()) trainer.run_hooks(new_info) + # check convergence for alphas + if alpha_distance.item() < weight_change_eps: + logger.info(f"Converged on alpha at step {state.step}: {alpha_distance:.4f}") + break + trainer.run_hooks(new_info, force=True) - final_weights = {domain: float(state.average_alpha[Domain, index]) for domain, index in domain_to_index.items()} + alpha = state.average_alpha + final_weights = _alpha_weights_to_dict(Domain, alpha, domain_to_index) + + levanter.tracker.log_summary({"final_alpha": final_weights}) + + return final_weights - levanter.tracker.log_summary({"final_weights": final_weights}) +def _alpha_weights_to_dict(Domain, alpha, domain_name_to_index): + final_weights = {domain: float(alpha[Domain, index]) for domain, index in domain_name_to_index.items()} return final_weights @@ -254,3 +258,17 @@ def _compute_per_domain_losses(Domain, domains, losses): one_hot_domains = hax.nn.one_hot(domains, Domain) # Domain x Batch per_domain_losses = hax.dot(losses.axes, one_hot_domains, losses) return per_domain_losses + + +def _domain_weighted_loss(losses, Domain, domains, alpha): + one_hot_domains = hax.nn.one_hot(domains, Domain) # Domain x Batch + # basically einsum(" * -> ", alpha, one_hot_domains, excess_losses) + # TODO: I'd like to make the syntax for this nicer. einsum would be like + # einsum("d,bd,b... 
-> ()" or something) + # but it's really just collapsing all axes + # maybe we could do something like + # loss = hax.contract_to(alpha, one_hot_domains, excess_losses, out_axes=()) + # or i guess we could just do + # hax.dot(excess_losses.axes + (Domain,), alpha, one_hot_domains, excess_losses, out_axes=()).scalar() + + return hax.dot(losses.axes + (Domain,), alpha, one_hot_domains, losses).scalar() From 710900c6fcb79dc3f6366608bd7c584907b420e0 Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 3 Jan 2024 10:46:40 -0800 Subject: [PATCH 091/205] use auto_sharded internally, undeprecate it b/c it has a point --- src/levanter/tensorstore_serialization.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/levanter/tensorstore_serialization.py b/src/levanter/tensorstore_serialization.py index 51a253163..f75ee87ff 100644 --- a/src/levanter/tensorstore_serialization.py +++ b/src/levanter/tensorstore_serialization.py @@ -17,6 +17,7 @@ from tensorstore import TensorStore import haliax as hax +import haliax.tree_util as htu from haliax.partitioning import ResourceMapping from haliax.util import is_named_array @@ -138,7 +139,7 @@ def tree_deserialize_leaves_tensorstore( """ # TODO: support ShapeDtypeStructs that are not NamedArrays leaf_key_paths = jax_utils.leaf_key_paths(pytree, is_leaf=is_named_array) - specs = jtu.tree_map(partial(_tensorstore_spec_for, checkpoint_dir), leaf_key_paths, is_leaf=is_named_array) + specs = htu.tree_map(partial(_tensorstore_spec_for, checkpoint_dir), leaf_key_paths) deser_partial = functools.partial(_deserialize_one_leaf, axis_mapping=axis_mapping, mesh=mesh) From 5f9d96d992cfe2a2a1d9a53227f476d3008caac1 Mon Sep 17 00:00:00 2001 From: David Hall Date: Thu, 4 Jan 2024 09:01:12 -0800 Subject: [PATCH 092/205] fix docs --- docs/Configuration-Guide.md | 4 ++-- src/levanter/trainer.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/Configuration-Guide.md b/docs/Configuration-Guide.md index f203f0dcc..8336b1eb8 100644 --- a/docs/Configuration-Guide.md +++ b/docs/Configuration-Guide.md @@ -290,7 +290,7 @@ If you're not using SLURM or TPUs, you can specify the cluster manually using th ## Optimizer -[levanter.trainer.OptimizerConfig][] is a dataclass that specifies the optimizer configuration. It has the following fields: +[levanter.optim.OptimizerConfig][] is a dataclass that specifies the optimizer configuration. It has the following fields: | Parameter | Description | Default | |-----------------|-------------------------------------------------------------------|----------| @@ -358,7 +358,7 @@ trainer: ### Optimizer -::: levanter.trainer.OptimizerConfig +::: levanter.optim.OptimizerConfig ### LM Model diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index fdfdb40fc..0309e49d2 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -175,7 +175,7 @@ def __init__( Args: config: the trainer config - optimizer: the optimizer, e.g. `optax.adam(1e-3)` or produced by [levanter.trainer.OptimizerConfig][] + optimizer: the optimizer, e.g. `optax.adam(1e-3)` or produced by [levanter.optim.OptimizerConfig][] loss_fn (Callable): the loss function. This should be a function that takes a model and some inputs and returns a scalar loss. It should be jit-able and should not have any side effects. 
""" From 04a74a181bcfb6935977e1af9e835e72f9ea9393 Mon Sep 17 00:00:00 2001 From: David Hall Date: Thu, 4 Jan 2024 09:01:50 -0800 Subject: [PATCH 093/205] use new dot syntax in doremi --- src/levanter/doremi.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/src/levanter/doremi.py b/src/levanter/doremi.py index 9834541ea..3b79f5d9e 100644 --- a/src/levanter/doremi.py +++ b/src/levanter/doremi.py @@ -262,13 +262,4 @@ def _compute_per_domain_losses(Domain, domains, losses): def _domain_weighted_loss(losses, Domain, domains, alpha): one_hot_domains = hax.nn.one_hot(domains, Domain) # Domain x Batch - # basically einsum(" * -> ", alpha, one_hot_domains, excess_losses) - # TODO: I'd like to make the syntax for this nicer. einsum would be like - # einsum("d,bd,b... -> ()" or something) - # but it's really just collapsing all axes - # maybe we could do something like - # loss = hax.contract_to(alpha, one_hot_domains, excess_losses, out_axes=()) - # or i guess we could just do - # hax.dot(excess_losses.axes + (Domain,), alpha, one_hot_domains, excess_losses, out_axes=()).scalar() - - return hax.dot(losses.axes + (Domain,), alpha, one_hot_domains, losses).scalar() + return hax.dot(alpha, one_hot_domains, losses, axis=None) From 6a20c95cb4eca8720ddaadf8515ba80e79ad3cb6 Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 9 Jan 2024 12:15:28 -0800 Subject: [PATCH 094/205] fix mixture init with prngkey --- src/levanter/data/mixture.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/levanter/data/mixture.py b/src/levanter/data/mixture.py index dbe255748..71556833a 100644 --- a/src/levanter/data/mixture.py +++ b/src/levanter/data/mixture.py @@ -2,7 +2,6 @@ import jax.random import numpy as np -from jax.random import PRNGKey from jaxtyping import PRNGKeyArray from haliax.util import StringHolderEnum @@ -48,7 +47,7 @@ def __init__( self.stop_strategy = stop_strategy if not isinstance(key, int): - key = jax.random.randint(PRNGKey(key)[0], (), 0, 2**31).item() + key = jax.random.randint(key, (), 0, 2**20).item() self.key = key From fd6d343d636e5bffd80a7697c1b35f1aabc7ee08 Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 9 Jan 2024 12:15:49 -0800 Subject: [PATCH 095/205] add a simple InMemoryDataset that takes a list --- src/levanter/data/dataset.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/levanter/data/dataset.py b/src/levanter/data/dataset.py index 3c49910a6..14c8979b3 100644 --- a/src/levanter/data/dataset.py +++ b/src/levanter/data/dataset.py @@ -24,6 +24,17 @@ def __iter__(self) -> Iterator[T]: raise NotImplementedError +class InMemoryDataset(ShardableDataset[T]): + def __init__(self, items: List[T]): + self.items = items + + def __iter__(self) -> Iterator[T]: + return iter(self.items) + + def shard(self, shard_id: int, num_shards: int) -> "InMemoryDataset[T]": + return InMemoryDataset(self.items[shard_id::num_shards]) + + class ShuffleDataset(ShardableDataset[T]): def __init__(self, dataset: Dataset[T], key: PRNGKey, buffer_size: int): self.dataset = dataset From f5b8d00bf5a78811bb57a55e9617443de538cd89 Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 9 Jan 2024 12:18:15 -0800 Subject: [PATCH 096/205] make keyiterator support just an int seed --- src/levanter/utils/jax_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/levanter/utils/jax_utils.py b/src/levanter/utils/jax_utils.py index 2cb275d70..a1253b500 100644 --- a/src/levanter/utils/jax_utils.py +++ 
b/src/levanter/utils/jax_utils.py @@ -161,7 +161,9 @@ def join_key(prefix, k): return f"{prefix}.{k}" if prefix else k -def key_iterator(key: PRNGKeyArray): +def key_iterator(key: PRNGKeyArray | int): + if isinstance(key, int): + key = jax.random.PRNGKey(key) while True: key, subkey = jax.random.split(key) yield subkey From 288e7fb9b2bbc06b35a0644550778f31b2027dd8 Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 9 Jan 2024 12:18:41 -0800 Subject: [PATCH 097/205] dumb bug in grad accum --- src/levanter/grad_accum.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/levanter/grad_accum.py b/src/levanter/grad_accum.py index 9b92e7b3a..39258665a 100644 --- a/src/levanter/grad_accum.py +++ b/src/levanter/grad_accum.py @@ -61,6 +61,9 @@ def microbatched( physical_axis_name = hax.partitioning.physical_axis_name(Batch, compute_axis_mapping) assert physical_axis_name is not None + if per_device_parallelism < 0: + raise ValueError(f"Bad value for {per_device_parallelism=}") + microbatch_size = data_axis_size * per_device_parallelism num_micro_steps = batch_size // microbatch_size Microbatch = Batch.resize(microbatch_size) @@ -138,7 +141,7 @@ def _zeros_like_tree(r_shape, axis_mapping, accum_dtype): def _zeros_like(mapping, dtype, n): if isinstance(n, hax.NamedArray): - return hax.auto_sharded(hax.zeros_like(n, dtype=dtype), mapping) + return hax.shard(hax.zeros_like(n, dtype=dtype), mapping) elif is_jax_array_like(n): return jnp.zeros_like(n, dtype) else: From c4da125aa84723bc33d28e0dee47b9e85b4ed185 Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 9 Jan 2024 12:18:57 -0800 Subject: [PATCH 098/205] fix some dumb bugs in new trainer --- src/levanter/trainer.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index 0309e49d2..0d428e812 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -409,8 +409,7 @@ def training_steps( levanter.tracker.log_metrics({"throughput/hook_time": hook_time()}, step=state.step) state = info.state - - yield info + yield info def train(self, state: TrainerState[M], train_loader: Iterable[X], run_hooks: bool = True) -> StepInfo[M]: """ @@ -490,7 +489,7 @@ def _train_step(self, state: TrainerState, *batch, **batch_kwargs) -> tuple[Scal loss, grads = self._compute_gradients_microbatched(model, batch, **batch_kwargs, key=key) new_state = self._take_train_step(state, model, grads, *batch, **batch_kwargs, key=key) - new_state = dataclasses.replace(new_state, _step=state._step + 1, training_key=new_key) + new_state = dataclasses.replace(new_state, training_key=new_key) return loss, new_state @@ -520,7 +519,7 @@ def _take_train_step(self, state: S, model, grads, *batch, **batch_kwargs) -> S: opt_state = self.optimizer.update_hessian(opt_state, self.loss_fn, model, *batch, **batch_kwargs) model = eqx.apply_updates(model, updates) - return dataclasses.replace(state, model=model, opt_state=opt_state) + return dataclasses.replace(state, _step=state._step + 1, model=model, opt_state=opt_state) def _initialize_state_from_scratch(self, model, training_key, is_trainable): # only force trainable params to param precision. 
Other params are cast to compute precision From 92575976f125447ebcb970d96f6b4524a965abc0 Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 9 Jan 2024 13:32:46 -0800 Subject: [PATCH 099/205] test for doremi and associated fixes --- src/levanter/doremi.py | 36 ++++++---- tests/test_doremi.py | 152 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 173 insertions(+), 15 deletions(-) create mode 100644 tests/test_doremi.py diff --git a/src/levanter/doremi.py b/src/levanter/doremi.py index 3b79f5d9e..10dd07228 100644 --- a/src/levanter/doremi.py +++ b/src/levanter/doremi.py @@ -32,7 +32,7 @@ class DoremiState(TrainerState): average_alpha: hax.NamedArray def update_alpha(self, alpha): - average_alpha = self.average_alpha + (alpha - self.average_alpha) / (self.step + 1) + average_alpha = self.average_alpha + (alpha - self.average_alpha) / (self._step + 1) return dataclasses.replace(self, alpha=alpha, average_alpha=average_alpha) @@ -46,9 +46,8 @@ def __init__(self, trainer: TrainerConfig, optimizer: optax.GradientTransformati # TODO: I'd like to not need to override trainer for this def _initialize_state_from_scratch(self, model: Callable[[], M], training_key, is_trainable): base_state = super()._initialize_state_from_scratch(model, training_key, is_trainable) - return DoremiState( - base_state.step, base_state.model, base_state.opt_state, base_state.training_key, self.initial_alpha - ) + + return DoremiState(**base_state.__dict__, alpha=self.initial_alpha, average_alpha=self.initial_alpha) @dataclasses.dataclass @@ -59,13 +58,19 @@ class DoReMiConfig: sampling_weights: Optional[dict[str, float]] = None +DEFAULT_DOREMI_TRAINER_CONFIG = TrainerConfig( + num_train_steps=10000, + train_batch_size=512, +) + + def estimate_mixture_weights( - trainer_config: TrainerConfig, initial_proxy: M, ref: M, data_sources: dict[str, ShardableDataset[T]], sampling_weights: Optional[dict[str, float]] = None, *, + trainer_config: TrainerConfig = DEFAULT_DOREMI_TRAINER_CONFIG, loss_fn: ComputeLossFunction[M, T] = ModuleComputeLoss(), domain_weight_step_size: float = 1.0, smoothing: float = 1e-3, @@ -90,8 +95,6 @@ def estimate_mixture_weights( if len(data_sources) <= 1: raise ValueError("Must have at least two data sources") - ref = _prepare_ref_model(ref, trainer_config) - training_key, data_key = jrandom.split(key) domain_indices = list(data_sources.keys()) domain_to_index = {domain: index for index, domain in enumerate(domain_indices)} @@ -102,6 +105,8 @@ def estimate_mixture_weights( initial_alpha = hax.ones(Domain) / Domain.size trainer = DoReMiTrainer(trainer_config, optax.adamw(1e-3), initial_alpha) + with trainer: + ref = _prepare_ref_model(ref, trainer_config) if sampling_weights is not None: assert set(sampling_weights.keys()) == set(data_sources.keys()) @@ -112,7 +117,7 @@ def estimate_mixture_weights( sampling_weights = {domain: 1 / len(data_sources) for domain in data_sources.keys()} # calculate per-token losses for proxy and ref - def compute_excess_loss(ref, proxy, batch): + def compute_excess_loss(proxy, ref, batch): proxy_losses = loss_fn(proxy, batch, reduction_axis=()) ref_losses = loss_fn(ref, batch, reduction_axis=()) # calculate excess losses @@ -130,6 +135,7 @@ def doremi_step(state: DoremiState, ref, batch, domains): ) clipped_losses = hax.maximum(excess_losses, 0) + per_domain_losses = _compute_per_domain_losses(Domain, domains, clipped_losses) # Update domain weights @@ -141,7 +147,7 @@ def doremi_step(state: DoremiState, ref, batch, domains): # Update proxy model weights θt for 
the objective L(θt−1, αt) (using Adam, Adafactor, etc.) loss, grad_loss = eqx.filter_value_and_grad(_domain_weighted_loss)(excess_losses, Domain, domains, alpha) - grad = excess_backward(grad_loss) + grad = excess_backward(grad_loss)[0] new_state = trainer._take_train_step(state, proxy, grad) new_state = new_state.update_alpha(alpha) @@ -155,7 +161,7 @@ def doremi_step(state: DoremiState, ref, batch, domains): "alpha_distance_from_uniform": distance_from_uniform, "alpha": alpha_dict, }, - step=state.step, + step=state._step, ) return loss, alpha_distance, new_state @@ -193,16 +199,16 @@ def doremi_step(state: DoremiState, ref, batch, domains): trainer.run_hooks(new_info, force=True) - alpha = state.average_alpha - final_weights = _alpha_weights_to_dict(Domain, alpha, domain_to_index) + alpha = state.average_alpha + final_weights = _alpha_weights_to_dict(Domain, alpha, domain_to_index) - levanter.tracker.log_summary({"final_alpha": final_weights}) + levanter.tracker.log_summary({"final_alpha": final_weights}) - return final_weights + return {k: float(v) for k, v in final_weights.items()} def _alpha_weights_to_dict(Domain, alpha, domain_name_to_index): - final_weights = {domain: float(alpha[Domain, index]) for domain, index in domain_name_to_index.items()} + final_weights = {domain: alpha[Domain, index] for domain, index in domain_name_to_index.items()} return final_weights diff --git a/tests/test_doremi.py b/tests/test_doremi.py new file mode 100644 index 000000000..dbb9a9889 --- /dev/null +++ b/tests/test_doremi.py @@ -0,0 +1,152 @@ +import equinox +import jax.random +import optax + +import haliax as hax + +from levanter.callbacks import eval_loss_loop +from levanter.data.dataset import ShardableDataset +from levanter.data.mixture import MixtureDataset +from levanter.trainer import Trainer, TrainerConfig +from levanter.utils.jax_utils import key_iterator +from levanter.utils.py_utils import non_caching_cycle + + +class Example(equinox.Module): + x: hax.NamedArray + y: hax.NamedArray + + +Block = hax.Axis("Block", 1024) + + +class LogitDataset(ShardableDataset[Example]): + def __init__(self, W, noise, x_mask, x_bias, *, key): + self.W = W + self.noise = noise + self.x_mask = x_mask + self.x_bias = x_bias + self.key = key + + def __iter__(self): + key_iter = key_iterator(self.key) + Dim = self.W.axes[0] + while True: + x_block = hax.random.normal(next(key_iter), (Block, Dim)) * self.x_mask + self.x_bias + noise = hax.random.normal(next(key_iter), (Block,)) * self.noise + y_block = (hax.nn.sigmoid(hax.dot(x_block, self.W, axis=Dim) + noise) > 0.5).astype(float) + for i in range(Block.size): + yield Example(x=x_block[Block, i], y=hax.named(y_block[Block, i], ())) + + def shard(self, shard_id: int, num_shards: int): + return LogitDataset(self.W, self.noise, self.x_mask, self.x_bias, key=jax.random.fold_in(self.key, shard_id)) + + +def test_estimate_mixture_weights(): + # we create 3 simple logistic regression datasets + # 1. x is moderately predictive of y (y ~ [0, 0.5, 0.5] x + N(0, noise^2) > 0.5) + # 2. x is not predictive of y at all, y is highly random (y ~ N(0, 1)) + # 3. 
x is highly predictive of y, but it's very easy (y = sigmoid([1, 0, 0] x > 0.5) + + Dim = hax.Axis("Dim", 5) + Batch = hax.Axis("Batch", 32) + + keys = key_iterator(0) + + # W = hax.random.normal(next(keys), (Dim,)) + W1 = hax.named([0.0, 0.5, 0.5, 0.0, 0.0], (Dim,)) + x1_mask = hax.named([0.0, 1.0, 1.0, 0.0, 0.0], (Dim,)) + W2 = hax.named([0.0, 0.0, 0.0, 0.0, 0.0], (Dim,)) + x2_mask = hax.named([0.0, 0.0, 0.0, 1.0, 1.0], (Dim,)) + W3 = hax.named([1.0, 0.0, 0.0, 0.0, 0.0], (Dim,)) + x3_mask = hax.named([1.0, 0.0, 0.0, 0.0, 0.0], (Dim,)) + x3_bias = hax.named([4.0, 0.0, 0.0, 0.0, 0.0], (Dim,)) + + # y = sigmoid(Wx + b + N(0, noise^2)) > 0.5 + ds1 = LogitDataset(W1, 0.1, x1_mask, 0.0, key=next(keys)) + ds2 = LogitDataset(W2, 2.0, x2_mask, 0.0, key=next(keys)) + ds3 = LogitDataset(W3, 0.0, x3_mask, x3_bias, key=next(keys)) + + # TODO: remove key as a requirement for models + def compute_loss_fn(model, example, reduction=hax.mean, reduction_axis=None, key=None): + del key + y_pred = model(example.x) + return hax.nn.binary_cross_entropy_loss(y_pred, example.y, reduction=reduction, reduction_axis=reduction_axis) + + tiny_trainer_config = TrainerConfig( + num_train_steps=600, train_batch_size=Batch.size, tracker=(), id="kmaklfmaf", per_device_parallelism=Batch.size + ) + + optimizer = optax.adam(1e-2) + import jax + + jax.config.update("jax_traceback_filtering", "off") + + trainer = Trainer(tiny_trainer_config, optimizer, compute_loss_fn) + + def fit_to_dataset(dataset): + initial_model = init_model() + with trainer: + state = trainer.initial_state(next(keys), model=initial_model) + loader = trainer.replicated_loader(dataset, Batch) + loader = non_caching_cycle(loader) + + loss = 0.0 + + # state = trainer.train(state, loader, run_hooks=False) + for state in trainer.training_steps(state, loader, run_hooks=False): + if state.step >= 200: + loss += state.loss + + return state.model, (loss / (state.step - 200)) + + def init_model(): + return hax.nn.Linear.init( + Dim, + (), + use_bias=True, + key=next(keys), + ) + + m1, loss1 = fit_to_dataset(ds1) + m2, loss2 = fit_to_dataset(ds2) + m3, loss3 = fit_to_dataset(ds3) + + assert loss3 < loss1 < loss2 + + datasets = {"d1": ds1, "d2": ds2, "d3": ds3} + + ref_model, ref_loss = fit_to_dataset(MixtureDataset(datasets, weights={k: 1 / 3.0 for k in datasets.keys()})) + + # let's see the loss on each dataset + l1_ref = eval_loss_loop( + compute_loss_fn, ref_model, trainer.replicated_loader(ds1, Batch), max_batches=10, name="d1" + ) + l2_ref = eval_loss_loop( + compute_loss_fn, ref_model, trainer.replicated_loader(ds2, Batch), max_batches=10, name="d2" + ) + l3_ref = eval_loss_loop( + compute_loss_fn, ref_model, trainer.replicated_loader(ds3, Batch), max_batches=10, name="d3" + ) + + assert l3_ref < l1_ref < l2_ref + + from levanter.doremi import estimate_mixture_weights + + w = estimate_mixture_weights( + initial_proxy=init_model(), + ref=ref_model, + data_sources=datasets, + weight_change_eps=1e-4, + trainer_config=tiny_trainer_config, + key=next(keys), + loss_fn=compute_loss_fn, + ) + + w1 = w["d1"] + w2 = w["d2"] + w3 = w["d3"] + + assert w1 > w3 > w2 + assert abs(w1 + w2 + w3 - 1.0) < 1e-3 + assert w2 < 0.05 # the noise distribution should get a very low weight From 317b10dd2b2984840b9ef19e796409f8a932f1d3 Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 9 Jan 2024 13:33:42 -0800 Subject: [PATCH 100/205] depend on haliax dev for levanter dev --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml 
b/pyproject.toml index 2e134974d..342e9bd45 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ dependencies = [ # jax = {version = ">=0.4.10,<0.5.0"} # "haliax>=1.3,<2.0", # Haliax changes in step with levanter, so we'll just use the git version except for releases. - "haliax @ git+https://github.com/stanford-crfm/haliax.git", + "haliax @ git+https://github.com/stanford-crfm/haliax.git@dev", "equinox>=0.10.7", "jaxtyping>=0.2.20", "transformers>=4.22.0", From e4d1385189da2d8340368e424c0537230991819e Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 9 Jan 2024 13:54:43 -0800 Subject: [PATCH 101/205] fix gsm8k_lora --- examples/gsm8k-lora/gsm8k_lora.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/examples/gsm8k-lora/gsm8k_lora.py b/examples/gsm8k-lora/gsm8k_lora.py index 7531d2411..0f4ba005f 100644 --- a/examples/gsm8k-lora/gsm8k_lora.py +++ b/examples/gsm8k-lora/gsm8k_lora.py @@ -148,12 +148,7 @@ def train(config: TrainArgs): optimizer = config.optimizer.build(config.trainer.num_train_steps) - def compute_loss(model: LmHeadModel, example: LmExample, key=None): - return model.compute_loss(example, key=key) - - # end major difference from Alpaca - - with Trainer(config.trainer, optimizer, compute_loss) as trainer: + with Trainer(config.trainer, optimizer) as trainer: # how we shard parameters across devices parameter_axis_mapping = config.trainer.parameter_axis_mapping From ddcdac75b73637444b5e656f411b43fb7f6f85ce Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 9 Jan 2024 14:09:03 -0800 Subject: [PATCH 102/205] add a small_pile configuration --- config/gpt2_small_pile.yaml | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) create mode 100644 config/gpt2_small_pile.yaml diff --git a/config/gpt2_small_pile.yaml b/config/gpt2_small_pile.yaml new file mode 100644 index 000000000..ee59ffe6d --- /dev/null +++ b/config/gpt2_small_pile.yaml @@ -0,0 +1,23 @@ +data: !include data/pile_source_old.yaml +model: + type: gpt2 + hidden_dim: 768 + num_heads: 12 + num_layers: 12 + seq_len: 1024 + gradient_checkpointing: true + scale_attn_by_inverse_layer_idx: true +trainer: + tracker: + project: "levanter" + tags: [ "pile", "gpt2"] + + mp: p=f32,c=bfloat16 + model_axis_size: 1 + per_device_parallelism: 8 + + train_batch_size: 512 + num_train_steps: 50000 +optimizer: + learning_rate: 6e-4 + weight_decay: 0.1 From 792f7691991c41f533a2bd70d7e3791b648e7a6b Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 9 Jan 2024 14:11:07 -0800 Subject: [PATCH 103/205] make it len 2048 --- config/gpt2_small_pile.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/config/gpt2_small_pile.yaml b/config/gpt2_small_pile.yaml index ee59ffe6d..19512c3dd 100644 --- a/config/gpt2_small_pile.yaml +++ b/config/gpt2_small_pile.yaml @@ -4,7 +4,7 @@ model: hidden_dim: 768 num_heads: 12 num_layers: 12 - seq_len: 1024 + seq_len: 2048 gradient_checkpointing: true scale_attn_by_inverse_layer_idx: true trainer: @@ -16,7 +16,7 @@ trainer: model_axis_size: 1 per_device_parallelism: 8 - train_batch_size: 512 + train_batch_size: 256 num_train_steps: 50000 optimizer: learning_rate: 6e-4 From e16b3af78977bd7e3cbe8c8f1a1581c50b04dad5 Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 10 Jan 2024 09:17:33 -0800 Subject: [PATCH 104/205] add doremi main --- config/doremi/doremi_nano.yaml | 28 +++++++ src/levanter/data/text.py | 7 ++ src/levanter/doremi.py | 31 ++++++- src/levanter/main/doremi_lm.py | 145 +++++++++++++++++++++++++++++++++ 4 files changed, 207 
insertions(+), 4 deletions(-) create mode 100644 config/doremi/doremi_nano.yaml create mode 100644 src/levanter/main/doremi_lm.py diff --git a/config/doremi/doremi_nano.yaml b/config/doremi/doremi_nano.yaml new file mode 100644 index 000000000..397e91239 --- /dev/null +++ b/config/doremi/doremi_nano.yaml @@ -0,0 +1,28 @@ +data: + configs: + wikitext: + id: dlwh/wikitext_103_detokenized + w2: + id: dlwh/wikitext_103_detokenized + train_weights: + wikitext: 0.5 + w2: 0.5 +model: + type: gpt2 + hidden_dim: 32 + num_heads: 4 + num_layers: 2 +trainer: + mp: f32 + num_train_steps: 100 + + checkpointer: + keep: + - every: 50 + save_interval: 5m + + train_batch_size: 32 + + tensor_parallel_axes: ["mlp", "heads"] + fsdp_axis: "embed" + batch_axis: "batch" diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index a2cfa947d..f952a882a 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -658,6 +658,13 @@ def train_set( token_datasets = {name: TokenSeqDataset(cache, seq_len, stride=None) for name, cache in doc_caches.items()} return MixtureDataset(datasets=token_datasets, weights=self.train_weights) + def training_sets( + self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True + ) -> Mapping[str, ShardableDataset[np.ndarray]]: + doc_caches = self.build_caches("train", monitors=monitors) + token_datasets = {name: TokenSeqDataset(cache, seq_len, stride=None) for name, cache in doc_caches.items()} + return token_datasets + def validation_sets( self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True ) -> Mapping[str, ShardableDataset[np.ndarray]]: diff --git a/src/levanter/doremi.py b/src/levanter/doremi.py index 10dd07228..cf10345a3 100644 --- a/src/levanter/doremi.py +++ b/src/levanter/doremi.py @@ -12,6 +12,7 @@ from haliax.types import IntScalar import levanter.tracker +from levanter.callbacks import eval_loss_loop from levanter.data import ShardableDataset from levanter.data.mixture import MixtureDataset from levanter.logging import capture_time @@ -56,6 +57,7 @@ class DoReMiConfig: domain_weight_step_size: float = 1.0 smoothing: float = 1e-3 sampling_weights: Optional[dict[str, float]] = None + weight_change_eps: float = 1e-3 DEFAULT_DOREMI_TRAINER_CONFIG = TrainerConfig( @@ -70,7 +72,9 @@ def estimate_mixture_weights( data_sources: dict[str, ShardableDataset[T]], sampling_weights: Optional[dict[str, float]] = None, *, + validation_sets: Optional[dict[str, ShardableDataset[T]]] = None, trainer_config: TrainerConfig = DEFAULT_DOREMI_TRAINER_CONFIG, + optimizer: optax.GradientTransformation = optax.adamw(1e-3), loss_fn: ComputeLossFunction[M, T] = ModuleComputeLoss(), domain_weight_step_size: float = 1.0, smoothing: float = 1e-3, @@ -104,10 +108,26 @@ def estimate_mixture_weights( Domain = hax.Axis("domain", len(domain_indices)) initial_alpha = hax.ones(Domain) / Domain.size - trainer = DoReMiTrainer(trainer_config, optax.adamw(1e-3), initial_alpha) + trainer = DoReMiTrainer(trainer_config, optimizer, initial_alpha) with trainer: ref = _prepare_ref_model(ref, trainer_config) + if validation_sets is not None: + for domain, dataset in validation_sets.items(): + loss = eval_loss_loop( + trainer.loss_fn, + ref, + trainer.replicated_loader(dataset, trainer.EvalBatch), + name=f"ref {domain}", + max_batches=trainer_config.max_eval_batches, + ) + print(f"Loss of ref model on domain {domain}: {loss:.3f}") + levanter.tracker.log_summary({f"eval/ref/loss/{domain}": loss}) + + if validation_sets is not None: + for domain, dataset in 
validation_sets.items(): + trainer.add_eval_hook(dataset, name=domain) + if sampling_weights is not None: assert set(sampling_weights.keys()) == set(data_sources.keys()) sampling_weights = { @@ -261,11 +281,14 @@ def __iter__(self) -> Iterator[Tuple[T, hax.NamedArray]]: def _compute_per_domain_losses(Domain, domains, losses): + # TODO: this should weight by masked tokens one_hot_domains = hax.nn.one_hot(domains, Domain) # Domain x Batch - per_domain_losses = hax.dot(losses.axes, one_hot_domains, losses) - return per_domain_losses + return hax.mean(losses.broadcast_axis(Domain) * one_hot_domains, axis=losses.axes) + # per_domain_losses = hax.dot(losses.axes, one_hot_domains, losses) + # return per_domain_losses def _domain_weighted_loss(losses, Domain, domains, alpha): one_hot_domains = hax.nn.one_hot(domains, Domain) # Domain x Batch - return hax.dot(alpha, one_hot_domains, losses, axis=None) + return hax.mean(losses.broadcast_axis(Domain) * one_hot_domains * alpha, axis=None) + # return hax.dot(alpha, one_hot_domains, losses, axis=None) diff --git a/src/levanter/main/doremi_lm.py b/src/levanter/main/doremi_lm.py new file mode 100644 index 000000000..5f9fa7862 --- /dev/null +++ b/src/levanter/main/doremi_lm.py @@ -0,0 +1,145 @@ +import logging +from dataclasses import dataclass, field +from typing import Union + +import equinox as eqx +import jax.random as jrandom + +from haliax import Axis +from haliax.partitioning import named_jit, round_axis_for_partitioning + +import levanter +from levanter.compat.hf_checkpoints import HFCompatConfig +from levanter.data.text import CausalLmDataset, LMMixtureDatasetConfig +from levanter.doremi import DoReMiConfig, estimate_mixture_weights +from levanter.models.gpt2 import Gpt2Config +from levanter.models.lm_model import LmConfig +from levanter.optim import AdamConfig, OptimizerConfig +from levanter.trainer import TrainerConfig +from levanter.utils.tree_utils import inference_mode + + +logger = logging.getLogger(__name__) + + +@dataclass +class TrainLmConfig: + ref_model_path: str + ref_model_from_hf: bool = False + + data: LMMixtureDatasetConfig = field(default_factory=LMMixtureDatasetConfig) + trainer: TrainerConfig = field(default_factory=TrainerConfig) + model: LmConfig = field(default_factory=Gpt2Config) + optimizer: OptimizerConfig = field(default_factory=AdamConfig) + doremi: DoReMiConfig = field(default_factory=DoReMiConfig) + + # config related to continued pretraining + initialize_from_hf: Union[bool, str] = False + """if provided, this will override the model config in the config. if true, use the default hf checkpoint for this model class""" + use_hf_model_config: bool = False # if true, replace the model config with the hf config from the checkpoint + + # TODO: atm we don't support loading from a checkpoint that has a different tokenizer. this is a bit annoying + # TODO: atm you have to at least specify a levanter model config with the same type as the hf checkpoint + + +def main(config: TrainLmConfig): + levanter.initialize(config) + + tokenizer = config.data.the_tokenizer + + # this is some unpleasant code to allow us to initialize from a hf checkpoint. 
If this is your first read through, + # I recommend skipping it for now + if config.initialize_from_hf: + if config.trainer.initialize_from is not None: + raise ValueError("Cannot specify both initialize_from_hf and initialize_from") + + assert isinstance(config.model, HFCompatConfig) + converter = config.model.default_hf_checkpoint_converter + if hasattr(tokenizer, "vocab") and tokenizer.vocab != converter.tokenizer.vocab: + logger.warning("The tokenizers appear to be different. You may want to check this.") + + if isinstance(config.initialize_from_hf, str): + converter = converter.replaced(reference_checkpoint=config.initialize_from_hf, tokenizer=tokenizer) + else: + converter = converter.replaced(tokenizer=tokenizer) + + if config.use_hf_model_config: + # TODO: log diff of old and new config + # NB: gross mutability + config.model = converter.config_from_hf_config(converter.default_hf_config) + elif isinstance(config.model, HFCompatConfig): + converter = config.model.default_hf_checkpoint_converter + converter = converter.replaced(tokenizer=tokenizer) + else: + converter = None + + optimizer = config.optimizer.build(config.trainer.num_train_steps) + + parameter_axis_mapping = config.trainer.parameter_axis_mapping + + with config.trainer.device_mesh: + vocab_size = len(tokenizer) + Vocab = round_axis_for_partitioning(Axis("vocab", vocab_size), parameter_axis_mapping) + if vocab_size != Vocab.size: + logger.info(f"Rounding vocab size from {vocab_size} to {Vocab.size} for partitioning") + + # initialize the ref model + if config.ref_model_from_hf: + assert converter is not None + ref_model = converter.load_pretrained(type(config.model)) + else: + ref_model_shape = eqx.filter_eval_shape(config.model.build, Vocab, key=jrandom.PRNGKey(0)) + ref_model = levanter.checkpoint.load_checkpoint( + ref_model_shape, config.ref_model_path, axis_mapping=parameter_axis_mapping, subpath="model" + ) + + ref_model = inference_mode(ref_model, True) + + training_key, model_key = jrandom.split(jrandom.PRNGKey(config.trainer.seed), 2) + + @named_jit(axis_resources=parameter_axis_mapping) + def init_proxy_model(): + return config.model.build(Vocab, key=model_key) + + proxy_model = init_proxy_model() + + train_datasets = config.data.training_sets(ref_model.Pos.size) + valid_datasets = config.data.validation_sets(ref_model.Pos.size) + + train_datasets = { + k: CausalLmDataset(v, config.model.Pos, config.model.KeyPos, ignore_index=config.data.ignore_token_id) + for k, v in train_datasets.items() + } + valid_datasets = { + k: CausalLmDataset(v, config.model.Pos, config.model.KeyPos, ignore_index=config.data.ignore_token_id) + for k, v in valid_datasets.items() + } + + mixture_weights = estimate_mixture_weights( + proxy_model, + ref=ref_model, + data_sources=train_datasets, + trainer_config=config.trainer, + optimizer=optimizer, + weight_change_eps=config.doremi.weight_change_eps, + domain_weight_step_size=config.doremi.domain_weight_step_size, + sampling_weights=config.doremi.sampling_weights, + validation_sets=valid_datasets, + key=training_key, + ) + + print(mixture_weights) + + # dump to a yaml file + weights_path = "mixture_weights.yaml" + with open(weights_path, "w") as f: + import yaml + + yaml.dump(mixture_weights, f) + + # log as an artifact + levanter.tracker.current_tracker().log_artifact(weights_path, name="mixture_weights.yaml") + + +if __name__ == "__main__": + levanter.config.main(main)() From a272ca9a2ed34b8853833d29e4a8a4c0be6f9382 Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 10 Jan 2024 
10:54:43 -0800 Subject: [PATCH 105/205] we install haliax from source with the pyproject.toml --- .github/workflows/run_entry_tests.yaml | 2 -- .github/workflows/run_tests.yaml | 2 -- 2 files changed, 4 deletions(-) diff --git a/.github/workflows/run_entry_tests.yaml b/.github/workflows/run_entry_tests.yaml index c958e9bf2..dbde2dbd1 100644 --- a/.github/workflows/run_entry_tests.yaml +++ b/.github/workflows/run_entry_tests.yaml @@ -21,8 +21,6 @@ jobs: run: | python -m pip install --upgrade pip pip install flake8 pytest - # install haliax from source b/c it's changing in parallel with this repo - pip install git+https://github.com/stanford-crfm/haliax.git pip install . "jax[cpu]==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}" - name: Run entry tests with pytest run: | diff --git a/.github/workflows/run_tests.yaml b/.github/workflows/run_tests.yaml index 46828d5b8..553c88b44 100644 --- a/.github/workflows/run_tests.yaml +++ b/.github/workflows/run_tests.yaml @@ -21,8 +21,6 @@ jobs: run: | python -m pip install --upgrade pip pip install flake8 pytest - # install haliax from source b/c it's changing in parallel with this repo - pip install git+https://github.com/stanford-crfm/haliax.git pip install . "jax[cpu]==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}" - name: Test with pytest run: | From e8d4b9d5a85734eb749e1315759bae4b7edaa5d8 Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 10 Jan 2024 11:25:19 -0800 Subject: [PATCH 106/205] fix doremi test when doing multidevice --- tests/test_doremi.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/test_doremi.py b/tests/test_doremi.py index dbb9a9889..e1c545ead 100644 --- a/tests/test_doremi.py +++ b/tests/test_doremi.py @@ -1,4 +1,5 @@ import equinox +import jax import jax.random import optax @@ -74,13 +75,14 @@ def compute_loss_fn(model, example, reduction=hax.mean, reduction_axis=None, key return hax.nn.binary_cross_entropy_loss(y_pred, example.y, reduction=reduction, reduction_axis=reduction_axis) tiny_trainer_config = TrainerConfig( - num_train_steps=600, train_batch_size=Batch.size, tracker=(), id="kmaklfmaf", per_device_parallelism=Batch.size + num_train_steps=600, + train_batch_size=Batch.size, + tracker=(), + id="kmaklfmaf", + per_device_parallelism=Batch.size // len(jax.devices()), ) optimizer = optax.adam(1e-2) - import jax - - jax.config.update("jax_traceback_filtering", "off") trainer = Trainer(tiny_trainer_config, optimizer, compute_loss_fn) From 5c489c151c98f3ac1b7a0818694ee83622d067b8 Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 10 Jan 2024 12:32:09 -0800 Subject: [PATCH 107/205] add a pile_mixture.yaml --- config/data/pile_mixture.yaml | 137 ++++++++++++++++++++++++++++++++++ 1 file changed, 137 insertions(+) create mode 100644 config/data/pile_mixture.yaml diff --git a/config/data/pile_mixture.yaml b/config/data/pile_mixture.yaml new file mode 100644 index 000000000..59adc4128 --- /dev/null +++ b/config/data/pile_mixture.yaml @@ -0,0 +1,137 @@ +cache_dir: "gs://levanter-data/tokenized/pile_domains/" +tokenizer: "EleutherAI/gpt-neox-20b" +configs: + arxiv: + train_urls: + - gs://levanter-data/pile_domains/arxiv/train-{00..29}.jsonl.zst + validation_urls: + - gs://levanter-data/pile_domains/arxiv/validation.jsonl.zst + books2: + train_urls: + - gs://levanter-data/pile_domains/books2/train-{00..29}.jsonl.zst + validation_urls: + - gs://levanter-data/pile_domains/books2/validation.jsonl.zst + books3: + train_urls: + - 
gs://levanter-data/pile_domains/books3/train-{00..29}.jsonl.zst + validation_urls: + - gs://levanter-data/pile_domains/books3/validation.jsonl.zst + dm_math: + train_urls: + - gs://levanter-data/pile_domains/dm_math/train-{00..29}.jsonl.zst + validation_urls: + - gs://levanter-data/pile_domains/dm_math/validation.jsonl.zst + enron: + train_urls: + - gs://levanter-data/pile_domains/enron/train-{00..29}.jsonl.zst + validation_urls: + - gs://levanter-data/pile_domains/enron/validation.jsonl.zst + europarl: + train_urls: + - gs://levanter-data/pile_domains/europarl/train-{00..29}.jsonl.zst + validation_urls: + - gs://levanter-data/pile_domains/europarl/validation.jsonl.zst + free_law: + train_urls: + - gs://levanter-data/pile_domains/free_law/train-{00..29}.jsonl.zst + validation_urls: + - gs://levanter-data/pile_domains/free_law/validation.jsonl.zst + github: + train_urls: + - gs://levanter-data/pile_domains/github/train-{00..29}.jsonl.zst + validation_urls: + - gs://levanter-data/pile_domains/github/validation.jsonl.zst + hackernews: + train_urls: + - gs://levanter-data/pile_domains/hackernews/train-{00..29}.jsonl.zst + validation_urls: + - gs://levanter-data/pile_domains/hackernews/validation.jsonl.zst + nih: + train_urls: + - gs://levanter-data/pile_domains/nih/train-{00..29}.jsonl.zst + validation_urls: + - gs://levanter-data/pile_domains/nih/validation.jsonl.zst + opensubtitles: + train_urls: + - gs://levanter-data/pile_domains/opensubtitles/train-{00..29}.jsonl.zst + validation_urls: + - gs://levanter-data/pile_domains/opensubtitles/validation.jsonl.zst + owt2: + train_urls: + - gs://levanter-data/pile_domains/owt2/train-{00..29}.jsonl.zst + validation_urls: + - gs://levanter-data/pile_domains/owt2/validation.jsonl.zst + pg_19: + train_urls: + - gs://levanter-data/pile_domains/pg_19/train-{00..29}.jsonl.zst + validation_urls: + - gs://levanter-data/pile_domains/pg_19/validation.jsonl.zst + philpapers: + train_urls: + - gs://levanter-data/pile_domains/philpapers/train-{00..29}.jsonl.zst + validation_urls: + - gs://levanter-data/pile_domains/philpapers/validation.jsonl.zst + pile_cc: + train_urls: + - gs://levanter-data/pile_domains/pile_cc/train-{00..29}.jsonl.zst + validation_urls: + - gs://levanter-data/pile_domains/pile_cc/validation.jsonl.zst + pubmed_abs: + train_urls: + - gs://levanter-data/pile_domains/pubmed_abs/train-{00..29}.jsonl.zst + validation_urls: + - gs://levanter-data/pile_domains/pubmed_abs/validation.jsonl.zst + pubmed_central: + train_urls: + - gs://levanter-data/pile_domains/pubmed_central/train-{00..29}.jsonl.zst + validation_urls: + - gs://levanter-data/pile_domains/pubmed_central/validation.jsonl.zst + stack_exchange: + train_urls: + - gs://levanter-data/pile_domains/stack_exchange/train-{00..29}.jsonl.zst + validation_urls: + - gs://levanter-data/pile_domains/stack_exchange/validation.jsonl.zst + ubuntu_irc: + train_urls: + - gs://levanter-data/pile_domains/ubuntu_irc/train-{00..29}.jsonl.zst + validation_urls: + - gs://levanter-data/pile_domains/ubuntu_irc/validation.jsonl.zst + uspto: + train_urls: + - gs://levanter-data/pile_domains/uspto/train-{00..29}.jsonl.zst + validation_urls: + - gs://levanter-data/pile_domains/uspto/validation.jsonl.zst + wiki_en: + train_urls: + - gs://levanter-data/pile_domains/wiki_en/train-{00..29}.jsonl.zst + validation_urls: + - gs://levanter-data/pile_domains/wiki_en/validation.jsonl.zst + youtube_subtitles: + train_urls: + - gs://levanter-data/pile_domains/youtube_subtitles/train-{00..29}.jsonl.zst + validation_urls: + - 
gs://levanter-data/pile_domains/youtube_subtitles/validation.jsonl.zst +train_weights: + # these weights come from the paper https://arxiv.org/pdf/2101.00027.pdf + pile_cc: 0.1811 + pubmed_central: 0.1440 + books3: 0.1207 + owt2: 0.1001 + arxiv: 0.0896 + github: 0.0759 + free_law: 0.0612 + stack_exchange: 0.0513 + uspto: 0.0365 + pubmed_abs: 0.0307 + pg_19: 0.0217 + opensubtitles: 0.0155 + wiki_en: 0.0153 + dm_math: 0.0124 + ubuntu_irc: 0.0088 + books2: 0.0075 + europarl: 0.0073 + hackernews: 0.0062 + youtube_subtitles: 0.0060 + philpapers: 0.0038 + nih: 0.0030 + enron: 0.0014 From 16721485ac2abf6a954fcf9138d29de49cb95faf Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 10 Jan 2024 12:32:43 -0800 Subject: [PATCH 108/205] add a config for the small pile mixture --- config/gpt2_small_pile_mixture.yaml | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) create mode 100644 config/gpt2_small_pile_mixture.yaml diff --git a/config/gpt2_small_pile_mixture.yaml b/config/gpt2_small_pile_mixture.yaml new file mode 100644 index 000000000..a79ec8052 --- /dev/null +++ b/config/gpt2_small_pile_mixture.yaml @@ -0,0 +1,23 @@ +data: !include data/pile_mixture.yaml +model: + type: gpt2 + hidden_dim: 768 + num_heads: 12 + num_layers: 12 + seq_len: 2048 + gradient_checkpointing: true + scale_attn_by_inverse_layer_idx: true +trainer: + tracker: + project: "levanter" + tags: [ "pile", "gpt2"] + + mp: p=f32,c=bfloat16 + model_axis_size: 1 + per_device_parallelism: 8 + + train_batch_size: 256 + num_train_steps: 50000 +optimizer: + learning_rate: 6e-4 + weight_decay: 0.1 From f485c5f33328743d81f2767304fb89f9dcd7caa9 Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 10 Jan 2024 13:18:37 -0800 Subject: [PATCH 109/205] reduce default rows per chunk and see if that helps with these big subcorpora --- src/levanter/data/shard_cache.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/levanter/data/shard_cache.py b/src/levanter/data/shard_cache.py index 33e9fcd41..f01ed32ff 100644 --- a/src/levanter/data/shard_cache.py +++ b/src/levanter/data/shard_cache.py @@ -45,7 +45,7 @@ logger = pylogging.getLogger(__name__) -DEFAULT_ROWS_PER_CHUNK = 1024 * 32 +DEFAULT_ROWS_PER_CHUNK = 8192 LEDGER_FILE_NAME = "cache_ledger.json" From b2d8a585a88e5d11e480c58a8722bd41d48093f4 Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 10 Jan 2024 14:24:45 -0800 Subject: [PATCH 110/205] add some more logging to see if we can figure out why it's running out of memory still --- src/levanter/data/shard_cache.py | 27 ++++++++++++++++++++------- src/levanter/data/sharded_dataset.py | 7 ++++++- src/levanter/main/cache_dataset.py | 1 + 3 files changed, 27 insertions(+), 8 deletions(-) diff --git a/src/levanter/data/shard_cache.py b/src/levanter/data/shard_cache.py index f01ed32ff..d0b3876b5 100644 --- a/src/levanter/data/shard_cache.py +++ b/src/levanter/data/shard_cache.py @@ -46,6 +46,7 @@ logger = pylogging.getLogger(__name__) DEFAULT_ROWS_PER_CHUNK = 8192 +DEFAULT_MAX_BYTES_PER_BATCH = 256 * 1024 * 1024 # 256 MB, this is pre-preprocessing python object size LEDGER_FILE_NAME = "cache_ledger.json" @@ -76,7 +77,7 @@ def build_cache( from shard names to iterators over the data in that shard. processor: A BatchProcessor that will be used to process batches of data. This is the main place where you can customize the preprocessing pipeline. - batch_size: The number of input examples to process at once. + batch_size: When reading from the cache, how many examples to read at a time. 
rows_per_chunk: The number of rows to write to each chunk. May be smaller at the end of a shard. await_finished: If True, this function will block until the cache is finished. If False, it will return immediately. @@ -321,6 +322,7 @@ def _shard_reader_generator(shard_source: ShardedDataset[T], shard_idx: int, sta # chunks and earlier shards. (So that we approximately generate following the global order.) @ray.remote(num_cpus=1, scheduling_strategy="SPREAD") def _alternating_shard_reader( + name: str, builder_ref: ActorHandle, # _ChunkCacheBuilder shard_writers: ActorHandle, # _GroupedShardWriter shard_source: ShardedDataset[T], @@ -330,6 +332,8 @@ def _alternating_shard_reader( batch_size, num_rows_per_chunk, ): + pylogging.basicConfig(level=pylogging.INFO) + logger = pylogging.getLogger(f"shard_reader.{name}") shard_pqueue: list[tuple[int, int]] = [] # heapq of (num_chunks, shard_idx) shard_readers: dict[int, Iterator[list[T]]] = {} try: @@ -350,10 +354,11 @@ def _alternating_shard_reader( ) except Exception as e: # noqa logger.exception(f"Error while initializing shard {shard_name}") - ray.get(shard_writers[shard_name].shard_failed.remote(ser_exc_info())) + # fire and forget + shard_writers[shard_name].shard_failed.remote(ser_exc_info()) raise e - MAX_INFLIGHT = 30 + MAX_INFLIGHT = 10 back_pressure_queue: list[ray.ObjectRef] = [] while len(shard_pqueue) > 0: @@ -381,6 +386,7 @@ def _alternating_shard_reader( # we want to limit the number of pending tasks, so we wait until we're below the limit # before we start reading the next batch while len(back_pressure_queue) >= MAX_INFLIGHT: + logger.debug(f"Waiting for back pressure queue to drain: {len(back_pressure_queue)}") finished_ref, back_pressure_queue = ray.wait(back_pressure_queue, num_returns=1) priority = priority_fn(shard_idx, chunk_id) @@ -410,7 +416,8 @@ def _alternating_shard_reader( except Exception as e: # noqa logger.exception(f"Error while processing shard {shard_name}") - ray.get(shard_writers.shard_failed.remote(shard_name, ser_exc_info())) + # fire and forget + shard_writers.shard_failed.remote(shard_name, ser_exc_info()) raise e @@ -890,11 +897,13 @@ def __init__( self, broker_ref, cache_dir: str, + name: str, source: ShardedDataset[T], processor: BatchProcessor[T], rows_per_chunk: int, ): pylogging.basicConfig(level=pylogging.INFO) + self.logger = pylogging.getLogger(f"{__name__}.{name}") self.broker_ref = broker_ref self.shard_status: Dict[str, _ShardStatus] = dict() self._current_round_robin = [] @@ -904,10 +913,10 @@ def __init__( self_ref = current_actor_handle() if len(source.shard_names) == 0: - logger.warning("No shards to index?!?") + self.logger.warning("No shards to index?!?") self._finish() else: - logger.info(f"Starting cache build for {len(source.shard_names)} shards") + self.logger.info(f"Starting cache build for {len(source.shard_names)} shards") self._shard_writers = [] self._shard_readers = [] @@ -936,6 +945,7 @@ def priority_fn(shard_idx, chunk_idx): self._processor_actors.append(processor_actor) reader = _alternating_shard_reader.remote( + name, self_ref, writer, source, @@ -1072,7 +1082,10 @@ def __init__(self, cache_dir: str, source: ShardedDataset[T], processor: BatchPr self._finished_promise.set_result(None) except FileNotFoundError: self_ref = ray.runtime_context.get_runtime_context().current_actor - self._builder_actor = ChunkCacheBuilder.remote(self_ref, self._cache_dir, self._source, self._processor, self._rows_per_chunk) # type: ignore + # only use the last two components of the name since it 
gets kind of long + path_for_name = os.path.join(*self._cache_dir.split("/")[-2:]) + name = f"builder::{path_for_name}" + self._builder_actor = ChunkCacheBuilder.remote(self_ref, self._cache_dir, name, self._source, self._processor, self._rows_per_chunk) # type: ignore def is_finished(self): return self._is_finished diff --git a/src/levanter/data/sharded_dataset.py b/src/levanter/data/sharded_dataset.py index 3f3f8c036..1ceae6366 100644 --- a/src/levanter/data/sharded_dataset.py +++ b/src/levanter/data/sharded_dataset.py @@ -93,7 +93,12 @@ def build_cache( source, processor = _construct_composite_batch_processor(self) cache = build_cache( - path, source, processor, rows_per_chunk=rows_per_chunk, await_finished=await_finished, monitors=monitors + path, + source, + processor, + rows_per_chunk=rows_per_chunk, + await_finished=await_finished, + monitors=monitors, ) return DictCacheDataset(cache) diff --git a/src/levanter/main/cache_dataset.py b/src/levanter/main/cache_dataset.py index 9ee6614ca..4dd46e63c 100644 --- a/src/levanter/main/cache_dataset.py +++ b/src/levanter/main/cache_dataset.py @@ -46,6 +46,7 @@ def main(args: RayCachedLMDatasetConfig): rows_per_chunk=args.rows_per_chunk, await_finished=False, monitors=monitors, + batch_size=128, ) cache.await_finished() From f76e46644d731fc855a305beb904cd54cc5f54ed Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 10 Jan 2024 16:58:15 -0800 Subject: [PATCH 111/205] add some more logging to see if we can figure out why it's running out of memory still --- src/levanter/data/shard_cache.py | 45 +++++++++++++++++++++++++------- src/levanter/utils/py_utils.py | 23 ++++++++++++++++ tests/test_py_utils.py | 8 ++++++ 3 files changed, 66 insertions(+), 10 deletions(-) create mode 100644 tests/test_py_utils.py diff --git a/src/levanter/data/shard_cache.py b/src/levanter/data/shard_cache.py index d0b3876b5..398d097bb 100644 --- a/src/levanter/data/shard_cache.py +++ b/src/levanter/data/shard_cache.py @@ -33,6 +33,7 @@ import levanter.tracker from .. import logging +from ..utils.py_utils import actual_sizeof from ..utils.ray_utils import ExceptionInfo, RefBox, current_actor_handle, ser_exc_info from . 
import ShardableDataset from ._preprocessor import BatchProcessor, BatchResult, as_record_batch, dict_from_record_batch @@ -360,6 +361,33 @@ def _alternating_shard_reader( MAX_INFLIGHT = 10 back_pressure_queue: list[ray.ObjectRef] = [] + back_pressure_sizes: dict[ray.ObjectRef, int] = {} + retained_batch_sizes = 0 + + def enqueue_to_backpressure(batch, batch_result_ref): + nonlocal back_pressure_queue, retained_batch_sizes + back_pressure_queue.append(batch_result_ref) + if logger.level <= pylogging.DEBUG: + size = actual_sizeof(batch) + retained_batch_sizes += size + back_pressure_sizes[batch_result_ref] = size + if retained_batch_sizes > 1024 * 1024 * 1024: + logger.debug(f"Retained batch sizes: {retained_batch_sizes}") + + size = MAX_INFLIGHT + # we want to limit the number of pending tasks, so we wait until we're at or below the limit + # before we start reading the next batch + drain_back_pressure_to(size) + + def drain_back_pressure_to(size): + nonlocal back_pressure_queue, retained_batch_sizes + # drain until at most `size` refs remain in flight; with `>` (rather than `>=`) a call with + # size=0 empties the queue instead of calling ray.wait on an empty list + while len(back_pressure_queue) > size: + logger.debug(f"Waiting for back pressure queue to drain: {len(back_pressure_queue)}") + finished_ref, back_pressure_queue = ray.wait(back_pressure_queue, num_returns=1) + + if logger.level <= pylogging.DEBUG: + # use a fresh name here: reassigning `size` would clobber the loop bound + freed = back_pressure_sizes.pop(finished_ref[0]) + retained_batch_sizes -= freed while len(shard_pqueue) > 0: chunk_id, shard_idx = heapq.heappop(shard_pqueue) @@ -383,21 +411,16 @@ def _alternating_shard_reader( total_chunk_rows += len(batch) if batch: - # we want to limit the number of pending tasks, so we wait until we're below the limit - # before we start reading the next batch - while len(back_pressure_queue) >= MAX_INFLIGHT: - logger.debug(f"Waiting for back pressure queue to drain: {len(back_pressure_queue)}") - finished_ref, back_pressure_queue = ray.wait(back_pressure_queue, num_returns=1) - priority = priority_fn(shard_idx, chunk_id) - batch = ray.put(batch) - batch_result_ref = ray.get(processor_actor.submit.remote(priority=priority, batch=RefBox(batch))) + priority = priority_fn(shard_idx, chunk_id) + batch_result_ref = ray.get( + processor_actor.submit.remote(priority=priority, batch=RefBox(ray.put(batch))) + ) shard_writers.chunk_batch_finished.remote( shard_name, chunk_id, chunk_batch_idx, RefBox(batch_result_ref) ) - back_pressure_queue.append(batch_result_ref) - chunk_batch_idx += 1 + enqueue_to_backpressure(batch, batch_result_ref) + del batch if total_chunk_rows >= num_rows_per_chunk or exhausted_shard: chunk_filled = True @@ -420,6 +443,8 @@ shard_writers.shard_failed.remote(shard_name, ser_exc_info()) raise e + drain_back_pressure_to(0) + def _initial_shard_metadatas(shard_source, shard_names, shard_group_writer): shard_metadatas: dict[str, ShardMetadata] = {} diff --git a/src/levanter/utils/py_utils.py b/src/levanter/utils/py_utils.py index a172b4498..a181c8193 100644 --- a/src/levanter/utils/py_utils.py +++ b/src/levanter/utils/py_utils.py @@ -1,4 +1,5 @@ import os +import sys from dataclasses import dataclass from typing import Callable, TypeVar @@ -142,3 +143,25 @@ def cached_classproperty(func: Callable[..., PropReturn]) -> PropReturn: cached_classproperty.__doc__ = _CachedClassProperty.__doc__ + + +def actual_sizeof(obj): + """similar to sys.getsizeof, but recurses into dicts and lists and other objects""" + seen = set() + size = 0 + objects = [obj] + while objects: + need_to_see = [] + for obj in objects: + if id(obj) in seen: + continue + seen.add(id(obj)) + size += sys.getsizeof(obj) + if isinstance(obj, dict):
need_to_see.extend(obj.values()) + elif hasattr(obj, "__dict__"): + need_to_see.extend(obj.__dict__.values()) + elif isinstance(obj, (list, tuple, set, frozenset)): + need_to_see.extend(obj) + objects = need_to_see + return size diff --git a/tests/test_py_utils.py b/tests/test_py_utils.py new file mode 100644 index 000000000..50b3461ee --- /dev/null +++ b/tests/test_py_utils.py @@ -0,0 +1,8 @@ +from levanter.utils.py_utils import actual_sizeof + + +def test_actual_sizeof(): + d1 = {"a": 1, "b": 2} + d2 = {"a": "this is a string", "b": "this is another string"} + + assert actual_sizeof(d1) < actual_sizeof(d2) From fc78716a3cc0a0db15f02692054ce07f218cbc6b Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 10 Jan 2024 20:51:37 -0800 Subject: [PATCH 112/205] dumb --- src/levanter/models/lm_model.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/levanter/models/lm_model.py b/src/levanter/models/lm_model.py index 674b70b5c..620a02f87 100644 --- a/src/levanter/models/lm_model.py +++ b/src/levanter/models/lm_model.py @@ -124,16 +124,14 @@ def compute_loss( reduced, and the result is a named array with axes (*batch axes, sequence_length). """ logits = self(example.tokens, example.attn_mask, key=key).astype(jnp.float32) + logits = logits.astype(jnp.float32) targets = hax.roll(example.tokens, -1, axis=self.Pos.name) target_y = hax.nn.one_hot(targets, self.Vocab, dtype=logits.dtype) - losses = cross_entropy_loss( + loss = cross_entropy_loss( logits, self.Vocab, target_y, reduction, reduction_axis=reduction_axis, where=example.loss_mask ) - if reduction is None: - return hax.where(example.loss_mask, losses, 0) - else: - return losses + return loss @property def vocab_size(self) -> int: From 4927f674511136fc969e61a5d7884bd379de263b Mon Sep 17 00:00:00 2001 From: David Hall Date: Thu, 11 Jan 2024 00:25:35 -0800 Subject: [PATCH 113/205] don't run the slow tests in CI --- .github/workflows/run_tests.yaml | 2 +- tests/test_doremi.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/run_tests.yaml b/.github/workflows/run_tests.yaml index 553c88b44..3af69bacf 100644 --- a/.github/workflows/run_tests.yaml +++ b/.github/workflows/run_tests.yaml @@ -24,4 +24,4 @@ jobs: pip install . "jax[cpu]==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}" - name: Test with pytest run: | - XLA_FLAGS=--xla_force_host_platform_device_count=8 PYTHONPATH=tests:src:. pytest tests -m "not entry" + XLA_FLAGS=--xla_force_host_platform_device_count=8 PYTHONPATH=tests:src:. pytest tests -m "not entry and not slow" diff --git a/tests/test_doremi.py b/tests/test_doremi.py index e1c545ead..a0c28e800 100644 --- a/tests/test_doremi.py +++ b/tests/test_doremi.py @@ -2,6 +2,7 @@ import jax import jax.random import optax +import pytest import haliax as hax @@ -43,6 +44,7 @@ def shard(self, shard_id: int, num_shards: int): return LogitDataset(self.W, self.noise, self.x_mask, self.x_bias, key=jax.random.fold_in(self.key, shard_id)) +@pytest.mark.slow def test_estimate_mixture_weights(): # we create 3 simple logistic regression datasets # 1. 
x is moderately predictive of y (y ~ [0, 0.5, 0.5] x + N(0, noise^2) > 0.5) From 1ceb00a493aaa496825f554616085042a33b7afa Mon Sep 17 00:00:00 2001 From: David Hall Date: Fri, 12 Jan 2024 12:44:36 -0800 Subject: [PATCH 114/205] wip --- scripts/split-pile-shards.py | 50 ++++++++++++++++++++++++++++++++ scripts/split-pile-shards.sh | 51 +++++++++++++++++++++++++++++++++ src/levanter/models/lm_model.py | 3 +- 3 files changed, 103 insertions(+), 1 deletion(-) create mode 100644 scripts/split-pile-shards.py create mode 100644 scripts/split-pile-shards.sh diff --git a/scripts/split-pile-shards.py b/scripts/split-pile-shards.py new file mode 100644 index 000000000..e6dd67661 --- /dev/null +++ b/scripts/split-pile-shards.py @@ -0,0 +1,50 @@ +import io +import json +import re +import sys +from pathlib import Path +import zstandard as zstd +import tqdm + + +def format_category(category): + return re.sub(r'[^a-z0-9_]', '', category.lower().replace('-', '_')) + + +def process_file(input_file_path): + base_file = Path(input_file_path).stem + ctx = zstd.ZstdDecompressor() + compressors = {} + + with open(input_file_path, 'rb') as compressed_file: + with ctx.stream_reader(compressed_file) as reader: + text_stream = io.TextIOWrapper(reader, encoding='utf-8') + for line in tqdm.tqdm(text_stream): + if not line.strip(): + continue # Skip empty lines + + # Decode line to string and load as JSON + data = json.loads(line) + category = data['meta']['pile_set_name'] + category = format_category(category) + output_dir = Path(category) + output_dir.mkdir(exist_ok=True) + output_file_path = output_dir / f"{base_file}.zst" + + # Check if compressor exists for this category, if not create it + if category not in compressors: + output_file = open(output_file_path, 'wb') + compressors[category] = zstd.ZstdCompressor().stream_writer(output_file) + + # Write to the compressor + compressors[category].write(line.encode('utf-8')) + compressors[category].flush() + + # Close all open compressors + for compressor in compressors.values(): + compressor.close() + + +if __name__ == '__main__': + for path in sys.argv[1:]: + process_file(path) \ No newline at end of file diff --git a/scripts/split-pile-shards.sh b/scripts/split-pile-shards.sh new file mode 100644 index 000000000..11fb43002 --- /dev/null +++ b/scripts/split-pile-shards.sh @@ -0,0 +1,51 @@ +#!/bin/bash + +# Assuming the input file is provided as the first argument +input_file="$1" +base_file=$(basename "$input_file" .jsonl.zst) + +declare -A fds + +# Function to set up a file descriptor for zstd for a given category +setup_fd() { + local category=$1 + local output_file="${category}/${base_file}.jsonl.zst" + + # Create the directory if it doesn't exist + mkdir -p "$category" + + # Set up a file descriptor for zstd process for this category + exec {fd}> >(zstd -z > "$output_file") + fds[$category]=$fd +} + +cleanup() { +for fd in "${fds[@]}"; do + exec {fd}>&- +done +} + +trap cleanup EXIT + +C=0 + +# Decompress the input file and process line by line +while IFS= read -r line; do + # Print a progress indicator + if [ $((C++ % 1000)) -eq 0 ]; then + echo -ne "\r$C" + fi + # Extract the category value + category=$(echo "$line" | jq -r '.meta.pile_set_name' | tr [:upper:] [:lower:] | tr - _ | tr -Cd [a-z0-9_]) + + # Check if we already have a pipe for this category, if not, set it up + if [ -z "${fds[$category]}" ]; then + setup_fd "$category" + fi + + # Write to the appropriate pipe + eval "echo \"\$line\" >&${fds[$category]}" +done < <(zstdcat "$input_file") + + +cleanup 
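For reference, the Python and shell scripts above normalize pile_set_name identically: lowercase, map "-" to "_", then drop any character outside [a-z0-9_]. A minimal standalone sketch of what that produces on a few Pile set names (the sample list is illustrative, not exhaustive); the awkward results for multi-word names help explain why the very next patch in this series replaces the regex with an explicit name table (categories_to_out_names):

import re

def format_category(category):
    # same normalization as split-pile-shards.py and the jq/tr pipeline above
    return re.sub(r"[^a-z0-9_]", "", category.lower().replace("-", "_"))

for name in ["Pile-CC", "Wikipedia (en)", "Gutenberg (PG-19)", "DM Mathematics"]:
    print(f"{name!r} -> {format_category(name)!r}")

# 'Pile-CC' -> 'pile_cc'
# 'Wikipedia (en)' -> 'wikipediaen'
# 'Gutenberg (PG-19)' -> 'gutenbergpg_19'
# 'DM Mathematics' -> 'dmmathematics'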
\ No newline at end of file diff --git a/src/levanter/models/lm_model.py b/src/levanter/models/lm_model.py index 620a02f87..b68a33f2b 100644 --- a/src/levanter/models/lm_model.py +++ b/src/levanter/models/lm_model.py @@ -123,7 +123,8 @@ def compute_loss( across the reduction axis (with reduction_axis=None meaning all axes). If reduction is None, the loss is not reduced, and the result is a named array with axes (*batch axes, sequence_length). """ - logits = self(example.tokens, example.attn_mask, key=key).astype(jnp.float32) + logits = self(example.tokens, example.attn_mask, key=key) + # TODO: would be nice if we made the dtype configurable logits = logits.astype(jnp.float32) targets = hax.roll(example.tokens, -1, axis=self.Pos.name) target_y = hax.nn.one_hot(targets, self.Vocab, dtype=logits.dtype) From bc7108c94a1bcfa0c16acab2a0781669e693afc0 Mon Sep 17 00:00:00 2001 From: David Hall Date: Sat, 13 Jan 2024 00:20:26 -0800 Subject: [PATCH 115/205] move the script, make it read off fsspec --- scripts/preproc/split-pile-shards.py | 75 ++++++++++++++++++++++++++++ scripts/split-pile-shards.py | 50 ------------------- scripts/split-pile-shards.sh | 51 ------------------- 3 files changed, 75 insertions(+), 101 deletions(-) create mode 100644 scripts/preproc/split-pile-shards.py delete mode 100644 scripts/split-pile-shards.py delete mode 100644 scripts/split-pile-shards.sh diff --git a/scripts/preproc/split-pile-shards.py b/scripts/preproc/split-pile-shards.py new file mode 100644 index 000000000..768b24874 --- /dev/null +++ b/scripts/preproc/split-pile-shards.py @@ -0,0 +1,75 @@ +import json +import os +import sys +from pathlib import Path + +import fsspec +import tqdm + + +OUT_PATH = "gs://levanter-data/pile-domains" + +categories_to_out_names = { + "ArXiv": "arxiv", + "BookCorpus2": "books2", + "Books3": "books3", + "DM Mathematics": "dm_math", + "Enron Emails": "enron", + "EuroParl": "europarl", + "FreeLaw": "freelaw", + "Github": "github", + "Gutenberg (PG-19)": "pg_19", + "HackerNews": "hackernews", + "NIH ExPorter": "nih", + "OpenSubtitles": "opensubtitles", + "OpenWebText2": "owt2", + "PhilPapers": "philpapers", + "Pile-CC": "pile_cc", + "PubMed Abstracts": "pubmed_abs", + "PubMed Central": "pubmed_central", + "StackExchange": "stack_exchange", + "USPTO Backgrounds": "uspto", + "Ubuntu IRC": "ubuntu_irc", + "Wikipedia (en)": "wiki_en", + "YoutubeSubtitles": "youtube_subtitles", +} + + +def format_category(category): + return categories_to_out_names[category] + + +def process_file(input_file_path): + base_file = Path(input_file_path).stem + compressors = {} + + with fsspec.open(input_file_path, "r", compression="infer") as text_stream: + for line in tqdm.tqdm(text_stream): + if not line.strip(): + continue # Skip empty lines + + # Decode line to string and load as JSON + data = json.loads(line) + category = data["meta"]["pile_set_name"] + category = format_category(category) + output_file_path = os.path.join(OUT_PATH, category, f"{base_file}.zst") + + # Check if compressor exists for this category, if not create it + if category not in compressors: + # output_file = open(output_file_path, 'wb') + output_file = fsspec.open(str(output_file_path), "wb", compression="infer").open() + print("opened", output_file_path) + compressors[category] = output_file + + # Write to the compressor + compressors[category].write(line.encode("utf-8")) + compressors[category].flush() + + # Close all open compressors + for compressor in compressors.values(): + compressor.close() + + +if __name__ == "__main__": 
+ for path in sys.argv[1:]: + process_file(path) diff --git a/scripts/split-pile-shards.py b/scripts/split-pile-shards.py deleted file mode 100644 index e6dd67661..000000000 --- a/scripts/split-pile-shards.py +++ /dev/null @@ -1,50 +0,0 @@ -import io -import json -import re -import sys -from pathlib import Path -import zstandard as zstd -import tqdm - - -def format_category(category): - return re.sub(r'[^a-z0-9_]', '', category.lower().replace('-', '_')) - - -def process_file(input_file_path): - base_file = Path(input_file_path).stem - ctx = zstd.ZstdDecompressor() - compressors = {} - - with open(input_file_path, 'rb') as compressed_file: - with ctx.stream_reader(compressed_file) as reader: - text_stream = io.TextIOWrapper(reader, encoding='utf-8') - for line in tqdm.tqdm(text_stream): - if not line.strip(): - continue # Skip empty lines - - # Decode line to string and load as JSON - data = json.loads(line) - category = data['meta']['pile_set_name'] - category = format_category(category) - output_dir = Path(category) - output_dir.mkdir(exist_ok=True) - output_file_path = output_dir / f"{base_file}.zst" - - # Check if compressor exists for this category, if not create it - if category not in compressors: - output_file = open(output_file_path, 'wb') - compressors[category] = zstd.ZstdCompressor().stream_writer(output_file) - - # Write to the compressor - compressors[category].write(line.encode('utf-8')) - compressors[category].flush() - - # Close all open compressors - for compressor in compressors.values(): - compressor.close() - - -if __name__ == '__main__': - for path in sys.argv[1:]: - process_file(path) \ No newline at end of file diff --git a/scripts/split-pile-shards.sh b/scripts/split-pile-shards.sh deleted file mode 100644 index 11fb43002..000000000 --- a/scripts/split-pile-shards.sh +++ /dev/null @@ -1,51 +0,0 @@ -#!/bin/bash - -# Assuming the input file is provided as the first argument -input_file="$1" -base_file=$(basename "$input_file" .jsonl.zst) - -declare -A fds - -# Function to set up a file descriptor for zstd for a given category -setup_fd() { - local category=$1 - local output_file="${category}/${base_file}.jsonl.zst" - - # Create the directory if it doesn't exist - mkdir -p "$category" - - # Set up a file descriptor for zstd process for this category - exec {fd}> >(zstd -z > "$output_file") - fds[$category]=$fd -} - -cleanup() { -for fd in "${fds[@]}"; do - exec {fd}>&- -done -} - -trap cleanup EXIT - -C=0 - -# Decompress the input file and process line by line -while IFS= read -r line; do - # Print a progress indicator - if [ $((C++ % 1000)) -eq 0 ]; then - echo -ne "\r$C" - fi - # Extract the category value - category=$(echo "$line" | jq -r '.meta.pile_set_name' | tr [:upper:] [:lower:] | tr - _ | tr -Cd [a-z0-9_]) - - # Check if we already have a pipe for this category, if not, set it up - if [ -z "${fds[$category]}" ]; then - setup_fd "$category" - fi - - # Write to the appropriate pipe - eval "echo \"\$line\" >&${fds[$category]}" -done < <(zstdcat "$input_file") - - -cleanup \ No newline at end of file From 69ca4a4af8e1873e41f6808734d3ae4bb197cda6 Mon Sep 17 00:00:00 2001 From: David Hall Date: Sat, 13 Jan 2024 00:21:00 -0800 Subject: [PATCH 116/205] update for reverted Haliax change --- src/levanter/doremi.py | 3 ++- tests/test_flash_attention.py | 2 +- tests/test_grad_accum.py | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/levanter/doremi.py b/src/levanter/doremi.py index cf10345a3..10de53635 100644 --- a/src/levanter/doremi.py +++ 
b/src/levanter/doremi.py @@ -145,6 +145,7 @@ def compute_excess_loss(proxy, ref, batch): return excess_losses # Loss is \sum_d alpha_d * (proxy - ref) (basically the unclipped excess loss with the new alpha) + # Note that (\sum_d \alpha_d ref) is a constant in the model params, so we can ignore it @hax.named_jit(axis_resources=trainer.parameter_axis_mapping, donate_args=(True,)) def doremi_step(state: DoremiState, ref, batch, domains): proxy = inference_mode(state.model, False) @@ -290,5 +291,5 @@ def _compute_per_domain_losses(Domain, domains, losses): def _domain_weighted_loss(losses, Domain, domains, alpha): one_hot_domains = hax.nn.one_hot(domains, Domain) # Domain x Batch - return hax.mean(losses.broadcast_axis(Domain) * one_hot_domains * alpha, axis=None) + return hax.mean(losses.broadcast_axis(Domain) * one_hot_domains * alpha, axis=None).scalar() # return hax.dot(alpha, one_hot_domains, losses, axis=None) diff --git a/tests/test_flash_attention.py b/tests/test_flash_attention.py index d4b1f08b4..8d1f5aab0 100644 --- a/tests/test_flash_attention.py +++ b/tests/test_flash_attention.py @@ -60,7 +60,7 @@ def test_grad_attention(): def d_attn(qkv, fn): q, k, v = qkv x_out = fn(KPos, Key, q, k, v, mask=mask) - return (x_out * x_out).sum() + return (x_out * x_out).sum().scalar() hax_val, (hax_dq, hax_dk, hax_dv) = d_attn((q, k, v), hnn.attention.dot_product_attention) fa_val, (fa_dq, fa_dk, fa_dv) = d_attn((q, k, v), functools.partial(flash_attention, QPos, inference=True)) diff --git a/tests/test_grad_accum.py b/tests/test_grad_accum.py index dd6bfa761..131ca3b89 100644 --- a/tests/test_grad_accum.py +++ b/tests/test_grad_accum.py @@ -44,7 +44,7 @@ def test_accumulate_gradients_sharded(parallelism, accum_steps): mlp = Mlp.init(In, Out, Mid, key=jax.random.PRNGKey(0)) def loss_fn(mlp, x): - return mlp(x).mean() + return mlp(x).mean().scalar() x = hax.random.normal(jax.random.PRNGKey(0), (Batch, In)) From ff5cb6d3cf668830ba41863b1e7bffb1fc622190 Mon Sep 17 00:00:00 2001 From: David Hall Date: Sat, 13 Jan 2024 00:22:26 -0800 Subject: [PATCH 117/205] update for reverted Haliax change --- tests/test_doremi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_doremi.py b/tests/test_doremi.py index a0c28e800..382389559 100644 --- a/tests/test_doremi.py +++ b/tests/test_doremi.py @@ -38,7 +38,7 @@ def __iter__(self): noise = hax.random.normal(next(key_iter), (Block,)) * self.noise y_block = (hax.nn.sigmoid(hax.dot(x_block, self.W, axis=Dim) + noise) > 0.5).astype(float) for i in range(Block.size): - yield Example(x=x_block[Block, i], y=hax.named(y_block[Block, i], ())) + yield Example(x=x_block[Block, i], y=y_block[Block, i]) def shard(self, shard_id: int, num_shards: int): return LogitDataset(self.W, self.noise, self.x_mask, self.x_bias, key=jax.random.fold_in(self.key, shard_id)) From d6bf2c077a8937f3072501c01445fe7e0d4964fd Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 15 Jan 2024 10:36:58 -0800 Subject: [PATCH 118/205] update paths for pile mixture --- config/data/pile_mixture.yaml | 90 +++++++++++++++++------------------ 1 file changed, 45 insertions(+), 45 deletions(-) diff --git a/config/data/pile_mixture.yaml b/config/data/pile_mixture.yaml index 59adc4128..fcae71f22 100644 --- a/config/data/pile_mixture.yaml +++ b/config/data/pile_mixture.yaml @@ -1,116 +1,116 @@ -cache_dir: "gs://levanter-data/tokenized/pile_domains/" +cache_dir: "gs://levanter-data/tokenized/pile-domains/" tokenizer: "EleutherAI/gpt-neox-20b" configs: arxiv: train_urls: - - 
gs://levanter-data/pile_domains/arxiv/train-{00..29}.jsonl.zst + - gs://levanter-data/pile-domains/arxiv/train-{00..29}.jsonl.zst validation_urls: - - gs://levanter-data/pile_domains/arxiv/validation.jsonl.zst + - gs://levanter-data/pile-domains/arxiv/validation.jsonl.zst books2: train_urls: - - gs://levanter-data/pile_domains/books2/train-{00..29}.jsonl.zst + - gs://levanter-data/pile-domains/books2/train-{00..29}.jsonl.zst validation_urls: - - gs://levanter-data/pile_domains/books2/validation.jsonl.zst + - gs://levanter-data/pile-domains/books2/validation.jsonl.zst books3: train_urls: - - gs://levanter-data/pile_domains/books3/train-{00..29}.jsonl.zst + - gs://levanter-data/pile-domains/books3/train-{00..29}.jsonl.zst validation_urls: - - gs://levanter-data/pile_domains/books3/validation.jsonl.zst + - gs://levanter-data/pile-domains/books3/validation.jsonl.zst dm_math: train_urls: - - gs://levanter-data/pile_domains/dm_math/train-{00..29}.jsonl.zst + - gs://levanter-data/pile-domains/dm_math/train-{00..29}.jsonl.zst validation_urls: - - gs://levanter-data/pile_domains/dm_math/validation.jsonl.zst + - gs://levanter-data/pile-domains/dm_math/validation.jsonl.zst enron: train_urls: - - gs://levanter-data/pile_domains/enron/train-{00..29}.jsonl.zst + - gs://levanter-data/pile-domains/enron/train-{00..29}.jsonl.zst validation_urls: - - gs://levanter-data/pile_domains/enron/validation.jsonl.zst + - gs://levanter-data/pile-domains/enron/validation.jsonl.zst europarl: train_urls: - - gs://levanter-data/pile_domains/europarl/train-{00..29}.jsonl.zst + - gs://levanter-data/pile-domains/europarl/train-{00..29}.jsonl.zst validation_urls: - - gs://levanter-data/pile_domains/europarl/validation.jsonl.zst + - gs://levanter-data/pile-domains/europarl/validation.jsonl.zst free_law: train_urls: - - gs://levanter-data/pile_domains/free_law/train-{00..29}.jsonl.zst + - gs://levanter-data/pile-domains/free_law/train-{00..29}.jsonl.zst validation_urls: - - gs://levanter-data/pile_domains/free_law/validation.jsonl.zst + - gs://levanter-data/pile-domains/free_law/validation.jsonl.zst github: train_urls: - - gs://levanter-data/pile_domains/github/train-{00..29}.jsonl.zst + - gs://levanter-data/pile-domains/github/train-{00..29}.jsonl.zst validation_urls: - - gs://levanter-data/pile_domains/github/validation.jsonl.zst + - gs://levanter-data/pile-domains/github/validation.jsonl.zst hackernews: train_urls: - - gs://levanter-data/pile_domains/hackernews/train-{00..29}.jsonl.zst + - gs://levanter-data/pile-domains/hackernews/train-{00..29}.jsonl.zst validation_urls: - - gs://levanter-data/pile_domains/hackernews/validation.jsonl.zst + - gs://levanter-data/pile-domains/hackernews/validation.jsonl.zst nih: train_urls: - - gs://levanter-data/pile_domains/nih/train-{00..29}.jsonl.zst + - gs://levanter-data/pile-domains/nih/train-{00..29}.jsonl.zst validation_urls: - - gs://levanter-data/pile_domains/nih/validation.jsonl.zst + - gs://levanter-data/pile-domains/nih/validation.jsonl.zst opensubtitles: train_urls: - - gs://levanter-data/pile_domains/opensubtitles/train-{00..29}.jsonl.zst + - gs://levanter-data/pile-domains/opensubtitles/train-{00..29}.jsonl.zst validation_urls: - - gs://levanter-data/pile_domains/opensubtitles/validation.jsonl.zst + - gs://levanter-data/pile-domains/opensubtitles/validation.jsonl.zst owt2: train_urls: - - gs://levanter-data/pile_domains/owt2/train-{00..29}.jsonl.zst + - gs://levanter-data/pile-domains/owt2/train-{00..29}.jsonl.zst validation_urls: - - 
gs://levanter-data/pile_domains/owt2/validation.jsonl.zst + - gs://levanter-data/pile-domains/owt2/validation.jsonl.zst pg_19: train_urls: - - gs://levanter-data/pile_domains/pg_19/train-{00..29}.jsonl.zst + - gs://levanter-data/pile-domains/pg_19/train-{00..29}.jsonl.zst validation_urls: - - gs://levanter-data/pile_domains/pg_19/validation.jsonl.zst + - gs://levanter-data/pile-domains/pg_19/validation.jsonl.zst philpapers: train_urls: - - gs://levanter-data/pile_domains/philpapers/train-{00..29}.jsonl.zst + - gs://levanter-data/pile-domains/philpapers/train-{00..29}.jsonl.zst validation_urls: - - gs://levanter-data/pile_domains/philpapers/validation.jsonl.zst + - gs://levanter-data/pile-domains/philpapers/validation.jsonl.zst pile_cc: train_urls: - - gs://levanter-data/pile_domains/pile_cc/train-{00..29}.jsonl.zst + - gs://levanter-data/pile-domains/pile_cc/train-{00..29}.jsonl.zst validation_urls: - - gs://levanter-data/pile_domains/pile_cc/validation.jsonl.zst + - gs://levanter-data/pile-domains/pile_cc/validation.jsonl.zst pubmed_abs: train_urls: - - gs://levanter-data/pile_domains/pubmed_abs/train-{00..29}.jsonl.zst + - gs://levanter-data/pile-domains/pubmed_abs/train-{00..29}.jsonl.zst validation_urls: - - gs://levanter-data/pile_domains/pubmed_abs/validation.jsonl.zst + - gs://levanter-data/pile-domains/pubmed_abs/validation.jsonl.zst pubmed_central: train_urls: - - gs://levanter-data/pile_domains/pubmed_central/train-{00..29}.jsonl.zst + - gs://levanter-data/pile-domains/pubmed_central/train-{00..29}.jsonl.zst validation_urls: - - gs://levanter-data/pile_domains/pubmed_central/validation.jsonl.zst + - gs://levanter-data/pile-domains/pubmed_central/validation.jsonl.zst stack_exchange: train_urls: - - gs://levanter-data/pile_domains/stack_exchange/train-{00..29}.jsonl.zst + - gs://levanter-data/pile-domains/stack_exchange/train-{00..29}.jsonl.zst validation_urls: - - gs://levanter-data/pile_domains/stack_exchange/validation.jsonl.zst + - gs://levanter-data/pile-domains/stack_exchange/validation.jsonl.zst ubuntu_irc: train_urls: - - gs://levanter-data/pile_domains/ubuntu_irc/train-{00..29}.jsonl.zst + - gs://levanter-data/pile-domains/ubuntu_irc/train-{00..29}.jsonl.zst validation_urls: - - gs://levanter-data/pile_domains/ubuntu_irc/validation.jsonl.zst + - gs://levanter-data/pile-domains/ubuntu_irc/validation.jsonl.zst uspto: train_urls: - - gs://levanter-data/pile_domains/uspto/train-{00..29}.jsonl.zst + - gs://levanter-data/pile-domains/uspto/train-{00..29}.jsonl.zst validation_urls: - - gs://levanter-data/pile_domains/uspto/validation.jsonl.zst + - gs://levanter-data/pile-domains/uspto/validation.jsonl.zst wiki_en: train_urls: - - gs://levanter-data/pile_domains/wiki_en/train-{00..29}.jsonl.zst + - gs://levanter-data/pile-domains/wiki_en/train-{00..29}.jsonl.zst validation_urls: - - gs://levanter-data/pile_domains/wiki_en/validation.jsonl.zst + - gs://levanter-data/pile-domains/wiki_en/validation.jsonl.zst youtube_subtitles: train_urls: - - gs://levanter-data/pile_domains/youtube_subtitles/train-{00..29}.jsonl.zst + - gs://levanter-data/pile-domains/youtube_subtitles/train-{00..29}.jsonl.zst validation_urls: - - gs://levanter-data/pile_domains/youtube_subtitles/validation.jsonl.zst + - gs://levanter-data/pile-domains/youtube_subtitles/validation.jsonl.zst train_weights: # these weights come from the paper https://arxiv.org/pdf/2101.00027.pdf pile_cc: 0.1811 From cc6044c2f6a13e4ef25ef8030768710af185d315 Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 15 Jan 2024 10:55:20 -0800 
Subject: [PATCH 119/205] fix new import --- src/levanter/optim/second_order.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/levanter/optim/second_order.py b/src/levanter/optim/second_order.py index fd0da7325..e3c205432 100644 --- a/src/levanter/optim/second_order.py +++ b/src/levanter/optim/second_order.py @@ -8,7 +8,7 @@ import optax from jax import numpy as jnp from optax._src import numerics -from optax._src.schedule import InjectHyperparamsState, _convert_floats +from optax._src.schedule import InjectHyperparamsState class HessianUpdateFn(typing.Protocol): @@ -226,3 +226,11 @@ def update_hessian(state, fn, model, *batch, **batch_kwargs): return SecondOrderTransformation(init_fn, update_fn, update_hessian) return wrapped_transform + + +# Cribbed from optax._src.schedule, which recently deleted this function. +def _convert_floats(x, dtype): + """Convert float-like inputs to dtype, rest pass through.""" + if jax.dtypes.scalar_type_of(x) == float: + return jnp.asarray(x, dtype=dtype) + return x From 415158a197a926210db33314224b3c150e99690d Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 15 Jan 2024 10:57:57 -0800 Subject: [PATCH 120/205] sigh --- config/data/pile_mixture.yaml | 44 +++++++++++++++++------------------ 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/config/data/pile_mixture.yaml b/config/data/pile_mixture.yaml index fcae71f22..062a4b70b 100644 --- a/config/data/pile_mixture.yaml +++ b/config/data/pile_mixture.yaml @@ -5,112 +5,112 @@ configs: train_urls: - gs://levanter-data/pile-domains/arxiv/train-{00..29}.jsonl.zst validation_urls: - - gs://levanter-data/pile-domains/arxiv/validation.jsonl.zst + - gs://levanter-data/pile-domains/arxiv/val.jsonl.zst books2: train_urls: - gs://levanter-data/pile-domains/books2/train-{00..29}.jsonl.zst validation_urls: - - gs://levanter-data/pile-domains/books2/validation.jsonl.zst + - gs://levanter-data/pile-domains/books2/val.jsonl.zst books3: train_urls: - gs://levanter-data/pile-domains/books3/train-{00..29}.jsonl.zst validation_urls: - - gs://levanter-data/pile-domains/books3/validation.jsonl.zst + - gs://levanter-data/pile-domains/books3/val.jsonl.zst dm_math: train_urls: - gs://levanter-data/pile-domains/dm_math/train-{00..29}.jsonl.zst validation_urls: - - gs://levanter-data/pile-domains/dm_math/validation.jsonl.zst + - gs://levanter-data/pile-domains/dm_math/val.jsonl.zst enron: train_urls: - gs://levanter-data/pile-domains/enron/train-{00..29}.jsonl.zst validation_urls: - - gs://levanter-data/pile-domains/enron/validation.jsonl.zst + - gs://levanter-data/pile-domains/enron/val.jsonl.zst europarl: train_urls: - gs://levanter-data/pile-domains/europarl/train-{00..29}.jsonl.zst validation_urls: - - gs://levanter-data/pile-domains/europarl/validation.jsonl.zst + - gs://levanter-data/pile-domains/europarl/val.jsonl.zst free_law: train_urls: - gs://levanter-data/pile-domains/free_law/train-{00..29}.jsonl.zst validation_urls: - - gs://levanter-data/pile-domains/free_law/validation.jsonl.zst + - gs://levanter-data/pile-domains/free_law/val.jsonl.zst github: train_urls: - gs://levanter-data/pile-domains/github/train-{00..29}.jsonl.zst validation_urls: - - gs://levanter-data/pile-domains/github/validation.jsonl.zst + - gs://levanter-data/pile-domains/github/val.jsonl.zst hackernews: train_urls: - gs://levanter-data/pile-domains/hackernews/train-{00..29}.jsonl.zst validation_urls: - - gs://levanter-data/pile-domains/hackernews/validation.jsonl.zst + - 
gs://levanter-data/pile-domains/hackernews/val.jsonl.zst nih: train_urls: - gs://levanter-data/pile-domains/nih/train-{00..29}.jsonl.zst validation_urls: - - gs://levanter-data/pile-domains/nih/validation.jsonl.zst + - gs://levanter-data/pile-domains/nih/val.jsonl.zst opensubtitles: train_urls: - gs://levanter-data/pile-domains/opensubtitles/train-{00..29}.jsonl.zst validation_urls: - - gs://levanter-data/pile-domains/opensubtitles/validation.jsonl.zst + - gs://levanter-data/pile-domains/opensubtitles/val.jsonl.zst owt2: train_urls: - gs://levanter-data/pile-domains/owt2/train-{00..29}.jsonl.zst validation_urls: - - gs://levanter-data/pile-domains/owt2/validation.jsonl.zst + - gs://levanter-data/pile-domains/owt2/val.jsonl.zst pg_19: train_urls: - gs://levanter-data/pile-domains/pg_19/train-{00..29}.jsonl.zst validation_urls: - - gs://levanter-data/pile-domains/pg_19/validation.jsonl.zst + - gs://levanter-data/pile-domains/pg_19/val.jsonl.zst philpapers: train_urls: - gs://levanter-data/pile-domains/philpapers/train-{00..29}.jsonl.zst validation_urls: - - gs://levanter-data/pile-domains/philpapers/validation.jsonl.zst + - gs://levanter-data/pile-domains/philpapers/val.jsonl.zst pile_cc: train_urls: - gs://levanter-data/pile-domains/pile_cc/train-{00..29}.jsonl.zst validation_urls: - - gs://levanter-data/pile-domains/pile_cc/validation.jsonl.zst + - gs://levanter-data/pile-domains/pile_cc/val.jsonl.zst pubmed_abs: train_urls: - gs://levanter-data/pile-domains/pubmed_abs/train-{00..29}.jsonl.zst validation_urls: - - gs://levanter-data/pile-domains/pubmed_abs/validation.jsonl.zst + - gs://levanter-data/pile-domains/pubmed_abs/val.jsonl.zst pubmed_central: train_urls: - gs://levanter-data/pile-domains/pubmed_central/train-{00..29}.jsonl.zst validation_urls: - - gs://levanter-data/pile-domains/pubmed_central/validation.jsonl.zst + - gs://levanter-data/pile-domains/pubmed_central/val.jsonl.zst stack_exchange: train_urls: - gs://levanter-data/pile-domains/stack_exchange/train-{00..29}.jsonl.zst validation_urls: - - gs://levanter-data/pile-domains/stack_exchange/validation.jsonl.zst + - gs://levanter-data/pile-domains/stack_exchange/val.jsonl.zst ubuntu_irc: train_urls: - gs://levanter-data/pile-domains/ubuntu_irc/train-{00..29}.jsonl.zst validation_urls: - - gs://levanter-data/pile-domains/ubuntu_irc/validation.jsonl.zst + - gs://levanter-data/pile-domains/ubuntu_irc/val.jsonl.zst uspto: train_urls: - gs://levanter-data/pile-domains/uspto/train-{00..29}.jsonl.zst validation_urls: - - gs://levanter-data/pile-domains/uspto/validation.jsonl.zst + - gs://levanter-data/pile-domains/uspto/val.jsonl.zst wiki_en: train_urls: - gs://levanter-data/pile-domains/wiki_en/train-{00..29}.jsonl.zst validation_urls: - - gs://levanter-data/pile-domains/wiki_en/validation.jsonl.zst + - gs://levanter-data/pile-domains/wiki_en/val.jsonl.zst youtube_subtitles: train_urls: - gs://levanter-data/pile-domains/youtube_subtitles/train-{00..29}.jsonl.zst validation_urls: - - gs://levanter-data/pile-domains/youtube_subtitles/validation.jsonl.zst + - gs://levanter-data/pile-domains/youtube_subtitles/val.jsonl.zst train_weights: # these weights come from the paper https://arxiv.org/pdf/2101.00027.pdf pile_cc: 0.1811 From d2a90aea7c2e741ff9d489da3e3ce488878d2e96 Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 15 Jan 2024 11:21:06 -0800 Subject: [PATCH 121/205] isjfo --- config/data/pile_mixture.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/config/data/pile_mixture.yaml 
b/config/data/pile_mixture.yaml index 062a4b70b..1578c7042 100644 --- a/config/data/pile_mixture.yaml +++ b/config/data/pile_mixture.yaml @@ -33,9 +33,9 @@ configs: - gs://levanter-data/pile-domains/europarl/val.jsonl.zst free_law: train_urls: - - gs://levanter-data/pile-domains/free_law/train-{00..29}.jsonl.zst + - gs://levanter-data/pile-domains/freelaw/train-{00..29}.jsonl.zst validation_urls: - - gs://levanter-data/pile-domains/free_law/val.jsonl.zst + - gs://levanter-data/pile-domains/freelaw/val.jsonl.zst github: train_urls: - gs://levanter-data/pile-domains/github/train-{00..29}.jsonl.zst From 058a9e000315860528181640429e5f8d5759c562 Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 15 Jan 2024 11:32:11 -0800 Subject: [PATCH 122/205] mdklmdlm --- config/data/pile_mixture.yaml | 44 +++++++++++++++++------------------ 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/config/data/pile_mixture.yaml b/config/data/pile_mixture.yaml index 1578c7042..ff75b8941 100644 --- a/config/data/pile_mixture.yaml +++ b/config/data/pile_mixture.yaml @@ -3,112 +3,112 @@ tokenizer: "EleutherAI/gpt-neox-20b" configs: arxiv: train_urls: - - gs://levanter-data/pile-domains/arxiv/train-{00..29}.jsonl.zst + - gs://levanter-data/pile-domains/arxiv/{00..29}.jsonl.zst validation_urls: - gs://levanter-data/pile-domains/arxiv/val.jsonl.zst books2: train_urls: - - gs://levanter-data/pile-domains/books2/train-{00..29}.jsonl.zst + - gs://levanter-data/pile-domains/books2/{00..29}.jsonl.zst validation_urls: - gs://levanter-data/pile-domains/books2/val.jsonl.zst books3: train_urls: - - gs://levanter-data/pile-domains/books3/train-{00..29}.jsonl.zst + - gs://levanter-data/pile-domains/books3/{00..29}.jsonl.zst validation_urls: - gs://levanter-data/pile-domains/books3/val.jsonl.zst dm_math: train_urls: - - gs://levanter-data/pile-domains/dm_math/train-{00..29}.jsonl.zst + - gs://levanter-data/pile-domains/dm_math/{00..29}.jsonl.zst validation_urls: - gs://levanter-data/pile-domains/dm_math/val.jsonl.zst enron: train_urls: - - gs://levanter-data/pile-domains/enron/train-{00..29}.jsonl.zst + - gs://levanter-data/pile-domains/enron/{00..29}.jsonl.zst validation_urls: - gs://levanter-data/pile-domains/enron/val.jsonl.zst europarl: train_urls: - - gs://levanter-data/pile-domains/europarl/train-{00..29}.jsonl.zst + - gs://levanter-data/pile-domains/europarl/{00..29}.jsonl.zst validation_urls: - gs://levanter-data/pile-domains/europarl/val.jsonl.zst free_law: train_urls: - - gs://levanter-data/pile-domains/freelaw/train-{00..29}.jsonl.zst + - gs://levanter-data/pile-domains/freelaw/{00..29}.jsonl.zst validation_urls: - gs://levanter-data/pile-domains/freelaw/val.jsonl.zst github: train_urls: - - gs://levanter-data/pile-domains/github/train-{00..29}.jsonl.zst + - gs://levanter-data/pile-domains/github/{00..29}.jsonl.zst validation_urls: - gs://levanter-data/pile-domains/github/val.jsonl.zst hackernews: train_urls: - - gs://levanter-data/pile-domains/hackernews/train-{00..29}.jsonl.zst + - gs://levanter-data/pile-domains/hackernews/{00..29}.jsonl.zst validation_urls: - gs://levanter-data/pile-domains/hackernews/val.jsonl.zst nih: train_urls: - - gs://levanter-data/pile-domains/nih/train-{00..29}.jsonl.zst + - gs://levanter-data/pile-domains/nih/{00..29}.jsonl.zst validation_urls: - gs://levanter-data/pile-domains/nih/val.jsonl.zst opensubtitles: train_urls: - - gs://levanter-data/pile-domains/opensubtitles/train-{00..29}.jsonl.zst + - gs://levanter-data/pile-domains/opensubtitles/{00..29}.jsonl.zst 
validation_urls: - gs://levanter-data/pile-domains/opensubtitles/val.jsonl.zst owt2: train_urls: - - gs://levanter-data/pile-domains/owt2/train-{00..29}.jsonl.zst + - gs://levanter-data/pile-domains/owt2/{00..29}.jsonl.zst validation_urls: - gs://levanter-data/pile-domains/owt2/val.jsonl.zst pg_19: train_urls: - - gs://levanter-data/pile-domains/pg_19/train-{00..29}.jsonl.zst + - gs://levanter-data/pile-domains/pg_19/{00..29}.jsonl.zst validation_urls: - gs://levanter-data/pile-domains/pg_19/val.jsonl.zst philpapers: train_urls: - - gs://levanter-data/pile-domains/philpapers/train-{00..29}.jsonl.zst + - gs://levanter-data/pile-domains/philpapers/{00..29}.jsonl.zst validation_urls: - gs://levanter-data/pile-domains/philpapers/val.jsonl.zst pile_cc: train_urls: - - gs://levanter-data/pile-domains/pile_cc/train-{00..29}.jsonl.zst + - gs://levanter-data/pile-domains/pile_cc/{00..29}.jsonl.zst validation_urls: - gs://levanter-data/pile-domains/pile_cc/val.jsonl.zst pubmed_abs: train_urls: - - gs://levanter-data/pile-domains/pubmed_abs/train-{00..29}.jsonl.zst + - gs://levanter-data/pile-domains/pubmed_abs/{00..29}.jsonl.zst validation_urls: - gs://levanter-data/pile-domains/pubmed_abs/val.jsonl.zst pubmed_central: train_urls: - - gs://levanter-data/pile-domains/pubmed_central/train-{00..29}.jsonl.zst + - gs://levanter-data/pile-domains/pubmed_central/{00..29}.jsonl.zst validation_urls: - gs://levanter-data/pile-domains/pubmed_central/val.jsonl.zst stack_exchange: train_urls: - - gs://levanter-data/pile-domains/stack_exchange/train-{00..29}.jsonl.zst + - gs://levanter-data/pile-domains/stack_exchange/{00..29}.jsonl.zst validation_urls: - gs://levanter-data/pile-domains/stack_exchange/val.jsonl.zst ubuntu_irc: train_urls: - - gs://levanter-data/pile-domains/ubuntu_irc/train-{00..29}.jsonl.zst + - gs://levanter-data/pile-domains/ubuntu_irc/{00..29}.jsonl.zst validation_urls: - gs://levanter-data/pile-domains/ubuntu_irc/val.jsonl.zst uspto: train_urls: - - gs://levanter-data/pile-domains/uspto/train-{00..29}.jsonl.zst + - gs://levanter-data/pile-domains/uspto/{00..29}.jsonl.zst validation_urls: - gs://levanter-data/pile-domains/uspto/val.jsonl.zst wiki_en: train_urls: - - gs://levanter-data/pile-domains/wiki_en/train-{00..29}.jsonl.zst + - gs://levanter-data/pile-domains/wiki_en/{00..29}.jsonl.zst validation_urls: - gs://levanter-data/pile-domains/wiki_en/val.jsonl.zst youtube_subtitles: train_urls: - - gs://levanter-data/pile-domains/youtube_subtitles/train-{00..29}.jsonl.zst + - gs://levanter-data/pile-domains/youtube_subtitles/{00..29}.jsonl.zst validation_urls: - gs://levanter-data/pile-domains/youtube_subtitles/val.jsonl.zst train_weights: From 9f16fbe10dd207f05ea6aacea6b5bf7425be717f Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 15 Jan 2024 11:55:09 -0800 Subject: [PATCH 123/205] make logging list names of caches --- src/levanter/data/text.py | 8 +++++--- src/levanter/utils/hf_utils.py | 3 +-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index f952a882a..212318433 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -568,9 +568,11 @@ def token_seq_dataset( return TokenSeqDataset(cache, seq_len) def build_or_load_cache( - self, split: str, monitors: Union[bool, List[MetricsMonitor]] = True + self, split: str, monitors: Union[bool, List[MetricsMonitor]] = True, logger_name: Optional[str] = None ) -> Optional[TokenizedDocumentCache]: split_cache_dir = os.path.join(self.cache_dir, split) + name = 
logger_name or os.path.basename(self.cache_dir) + try: return TokenizedDocumentCache.load(split_cache_dir, flatten_docs=True) except FileNotFoundError: @@ -585,8 +587,8 @@ def build_or_load_cache( if monitors is True: monitors = [ - LoggingMetricsMonitor(prefix=f"preprocessing/{split}", commit=False), - LoggerMetricsMonitor(f"preprocessing.{split}"), + LoggingMetricsMonitor(prefix=f"preprocessing/{name}/{split}", commit=False), + LoggerMetricsMonitor(f"preprocessing.{name}.{split}"), ] elif monitors is False: monitors = [] diff --git a/src/levanter/utils/hf_utils.py b/src/levanter/utils/hf_utils.py index ff9fdb7af..408a8c8da 100644 --- a/src/levanter/utils/hf_utils.py +++ b/src/levanter/utils/hf_utils.py @@ -18,8 +18,7 @@ def num_cpus_used_by_tokenizer(tokenizer) -> int: else: # This is a bit hacky, but HF's fast tokenizers are parallelized under the hood. # we reserve a couple of cores just so Ray has somewhere to run the coordinator. - # Empirically I never see it get past 10 (usually more like 5-8), so we'll say 8 - return min(max(1, logical_cpu_core_count() - 2), 8) + return min(max(1, logical_cpu_core_count() - 2), 32) else: return 1 From b80ef6acbba4eb9cf49f667e594beeffd8a745ea Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 15 Jan 2024 12:20:49 -0800 Subject: [PATCH 124/205] lower resource requirements to see if this gets us processing faster --- src/levanter/data/shard_cache.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/src/levanter/data/shard_cache.py b/src/levanter/data/shard_cache.py index 398d097bb..83b75bc8f 100644 --- a/src/levanter/data/shard_cache.py +++ b/src/levanter/data/shard_cache.py @@ -700,7 +700,7 @@ def task_running(self): # Ray does poorly with large numbers of actors (grumble grumble), so we can't have one actor per shard. # This class wraps a map of shard names to _ShardWriterWorkers, and manages the lifecycle of the workers. -@ray.remote(num_cpus=1, scheduling_strategy="SPREAD") # type: ignore +@ray.remote(num_cpus=0.25, scheduling_strategy="SPREAD") # type: ignore class _GroupShardWriterWorker: def __init__(self, parent_ref, cache_dir: str, shard_names: Sequence[str]): pylogging.basicConfig(level=pylogging.INFO) @@ -907,7 +907,7 @@ def _attempt_to_write_chunk_fragments(self, chunk_id) -> Optional[ChunkMetadata] return None -@ray.remote +@ray.remote(num_cpus=0.25) # keep this small b/c it doesn't do a lot class ChunkCacheBuilder: """ Actor that manages the in-progress global ordering on chunks. 
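The fractional num_cpus values in this patch lean on Ray treating CPU requests as divisible logical resources: four 0.25-CPU actors can share a single CPU slot, so many per-corpus coordinators can coexist on a small cluster. A minimal sketch under that assumption (toy actor and names, not the real ChunkCacheBuilder):

import ray

ray.init(num_cpus=1)

@ray.remote(num_cpus=0.25)
class Coordinator:
    def __init__(self, name):
        self.name = name

    def ping(self):
        return self.name

# all four actors fit within one logical CPU and start immediately
actors = [Coordinator.remote(f"builder-{i}") for i in range(4)]
print(ray.get([a.ping.remote() for a in actors]))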
ChunkCacheWriter's job is to hold the list of all @@ -1298,6 +1298,8 @@ def __init__( self._num_readers = num_readers self._reader_offset = reader_offset + name = os.path.join(*cache_dir.split("/")[-2:]) + self.logger = pylogging.getLogger(f"ShardCache.{name}") @staticmethod def load(cache_dir: str, batch_size: int) -> "ShardCache": @@ -1356,13 +1358,15 @@ def _get_chunk_unmapped(self, mapped_index: int, *, timeout: Optional[float] = N time_in = time.time() # we want to also log if we're waiting for a long time, so we do this in a loop while timeout is None or time.time() - time_in < timeout: - current_timeout = 20.0 # be generous + current_timeout = 20.0 if timeout is not None: current_timeout = min(current_timeout, timeout - (time.time() - time_in)) try: chunk = ray.get(self._broker.get_chunk.remote(mapped_index), timeout=current_timeout) except GetTimeoutError: - logger.warning(f"Waiting for chunk {mapped_index} after {int(time.time() - time_in)} seconds") + self.logger.warning(f"Waiting for chunk {mapped_index} for {int(time.time() - time_in)} seconds") + current_timeout *= 2 + current_timeout = min(current_timeout, 80) continue if chunk is None: @@ -1417,7 +1421,7 @@ def iter_batches_from_chunks(self, loop: bool = False): i = shard_offset while True: try: - logger.debug(f"Reading chunk {i}") + self.logger.debug(f"Reading chunk {i}") chunk = self._get_chunk_unmapped(i) i += self._num_readers yield from self._read_chunk(chunk) @@ -1430,7 +1434,7 @@ def iter_batches_from_chunks(self, loop: bool = False): else: break except Exception as e: - logger.exception("Error while reading from shard cache.") + self.logger.exception("Error while reading from shard cache.") raise e def __iter__(self): @@ -1491,7 +1495,7 @@ def _monitor_metrics(self): if metrics.is_finished: break except Exception as e: - logger.exception("Error while reading metrics from shard cache.") + self.logger.exception("Error while reading metrics from shard cache.") raise e From 6983ff097ebb46c91961fcefd05deb4b3ab4d766 Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 15 Jan 2024 14:52:37 -0800 Subject: [PATCH 125/205] let's make the chunkcachebuilders free --- src/levanter/data/shard_cache.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/levanter/data/shard_cache.py b/src/levanter/data/shard_cache.py index 83b75bc8f..3f91f7775 100644 --- a/src/levanter/data/shard_cache.py +++ b/src/levanter/data/shard_cache.py @@ -907,7 +907,7 @@ def _attempt_to_write_chunk_fragments(self, chunk_id) -> Optional[ChunkMetadata] return None -@ray.remote(num_cpus=0.25) # keep this small b/c it doesn't do a lot +@ray.remote(num_cpus=0.0) # keep this small b/c it doesn't do a lot class ChunkCacheBuilder: """ Actor that manages the in-progress global ordering on chunks. 
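The retry loop added to _get_chunk_unmapped above is a poll-with-backoff around ray.get: each GetTimeoutError doubles the wait (capped at 80 seconds), so the "waiting for chunk" warnings grow sparser instead of firing every 20 seconds. A minimal sketch of the pattern, with a toy task standing in for the broker and a short initial timeout so the retry path actually triggers:

import time

import ray
from ray.exceptions import GetTimeoutError

ray.init(num_cpus=1)

@ray.remote
def slow_chunk():
    time.sleep(5)  # pretend the chunk takes a while to materialize
    return "chunk-0"

ref = slow_chunk.remote()
timeout, time_in = 2.0, time.time()
while True:
    try:
        chunk = ray.get(ref, timeout=timeout)
        break
    except GetTimeoutError:
        print(f"waiting for chunk for {int(time.time() - time_in)} seconds")
        timeout = min(timeout * 2, 80.0)

print(chunk)  # chunk-0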
ChunkCacheWriter's job is to hold the list of all From 5f42ad8a1af5f21f726202241585c7ee7e418fb0 Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 15 Jan 2024 15:38:41 -0800 Subject: [PATCH 126/205] minimize use of optax internals --- src/levanter/optim/second_order.py | 7 ++----- src/levanter/optim/sophia.py | 33 ++++++++++++++++++++++-------- src/levanter/tracker/helpers.py | 10 ++++++--- 3 files changed, 33 insertions(+), 17 deletions(-) diff --git a/src/levanter/optim/second_order.py b/src/levanter/optim/second_order.py index e3c205432..036c6e157 100644 --- a/src/levanter/optim/second_order.py +++ b/src/levanter/optim/second_order.py @@ -7,8 +7,7 @@ import jax import optax from jax import numpy as jnp -from optax._src import numerics -from optax._src.schedule import InjectHyperparamsState +from optax import InjectHyperparamsState class HessianUpdateFn(typing.Protocol): @@ -189,10 +188,8 @@ def update_fn(updates, state, params=None): hparams = {k: _convert_floats(v, dtype) for k, v in state.hyperparams.items()} hparams.update(schedule_fn(state.count, dtype)) updates, inner_state = inner_factory(**other_hps, **hparams).update(updates, state.inner_state, params) - count_inc = numerics.safe_int32_increment(state.count) - # pylint:disable=too-many-function-args - return updates, InjectHyperparamsState(count_inc, hparams, inner_state) + return updates, InjectHyperparamsState(state.count + 1, hparams, inner_state) # pylint:enable=too-many-function-args def _find_first_floating_dtype(updates): diff --git a/src/levanter/optim/sophia.py b/src/levanter/optim/sophia.py index ce41758cd..6a26c1253 100644 --- a/src/levanter/optim/sophia.py +++ b/src/levanter/optim/sophia.py @@ -1,4 +1,5 @@ import abc +import functools import typing from dataclasses import dataclass from typing import Any, NamedTuple, Optional, TypeVar, runtime_checkable @@ -11,10 +12,6 @@ from jax.random import PRNGKey from jaxtyping import PRNGKeyArray -# TODO: remove dependency on _src internals -from optax._src import numerics -from optax._src.transform import bias_correction, update_moment - import levanter.tracker from levanter.optim.config import HessianOptConfig, OptimizerConfig from levanter.optim.second_order import SecondOrderTransformation, chain_second_order, inject_hyperparams @@ -294,8 +291,7 @@ def init_fn(params): def update_fn(updates, state, params=None): mu = update_moment(updates, state.mu, b1, 1) # nu = update_moment_per_elem_norm(updates, state.nu, b2, 2) - count_inc = numerics.safe_int32_increment(state.count) - mu_hat = bias_correction(mu, b1, count_inc) + mu_hat = bias_correction(mu, b1, state.count + 1) h_hat = state.h # track how often hessian is used mu_leaves = jax.tree_util.tree_leaves(mu_hat) @@ -328,7 +324,7 @@ def update_fn(updates, state, params=None): mu = jax.tree_util.tree_map(lambda t: t.astype(mu_dtype), mu) return updates, ScaleBySophiaState( - count=count_inc, hessian_count=state.hessian_count, mu=mu, h=h_hat, hess_key=state.hess_key + count=state.count + 1, hessian_count=state.hessian_count, mu=mu, h=h_hat, hess_key=state.hess_key ) def update_hessian(state, fn, model, *batch, **batch_kwargs): @@ -339,10 +335,9 @@ def _do_update(): new_hess = tree_filter_like(state.h, new_hess) # EMAs of hessian - hessian_count_inc = numerics.safe_int32_increment(state.hessian_count) nu = update_moment(new_hess, state.h, b2, 1) return ScaleBySophiaState( - count=state.count, hessian_count=hessian_count_inc, mu=state.mu, h=nu, hess_key=next_key + count=state.count, hessian_count=state.hessian_count + 1, 
mu=state.mu, h=nu, hess_key=next_key ) def _dont_update(): @@ -411,3 +406,23 @@ def stochastic_hessian_diagonal(fn, model, *args, hess_key: PRNGKey, **kwargs): hessian = jax.tree_util.tree_map(lambda grad, gaussian: grad * gaussian, product, g) return hessian + + +# Cribbed from optax._src.transform +def update_moment(updates, moments, decay, order): + """Compute the exponential moving average of the `order`-th moment.""" + return jax.tree_util.tree_map(lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + + +@functools.partial(jax.jit, inline=True) +def bias_correction(moment, decay, count): + """Performs bias correction. It becomes a no-op as count goes to infinity.""" + # The conversion to the data type of the moment ensures that bfloat16 remains + # bfloat16 in the optimizer state. This conversion has to be done after + # `bias_correction_` is calculated as calculating `decay**count` in low + # precision can result in it being rounded to 1 and subsequently a + # "division by zero" error. + bias_correction_ = 1 - decay**count + + # Perform division in the original precision. + return jax.tree_util.tree_map(lambda t: t / bias_correction_.astype(t.dtype), moment) diff --git a/src/levanter/tracker/helpers.py b/src/levanter/tracker/helpers.py index 31131d1ac..1091840c5 100644 --- a/src/levanter/tracker/helpers.py +++ b/src/levanter/tracker/helpers.py @@ -4,7 +4,6 @@ from typing import Optional from git import InvalidGitRepositoryError, NoSuchPathError, Repo -from optax._src.wrappers import MultiStepsState import levanter.tracker from levanter.utils.jax_utils import jnp_to_python @@ -14,8 +13,13 @@ def log_optimizer_hyperparams(opt_state, prefix: Optional[str] = None, *, step=None): - if isinstance(opt_state, MultiStepsState): - opt_state = opt_state.inner_opt_state + try: + from optax._src.wrappers import MultiStepsState + + if isinstance(opt_state, MultiStepsState): + opt_state = opt_state.inner_opt_state + except ImportError: + pass def wrap_key(key): if prefix: From e6e8d27f4ef29490c1a52e083c6020dfdeb9a63c Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 15 Jan 2024 17:32:34 -0800 Subject: [PATCH 127/205] fix a crash i don't understand --- src/levanter/data/shard_cache.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/levanter/data/shard_cache.py b/src/levanter/data/shard_cache.py index 3f91f7775..87708b6fa 100644 --- a/src/levanter/data/shard_cache.py +++ b/src/levanter/data/shard_cache.py @@ -385,7 +385,7 @@ def drain_back_pressure_to(size): logger.debug(f"Waiting for back pressure queue to drain: {len(back_pressure_queue)}") finished_ref, back_pressure_queue = ray.wait(back_pressure_queue, num_returns=1) - if logger.level <= pylogging.DEBUG: + if len(back_pressure_queue) and logger.level <= pylogging.DEBUG: size = back_pressure_sizes.pop(finished_ref[0]) retained_batch_sizes -= size From ab29e92fcfa8cc61069009d71f3af5314573daec Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 15 Jan 2024 22:37:00 -0800 Subject: [PATCH 128/205] let's reduce requirements some more to see if we can keep everything running. 
should really solve this correctly --- src/levanter/data/shard_cache.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/levanter/data/shard_cache.py b/src/levanter/data/shard_cache.py index 87708b6fa..a45209834 100644 --- a/src/levanter/data/shard_cache.py +++ b/src/levanter/data/shard_cache.py @@ -321,7 +321,7 @@ def _shard_reader_generator(shard_source: ShardedDataset[T], shard_idx: int, sta # This class is responsible for reading batches from a set of shards, prioritizing earlier # chunks and earlier shards. (So that we approximately generate following the global order.) -@ray.remote(num_cpus=1, scheduling_strategy="SPREAD") +@ray.remote(num_cpus=0.25, scheduling_strategy="SPREAD") def _alternating_shard_reader( name: str, builder_ref: ActorHandle, # _ChunkCacheBuilder @@ -374,14 +374,13 @@ def enqueue_to_backpressure(batch, batch_result_ref): if retained_batch_sizes > 1024 * 1024 * 1024: logger.debug(f"Retained batch sizes: {retained_batch_sizes}") - size = MAX_INFLIGHT # we want to limit the number of pending tasks, so we wait until we're below the limit # before we start reading the next batch - drain_back_pressure_to(size) + drain_back_pressure_to(MAX_INFLIGHT) def drain_back_pressure_to(size): nonlocal back_pressure_queue, retained_batch_sizes - while len(back_pressure_queue) >= size: + while len(back_pressure_queue) > size: logger.debug(f"Waiting for back pressure queue to drain: {len(back_pressure_queue)}") finished_ref, back_pressure_queue = ray.wait(back_pressure_queue, num_returns=1) From 83f0616de5880270030d294dc0fb9d09e2ab5960 Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 15 Jan 2024 22:39:42 -0800 Subject: [PATCH 129/205] let's reduce requirements some more to see if we can keep everything running. should really solve this correctly --- src/levanter/data/shard_cache.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/levanter/data/shard_cache.py b/src/levanter/data/shard_cache.py index a45209834..48e59d097 100644 --- a/src/levanter/data/shard_cache.py +++ b/src/levanter/data/shard_cache.py @@ -321,7 +321,9 @@ def _shard_reader_generator(shard_source: ShardedDataset[T], shard_idx: int, sta # This class is responsible for reading batches from a set of shards, prioritizing earlier # chunks and earlier shards. (So that we approximately generate following the global order.) -@ray.remote(num_cpus=0.25, scheduling_strategy="SPREAD") +# TODO: it's not great we set this to 0.0, but it's the only way to get it to work with the current +# ray scheduler when we try to index a large number of corpora. We should centralize this a bit better. +@ray.remote(num_cpus=0.0, scheduling_strategy="SPREAD") def _alternating_shard_reader( name: str, builder_ref: ActorHandle, # _ChunkCacheBuilder @@ -699,7 +701,7 @@ def task_running(self): # Ray does poorly with large numbers of actors (grumble grumble), so we can't have one actor per shard. # This class wraps a map of shard names to _ShardWriterWorkers, and manages the lifecycle of the workers. 
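The drain_back_pressure_to logic these patches keep tuning is the standard bounded-in-flight idiom for Ray producers: cap the number of outstanding object refs, block on ray.wait before submitting more, then drain to zero at the end. A self-contained sketch of the idiom (toy task; MAX_INFLIGHT as in the shard reader):

import ray

ray.init(num_cpus=2)

@ray.remote
def process(i):
    return i * i

MAX_INFLIGHT = 10
inflight: list[ray.ObjectRef] = []
results = []

for i in range(100):
    # like drain_back_pressure_to(MAX_INFLIGHT): wait until we're under the cap
    while len(inflight) >= MAX_INFLIGHT:
        done, inflight = ray.wait(inflight, num_returns=1)
        results.extend(ray.get(done))
    inflight.append(process.remote(i))

# like drain_back_pressure_to(0): collect everything still outstanding
results.extend(ray.get(inflight))
print(len(results))  # 100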
-@ray.remote(num_cpus=0.25, scheduling_strategy="SPREAD") # type: ignore +@ray.remote(num_cpus=0.0, scheduling_strategy="SPREAD") # type: ignore class _GroupShardWriterWorker: def __init__(self, parent_ref, cache_dir: str, shard_names: Sequence[str]): pylogging.basicConfig(level=pylogging.INFO) From def45cc8a0a800ff231dbc0b72d633f18053755f Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 15 Jan 2024 22:42:14 -0800 Subject: [PATCH 130/205] silly --- src/levanter/data/shard_cache.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/levanter/data/shard_cache.py b/src/levanter/data/shard_cache.py index 48e59d097..0df7f98d8 100644 --- a/src/levanter/data/shard_cache.py +++ b/src/levanter/data/shard_cache.py @@ -323,7 +323,7 @@ def _shard_reader_generator(shard_source: ShardedDataset[T], shard_idx: int, sta # chunks and earlier shards. (So that we approximately generate following the global order.) # TODO: it's not great we set this to 0.0, but it's the only way to get it to work with the current # ray scheduler when we try to index a large number of corpora. We should centralize this a bit better. -@ray.remote(num_cpus=0.0, scheduling_strategy="SPREAD") +@ray.remote(num_cpus=1.0, scheduling_strategy="SPREAD") def _alternating_shard_reader( name: str, builder_ref: ActorHandle, # _ChunkCacheBuilder @@ -383,7 +383,7 @@ def enqueue_to_backpressure(batch, batch_result_ref): def drain_back_pressure_to(size): nonlocal back_pressure_queue, retained_batch_sizes while len(back_pressure_queue) > size: - logger.debug(f"Waiting for back pressure queue to drain: {len(back_pressure_queue)}") + logger.info(f"Waiting for back pressure queue to drain: {len(back_pressure_queue)}") finished_ref, back_pressure_queue = ray.wait(back_pressure_queue, num_returns=1) if len(back_pressure_queue) and logger.level <= pylogging.DEBUG: From de821cae3c2058bbf46c2767452ac13aee6e9af9 Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 15 Jan 2024 22:54:14 -0800 Subject: [PATCH 131/205] ok so we're ok maybe --- src/levanter/data/shard_cache.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/levanter/data/shard_cache.py b/src/levanter/data/shard_cache.py index 0df7f98d8..7d01e190e 100644 --- a/src/levanter/data/shard_cache.py +++ b/src/levanter/data/shard_cache.py @@ -323,7 +323,7 @@ def _shard_reader_generator(shard_source: ShardedDataset[T], shard_idx: int, sta # chunks and earlier shards. (So that we approximately generate following the global order.) # TODO: it's not great we set this to 0.0, but it's the only way to get it to work with the current # ray scheduler when we try to index a large number of corpora. We should centralize this a bit better. 
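One caveat with the logger.level <= pylogging.DEBUG guards in these hunks: Logger.level is only the level set directly on that logger, which defaults to 0 (NOTSET), so the guard passes, and the size accounting runs, even when DEBUG output is actually suppressed by the root logger's INFO level from basicConfig. The stdlib-idiomatic check for "would a debug record be emitted?" is Logger.isEnabledFor, which consults the effective level inherited through the logger hierarchy; a minimal sketch:

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("levanter.data.shard_cache")

print(logger.level)                        # 0 (NOTSET): the <= DEBUG guard is always true
print(logger.isEnabledFor(logging.DEBUG))  # False: effective level is INFO

if logger.isEnabledFor(logging.DEBUG):
    # only pay for expensive bookkeeping when it will actually be logged
    logger.debug("retained batch sizes: ...")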
-@ray.remote(num_cpus=1.0, scheduling_strategy="SPREAD") +@ray.remote(num_cpus=0.25, scheduling_strategy="SPREAD") def _alternating_shard_reader( name: str, builder_ref: ActorHandle, # _ChunkCacheBuilder @@ -383,12 +383,12 @@ def enqueue_to_backpressure(batch, batch_result_ref): def drain_back_pressure_to(size): nonlocal back_pressure_queue, retained_batch_sizes while len(back_pressure_queue) > size: - logger.info(f"Waiting for back pressure queue to drain: {len(back_pressure_queue)}") + logger.debug(f"Waiting for back pressure queue to drain: {len(back_pressure_queue)}") finished_ref, back_pressure_queue = ray.wait(back_pressure_queue, num_returns=1) if len(back_pressure_queue) and logger.level <= pylogging.DEBUG: - size = back_pressure_sizes.pop(finished_ref[0]) - retained_batch_sizes -= size + obj_size = back_pressure_sizes.pop(finished_ref[0]) + retained_batch_sizes -= obj_size while len(shard_pqueue) > 0: chunk_id, shard_idx = heapq.heappop(shard_pqueue) @@ -398,9 +398,9 @@ def drain_back_pressure_to(size): exhausted_shard = False - chunk_batch_idx = 0 - chunk_filled = False - total_chunk_rows = 0 + chunk_batch_idx = 0 # the index of the batch within the chunk + chunk_filled = False # whether or not we've filled the chunk to max size + total_chunk_rows = 0 # the total number of rows in the chunk so far while not chunk_filled: batch = next(shard_iter, None) From cbddab85974b6f14f24f4fd698f8e122fb56d765 Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 15 Jan 2024 22:58:26 -0800 Subject: [PATCH 132/205] don't fetch local --- src/levanter/data/shard_cache.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/levanter/data/shard_cache.py b/src/levanter/data/shard_cache.py index 7d01e190e..f876edaca 100644 --- a/src/levanter/data/shard_cache.py +++ b/src/levanter/data/shard_cache.py @@ -384,7 +384,7 @@ def drain_back_pressure_to(size): nonlocal back_pressure_queue, retained_batch_sizes while len(back_pressure_queue) > size: logger.debug(f"Waiting for back pressure queue to drain: {len(back_pressure_queue)}") - finished_ref, back_pressure_queue = ray.wait(back_pressure_queue, num_returns=1) + finished_ref, back_pressure_queue = ray.wait(back_pressure_queue, num_returns=1, fetch_local=False) if len(back_pressure_queue) and logger.level <= pylogging.DEBUG: obj_size = back_pressure_sizes.pop(finished_ref[0]) From 5d0f9870f4dc30c5bbf7fcb874cfb2360ece669e Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 15 Jan 2024 23:06:47 -0800 Subject: [PATCH 133/205] wtf --- src/levanter/data/shard_cache.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/levanter/data/shard_cache.py b/src/levanter/data/shard_cache.py index f876edaca..0b899baf2 100644 --- a/src/levanter/data/shard_cache.py +++ b/src/levanter/data/shard_cache.py @@ -1003,6 +1003,7 @@ def new_chunk(self, shard_name: str, *chunks: ChunkMetadata): def shard_finished(self, shard_name: str, expected_num_chunks: int): """Callback method for when a shard worker has finished.""" shard_status = self.shard_status[shard_name] + assert shard_status.expected_num_chunks is None shard_status.expected_num_chunks = expected_num_chunks # we might still have buffered chunks, so we need to check if we can append them From 13cc5565713e83b4b92dec371facb15c99de6f6e Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 15 Jan 2024 23:12:08 -0800 Subject: [PATCH 134/205] what --- src/levanter/data/shard_cache.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/levanter/data/shard_cache.py 
b/src/levanter/data/shard_cache.py index 0b899baf2..69101d708 100644 --- a/src/levanter/data/shard_cache.py +++ b/src/levanter/data/shard_cache.py @@ -763,6 +763,7 @@ def __init__( if self.metadata_writer.is_finished: logger.info(f"Shard {shard_name} already finished. Skipping.") self._expected_num_chunks = self.metadata_writer.num_chunks + logger.info(f"Shard {shard_name} finished, from {self._expected_num_chunks} chunks, init") self.parent_ref.shard_finished.remote(self.shard_name, self._expected_num_chunks) self.collator = _ChunkCollator(cache_dir, shard_name) @@ -828,6 +829,9 @@ def _attempt_to_commit_chunks(self): if self._expected_num_chunks is not None and self.metadata_writer.num_chunks == self._expected_num_chunks: self.metadata_writer.finish() + logger.info( + f"Shard {self.shard_name} finished, from {self._expected_num_chunks} chunks, _attempt_to_commit_chunks" + ) self.parent_ref.shard_finished.remote(self.shard_name, self._expected_num_chunks) @@ -1003,7 +1007,9 @@ def new_chunk(self, shard_name: str, *chunks: ChunkMetadata): def shard_finished(self, shard_name: str, expected_num_chunks: int): """Callback method for when a shard worker has finished.""" shard_status = self.shard_status[shard_name] - assert shard_status.expected_num_chunks is None + assert ( + shard_status.expected_num_chunks is None + ), f"Shard {shard_name} already finished: {shard_status.expected_num_chunks} {expected_num_chunks}" shard_status.expected_num_chunks = expected_num_chunks # we might still have buffered chunks, so we need to check if we can append them From 5afac01ec3ede83a799efe88ff83212d7b79a6cc Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 15 Jan 2024 23:15:50 -0800 Subject: [PATCH 135/205] ok, think we figured it out --- src/levanter/data/shard_cache.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/levanter/data/shard_cache.py b/src/levanter/data/shard_cache.py index 69101d708..2c3b3f862 100644 --- a/src/levanter/data/shard_cache.py +++ b/src/levanter/data/shard_cache.py @@ -765,6 +765,9 @@ def __init__( self._expected_num_chunks = self.metadata_writer.num_chunks logger.info(f"Shard {shard_name} finished, from {self._expected_num_chunks} chunks, init") self.parent_ref.shard_finished.remote(self.shard_name, self._expected_num_chunks) + self.finished = True + else: + self.finished = False self.collator = _ChunkCollator(cache_dir, shard_name) @@ -823,12 +826,15 @@ def _attempt_to_commit_chunks(self): chunks_committed.append(chunk) if len(chunks_committed) > 0: + if self.finished: + raise RuntimeError("Tried to commit chunks after shard finished") # TODO: this is called inside an async call so we need to not block, but we do need to sequence # this to come before the shard_finished self.parent_ref.new_chunk.remote(self.shard_name, *chunks_committed) - if self._expected_num_chunks is not None and self.metadata_writer.num_chunks == self._expected_num_chunks: + if not self.finished and self.metadata_writer.num_chunks == self._expected_num_chunks: self.metadata_writer.finish() + self.finished = True logger.info( f"Shard {self.shard_name} finished, from {self._expected_num_chunks} chunks, _attempt_to_commit_chunks" ) From 41ac3621c2897a6acff01b2eeed6c3be6c681534 Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 15 Jan 2024 23:39:58 -0800 Subject: [PATCH 136/205] less logging --- src/levanter/data/shard_cache.py | 26 +++++++++++--------------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/src/levanter/data/shard_cache.py 
b/src/levanter/data/shard_cache.py index 2c3b3f862..d8d66639a 100644 --- a/src/levanter/data/shard_cache.py +++ b/src/levanter/data/shard_cache.py @@ -763,7 +763,6 @@ def __init__( if self.metadata_writer.is_finished: logger.info(f"Shard {shard_name} already finished. Skipping.") self._expected_num_chunks = self.metadata_writer.num_chunks - logger.info(f"Shard {shard_name} finished, from {self._expected_num_chunks} chunks, init") self.parent_ref.shard_finished.remote(self.shard_name, self._expected_num_chunks) self.finished = True else: @@ -835,9 +834,6 @@ def _attempt_to_commit_chunks(self): if not self.finished and self.metadata_writer.num_chunks == self._expected_num_chunks: self.metadata_writer.finish() self.finished = True - logger.info( - f"Shard {self.shard_name} finished, from {self._expected_num_chunks} chunks, _attempt_to_commit_chunks" - ) self.parent_ref.shard_finished.remote(self.shard_name, self._expected_num_chunks) @@ -1041,17 +1037,17 @@ def other_failed(self, error: ExceptionInfo): def _attempt_to_flush_buffers(self): # this is the most complex logic in this class. - # The global order on chunks is defined as "roundrobin" over shards, until one shard is done. - # after that, that shard is removed from the roundrobin and the process continues. - # roundrobin order is determined by self.source.shard_names - - # we are happy to release chunks that form a prefix of the global order so that they can be read - # to do that, we maintain the roundrobin order in self._current_round_robin - # and we maintain the current buffer for each shard in self.shard_status - # when we get a new chunk, we append it to the buffer for that shard - # when we get a finished message, we mark that shard as finished - # in either case, we check if we can send any chunks from the front of the roundrobin - # if we can, we send them to the broker + # The global order on chunks is defined as a roundrobin over shards, until one shard is done. + # After that, that shard is removed from the roundrobin and the process continues. + # Roundrobin order is determined by self.source.shard_names + + # We are happy to release chunks that form a prefix of the global order so that they can be read. + # To do that, we maintain the roundrobin order in self._current_round_robin + # and we maintain the current buffer for each shard in self.shard_status. + # When we get a new chunk, we append it to the buffer for that shard. + # When we get a finished message, we mark that shard as finished. + # In either case, we check if we can send any chunks from the front of the roundrobin. + # If we can, we send them to the broker # here "finished" means that the shard has sent all of its chunks and has told us that it's done. From c621a08c8eec0fdb51a99d9397094b8b661370bd Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 16 Jan 2024 22:13:33 -0800 Subject: [PATCH 137/205] toward turning the reader process into an actor too --- src/levanter/data/shard_cache.py | 272 ++++++++++++++++++++----------- 1 file changed, 179 insertions(+), 93 deletions(-) diff --git a/src/levanter/data/shard_cache.py b/src/levanter/data/shard_cache.py index d8d66639a..62aefbf84 100644 --- a/src/levanter/data/shard_cache.py +++ b/src/levanter/data/shard_cache.py @@ -33,7 +33,6 @@ import levanter.tracker from .. import logging -from ..utils.py_utils import actual_sizeof from ..utils.ray_utils import ExceptionInfo, RefBox, current_actor_handle, ser_exc_info from . 
import ShardableDataset from ._preprocessor import BatchProcessor, BatchResult, as_record_batch, dict_from_record_batch @@ -319,132 +318,219 @@ def _shard_reader_generator(shard_source: ShardedDataset[T], shard_idx: int, sta yield batch -# This class is responsible for reading batches from a set of shards, prioritizing earlier -# chunks and earlier shards. (So that we approximately generate following the global order.) -# TODO: it's not great we set this to 0.0, but it's the only way to get it to work with the current -# ray scheduler when we try to index a large number of corpora. We should centralize this a bit better. -@ray.remote(num_cpus=0.25, scheduling_strategy="SPREAD") -def _alternating_shard_reader( - name: str, - builder_ref: ActorHandle, # _ChunkCacheBuilder - shard_writers: ActorHandle, # _GroupedShardWriter - shard_source: ShardedDataset[T], - shard_names: Sequence[str], - priority_fn: Callable[[int, int], float], - processor_actor: ActorHandle, # BatchProcessorQueue - batch_size, - num_rows_per_chunk, -): - pylogging.basicConfig(level=pylogging.INFO) - logger = pylogging.getLogger(f"shard_reader.{name}") - shard_pqueue: list[tuple[int, int]] = [] # heapq of (num_chunks, shard_idx) - shard_readers: dict[int, Iterator[list[T]]] = {} - try: - shard_metadatas = _initial_shard_metadatas(shard_source, shard_names, shard_writers) - except Exception as e: - builder_ref.other_failed.remote(ser_exc_info()) - raise e +class PriorityWorkTaskGroupSpec(Protocol): + name: str + + def build(self) -> "PriorityWorkTaskGroup": + raise NotImplementedError() + + +class PriorityWorkTaskGroup(Protocol): + name: str + + def items(self) -> Sequence["PriorityWorkItem"]: + raise NotImplementedError() - batch_size = min(batch_size, num_rows_per_chunk) - for shard_name in shard_names: - shard_idx = shard_source.shard_names.index(shard_name) +class PriorityWorkItem(Protocol): + name: str + priority: float + + # TODO: bring back backpressure here? + def execute(self) -> bool: + """ + Returns true if the item is finished, false if it should be rescheduled. 
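+
+        A scheduler can drive a heap of these items along the lines of::
+
+            while queue:
+                item = heapq.heappop(queue)
+                if not item.execute():
+                    heapq.heappush(queue, item)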
+ """ + raise NotImplementedError() + + # needs to be sortable by priority + def __lt__(self, other: "PriorityWorkItem"): + if self.priority == other.priority: + return self.name < other.name + else: + return self.priority < other.priority + + def __le__(self, other: "PriorityWorkItem"): + if self.priority == other.priority: + return self.name <= other.name + else: + return self.priority <= other.priority + + +@dataclass +class ShardGroupToBeProcessed(PriorityWorkTaskGroupSpec): + name: str + builder_ref: ray.actor.ActorHandle # _ChunkCacheBuilder + writer: ray.actor.ActorHandle # _GroupedShardWriter + shard_source: ShardedDataset + shard_names: Sequence[str] + priority_fn: Callable[[int, int], float] + processor_actor: ray.actor.ActorHandle # BatchProcessorQueue + batch_size: int + num_rows_per_chunk: int + + def build(self) -> "PriorityWorkTaskGroup": + return ShardGroupTaskGroup(self) + + +class ShardGroupTaskGroup(PriorityWorkTaskGroup): + def __init__(self, spec: ShardGroupToBeProcessed): + self.spec = spec + self.logger = pylogging.getLogger(f"shard_reader.{self.spec.name}") + try: - shard_metadata = shard_metadatas[shard_name] - heapq.heappush(shard_pqueue, (len(shard_metadata.chunks), shard_idx)) - shard_readers[shard_idx] = _shard_reader_generator( - shard_source, shard_idx, shard_metadata.total_rows, batch_size + metadata: dict[str, ShardMetadata] = _initial_shard_metadatas( + self.spec.shard_source, self.spec.shard_names, self.spec.writer ) - except Exception as e: # noqa - logger.exception(f"Error while initializing shard {shard_name}") - # fire and forget - shard_writers[shard_name].shard_failed.remote(ser_exc_info()) + except Exception as e: + self.spec.builder_ref.other_failed.remote(ser_exc_info()) raise e - MAX_INFLIGHT = 10 - back_pressure_queue: list[ray.ObjectRef] = [] - back_pressure_sizes: dict[ray.ObjectRef, int] = {} - retained_batch_sizes = 0 - - def enqueue_to_backpressure(batch, batch_result_ref): - nonlocal back_pressure_queue, retained_batch_sizes - back_pressure_queue.append(batch_result_ref) - if logger.level <= pylogging.DEBUG: - size = actual_sizeof(batch) - retained_batch_sizes += size - back_pressure_sizes[batch_result_ref] = size - if retained_batch_sizes > 1024 * 1024 * 1024: - logger.debug(f"Retained batch sizes: {retained_batch_sizes}") - - # we want to limit the number of pending tasks, so we wait until we're below the limit - # before we start reading the next batch - drain_back_pressure_to(MAX_INFLIGHT) - - def drain_back_pressure_to(size): - nonlocal back_pressure_queue, retained_batch_sizes - while len(back_pressure_queue) > size: - logger.debug(f"Waiting for back pressure queue to drain: {len(back_pressure_queue)}") - finished_ref, back_pressure_queue = ray.wait(back_pressure_queue, num_returns=1, fetch_local=False) - - if len(back_pressure_queue) and logger.level <= pylogging.DEBUG: - obj_size = back_pressure_sizes.pop(finished_ref[0]) - retained_batch_sizes -= obj_size + batch_size = min(self.spec.batch_size, self.spec.num_rows_per_chunk) - while len(shard_pqueue) > 0: - chunk_id, shard_idx = heapq.heappop(shard_pqueue) - shard_name = shard_source.shard_names[shard_idx] - try: - shard_iter = shard_readers[shard_idx] + self._items: list[PriorityWorkItem] = [] - exhausted_shard = False + for shard_name in self.spec.shard_names: + shard_idx = self.spec.shard_source.shard_names.index(shard_name) + try: + shard_metadata = metadata[shard_name] + reader = _shard_reader_generator( + self.spec.shard_source, shard_idx, shard_metadata.total_rows, 
batch_size + ) + + if shard_metadata.is_finished: + self.logger.info(f"Shard {shard_name} already finished. Skipping.") + + task_name = f"shard_reader.{self.spec.name}.{shard_name}" + + chunk_idx = len(shard_metadata.chunks) + item = ShardReaderItem(self, task_name, shard_name, shard_idx, chunk_idx, reader) + + heapq.heappush(self._items, item) + except Exception as e: + self.logger.exception(f"Error while initializing shard {shard_name}") + self.spec.writer[shard_name].shard_failed.remote(ser_exc_info()) + raise e + + @property + def name(self): + return self.spec.name + + def items(self) -> Sequence["PriorityWorkItem"]: + return self._items - chunk_batch_idx = 0 # the index of the batch within the chunk - chunk_filled = False # whether or not we've filled the chunk to max size - total_chunk_rows = 0 # the total number of rows in the chunk so far +# NB This class is stateful +@dataclass +class ShardReaderItem(PriorityWorkItem): + group: ShardGroupTaskGroup + name: str + shard_name: str + shard_idx: int + chunk_idx: int + reader: Iterator[list] + + @property + def priority(self): + return self.group.spec.priority_fn(self.shard_idx, self.chunk_idx) + + @property + def spec(self): + return self.group.spec + + def execute(self) -> bool: + exhausted_shard = False + writer = self.spec.writer + + chunk_batch_idx = 0 # the index of the batch within the chunk + chunk_filled = False # whether or not we've filled the chunk to max size + total_chunk_rows = 0 # the total number of rows in the chunk + + try: while not chunk_filled: - batch = next(shard_iter, None) + batch = next(self.reader, None) if batch is None: exhausted_shard = True break - exhausted_shard = len(batch) < batch_size + exhausted_shard = len(batch) < self.spec.batch_size total_chunk_rows += len(batch) if batch: - priority = priority_fn(shard_idx, chunk_id) + priority = self.spec.priority_fn(self.shard_idx, self.chunk_idx) batch_result_ref = ray.get( - processor_actor.submit.remote(priority=priority, batch=RefBox(ray.put(batch))) + self.spec.processor_actor.submit.remote(priority=priority, batch=RefBox(ray.put(batch))) ) - shard_writers.chunk_batch_finished.remote( - shard_name, chunk_id, chunk_batch_idx, RefBox(batch_result_ref) + writer.chunk_batch_finished.remote( + self.shard_name, self.chunk_idx, chunk_batch_idx, RefBox(batch_result_ref) ) chunk_batch_idx += 1 - enqueue_to_backpressure(batch, batch_result_ref) + # enqueue_to_backpressure(batch, batch_result_ref) del batch - if total_chunk_rows >= num_rows_per_chunk or exhausted_shard: + if total_chunk_rows >= self.spec.num_rows_per_chunk or exhausted_shard: chunk_filled = True if chunk_batch_idx > 0: - shard_writers.chunk_finished_reading.remote(shard_name, chunk_id, chunk_batch_idx) - chunk_id += 1 + writer.chunk_finished_reading.remote(self.shard_name, self.chunk_idx, chunk_batch_idx) + old_prio = self.priority + self.chunk_idx += 1 + assert self.priority > old_prio if exhausted_shard: - shard_writers.shard_finished_reading.remote(shard_name, chunk_id) - del shard_readers[shard_idx] - del shard_metadatas[shard_name] - else: - # we're not done with this shard, so put it back in the queue - heapq.heappush(shard_pqueue, (chunk_id, shard_idx)) + writer.shard_finished_reading.remote(self.shard_name, self.chunk_idx) + # del shard_readers[shard_idx] + # del shard_metadatas[self.shard_name] + return exhausted_shard except Exception as e: # noqa - logger.exception(f"Error while processing shard {shard_name}") + self.group.logger.exception(f"Error while processing shard 
{self.shard_name}") # fire and forget - shard_writers.shard_failed.remote(shard_name, ser_exc_info()) + writer.shard_failed.remote(self.shard_name, ser_exc_info()) raise e - drain_back_pressure_to(0) + +# This class is responsible for reading batches from a set of shards, prioritizing earlier +# chunks and earlier shards. (So that we approximately generate following the global order.) +# TODO: it's not great we set this to 0.0, but it's the only way to get it to work with the current +# ray scheduler when we try to index a large number of corpora. We should centralize this a bit better. +@ray.remote(num_cpus=0.25, scheduling_strategy="SPREAD") +def _alternating_shard_reader( + name: str, + builder_ref: ActorHandle, # _ChunkCacheBuilder + shard_writers: ActorHandle, # _GroupedShardWriter + shard_source: ShardedDataset[T], + shard_names: Sequence[str], + priority_fn: Callable[[int, int], float], + processor_actor: ActorHandle, # BatchProcessorQueue + batch_size, + num_rows_per_chunk, +): + + group = ShardGroupToBeProcessed( + name=name, + builder_ref=builder_ref, + writer=shard_writers, + shard_source=shard_source, + shard_names=shard_names, + priority_fn=priority_fn, + processor_actor=processor_actor, + batch_size=batch_size, + num_rows_per_chunk=num_rows_per_chunk, + ).build() + pylogging.basicConfig(level=pylogging.INFO) + # logger = pylogging.getLogger(f"shard_reader.{name}") + shard_pqueue: list[PriorityWorkItem] = list(group.items()) + + while len(shard_pqueue) > 0: + item = heapq.heappop(shard_pqueue) + + if item.execute(): + # we're done, do nothing + pass + else: + # we're not done, put it back in the queue + heapq.heappush(shard_pqueue, item) def _initial_shard_metadatas(shard_source, shard_names, shard_group_writer): From 4d92af9c1f4f98d63fc38478dcdfde0280713d0f Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 16 Jan 2024 23:23:56 -0800 Subject: [PATCH 138/205] did we do it? --- src/levanter/data/shard_cache.py | 199 +++++++++++++++++++++---------- 1 file changed, 138 insertions(+), 61 deletions(-) diff --git a/src/levanter/data/shard_cache.py b/src/levanter/data/shard_cache.py index 62aefbf84..67b098b38 100644 --- a/src/levanter/data/shard_cache.py +++ b/src/levanter/data/shard_cache.py @@ -327,6 +327,7 @@ def build(self) -> "PriorityWorkTaskGroup": class PriorityWorkTaskGroup(Protocol): name: str + spec: PriorityWorkTaskGroupSpec def items(self) -> Sequence["PriorityWorkItem"]: raise NotImplementedError() @@ -335,11 +336,13 @@ def items(self) -> Sequence["PriorityWorkItem"]: class PriorityWorkItem(Protocol): name: str priority: float + spec: PriorityWorkTaskGroupSpec - # TODO: bring back backpressure here? - def execute(self) -> bool: + def execute(self) -> tuple[bool, Optional[ray.ObjectRef]]: """ Returns true if the item is finished, false if it should be rescheduled. + The object ref is used (1) to block shutting down the actor too early + and (2) for backpressure. 
""" raise NotImplementedError() @@ -357,6 +360,95 @@ def __le__(self, other: "PriorityWorkItem"): return self.priority <= other.priority +@ray.remote(num_cpus=1, scheduling_strategy="SPREAD") +class PriorityProcessorActor: + def __init__(self, max_in_flight: Optional[int] = 200): + pylogging.basicConfig(level=pylogging.INFO) + self._queue: list[PriorityWorkItem] = [] # heapq + self._queue_lock = threading.Lock() + self._shutdown_event = threading.Event() + self._current_item: Optional[PriorityWorkItem] = None + self._max_in_flight = max_in_flight + + self._processing_thread = threading.Thread(target=self._loop, daemon=True) + self._processing_thread.start() + + def add_work_group(self, group: PriorityWorkTaskGroupSpec): + items = group.build().items() + with self._queue_lock: + for item in items: + heapq.heappush(self._queue, item) + + def is_group_finished(self, group: PriorityWorkTaskGroupSpec): + with self._queue_lock: + if any(item.spec == group for item in self._queue): + return False + + if self._current_item is not None and self._current_item.spec == group: + return False + + logger.info(f"Group {group.name} is finished.") + + return True + + def cancel_work_group(self, group: PriorityWorkTaskGroupSpec): + # kill all the items in the group + with self._queue_lock: + self._queue = [item for item in self._queue if item.spec != group] + heapq.heapify(self._queue) + + def shutdown(self): + if not self._shutdown_event.is_set(): + self._shutdown_event.set() + + if self._processing_thread.is_alive(): + self._processing_thread.join() + + def _loop(self): + should_sleep = False + backpressure_queue: list[ray.ObjectRef] = [] + + def drain_backpressure_to(count): + nonlocal backpressure_queue + while len(backpressure_queue) > count: + finished, remaining = ray.wait(backpressure_queue, num_returns=1, fetch_local=False) + backpressure_queue = remaining + + while not self._shutdown_event.is_set(): + if should_sleep: + time.sleep(0.1) + + drain_backpressure_to(self._max_in_flight) + + with self._queue_lock: + if len(self._queue) == 0: + should_sleep = True + continue + else: + should_sleep = False + + item = heapq.heappop(self._queue) + self._current_item = item + + try: + item_is_finished, ref = item.execute() + if ref is not None: + backpressure_queue.append(ref) + except Exception: + logger.exception(f"Error while processing {item.name}. Killing all associated work.") + self.cancel_work_group(item.spec) + continue + + with self._queue_lock: + self._current_item = None + if not item_is_finished: + heapq.heappush(self._queue, item) + + logger.info("Shutting down PriorityProcessorActor. Waiting for backpressure to drain.") + drain_backpressure_to(0) + logger.info("Backpressure drained. Shutting down PriorityProcessorActor.") + + @dataclass class ShardGroupToBeProcessed(PriorityWorkTaskGroupSpec): name: str @@ -423,6 +515,11 @@ def items(self) -> Sequence["PriorityWorkItem"]: # NB This class is stateful @dataclass class ShardReaderItem(PriorityWorkItem): + """ + Each time execute is called, this class reads one chunk's worth of batches from the shard + and dispatches them to the processor. 
+ """ + group: ShardGroupTaskGroup name: str shard_name: str @@ -438,13 +535,14 @@ def priority(self): def spec(self): return self.group.spec - def execute(self) -> bool: + def execute(self) -> tuple[bool, Optional[ray.ObjectRef]]: exhausted_shard = False writer = self.spec.writer chunk_batch_idx = 0 # the index of the batch within the chunk chunk_filled = False # whether or not we've filled the chunk to max size total_chunk_rows = 0 # the total number of rows in the chunk + batch_result_ref = None try: while not chunk_filled: @@ -479,10 +577,10 @@ def execute(self) -> bool: if exhausted_shard: writer.shard_finished_reading.remote(self.shard_name, self.chunk_idx) - # del shard_readers[shard_idx] - # del shard_metadatas[self.shard_name] - return exhausted_shard + logger.debug(f"Finished reading one chunk of shard {self.shard_name}: {self.chunk_idx} {exhausted_shard}") + + return exhausted_shard, batch_result_ref except Exception as e: # noqa self.group.logger.exception(f"Error while processing shard {self.shard_name}") # fire and forget @@ -490,49 +588,6 @@ def execute(self) -> bool: raise e -# This class is responsible for reading batches from a set of shards, prioritizing earlier -# chunks and earlier shards. (So that we approximately generate following the global order.) -# TODO: it's not great we set this to 0.0, but it's the only way to get it to work with the current -# ray scheduler when we try to index a large number of corpora. We should centralize this a bit better. -@ray.remote(num_cpus=0.25, scheduling_strategy="SPREAD") -def _alternating_shard_reader( - name: str, - builder_ref: ActorHandle, # _ChunkCacheBuilder - shard_writers: ActorHandle, # _GroupedShardWriter - shard_source: ShardedDataset[T], - shard_names: Sequence[str], - priority_fn: Callable[[int, int], float], - processor_actor: ActorHandle, # BatchProcessorQueue - batch_size, - num_rows_per_chunk, -): - - group = ShardGroupToBeProcessed( - name=name, - builder_ref=builder_ref, - writer=shard_writers, - shard_source=shard_source, - shard_names=shard_names, - priority_fn=priority_fn, - processor_actor=processor_actor, - batch_size=batch_size, - num_rows_per_chunk=num_rows_per_chunk, - ).build() - pylogging.basicConfig(level=pylogging.INFO) - # logger = pylogging.getLogger(f"shard_reader.{name}") - shard_pqueue: list[PriorityWorkItem] = list(group.items()) - - while len(shard_pqueue) > 0: - item = heapq.heappop(shard_pqueue) - - if item.execute(): - # we're done, do nothing - pass - else: - # we're not done, put it back in the queue - heapq.heappush(shard_pqueue, item) - - def _initial_shard_metadatas(shard_source, shard_names, shard_group_writer): shard_metadatas: dict[str, ShardMetadata] = {} _metadata_futures = [shard_group_writer.current_metadata.remote(name) for name in shard_names] @@ -1055,25 +1110,47 @@ def priority_fn(shard_idx, chunk_idx): self._current_round_robin.append(shard_name) shard_groups[i % num_shard_groups].append(shard_name) - for shard_group in shard_groups: + for group_id, shard_group in enumerate(shard_groups): writer = _GroupShardWriterWorker.remote(self_ref, cache_dir, shard_group) # type: ignore self._shard_writers.append(writer) + # TODO: would probably be better if we didn't create one of these per shard group processor_actor = _BatchProcessorQueue.remote(processor) # type: ignore self._processor_actors.append(processor_actor) - reader = _alternating_shard_reader.remote( - name, - self_ref, - writer, - source, - shard_group, - priority_fn, - processor_actor, - processor.batch_size, - 
rows_per_chunk, + work_item = ShardGroupToBeProcessed( + name=name, + builder_ref=self_ref, + writer=writer, + shard_source=source, + shard_names=shard_group, + priority_fn=priority_fn, + processor_actor=processor_actor, + batch_size=processor.batch_size, + num_rows_per_chunk=rows_per_chunk, ) - self._shard_readers.append(reader) + + # we want global names so that different tasks can coordinate priorities + priority_actor_name = f"priority_processor.{group_id}" + + reader_actor = PriorityProcessorActor.options( # type: ignore + name=priority_actor_name, get_if_exists=True + ).remote() + + ray.get(reader_actor.add_work_group.remote(work_item)) + + # reader = _alternating_shard_reader.remote( + # name, + # self_ref, + # writer, + # source, + # shard_group, + # priority_fn, + # processor_actor, + # processor.batch_size, + # rows_per_chunk, + # ) + self._shard_readers.append(reader_actor) def new_chunk(self, shard_name: str, *chunks: ChunkMetadata): """Callback method for when a shard worker has produced a new chunk.""" From 257dfa7ddbffe8ae86f216f277a5d6701e6b696d Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 16 Jan 2024 23:34:49 -0800 Subject: [PATCH 139/205] wandb: only force a step if commit is true --- src/levanter/tracker/wandb.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/levanter/tracker/wandb.py b/src/levanter/tracker/wandb.py index 83866656e..28416389d 100644 --- a/src/levanter/tracker/wandb.py +++ b/src/levanter/tracker/wandb.py @@ -49,6 +49,11 @@ def log_hyperparameters(self, hparams: dict[str, Any]): self.run.config.update(hparams, allow_val_change=True) def log(self, metrics: dict[str, Any], *, step, commit=None): + if step is None: + if commit is False: + step = self.run.step + else: + step = None self.run.log(metrics, step=step, commit=commit) def log_summary(self, metrics: dict[str, Any]): From 23865a14305c608a7a08db62626c0856e5afa220 Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 16 Jan 2024 23:35:47 -0800 Subject: [PATCH 140/205] don't crash if n == 0 --- src/levanter/callbacks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/levanter/callbacks.py b/src/levanter/callbacks.py index 409b235f9..0c462a997 100644 --- a/src/levanter/callbacks.py +++ b/src/levanter/callbacks.py @@ -59,8 +59,8 @@ def eval_loss_loop(loss_fn, model, dataset, max_batches: Optional[int] = None, n if n > 0: total_loss /= n - logger.info(f"eval loading time: {total_load_time / n:.3f} s/ba") - logger.info(f"eval loss time: {total_loss_time / n:.3f} s/ba") + logger.info(f"eval loading time: {total_load_time / n:.3f} s/ba") + logger.info(f"eval loss time: {total_loss_time / n:.3f} s/ba") return total_loss From 70c00f129fd39bf1d97f04af54569ce3dfaf942c Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 16 Jan 2024 23:49:05 -0800 Subject: [PATCH 141/205] wandb: maybe this gives the behavior i want? 
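The subtlety being worked around here is how `wandb.Run.log` treats `step` and `commit`: with `commit=False`, several calls can add metrics to the same step before a later call flushes it. A small sketch of that behavior (offline mode, made-up metric names):

```python
import wandb

run = wandb.init(project="demo", mode="offline")  # offline: no account needed

# commit=False accumulates metrics at step 10 without advancing the run,
# so a later call can attach more metrics to the same step.
run.log({"train/loss": 0.9}, step=10, commit=False)
run.log({"train/lr": 3e-4}, step=10)  # flushes everything logged for step 10

run.finish()
```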
--- src/levanter/tracker/wandb.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/levanter/tracker/wandb.py b/src/levanter/tracker/wandb.py index 28416389d..d217ab000 100644 --- a/src/levanter/tracker/wandb.py +++ b/src/levanter/tracker/wandb.py @@ -49,11 +49,9 @@ def log_hyperparameters(self, hparams: dict[str, Any]): self.run.config.update(hparams, allow_val_change=True) def log(self, metrics: dict[str, Any], *, step, commit=None): - if step is None: - if commit is False: - step = self.run.step - else: - step = None + if step is None and not commit: + step = self.run.step + self.run.log(metrics, step=step, commit=commit) def log_summary(self, metrics: dict[str, Any]): From 4c5436536cf956bde8843f1de5004f69567bef9f Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 17 Jan 2024 00:01:41 -0800 Subject: [PATCH 142/205] mklafmlkafml --- src/levanter/tracker/tracker.py | 2 +- src/levanter/tracker/wandb.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/levanter/tracker/tracker.py b/src/levanter/tracker/tracker.py index 62668b70f..8b6816f17 100644 --- a/src/levanter/tracker/tracker.py +++ b/src/levanter/tracker/tracker.py @@ -27,7 +27,7 @@ def log_hyperparameters(self, hparams: dict[str, Any]): pass @abc.abstractmethod - def log(self, metrics: dict[str, typing.Any], *, step, commit: Optional[bool] = None): + def log(self, metrics: dict[str, typing.Any], *, step: Optional[int], commit: Optional[bool] = None): """ Log metrics to the tracker. Step is always required. diff --git a/src/levanter/tracker/wandb.py b/src/levanter/tracker/wandb.py index d217ab000..723b9618c 100644 --- a/src/levanter/tracker/wandb.py +++ b/src/levanter/tracker/wandb.py @@ -50,6 +50,7 @@ def log_hyperparameters(self, hparams: dict[str, Any]): def log(self, metrics: dict[str, Any], *, step, commit=None): if step is None and not commit: + print("WARNING: logging metrics without step or commit. 
Inferring") step = self.run.step self.run.log(metrics, step=step, commit=commit) From 6148381de3864e00d6fc181629498f14cbcb1772 Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 15 Jan 2024 15:38:41 -0800 Subject: [PATCH 143/205] minimize use of optax internals --- src/levanter/optim/second_order.py | 14 ++++++++----- src/levanter/optim/sophia.py | 33 ++++++++++++++++++++++-------- src/levanter/tracker/helpers.py | 10 ++++++--- 3 files changed, 40 insertions(+), 17 deletions(-) diff --git a/src/levanter/optim/second_order.py b/src/levanter/optim/second_order.py index fd0da7325..f262980c5 100644 --- a/src/levanter/optim/second_order.py +++ b/src/levanter/optim/second_order.py @@ -7,8 +7,7 @@ import jax import optax from jax import numpy as jnp -from optax._src import numerics -from optax._src.schedule import InjectHyperparamsState, _convert_floats +from optax import InjectHyperparamsState class HessianUpdateFn(typing.Protocol): @@ -189,10 +188,8 @@ def update_fn(updates, state, params=None): hparams = {k: _convert_floats(v, dtype) for k, v in state.hyperparams.items()} hparams.update(schedule_fn(state.count, dtype)) updates, inner_state = inner_factory(**other_hps, **hparams).update(updates, state.inner_state, params) - count_inc = numerics.safe_int32_increment(state.count) - # pylint:disable=too-many-function-args - return updates, InjectHyperparamsState(count_inc, hparams, inner_state) + return updates, InjectHyperparamsState(state.count + 1, hparams, inner_state) # pylint:enable=too-many-function-args def _find_first_floating_dtype(updates): @@ -226,3 +223,10 @@ def update_hessian(state, fn, model, *batch, **batch_kwargs): return SecondOrderTransformation(init_fn, update_fn, update_hessian) return wrapped_transform + + +def _convert_floats(x, dtype): + """Convert float-like inputs to dtype, rest pass through.""" + if jax.dtypes.scalar_type_of(x) == float: + return jnp.asarray(x, dtype=dtype) + return x diff --git a/src/levanter/optim/sophia.py b/src/levanter/optim/sophia.py index 9506e9eb7..a73f55329 100644 --- a/src/levanter/optim/sophia.py +++ b/src/levanter/optim/sophia.py @@ -1,4 +1,5 @@ import abc +import functools import typing from dataclasses import dataclass from typing import Any, NamedTuple, Optional, TypeVar, runtime_checkable @@ -11,10 +12,6 @@ from jax.random import PRNGKey from jaxtyping import PRNGKeyArray -# TODO: remove dependency on _src internals -from optax._src import numerics -from optax._src.transform import bias_correction, update_moment - import levanter.tracker from levanter.optim.config import HessianOptConfig, OptimizerConfig from levanter.optim.second_order import SecondOrderTransformation, chain_second_order, inject_hyperparams @@ -294,8 +291,7 @@ def init_fn(params): def update_fn(updates, state, params=None): mu = update_moment(updates, state.mu, b1, 1) # nu = update_moment_per_elem_norm(updates, state.nu, b2, 2) - count_inc = numerics.safe_int32_increment(state.count) - mu_hat = bias_correction(mu, b1, count_inc) + mu_hat = bias_correction(mu, b1, state.count + 1) h_hat = state.h # track how often hessian is used mu_leaves = jax.tree_util.tree_leaves(mu_hat) @@ -328,7 +324,7 @@ def update_fn(updates, state, params=None): mu = jax.tree_util.tree_map(lambda t: t.astype(mu_dtype), mu) return updates, ScaleBySophiaState( - count=count_inc, hessian_count=state.hessian_count, mu=mu, h=h_hat, hess_key=state.hess_key + count=state.count + 1, hessian_count=state.hessian_count, mu=mu, h=h_hat, hess_key=state.hess_key ) def update_hessian(state, fn, model, 
*batch, **batch_kwargs): @@ -338,10 +334,9 @@ def _do_update(): # new_hess = jax.tree_util.tree_map(lambda h: jnp.clip(h, -1, 1), new_hess) # EMAs of hessian - hessian_count_inc = numerics.safe_int32_increment(state.hessian_count) nu = update_moment(new_hess, state.h, b2, 1) return ScaleBySophiaState( - count=state.count, hessian_count=hessian_count_inc, mu=state.mu, h=nu, hess_key=next_key + count=state.count, hessian_count=state.hessian_count + 1, mu=state.mu, h=nu, hess_key=next_key ) def _dont_update(): @@ -410,3 +405,23 @@ def stochastic_hessian_diagonal(fn, model, *args, hess_key: PRNGKey, **kwargs): hessian = jax.tree_util.tree_map(lambda grad, gaussian: grad * gaussian, product, g) return hessian + + +# Cribbed from optax._src.transform +def update_moment(updates, moments, decay, order): + """Compute the exponential moving average of the `order`-th moment.""" + return jax.tree_util.tree_map(lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + + +@functools.partial(jax.jit, inline=True) +def bias_correction(moment, decay, count): + """Performs bias correction. It becomes a no-op as count goes to infinity.""" + # The conversion to the data type of the moment ensures that bfloat16 remains + # bfloat16 in the optimizer state. This conversion has to be done after + # `bias_correction_` is calculated as calculating `decay**count` in low + # precision can result in it being rounded to 1 and subsequently a + # "division by zero" error. + bias_correction_ = 1 - decay**count + + # Perform division in the original precision. + return jax.tree_util.tree_map(lambda t: t / bias_correction_.astype(t.dtype), moment) diff --git a/src/levanter/tracker/helpers.py b/src/levanter/tracker/helpers.py index 31131d1ac..1091840c5 100644 --- a/src/levanter/tracker/helpers.py +++ b/src/levanter/tracker/helpers.py @@ -4,7 +4,6 @@ from typing import Optional from git import InvalidGitRepositoryError, NoSuchPathError, Repo -from optax._src.wrappers import MultiStepsState import levanter.tracker from levanter.utils.jax_utils import jnp_to_python @@ -14,8 +13,13 @@ def log_optimizer_hyperparams(opt_state, prefix: Optional[str] = None, *, step=None): - if isinstance(opt_state, MultiStepsState): - opt_state = opt_state.inner_opt_state + try: + from optax._src.wrappers import MultiStepsState + + if isinstance(opt_state, MultiStepsState): + opt_state = opt_state.inner_opt_state + except ImportError: + pass def wrap_key(key): if prefix: From 8274cad619f331c836540d7174c416aa13beed32 Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 17 Jan 2024 21:56:51 -0800 Subject: [PATCH 144/205] what --- src/levanter/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index 0d428e812..c70afb294 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -199,7 +199,7 @@ def loss_fn(self): Wrapped loss function that casts the model to compute precision and sets the context axis mapping to compute """ - @named_jit(axis_resources=self.compute_axis_mapping) + @named_jit(axis_resources=self.parameter_axis_mapping) @functools.wraps(self._raw_loss_function) def fn(model, *batch, **batch_kwargs): with hax.axis_mapping(self.compute_axis_mapping): From b980c9f9883f2e7c7ec3a68acb3df0f3a4c60ee2 Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 17 Jan 2024 22:05:20 -0800 Subject: [PATCH 145/205] actually this is probably better --- src/levanter/doremi.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git 
a/src/levanter/doremi.py b/src/levanter/doremi.py index 10de53635..e35c753d9 100644 --- a/src/levanter/doremi.py +++ b/src/levanter/doremi.py @@ -114,8 +114,14 @@ def estimate_mixture_weights( if validation_sets is not None: for domain, dataset in validation_sets.items(): + + @eqx.filter_jit + def eval_loss(model, *batch, **batch_kwargs): + model = inference_mode(model, True) + return trainer.loss_fn(model, *batch, **batch_kwargs, key=None) + loss = eval_loss_loop( - trainer.loss_fn, + eval_loss, ref, trainer.replicated_loader(dataset, trainer.EvalBatch), name=f"ref {domain}", From 36f25a048aec0c7ba689e9c8964c0123abb7ac0e Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 17 Jan 2024 22:05:31 -0800 Subject: [PATCH 146/205] actually this is probably better --- src/levanter/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index c70afb294..0d428e812 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -199,7 +199,7 @@ def loss_fn(self): Wrapped loss function that casts the model to compute precision and sets the context axis mapping to compute """ - @named_jit(axis_resources=self.parameter_axis_mapping) + @named_jit(axis_resources=self.compute_axis_mapping) @functools.wraps(self._raw_loss_function) def fn(model, *batch, **batch_kwargs): with hax.axis_mapping(self.compute_axis_mapping): From 4da71124260f3e7c6b4f3ca5bf8b14a064f56ee4 Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 17 Jan 2024 22:25:22 -0800 Subject: [PATCH 147/205] dumb --- src/levanter/doremi.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/levanter/doremi.py b/src/levanter/doremi.py index e35c753d9..0d0e9eef8 100644 --- a/src/levanter/doremi.py +++ b/src/levanter/doremi.py @@ -184,8 +184,8 @@ def doremi_step(state: DoremiState, ref, batch, domains): levanter.tracker.jit_log_metrics( { - "change_in_alpha": alpha_distance, - "alpha_distance_from_uniform": distance_from_uniform, + "change_in_alpha": alpha_distance.scalar(), + "alpha_distance_from_uniform": distance_from_uniform.scalar(), "alpha": alpha_dict, }, step=state._step, From 614752046aa36e0b06d52bfadf323b0d68387482 Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 17 Jan 2024 22:25:40 -0800 Subject: [PATCH 148/205] mkladmlkad --- src/levanter/doremi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/levanter/doremi.py b/src/levanter/doremi.py index 0d0e9eef8..54ce7607c 100644 --- a/src/levanter/doremi.py +++ b/src/levanter/doremi.py @@ -235,7 +235,7 @@ def doremi_step(state: DoremiState, ref, batch, domains): def _alpha_weights_to_dict(Domain, alpha, domain_name_to_index): - final_weights = {domain: alpha[Domain, index] for domain, index in domain_name_to_index.items()} + final_weights = {domain: alpha[Domain, index].scalar() for domain, index in domain_name_to_index.items()} return final_weights From e166a7813d9e590a8ee6b40e72558a8393177180 Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 17 Jan 2024 22:27:47 -0800 Subject: [PATCH 149/205] fix key order for doremi --- src/levanter/doremi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/levanter/doremi.py b/src/levanter/doremi.py index 54ce7607c..27488d080 100644 --- a/src/levanter/doremi.py +++ b/src/levanter/doremi.py @@ -128,7 +128,7 @@ def eval_loss(model, *batch, **batch_kwargs): max_batches=trainer_config.max_eval_batches, ) print(f"Loss of ref model on domain {domain}: {loss:.3f}") - 
levanter.tracker.log_summary({f"eval/ref/loss/{domain}": loss}) + levanter.tracker.log_summary({f"eval/ref/{domain}/loss": loss}) if validation_sets is not None: for domain, dataset in validation_sets.items(): From e6b581bd0f2a4434eb07be9500f1d6a5c4091a10 Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 17 Jan 2024 22:46:21 -0800 Subject: [PATCH 150/205] remove excess log --- src/levanter/tracker/wandb.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/levanter/tracker/wandb.py b/src/levanter/tracker/wandb.py index 723b9618c..d217ab000 100644 --- a/src/levanter/tracker/wandb.py +++ b/src/levanter/tracker/wandb.py @@ -50,7 +50,6 @@ def log_hyperparameters(self, hparams: dict[str, Any]): def log(self, metrics: dict[str, Any], *, step, commit=None): if step is None and not commit: - print("WARNING: logging metrics without step or commit. Inferring") step = self.run.step self.run.log(metrics, step=step, commit=commit) From 8c64be5b26173104f4deee07caca4b09da7a60ef Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 17 Jan 2024 22:48:27 -0800 Subject: [PATCH 151/205] remove a redundant log message --- src/levanter/data/shard_cache.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/levanter/data/shard_cache.py b/src/levanter/data/shard_cache.py index 67b098b38..f3565665b 100644 --- a/src/levanter/data/shard_cache.py +++ b/src/levanter/data/shard_cache.py @@ -902,7 +902,6 @@ def __init__( self.parent_ref.new_chunk.remote(shard_name, *self.metadata_writer.chunks) if self.metadata_writer.is_finished: - logger.info(f"Shard {shard_name} already finished. Skipping.") self._expected_num_chunks = self.metadata_writer.num_chunks self.parent_ref.shard_finished.remote(self.shard_name, self._expected_num_chunks) self.finished = True From e89e7093ab840656f81ba16dd1e373331b0516ce Mon Sep 17 00:00:00 2001 From: David Hall Date: Thu, 18 Jan 2024 10:04:06 -0800 Subject: [PATCH 152/205] fixed more bugs --- src/levanter/doremi.py | 38 ++++++++++++++++------------------ src/levanter/main/doremi_lm.py | 1 - tests/test_doremi.py | 1 - 3 files changed, 18 insertions(+), 22 deletions(-) diff --git a/src/levanter/doremi.py b/src/levanter/doremi.py index 27488d080..40da82bbb 100644 --- a/src/levanter/doremi.py +++ b/src/levanter/doremi.py @@ -57,7 +57,6 @@ class DoReMiConfig: domain_weight_step_size: float = 1.0 smoothing: float = 1e-3 sampling_weights: Optional[dict[str, float]] = None - weight_change_eps: float = 1e-3 DEFAULT_DOREMI_TRAINER_CONFIG = TrainerConfig( @@ -78,7 +77,6 @@ def estimate_mixture_weights( loss_fn: ComputeLossFunction[M, T] = ModuleComputeLoss(), domain_weight_step_size: float = 1.0, smoothing: float = 1e-3, - weight_change_eps: float = 1e-3, key: PRNGKeyArray, ) -> dict[str, float]: """ @@ -113,13 +111,13 @@ def estimate_mixture_weights( ref = _prepare_ref_model(ref, trainer_config) if validation_sets is not None: - for domain, dataset in validation_sets.items(): - @eqx.filter_jit - def eval_loss(model, *batch, **batch_kwargs): - model = inference_mode(model, True) - return trainer.loss_fn(model, *batch, **batch_kwargs, key=None) + @eqx.filter_jit + def eval_loss(model, *batch, **batch_kwargs): + model = inference_mode(model, True) + return trainer.loss_fn(model, *batch, **batch_kwargs, key=None) + for domain, dataset in validation_sets.items(): loss = eval_loss_loop( eval_loss, ref, @@ -163,7 +161,9 @@ def doremi_step(state: DoremiState, ref, batch, domains): clipped_losses = hax.maximum(excess_losses, 0) - per_domain_losses = _compute_per_domain_losses(Domain, domains, 
clipped_losses) + mean_excess_loss = hax.mean(excess_losses, axis=None).scalar() + + per_domain_losses = _compute_per_domain_losses(trainer.TrainBatch, Domain, domains, clipped_losses) # Update domain weights alpha = state.alpha * hax.exp(domain_weight_step_size * per_domain_losses) @@ -186,7 +186,8 @@ def doremi_step(state: DoremiState, ref, batch, domains): { "change_in_alpha": alpha_distance.scalar(), "alpha_distance_from_uniform": distance_from_uniform.scalar(), - "alpha": alpha_dict, + **{f"alpha/{domain}": weight for domain, weight in alpha_dict.items()}, + "mean_excess_loss": mean_excess_loss, }, step=state._step, ) @@ -219,11 +220,6 @@ def doremi_step(state: DoremiState, ref, batch, domains): trainer.run_hooks(new_info) - # check convergence for alphas - if alpha_distance.item() < weight_change_eps: - logger.info(f"Converged on alpha at step {state.step}: {alpha_distance:.4f}") - break - trainer.run_hooks(new_info, force=True) alpha = state.average_alpha @@ -287,15 +283,17 @@ def __iter__(self) -> Iterator[Tuple[T, hax.NamedArray]]: yield item, self.domain_index -def _compute_per_domain_losses(Domain, domains, losses): +def _compute_per_domain_losses(Batch, Domain, domains, losses): # TODO: this should weight by masked tokens one_hot_domains = hax.nn.one_hot(domains, Domain) # Domain x Batch - return hax.mean(losses.broadcast_axis(Domain) * one_hot_domains, axis=losses.axes) - # per_domain_losses = hax.dot(losses.axes, one_hot_domains, losses) - # return per_domain_losses + # return hax.mean(losses * one_hot_domains, axis=losses.axes) + per_domain_losses = hax.dot(one_hot_domains, losses, axis=losses.axes) + norm = hax.maximum(hax.dot(one_hot_domains, losses != 0, axis=losses.axes), 1) + return per_domain_losses / norm def _domain_weighted_loss(losses, Domain, domains, alpha): one_hot_domains = hax.nn.one_hot(domains, Domain) # Domain x Batch - return hax.mean(losses.broadcast_axis(Domain) * one_hot_domains * alpha, axis=None).scalar() - # return hax.dot(alpha, one_hot_domains, losses, axis=None) + # return hax.mean(losses * one_hot_domains * alpha, axis=None).scalar() + total = hax.dot(alpha, one_hot_domains, losses, axis=None) + return total.scalar() / losses.size diff --git a/src/levanter/main/doremi_lm.py b/src/levanter/main/doremi_lm.py index 5f9fa7862..850a2c730 100644 --- a/src/levanter/main/doremi_lm.py +++ b/src/levanter/main/doremi_lm.py @@ -121,7 +121,6 @@ def init_proxy_model(): data_sources=train_datasets, trainer_config=config.trainer, optimizer=optimizer, - weight_change_eps=config.doremi.weight_change_eps, domain_weight_step_size=config.doremi.domain_weight_step_size, sampling_weights=config.doremi.sampling_weights, validation_sets=valid_datasets, diff --git a/tests/test_doremi.py b/tests/test_doremi.py index 382389559..c6ac76a47 100644 --- a/tests/test_doremi.py +++ b/tests/test_doremi.py @@ -141,7 +141,6 @@ def init_model(): initial_proxy=init_model(), ref=ref_model, data_sources=datasets, - weight_change_eps=1e-4, trainer_config=tiny_trainer_config, key=next(keys), loss_fn=compute_loss_fn, From 33600fd2163723bcb8ae1c1cfc301ca69bce86eb Mon Sep 17 00:00:00 2001 From: David Hall Date: Thu, 18 Jan 2024 11:50:28 -0800 Subject: [PATCH 153/205] almost there --- src/levanter/doremi.py | 80 ++++++++++++++++++++---------------------- 1 file changed, 39 insertions(+), 41 deletions(-) diff --git a/src/levanter/doremi.py b/src/levanter/doremi.py index 40da82bbb..c08be31d5 100644 --- a/src/levanter/doremi.py +++ b/src/levanter/doremi.py @@ -126,7 +126,7 @@ def eval_loss(model, 
*batch, **batch_kwargs): max_batches=trainer_config.max_eval_batches, ) print(f"Loss of ref model on domain {domain}: {loss:.3f}") - levanter.tracker.log_summary({f"eval/ref/{domain}/loss": loss}) + levanter.tracker.log_metrics({f"eval/ref/{domain}/loss": loss}, step=0, commit=False) if validation_sets is not None: for domain, dataset in validation_sets.items(): @@ -140,54 +140,51 @@ def eval_loss(model, *batch, **batch_kwargs): else: sampling_weights = {domain: 1 / len(data_sources) for domain in data_sources.keys()} - # calculate per-token losses for proxy and ref - def compute_excess_loss(proxy, ref, batch): - proxy_losses = loss_fn(proxy, batch, reduction_axis=()) - ref_losses = loss_fn(ref, batch, reduction_axis=()) - # calculate excess losses - excess_losses = proxy_losses - ref_losses - return excess_losses - # Loss is \sum_d alpha_d * (proxy - ref) (basically the unclipped excess loss with the new alpha) - # Note that (\sum_d \alpha_d ref) is a constant in the model params, so we can ignore it + # Note that (\sum_d \alpha_d ref) is a constant in the model params, so we can ignore it for gradient computation + # (JAX would ignore it for us I think but it's nice to be explicit and lets us log better) @hax.named_jit(axis_resources=trainer.parameter_axis_mapping, donate_args=(True,)) def doremi_step(state: DoremiState, ref, batch, domains): proxy = inference_mode(state.model, False) with hax.axis_mapping(trainer.compute_axis_mapping): - # this is one of those times when PyTorch's backward() is nice - excess_losses, excess_backward = eqx.filter_vjp( - lambda proxy: compute_excess_loss(proxy, ref, batch), proxy - ) + # calculate per-token losses for proxy and ref + proxy_losses, proxy_loss_bwd = eqx.filter_vjp(lambda p: loss_fn(p, batch, reduction_axis=()), proxy) + ref_losses = loss_fn(ref, batch, reduction_axis=()) + # calculate excess losses, aggregate per-domain losses + excess_losses = proxy_losses - ref_losses clipped_losses = hax.maximum(excess_losses, 0) - - mean_excess_loss = hax.mean(excess_losses, axis=None).scalar() - - per_domain_losses = _compute_per_domain_losses(trainer.TrainBatch, Domain, domains, clipped_losses) + per_domain_losses = _compute_per_domain_losses(clipped_losses, Domain, domains) # Update domain weights alpha = state.alpha * hax.exp(domain_weight_step_size * per_domain_losses) alpha /= hax.sum(alpha) alpha = (1 - smoothing) * alpha + initial_alpha * smoothing - distance_from_uniform = hax.sum(hax.abs(alpha - initial_alpha)) - # Update proxy model weights θt for the objective L(θt−1, αt) (using Adam, Adafactor, etc.) + # Note DoReMi says to use the unclipped excess loss here. 
Confirmed with Michael loss, grad_loss = eqx.filter_value_and_grad(_domain_weighted_loss)(excess_losses, Domain, domains, alpha) - grad = excess_backward(grad_loss)[0] + grad = proxy_loss_bwd(grad_loss)[0] new_state = trainer._take_train_step(state, proxy, grad) new_state = new_state.update_alpha(alpha) + # log metrics + distance_from_uniform = hax.sum(hax.abs(alpha - initial_alpha)) + mean_excess_loss = hax.mean(excess_losses).scalar() + mean_proxy_loss = hax.mean(proxy_losses).scalar() alpha_distance = hax.sum(hax.abs(new_state.average_alpha - state.average_alpha)) - alpha_dict = _alpha_weights_to_dict(Domain, new_state.average_alpha, domain_to_index) + alpha_dict = _decode_domain_array(Domain, new_state.average_alpha, domain_to_index) + per_domain_dict = _decode_domain_array(Domain, per_domain_losses, domain_to_index) levanter.tracker.jit_log_metrics( { "change_in_alpha": alpha_distance.scalar(), "alpha_distance_from_uniform": distance_from_uniform.scalar(), + "train/mean_excess_loss": mean_excess_loss, + "train/mean_proxy_loss": mean_proxy_loss, **{f"alpha/{domain}": weight for domain, weight in alpha_dict.items()}, - "mean_excess_loss": mean_excess_loss, + **{f"train/{domain}/loss": loss for domain, loss in per_domain_dict.items()}, }, step=state._step, ) @@ -223,18 +220,35 @@ def doremi_step(state: DoremiState, ref, batch, domains): trainer.run_hooks(new_info, force=True) alpha = state.average_alpha - final_weights = _alpha_weights_to_dict(Domain, alpha, domain_to_index) + final_weights = _decode_domain_array(Domain, alpha, domain_to_index) levanter.tracker.log_summary({"final_alpha": final_weights}) return {k: float(v) for k, v in final_weights.items()} -def _alpha_weights_to_dict(Domain, alpha, domain_name_to_index): +def _decode_domain_array(Domain, alpha, domain_name_to_index): final_weights = {domain: alpha[Domain, index].scalar() for domain, index in domain_name_to_index.items()} return final_weights +def _compute_per_domain_losses(losses, Domain, domains): + """Compute per-domain average losses from a batch of losses""" + # out[d] = E[losses | domain=d] + one_hot_domains = hax.nn.one_hot(domains, Domain) # Domain x Batch + per_domain_losses = hax.dot(one_hot_domains, losses, axis=losses.axes, out_axes=(Domain,)) + # count the number of losses for each domain + norm = hax.dot(one_hot_domains, losses != 0, axis=losses.axes, out_axes=(Domain,)) + norm = hax.maximum(norm, 1.0) # don't nan if there are no losses for a domain + return per_domain_losses / norm + + +def _domain_weighted_loss(losses, Domain, domains, alpha): + """Average loss weighted by domain weights""" + per_domain_losses = _compute_per_domain_losses(losses, Domain, domains) + return hax.dot(alpha, per_domain_losses, axis=Domain).scalar() + + def _prepare_ref_model(ref, trainer): return hax.named_jit( lambda m: trainer.mp.cast_to_compute(inference_mode(m, True)), @@ -281,19 +295,3 @@ def shard(self, shard_id: int, num_shards: int) -> "DomainTaggedDataset[T]": def __iter__(self) -> Iterator[Tuple[T, hax.NamedArray]]: for item in self.dataset: yield item, self.domain_index - - -def _compute_per_domain_losses(Batch, Domain, domains, losses): - # TODO: this should weight by masked tokens - one_hot_domains = hax.nn.one_hot(domains, Domain) # Domain x Batch - # return hax.mean(losses * one_hot_domains, axis=losses.axes) - per_domain_losses = hax.dot(one_hot_domains, losses, axis=losses.axes) - norm = hax.maximum(hax.dot(one_hot_domains, losses != 0, axis=losses.axes), 1) - return per_domain_losses / norm - - -def 
_domain_weighted_loss(losses, Domain, domains, alpha): - one_hot_domains = hax.nn.one_hot(domains, Domain) # Domain x Batch - # return hax.mean(losses * one_hot_domains * alpha, axis=None).scalar() - total = hax.dot(alpha, one_hot_domains, losses, axis=None) - return total.scalar() / losses.size From efbdd313e931141276ab560ff679aa93b37ad196 Mon Sep 17 00:00:00 2001 From: David Hall Date: Thu, 18 Jan 2024 15:52:00 -0800 Subject: [PATCH 154/205] don't log a value for domains with no data on a step --- src/levanter/doremi.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/levanter/doremi.py b/src/levanter/doremi.py index c08be31d5..602dae4db 100644 --- a/src/levanter/doremi.py +++ b/src/levanter/doremi.py @@ -184,7 +184,8 @@ def doremi_step(state: DoremiState, ref, batch, domains): "train/mean_excess_loss": mean_excess_loss, "train/mean_proxy_loss": mean_proxy_loss, **{f"alpha/{domain}": weight for domain, weight in alpha_dict.items()}, - **{f"train/{domain}/loss": loss for domain, loss in per_domain_dict.items()}, + # just skip domains with no excess loss + **{f"train/{domain}/excess_loss": loss for domain, loss in per_domain_dict.items() if loss > 0}, }, step=state._step, ) From a8102425f2dcd8c26b1f56c8890ca0d7692d4539 Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 30 Jan 2024 13:36:29 -0800 Subject: [PATCH 155/205] bring over the trainer-abstraction doc --- docs/design/Trainer-Abstraction.md | 157 +++++++++++++++++++++++++++++ 1 file changed, 157 insertions(+) create mode 100644 docs/design/Trainer-Abstraction.md diff --git a/docs/design/Trainer-Abstraction.md b/docs/design/Trainer-Abstraction.md new file mode 100644 index 000000000..647058b6e --- /dev/null +++ b/docs/design/Trainer-Abstraction.md @@ -0,0 +1,157 @@ +# Trainer Abstraction Cleanup + +## Current Status (2024-01-23) + +### Trainer's current jobs + +Trainer currently has these jobs: + +* Handle registering and running callbacks +* Handle checkpointing (delegated to `Checkpointer`) +* Handling initialization, including loading from a checkpoint (partially delegated to `Checkpointer`) +* train_step/train_steps/training loop +* holding execution environment details (e.g. mixed precision policy, device mesh, etc) +* handles making data loaders (with the right sharding etc) +* sets up microbatching/grad accum (mostly factored out into nice pieces +* actually taking the step + +It would be nice if this were orthogonalized as much as possible. + +Hooks are already mostly delegated out to TrainerHooks so that's not too bad, and checkpoints are well encapsulated in the Checkpointer class, +other than the initialization/resume logic. + +Execution Environment is just work to abstract, and dovetails well with other work (i.e. just-in-time mixed precision). + +A lot of changes live in the doremi branch, because it needs an augmented trainer state to do its work + + + +### Other things that bother me + +* the cached_property loss_fn is smelly and actually behaves badly because jit(grad(jit(f))) doesn't work well +* I don't love the story for extending TrainerState + +### TrainerState extension + +We want TrainerState to be extensible, which means that: + +* it needs to inheritable +* methods like train_step need to be able to be overridden +* core "train_step" logic needs to be reusable (e.g. the logic for accumulating gradients, etc) in a way that + returns the right type (e.g. TrainerState or subclass) + +In Haliax we do initialization with a static/classmethod on modules, rather than the ctor. 
+It's useful to have a "plain old constructor" for various modules as well.
+
+## Initialization/Resume
+
+
+### Requirements
+
+There are 3 core cases to consider:
+
+1. No checkpoint, initialize from scratch
+2. Checkpoint exists, load the checkpoint and initialize "unserialized"/missing state
+3. Partial checkpoint exists (e.g. only the model), load the checkpoint and initialize "unserialized"/missing state
+
+Typically, (3) is a full checkpoint, but we only want to load the model. This is useful for things like
+fine-tuning a model, where we want to load the model but not the optimizer state.
+
+On top of that, we currently differentiate between passing in a model_init function and a model. This
+complicates things a bit, but model_init is preferred because:
+
+1. it's more memory/time efficient when initializing from a checkpoint
+2. it makes it easier to get sharding and mixed precision right immediately.
+
+For (1), I think the time isn't a big deal, but we need a way of dealing
+with the memory. One could maybe delete the passed-in model (preserving only the shape)
+once we determine the checkpoint exists?
+
+For (2), we also want to get the mixed precision and sharding set up correctly immediately. Passing in a model_init
+allows us to wrap it in the right jit and partition magic to get that right.
+We can and should expose (2) as a function...
+
+
+Another complexity is `is_trainable`, which is a FilterSpec that allows you to specify which parts of the model
+are trainable. This is useful for things like fine-tuning, where you want to freeze some layers. We use is_trainable in
+4 ways:
+
+* only the is_trainable parts of a model get an optimizer_state associated with them
+* we only serialize/deserialize the is_trainable parts of a model
+* we only compute gradients for the is_trainable parts of a model
+* we store the non-trainable parts of the model in compute precision, and the trainable parts in the param precision
+
+### Current Initialization w/o checkpoint
+
+This is conceptually what happens when there is no checkpointing:
+
+```python
+@hax.named_jit(out_axis_resources=parameter_mapping)
+def initialize(optimizer, model_init, is_trainable, mp):
+  model = model_init()
+  trainable_model = eqx.filter(model, is_trainable)
+  optimizer_state = optimizer.init(trainable_model)
+
+  model = _cast_model_by_trainability(model, is_trainable, mp)
+
+  state = TrainerState(
+      _step=0,
+      model=model,
+      optimizer_state=optimizer_state,
+      is_trainable=is_trainable,
+  )
+
+  state = hax.shard(state, parameter_mapping)
+
+  return state
+
+
+def _cast_model_by_trainability(model, is_trainable, mp):
+    trainable_model, non_trainable_model = eqx.partition(model, is_trainable)
+    non_trainable_model = mp.cast_to_compute(non_trainable_model)
+    trainable_model = mp.cast_to_param(trainable_model)
+    model = eqx.combine(trainable_model, non_trainable_model)
+    return model
+```
+
+
+
+### Current logic for initialization w/ checkpoint
+
+The logic for initial_state is pretty complex. There are 3 cases to consider:
+
+1. No checkpoint, initialize from scratch
+2. Checkpoint exists, load the checkpoint and initialize "unserialized"/missing state
+3. Partial checkpoint exists (e.g. only the model), load the checkpoint and initialize "unserialized"/missing state
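+
+For intuition, case (3) in isolation amounts to something like the following sketch (`init_fresh_state` and the
+checkpoint path are hypothetical; `load_checkpoint`'s `subpath` argument is used the same way in the real flow below):
+
+```python
+import dataclasses
+
+import equinox as eqx
+
+# shape-only evaluation: no parameters are actually materialized here
+state_shape = eqx.filter_eval_shape(init_fresh_state)
+
+# pull just the model subtree out of a full checkpoint, e.g. for fine-tuning
+model = load_checkpoint(state_shape.model, "checkpoints/my-run", subpath="model")
+
+# everything else (optimizer state, training key, ...) comes from scratch; note this
+# materializes a fresh model only to overwrite it -- exactly the memory concern above
+state = dataclasses.replace(init_fresh_state(), model=model)
+```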
+
+At the moment the flow is:
+
+```python
+
+state_shape = eval_shape(_initialize_from_scratch(model_init_fn))
+if checkpoint_exists:
+    partial_state = load_checkpoint(state_shape, path)
+elif partial_checkpoint_exists:
+    partial_checkpoint = load_checkpoint(state_shape.model, path, subpath="model")
+    partial_state = dataclasses.replace(partial_state, model=partial_checkpoint)
+
+state = jit(lambda s: combine(s, _initialize_from_scratch(model_init_fn)), partial_state)
+```
+
+I'd like to hoist this out so it's not dependent on the Trainer class, and so that it's easier to test.
+
+One of the things I was trying to accomplish was to define a checkpointed_or_initialize function that was just
+
+```python
+state_shape = eval_shape(f)
+if checkpoint_exists:
+    partial_state = load_checkpoint(state_shape, path)
+else:
+    partial_state = eqx.filter(state_shape, lambda v: False)
+
+state = jit(lambda s: combine(s, f()), partial_state)
+
+```
+
+But this doesn't actually compose well: you can't really do IO inside of eval_shape, so you can't really combine two
+of those... or can you?

From e49fb38031a81b6413b7461939d8c79c8925590d Mon Sep 17 00:00:00 2001
From: David Hall
Date: Tue, 30 Jan 2024 14:51:18 -0800
Subject: [PATCH 156/205] remove the wrapped loss_fn thing from trainer

---
 src/levanter/doremi.py  |  6 ++++--
 src/levanter/trainer.py | 44 +++++++++++++++++------------------------
 2 files changed, 22 insertions(+), 28 deletions(-)

diff --git a/src/levanter/doremi.py b/src/levanter/doremi.py
index 602dae4db..fcc863c3e 100644
--- a/src/levanter/doremi.py
+++ b/src/levanter/doremi.py
@@ -115,7 +115,8 @@ def estimate_mixture_weights(
     @eqx.filter_jit
     def eval_loss(model, *batch, **batch_kwargs):
         model = inference_mode(model, True)
-        return trainer.loss_fn(model, *batch, **batch_kwargs, key=None)
+        with hax.axis_mapping(trainer.compute_axis_mapping):
+            return trainer.loss_fn(model, *batch, **batch_kwargs, key=None)
 
     for domain, dataset in validation_sets.items():
         loss = eval_loss_loop(
@@ -185,7 +186,8 @@ def doremi_step(state: DoremiState, ref, batch, domains):
                 "train/mean_proxy_loss": mean_proxy_loss,
                 **{f"alpha/{domain}": weight for domain, weight in alpha_dict.items()},
                 # just skip domains with no excess loss
-                **{f"train/{domain}/excess_loss": loss for domain, loss in per_domain_dict.items() if loss > 0},
+                # TODO: we need to skip logging things that are 0, but can't do that in jit, have to do it in python
+                **{f"train/{domain}/excess_loss": loss for domain, loss in per_domain_dict.items()},
             },
             step=state._step,
         )
diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py
index 0d428e812..db7422e12 100644
--- a/src/levanter/trainer.py
+++ b/src/levanter/trainer.py
@@ -1,7 +1,6 @@
 import atexit
 import copy
 import dataclasses
-import functools
 import logging as pylogging
 import os
 import sys
@@ -182,7 +181,7 @@ def __init__(
         self.hooks = TrainerHooks()
         self.config = config
         self.optimizer = optimizer
-        self._raw_loss_function = loss_fn or ModuleComputeLoss()
+        self.loss_fn = loss_fn or ModuleComputeLoss()
         if isinstance(config.tracker, Sequence):
             self.tracker = levanter.tracker.CompositeTracker([c.init(self.run_id) for c in config.tracker])
         else:
@@ -193,21 +192,6 @@ def __init__(
         if add_default_hooks:
             self._add_default_hooks()
 
-    @cached_property
-    def loss_fn(self):
-        """
-        Wrapped loss function that casts the model to compute precision and sets the context axis mapping to compute
-        """
-
-        
@named_jit(axis_resources=self.compute_axis_mapping) - @functools.wraps(self._raw_loss_function) - def fn(model, *batch, **batch_kwargs): - with hax.axis_mapping(self.compute_axis_mapping): - model = self.mp.cast_to_compute(model) - return _ensure_scalar(self._raw_loss_function(model, *batch, **batch_kwargs)) - - return fn - @property def run_id(self) -> str: """Returns the run id""" @@ -443,7 +427,8 @@ def add_eval_hook(self, eval_dataset, name: Optional[str] = None): @eqx.filter_jit def eval_loss(model, *batch, **batch_kwargs): model = inference_mode(model, True) - return self.loss_fn(model, *batch, **batch_kwargs, key=None) + with hax.axis_mapping(self.compute_axis_mapping): + return self.loss_fn(model, *batch, **batch_kwargs, key=None) self.add_hook( callbacks.compute_validation_loss( @@ -486,21 +471,30 @@ def _train_step(self, state: TrainerState, *batch, **batch_kwargs) -> tuple[Scal key, new_key = jax.random.split(state.training_key) model = inference_mode(state.model, False) - loss, grads = self._compute_gradients_microbatched(model, batch, **batch_kwargs, key=key) + def loss_fn(model, *batch, **batch_kwargs): + model = inference_mode(model, False) + # TODO: when we get ResourceEnvs in place, we can remove this cast_to_compute + model = self.mp.cast_to_compute(model) + with hax.axis_mapping(self.compute_axis_mapping): + return self.loss_fn(model, *batch, **batch_kwargs).scalar() + + # only train on the trainable parameters. We're leaning on JAX to do dead code elimination for us + loss, grads = self._compute_gradients_microbatched(loss_fn, model, batch, **batch_kwargs, key=key) + train_grads = eqx.filter(grads, state.is_trainable) - new_state = self._take_train_step(state, model, grads, *batch, **batch_kwargs, key=key) + new_state = self._take_train_step(state, model, train_grads, *batch, **batch_kwargs, key=key) new_state = dataclasses.replace(new_state, training_key=new_key) return loss, new_state - def _compute_gradients_microbatched(self, model: M, batch, **batch_kwargs) -> tuple[Scalar, M]: - grad_fn = eqx.filter_value_and_grad(self.loss_fn, has_aux=False) + def _compute_gradients_microbatched(self, loss_fn, model: M, batch, **batch_kwargs) -> tuple[Scalar, M]: + grad_fn = eqx.filter_value_and_grad(loss_fn, has_aux=False) grad_fn = microbatched( grad_fn, self.TrainBatch, self.config.per_device_parallelism, self.parameter_axis_mapping, - self.parameter_axis_mapping, + self.compute_axis_mapping, ) return grad_fn(model, *batch, **batch_kwargs) @@ -508,11 +502,9 @@ def _take_train_step(self, state: S, model, grads, *batch, **batch_kwargs) -> S: """ Takes a training step. This is a separate method so that it can be overridden or used in a subclass. """ - # only train on the trainable parameters. We're leaning on JAX to do dead code elimination for us with hax.axis_mapping(self.parameter_axis_mapping): - train_grads = _partition_trainable_params(grads, state.is_trainable)[0] trainable_model = _partition_trainable_params(model, state.is_trainable)[0] - updates, opt_state = self.optimizer.update(train_grads, state.opt_state, params=trainable_model) + updates, opt_state = self.optimizer.update(grads, state.opt_state, params=trainable_model) # Sophia, e.g. if isinstance(self.optimizer, SecondOrderTransformation): From 13dc392ed67dbc1f0e3eac553d43f50443ea88df Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 30 Jan 2024 15:57:48 -0800 Subject: [PATCH 157/205] factor out a take_opt_step. 
need to decide where to put it --- src/levanter/trainer.py | 36 +++++++++++++++++++++++++----------- 1 file changed, 25 insertions(+), 11 deletions(-) diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index db7422e12..e54f53ce7 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -480,9 +480,8 @@ def loss_fn(model, *batch, **batch_kwargs): # only train on the trainable parameters. We're leaning on JAX to do dead code elimination for us loss, grads = self._compute_gradients_microbatched(loss_fn, model, batch, **batch_kwargs, key=key) - train_grads = eqx.filter(grads, state.is_trainable) - new_state = self._take_train_step(state, model, train_grads, *batch, **batch_kwargs, key=key) + new_state = self._take_train_step(state, loss_fn, model, grads, *batch, **batch_kwargs, key=key) new_state = dataclasses.replace(new_state, training_key=new_key) return loss, new_state @@ -498,20 +497,17 @@ def _compute_gradients_microbatched(self, loss_fn, model: M, batch, **batch_kwar ) return grad_fn(model, *batch, **batch_kwargs) - def _take_train_step(self, state: S, model, grads, *batch, **batch_kwargs) -> S: + def _take_train_step(self, state: S, loss_fn, model, grads, *batch, **batch_kwargs) -> S: """ Takes a training step. This is a separate method so that it can be overridden or used in a subclass. """ with hax.axis_mapping(self.parameter_axis_mapping): - trainable_model = _partition_trainable_params(model, state.is_trainable)[0] - updates, opt_state = self.optimizer.update(grads, state.opt_state, params=trainable_model) - - # Sophia, e.g. - if isinstance(self.optimizer, SecondOrderTransformation): - opt_state = self.optimizer.update_hessian(opt_state, self.loss_fn, model, *batch, **batch_kwargs) - model = eqx.apply_updates(model, updates) + partial_loss = lambda model: loss_fn(model, *batch, **batch_kwargs) + model, opt_state = take_opt_step( + self.optimizer, model, state.opt_state, grads, partial_loss, state.is_trainable + ) - return dataclasses.replace(state, _step=state._step + 1, model=model, opt_state=opt_state) + return dataclasses.replace(state, _step=state._step + 1, model=model, opt_state=opt_state) def _initialize_state_from_scratch(self, model, training_key, is_trainable): # only force trainable params to param precision. Other params are cast to compute precision @@ -810,3 +806,21 @@ def _ensure_scalar(x: hax.types.Scalar | hax.NamedArray) -> hax.types.Scalar: return x.scalar() else: return x + + +def take_opt_step( + optimizer, + model: M, + opt_state: OptState, + grads: M, + obj_fn: Optional[Callable[[M], Scalar]] = None, + is_trainable: PyTree[FilterSpec] = True, +) -> tuple[M, OptState]: + train_grads = eqx.filter(grads, is_trainable) + trainable_model = eqx.filter(model, is_trainable) + updates, opt_state = optimizer.update(train_grads, opt_state, params=trainable_model) + # Sophia, e.g. 
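+    # (second-order optimizers keep curvature estimates in their state; update_hessian needs the objective to refresh them)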
+    if isinstance(optimizer, SecondOrderTransformation):
+        opt_state = optimizer.update_hessian(opt_state, obj_fn, model)
+    model = eqx.apply_updates(model, updates)
+    return model, opt_state

From 514da051afa184c283a117d6feb68fc553a8d6aa Mon Sep 17 00:00:00 2001
From: David Hall
Date: Tue, 30 Jan 2024 15:58:50 -0800
Subject: [PATCH 158/205] explicitly expose microbatch_size, use it in microbatched

---
 src/levanter/grad_accum.py | 14 ++++++++++----
 src/levanter/trainer.py    |  6 +++++-
 2 files changed, 15 insertions(+), 5 deletions(-)

diff --git a/src/levanter/grad_accum.py b/src/levanter/grad_accum.py
index 39258665a..2476d4e37 100644
--- a/src/levanter/grad_accum.py
+++ b/src/levanter/grad_accum.py
@@ -28,7 +28,7 @@ class ReductionType(enum.Enum):
 def microbatched(
     fn: Callable[Args, R],
     Batch: Axis,
-    per_device_parallelism: int,
+    microbatch_size: int,
     accum_axis_mapping,
     compute_axis_mapping,
     patch_in_rng_key: Optional[str] = "key",
@@ -39,6 +39,13 @@ def microbatched(
     Wraps a function that takes a batch and changes it to instead take microbatches and accumulate the results
     This function has to reduce the batch axis, so it can't be used for functions that need to keep the batch axis.
 
+    Can be used as a decorator with functools.partial, e.g.:
+
+    >>> @functools.partial(microbatched, Batch=Batch, microbatch_size=4)
+    ... def my_fn(x):
+    ...     return hax.mean(x + 1)
+
+
     Args:
         fn: a function to wrap
         Batch: the batch axis
@@ -61,10 +68,9 @@ def microbatched(
     physical_axis_name = hax.partitioning.physical_axis_name(Batch, compute_axis_mapping)
     assert physical_axis_name is not None
 
-    if per_device_parallelism < 0:
-        raise ValueError(f"Bad value for {per_device_parallelism=}")
+    if microbatch_size <= 0:
+        raise ValueError(f"Bad value for {microbatch_size=}")
 
-    microbatch_size = data_axis_size * per_device_parallelism
     num_micro_steps = batch_size // microbatch_size
     Microbatch = Batch.resize(microbatch_size)
     AccumStep = Axis("accum_step", num_micro_steps)
diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py
index e54f53ce7..b28186732 100644
--- a/src/levanter/trainer.py
+++ b/src/levanter/trainer.py
@@ -491,7 +491,7 @@ def _compute_gradients_microbatched(self, loss_fn, model: M, batch, **batch_kwar
         grad_fn = microbatched(
             grad_fn,
             self.TrainBatch,
-            self.config.per_device_parallelism,
+            self.config.microbatch_size,
             self.parameter_axis_mapping,
             self.compute_axis_mapping,
         )
@@ -623,6 +623,10 @@ def TrainBatch(self):
     def EvalBatch(self):
         return Axis("batch", self.eval_batch_size)
 
+    @property
+    def microbatch_size(self):
+        return self.per_device_parallelism * self.data_axis_size
+
     def __post_init__(self):
         if self.wandb is not None:
             warnings.warn("wandb is deprecated. use tracker with type wandb instead", DeprecationWarning)

From f797a859eab624665555220fae044e3502b977ac Mon Sep 17 00:00:00 2001
From: David Hall
Date: Tue, 30 Jan 2024 16:01:44 -0800
Subject: [PATCH 159/205] comment about custom_jvp on microbatched

---
 src/levanter/grad_accum.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/levanter/grad_accum.py b/src/levanter/grad_accum.py
index 2476d4e37..3e472c8ef 100644
--- a/src/levanter/grad_accum.py
+++ b/src/levanter/grad_accum.py
@@ -24,6 +24,8 @@ class ReductionType(enum.Enum):
     # TODO: add MAX?
 
 
+# TODO: should we use a custom_jvp on microbatched? 
+ # cf https://github.com/google-research/t5x/blob/main/t5x/trainer.py#L617 def microbatched( fn: Callable[Args, R], From 43019308f920b29e3286b543570312597e3def3a Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 30 Jan 2024 16:37:56 -0800 Subject: [PATCH 160/205] unneeded cast --- src/levanter/trainer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index b28186732..d1cd12c55 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -472,7 +472,6 @@ def _train_step(self, state: TrainerState, *batch, **batch_kwargs) -> tuple[Scal model = inference_mode(state.model, False) def loss_fn(model, *batch, **batch_kwargs): - model = inference_mode(model, False) # TODO: when we get ResourceEnvs in place, we can remove this cast_to_compute model = self.mp.cast_to_compute(model) with hax.axis_mapping(self.compute_axis_mapping): From d3416b1abc9101e832cac3725aaf9925e6c86916 Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 30 Jan 2024 20:32:01 -0800 Subject: [PATCH 161/205] rename to mixed-precision.md --- src/levanter/trainer.py | 22 +++++++++------------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index d1cd12c55..494ff760b 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -480,8 +480,16 @@ def loss_fn(model, *batch, **batch_kwargs): # only train on the trainable parameters. We're leaning on JAX to do dead code elimination for us loss, grads = self._compute_gradients_microbatched(loss_fn, model, batch, **batch_kwargs, key=key) - new_state = self._take_train_step(state, loss_fn, model, grads, *batch, **batch_kwargs, key=key) + with hax.axis_mapping(self.parameter_axis_mapping): + partial_loss = lambda model: loss_fn(model, *batch, **batch_kwargs) + model, opt_state = take_opt_step( + self.optimizer, model, state.opt_state, grads, partial_loss, state.is_trainable + ) + + new_state = dataclasses.replace(state, model=model, opt_state=opt_state) + new_state = dataclasses.replace(new_state, training_key=new_key) + new_state = dataclasses.replace(new_state, _step=state._step + 1) return loss, new_state @@ -496,18 +504,6 @@ def _compute_gradients_microbatched(self, loss_fn, model: M, batch, **batch_kwar ) return grad_fn(model, *batch, **batch_kwargs) - def _take_train_step(self, state: S, loss_fn, model, grads, *batch, **batch_kwargs) -> S: - """ - Takes a training step. This is a separate method so that it can be overridden or used in a subclass. - """ - with hax.axis_mapping(self.parameter_axis_mapping): - partial_loss = lambda model: loss_fn(model, *batch, **batch_kwargs) - model, opt_state = take_opt_step( - self.optimizer, model, state.opt_state, grads, partial_loss, state.is_trainable - ) - - return dataclasses.replace(state, _step=state._step + 1, model=model, opt_state=opt_state) - def _initialize_state_from_scratch(self, model, training_key, is_trainable): # only force trainable params to param precision. 
Other params are cast to compute precision model = cast_params_by_trainability(model, self.mp, is_trainable) From 95529090474bd797bdb26e263c5f64803182c425 Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 30 Jan 2024 20:43:36 -0800 Subject: [PATCH 162/205] cleanup ctors for BatchLoaders some --- src/levanter/data/loader.py | 22 ++++++++++++++-------- src/levanter/main/eval_lm.py | 2 +- src/levanter/main/viz_logprobs.py | 6 +++--- src/levanter/trainer.py | 4 ++-- tests/test_replicated_loader.py | 14 +++++++------- 5 files changed, 27 insertions(+), 21 deletions(-) diff --git a/src/levanter/data/loader.py b/src/levanter/data/loader.py index d3710207e..ab7a43b2f 100644 --- a/src/levanter/data/loader.py +++ b/src/levanter/data/loader.py @@ -35,16 +35,20 @@ class BatchLoader(Iterable[Ex], abc.ABC): Batch: hax.Axis - mesh: Mesh + _mesh: Mesh axis_resources: Optional[ResourceMapping] - def __init__(self, max_capacity: Optional[int], axis_resources: Optional[ResourceMapping]): + def __init__( + self, Batch: hax.Axis, mesh: Mesh, axis_resources: Optional[ResourceMapping], max_capacity: Optional[int] + ): """ :param max_capacity: if not None, the maximum number of batches to keep in memory at once. If <0 then load in the main thread :param axis_resources: """ self.max_capacity = max_capacity self.axis_resources = axis_resources + self._mesh = mesh + self.Batch = Batch def __iter__(self) -> Iterator[Ex]: ax_resources = self.axis_resources @@ -110,7 +114,7 @@ def get_local_data_for_leaf(indices: _TensorSliceIndex, leaf_index: int) -> Arra def make_global_array_for_leaf(leaf_index, item_leaf_shape: Union[ShapeSpec, NamedShapeSpec]): raw_array = jax.make_array_from_callback( to_raw_shape(item_leaf_shape), - jax.sharding.NamedSharding(self.mesh, self._pspec_for(item_leaf_shape)), + jax.sharding.NamedSharding(self._mesh, self._pspec_for(item_leaf_shape)), lambda indices: get_local_data_for_leaf(indices, leaf_index), ) if isinstance(item_leaf_shape, NamedShapeSpec): @@ -161,8 +165,8 @@ class ShardedBatchLoader(BatchLoader[Ex]): def __init__( self, local_dataset: ShardableDataset[Ex], - mesh: Mesh, Batch: hax.Axis, + mesh: Optional[Mesh] = None, axis_resources: Optional[ResourceMapping] = None, max_capacity: Optional[int] = 10, *, @@ -183,7 +187,7 @@ def __init__( assert self.Batch.size % num_data_process_groups == 0 self.item_dataset = local_dataset.shard(process_data_pos, num_data_process_groups) - super().__init__(max_capacity, axis_resources) + super().__init__(Batch, mesh, axis_resources, max_capacity) def _produce_batches(self) -> Iterator[PyTree]: one_item_generator = non_caching_cycle(self.item_dataset) @@ -238,17 +242,19 @@ class ReplicatedBatchLoader(BatchLoader[Ex]): def __init__( self, item_dataset: Dataset[Ex], - mesh: Mesh, Batch: hax.Axis, + mesh: Optional[Mesh] = None, axis_resources: Optional[ResourceMapping] = None, max_capacity: Optional[int] = 10, ): assert item_dataset is not None self.item_dataset = item_dataset - self.mesh = mesh self.Batch = Batch - super().__init__(max_capacity, axis_resources) + if mesh is None: + mesh = hax.current_resource_env().mesh + + super().__init__(Batch, mesh, axis_resources, max_capacity) def _produce_batches(self): for batch in _batched(self.item_dataset, self.Batch.size): diff --git a/src/levanter/main/eval_lm.py b/src/levanter/main/eval_lm.py index c7976ad41..c0ca2188c 100644 --- a/src/levanter/main/eval_lm.py +++ b/src/levanter/main/eval_lm.py @@ -57,11 +57,11 @@ def main(config: EvalLmConfig): raw_dataset = CausalLmDataset(validation_set, Pos, 
KeyPos) # type: ignore - eval_loader = ReplicatedBatchLoader(raw_dataset, config.trainer.device_mesh, Batch) compute_axis_mapping = config.trainer.compute_axis_mapping parameter_axis_mapping = config.trainer.parameter_axis_mapping with config.trainer.device_mesh, hax.axis_mapping(parameter_axis_mapping): + eval_loader = ReplicatedBatchLoader(raw_dataset, Batch) key = jax.random.PRNGKey(0) vocab_size = len(tokenizer) diff --git a/src/levanter/main/viz_logprobs.py b/src/levanter/main/viz_logprobs.py index bc89620f3..a7778b896 100644 --- a/src/levanter/main/viz_logprobs.py +++ b/src/levanter/main/viz_logprobs.py @@ -44,10 +44,10 @@ def main(config: VizGpt2Config): Pos = config.model.Pos KeyPos = config.model.KeyPos + validation_set = config.data.validation_set(Pos.size) + assert validation_set is not None eval_loader = ReplicatedBatchLoader( - CausalLmDataset(config.data.validation_set(Pos.size), Pos, KeyPos), # type: ignore - config.trainer.device_mesh, - EvalBatch, + CausalLmDataset(validation_set, Pos, KeyPos), EvalBatch, config.trainer.device_mesh ) compute_axis_mapping = config.trainer.compute_axis_mapping diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index 494ff760b..2281208a5 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -448,7 +448,7 @@ def replicated_loader(self, dataset: Dataset[X], batch_axis: Axis) -> Replicated Returns: ReplicatedBatchLoader: the batch loader """ - return ReplicatedBatchLoader(dataset, self.device_mesh, batch_axis, self.compute_axis_mapping) + return ReplicatedBatchLoader(dataset, batch_axis, self.device_mesh, self.compute_axis_mapping) def sharded_loader(self, dataset: ShardableDataset[X], batch_axis: Axis) -> ShardedBatchLoader[X]: """Creates a sharded batch loader for the given dataset. 
Generally you should use this @@ -461,7 +461,7 @@ def sharded_loader(self, dataset: ShardableDataset[X], batch_axis: Axis) -> Shar Returns: ShardedBatchLoader: the batch loader """ - return ShardedBatchLoader(dataset, self.device_mesh, batch_axis, self.compute_axis_mapping) + return ShardedBatchLoader(dataset, batch_axis, self.device_mesh, self.compute_axis_mapping) @cached_property def _jit_train_step_fn(self): diff --git a/tests/test_replicated_loader.py b/tests/test_replicated_loader.py index 431a1c0bb..ea6272165 100644 --- a/tests/test_replicated_loader.py +++ b/tests/test_replicated_loader.py @@ -45,7 +45,7 @@ def test_local_batched_data_loading_model_axis_2(): seq_len = 128 cache = _small_dataset(seq_len) Batch = Axis("batch", len(devices)) - loader = ReplicatedBatchLoader(cache, mesh, Batch) + loader = ReplicatedBatchLoader(cache, Batch, mesh) batches = list(itertools.islice(loader, 10)) for batch in batches: @@ -65,7 +65,7 @@ def test_local_batched_data_loading_model_axis_1(): seq_len = 128 cache = _small_dataset(seq_len) Batch = Axis("batch", len(devices)) - loader = ReplicatedBatchLoader(cache, mesh, Batch) + loader = ReplicatedBatchLoader(cache, Batch, mesh) batches = list(itertools.islice(loader, 10)) for batch in batches: @@ -109,7 +109,7 @@ def test_structured_batches_model_axis_1(): seq_len = 128 dataset = StructuredDataset(seq_len, 0, 256, 1) Batch = Axis("batch", len(devices)) - loader = ReplicatedBatchLoader(dataset, mesh, Batch) + loader = ReplicatedBatchLoader(dataset, Batch, mesh) batches = list(itertools.islice(loader, 10)) for batch in batches: @@ -129,7 +129,7 @@ def test_structured_batches_model_axis_2(): seq_len = 128 dataset = StructuredDataset(seq_len, 0, 256, 1) Batch = Axis("batch", len(devices)) - loader = ReplicatedBatchLoader(dataset, mesh, Batch) + loader = ReplicatedBatchLoader(dataset, Batch, mesh) batches = list(itertools.islice(loader, 10)) for batch in batches: @@ -185,7 +185,7 @@ def test_structured_batches_model_axis_1_with_names(): Width = Axis("Width", 16) dataset = StructuredDatasetWithNames(Height, Width, 0, 256, 1) Batch = Axis("batch", len(devices)) - loader = ReplicatedBatchLoader(dataset, mesh, Batch) + loader = ReplicatedBatchLoader(dataset, Batch, mesh) batches = list(itertools.islice(loader, 10)) for batch in batches: @@ -208,7 +208,7 @@ def test_structured_batches_model_axis_2_with_names(): Width = Axis("Width", 16) dataset = StructuredDatasetWithNames(Height, Width, 0, 256, 1) Batch = Axis("batch", len(devices)) - loader = ReplicatedBatchLoader(dataset, mesh, Batch) + loader = ReplicatedBatchLoader(dataset, Batch, mesh) batches = list(itertools.islice(loader, 10)) for batch in batches: @@ -230,7 +230,7 @@ def test_structured_batches_model_axis_2_subsharded(): with mesh, haliax.axis_mapping({"batch": ResourceAxis.DATA, Height.name: ResourceAxis.MODEL}): dataset = StructuredDatasetWithNames(Height, Width, 0, 256, 1) Batch = Axis("batch", len(devices)) - loader = ReplicatedBatchLoader(dataset, mesh, Batch) + loader = ReplicatedBatchLoader(dataset, Batch, mesh) batches = list(itertools.islice(loader, 10)) for batch in batches: From 888d35e19c44fe532e62cad1b8eebd6c5be512f1 Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 30 Jan 2024 20:47:16 -0800 Subject: [PATCH 163/205] misc cleanup --- src/levanter/main/viz_logprobs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/levanter/main/viz_logprobs.py b/src/levanter/main/viz_logprobs.py index a7778b896..798bd152e 100644 --- a/src/levanter/main/viz_logprobs.py +++ 
b/src/levanter/main/viz_logprobs.py @@ -57,7 +57,7 @@ def main(config: VizGpt2Config): key = jax.random.PRNGKey(0) vocab_size = len(tokenizer) - Vocab = round_axis_for_partitioning(Axis("vocab", vocab_size), compute_axis_mapping) + Vocab = round_axis_for_partitioning(Axis("vocab", vocab_size)) if vocab_size != Vocab.size: logger.info(f"Rounding vocab size from {vocab_size} to {Vocab.size} for partitioning") @@ -82,7 +82,7 @@ def compute_log_probs(model: LmHeadModel, example: LmExample): assert model is not None - model = hax.shard_with_axis_mapping(model, parameter_axis_mapping) + model = hax.shard(model) compute_and_visualize_log_probs( path=config.path, From 49a409b618f1c6e0911fde5c7ee78979db32435d Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 30 Jan 2024 23:27:17 -0800 Subject: [PATCH 164/205] wip --- docs/design/Trainer-Abstraction.md | 24 +++---- src/levanter/checkpoint.py | 11 ++-- src/levanter/data/loader.py | 76 +++++++++++------------ src/levanter/main/eval_lm.py | 2 +- src/levanter/tensorstore_serialization.py | 12 ++-- src/levanter/trainer.py | 31 ++++++--- tests/test_lora.py | 3 +- tests/test_replicated_loader.py | 14 ++--- tests/test_sharded_loader.py | 18 +++--- 9 files changed, 99 insertions(+), 92 deletions(-) diff --git a/docs/design/Trainer-Abstraction.md b/docs/design/Trainer-Abstraction.md index 647058b6e..f03ec286a 100644 --- a/docs/design/Trainer-Abstraction.md +++ b/docs/design/Trainer-Abstraction.md @@ -88,22 +88,22 @@ This is conceptually what happens when there is no checkpointing: ```python @hax.named_jit(out_axis_resources=parameter_mapping) def initialize(optimizer, model_init, is_trainable, mp): - model = model_init() - trainable_model = eqx.filter(model, is_trainable) - optimizer_state = optimizer.init(trainable_model) + model = model_init() + trainable_model = eqx.filter(model, is_trainable) + optimizer_state = optimizer.init(trainable_model) - model = _cast_model_by_trainability(model, is_trainable, mp) + model = _cast_model_by_trainability(model, is_trainable, mp) - state = TrainerState( - _step=0, - model=model, - optimizer_state=optimizer_state, - is_trainable=is_trainable, - ) + state = TrainerState( + _step=0, + model=model, + optimizer_state=optimizer_state, + is_trainable=is_trainable, + ) - state = hax.shard(state, parameter_mapping) + state = hax.shard(state, parameter_mapping) - return state + return state def _cast_model_by_trainability(model, is_trainable, mp): diff --git a/src/levanter/checkpoint.py b/src/levanter/checkpoint.py index d1076f325..42c15ec01 100644 --- a/src/levanter/checkpoint.py +++ b/src/levanter/checkpoint.py @@ -265,11 +265,10 @@ def save_metadata(checkpoint_path, fs, step): def load_checkpoint( tree: M, checkpoint_path: PathLike, + env: Optional[haliax.ResourceEnv] = None, *, subpath: Optional[str] = None, discover_latest=True, - axis_mapping: Optional[haliax.partitioning.ResourceMapping] = None, - mesh: Optional[jax.sharding.Mesh] = None, ) -> M: """ Load a checkpoint from a given path. If discover_latest is True, then the latest checkpoint @@ -282,8 +281,7 @@ def load_checkpoint( checkpoint_path: the path to load the checkpoint from subpath: the subpath to load from the checkpoint discover_latest: whether to discover the latest checkpoint in the given path - axis_mapping: the axis mapping to use for loading the checkpoint - mesh: the mesh to use for loading the checkpoint + env: the resource env to use for loading the checkpoint. 
if None, the current resource env is used Returns: the loaded checkpoint, with the same structure as the exemplar tree @@ -306,7 +304,7 @@ def load_checkpoint( checkpoint_path = os.path.join(checkpoint_path, subpath) try: - tree = tree_deserialize_leaves_tensorstore(checkpoint_path, tree, axis_mapping=axis_mapping, mesh=mesh) + tree = tree_deserialize_leaves_tensorstore(checkpoint_path, tree, env) return tree except: # noqa from levanter.trainer import TrainerState @@ -328,8 +326,7 @@ def load_checkpoint( training_state = tree_deserialize_leaves_tensorstore( os.path.join(checkpoint_path, "training_state"), training_state, - axis_mapping=axis_mapping, - mesh=mesh, + env ) opt_state, key = training_state diff --git a/src/levanter/data/loader.py b/src/levanter/data/loader.py index ab7a43b2f..bd02cfd9c 100644 --- a/src/levanter/data/loader.py +++ b/src/levanter/data/loader.py @@ -8,12 +8,11 @@ import jax.numpy as jnp import jax.tree_util as jtu from jax.experimental import multihost_utils -from jax.sharding import Mesh, PartitionSpec +from jax.sharding import PartitionSpec from jaxtyping import Array, PyTree import haliax as hax from haliax import NamedArray -from haliax.partitioning import ResourceMapping from haliax.util import is_named_array import levanter.mesh @@ -34,29 +33,26 @@ class BatchLoader(Iterable[Ex], abc.ABC): + """ + Args: + Batch: the batch size + resource_env: the resource environment, if None then use the current one + max_capacity: if not None, the maximum number of batches to keep in memory at once. If <0 then load in the main thread + """ Batch: hax.Axis - _mesh: Mesh - axis_resources: Optional[ResourceMapping] + def __init__( - self, Batch: hax.Axis, mesh: Mesh, axis_resources: Optional[ResourceMapping], max_capacity: Optional[int] + self, Batch: hax.Axis, resource_env: hax.ResourceEnv, max_capacity: Optional[int] ): - """ - :param max_capacity: if not None, the maximum number of batches to keep in memory at once. 
If <0 then load in the main thread - :param axis_resources: - """ + self.max_capacity = max_capacity - self.axis_resources = axis_resources - self._mesh = mesh + self.resource_env = resource_env or hax.current_resource_env() self.Batch = Batch def __iter__(self) -> Iterator[Ex]: - ax_resources = self.axis_resources - if ax_resources is None: - ax_resources = hax.partitioning.current_thread_local_mapping() - def produce_batches(): - with hax.axis_mapping(ax_resources): + with self.resource_env: for batch in self._produce_batches(): yield batch @@ -114,7 +110,7 @@ def get_local_data_for_leaf(indices: _TensorSliceIndex, leaf_index: int) -> Arra def make_global_array_for_leaf(leaf_index, item_leaf_shape: Union[ShapeSpec, NamedShapeSpec]): raw_array = jax.make_array_from_callback( to_raw_shape(item_leaf_shape), - jax.sharding.NamedSharding(self._mesh, self._pspec_for(item_leaf_shape)), + jax.sharding.NamedSharding(self.resource_env.mesh, self._pspec_for(item_leaf_shape)), lambda indices: get_local_data_for_leaf(indices, leaf_index), ) if isinstance(item_leaf_shape, NamedShapeSpec): @@ -135,10 +131,10 @@ def make_global_array_for_leaf(leaf_index, item_leaf_shape: Union[ShapeSpec, Nam def _pspec_for(self, shape_spec: Union[ShapeSpec, NamedShapeSpec]) -> PartitionSpec: if isinstance(shape_spec, ShapeSpec): # type: ignore - batch_name = hax.partitioning.physical_axis_name(self.Batch, self.axis_resources) + batch_name = hax.partitioning.physical_axis_name(self.Batch, self.resource_env) return PartitionSpec(batch_name, *((None,) * (len(shape_spec.shape) - 1))) else: - return hax.partitioning.pspec_for_axis(shape_spec.shape, self.axis_resources) # type: ignore + return hax.partitioning.pspec_for_axis(shape_spec.shape, self.resource_env) # type: ignore class ShardedBatchLoader(BatchLoader[Ex]): @@ -157,8 +153,8 @@ class ShardedBatchLoader(BatchLoader[Ex]): load, by determining which row(s) of the device mesh the process is responsible for. :arg local_dataset: a dataset that is shardable and can be iterated over - :arg mesh: the device mesh :arg Batch: the batch size + :arg env: the resource environment, if None then use the current one :param max_capacity: if not None, the maximum number of batches to keep in memory at once. 
If <0 then load in the main thread """ @@ -166,28 +162,31 @@ def __init__( self, local_dataset: ShardableDataset[Ex], Batch: hax.Axis, - mesh: Optional[Mesh] = None, - axis_resources: Optional[ResourceMapping] = None, + env: Optional[hax.ResourceEnv] = None, max_capacity: Optional[int] = 10, *, override_process_data_pos: Optional[int] = None, # for testing override_process_data_groups: Optional[int] = None, # for testing ): - self.mesh = mesh - self.Batch = Batch - - process_data_pos = override_process_data_pos or levanter.mesh.process_mesh_position(mesh)[0] - num_data_process_groups = override_process_data_groups or levanter.mesh.process_mesh_size(mesh)[0] + env = env or hax.current_resource_env() + # TODO: this could be better + mesh = env.mesh + if mesh is not None: + process_data_pos = override_process_data_pos or levanter.mesh.process_mesh_position(mesh)[0] + num_data_process_groups = override_process_data_groups or levanter.mesh.process_mesh_size(mesh)[0] + else: + process_data_pos = override_process_data_pos or 0 + num_data_process_groups = override_process_data_groups or 1 if not override_process_data_groups: assert num_data_process_groups <= jax.process_count() self.process_data_pos = process_data_pos self.num_data_process_groups = num_data_process_groups - assert self.Batch.size % num_data_process_groups == 0 + assert Batch.size % num_data_process_groups == 0 self.item_dataset = local_dataset.shard(process_data_pos, num_data_process_groups) - super().__init__(Batch, mesh, axis_resources, max_capacity) + super().__init__(Batch, env, max_capacity) def _produce_batches(self) -> Iterator[PyTree]: one_item_generator = non_caching_cycle(self.item_dataset) @@ -232,29 +231,24 @@ class ReplicatedBatchLoader(BatchLoader[Ex]): Note: this class discards the final batch if it is smaller than the batch size. - :arg item_dataset: a dataset that is shardable and can be iterated over - :arg mesh: the device mesh - :arg Batch: the batch size - :arg axis_resources: the resources for the batch axis - :param max_capacity: if not None, the maximum number of batches to keep in memory at once. If <0 then load in the main thread + Args: + item_dataset: the dataset to load + Batch: the batch size + env: the resource environment + max_capacity: if not None, the maximum number of batches to keep in memory at once. 
If <0 then load in the main thread """ def __init__( self, item_dataset: Dataset[Ex], Batch: hax.Axis, - mesh: Optional[Mesh] = None, - axis_resources: Optional[ResourceMapping] = None, + env: Optional[hax.ResourceEnv] = None, max_capacity: Optional[int] = 10, ): assert item_dataset is not None self.item_dataset = item_dataset - self.Batch = Batch - - if mesh is None: - mesh = hax.current_resource_env().mesh - super().__init__(Batch, mesh, axis_resources, max_capacity) + super().__init__(Batch, env, max_capacity) def _produce_batches(self): for batch in _batched(self.item_dataset, self.Batch.size): diff --git a/src/levanter/main/eval_lm.py b/src/levanter/main/eval_lm.py index c0ca2188c..2937218bc 100644 --- a/src/levanter/main/eval_lm.py +++ b/src/levanter/main/eval_lm.py @@ -60,7 +60,7 @@ def main(config: EvalLmConfig): compute_axis_mapping = config.trainer.compute_axis_mapping parameter_axis_mapping = config.trainer.parameter_axis_mapping - with config.trainer.device_mesh, hax.axis_mapping(parameter_axis_mapping): + with config.trainer.param_env: eval_loader = ReplicatedBatchLoader(raw_dataset, Batch) key = jax.random.PRNGKey(0) diff --git a/src/levanter/tensorstore_serialization.py b/src/levanter/tensorstore_serialization.py index f75ee87ff..67249e6d0 100644 --- a/src/levanter/tensorstore_serialization.py +++ b/src/levanter/tensorstore_serialization.py @@ -89,9 +89,9 @@ async def load_array_from_tensorstore(spec): return await t.read("C") -async def _deserialize_one_leaf(like, spec, axis_mapping, mesh): +async def _deserialize_one_leaf(like, spec, env): if is_named_array(like): - return await _deserialize_named_array(like, spec, axis_mapping, mesh) + return await _deserialize_named_array(like, spec, env) elif isinstance(like, jax.Array): if not like.is_fully_addressable: return await array_ser.async_deserialize(like.sharding, spec, global_shape=like.shape, dtype=like.dtype) @@ -108,11 +108,11 @@ async def _deserialize_one_leaf(like, spec, axis_mapping, mesh): raise TypeError(f"Can't deserialize {type(like)}") -async def _deserialize_named_array(like, spec, axis_mapping, mesh): +async def _deserialize_named_array(like, spec, env): # the main thing we're worried about is deserialized NamedArrays that are not yet arrays but are ShapedDtypeStructs. 
# These don't (currently) have sharding info, but we can infer it from the axes if isinstance(like.array, jax.ShapeDtypeStruct): - sharding = hax.partitioning.sharding_for_axis(like.axes, axis_mapping, mesh) + sharding = hax.partitioning.sharding_for_axis(like.axes, env) array = await array_ser.async_deserialize(sharding, spec, global_shape=like.array.shape, dtype=like.dtype) assert sharding.is_equivalent_to(array.sharding, len(like.array.shape)) return hax.NamedArray(array, like.axes) @@ -122,7 +122,7 @@ async def _deserialize_named_array(like, spec, axis_mapping, mesh): def tree_deserialize_leaves_tensorstore( - checkpoint_dir, pytree, axis_mapping: Optional[ResourceMapping] = None, mesh: Optional[Mesh] = None + checkpoint_dir, pytree, env: Optional[hax.ResourceEnv] = None ): """ Deserializes a PyTree of Arrays and NamedArrays from a Tensorstore checkpoint, returning a pytree with the same shape @@ -141,7 +141,7 @@ def tree_deserialize_leaves_tensorstore( leaf_key_paths = jax_utils.leaf_key_paths(pytree, is_leaf=is_named_array) specs = htu.tree_map(partial(_tensorstore_spec_for, checkpoint_dir), leaf_key_paths) - deser_partial = functools.partial(_deserialize_one_leaf, axis_mapping=axis_mapping, mesh=mesh) + deser_partial = functools.partial(_deserialize_one_leaf, env=env) async def _do_deserialize(): futures = jtu.tree_map(deser_partial, pytree, specs, is_leaf=is_named_array) diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index 2281208a5..9ecb82210 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -229,6 +229,14 @@ def parameter_axis_mapping(self) -> ResourceMapping: def compute_axis_mapping(self) -> ResourceMapping: return self.config.compute_axis_mapping + @property + def param_env(self) -> hax.ResourceEnv: + return self.config.param_env + + @property + def compute_env(self) -> hax.ResourceEnv: + return self.config.compute_env + @property def device_mesh(self) -> Mesh: return self.config.device_mesh @@ -244,8 +252,7 @@ def EvalBatch(self): def __enter__(self): this_managers = [ levanter.current_tracker(self.tracker), - self.device_mesh, - hax.axis_mapping(self.parameter_axis_mapping), + self.param_env, ] self._cmanagers.append(this_managers) @@ -324,7 +331,7 @@ def init_state_and_model(model_init, training_key, is_trainable): if do_load_checkpoint is not False: try: - state = load_checkpoint(saveable_state_shape, checkpoint_path, axis_mapping=axis_mapping, mesh=mesh) + state = load_checkpoint(saveable_state_shape, checkpoint_path, self.param_env) except FileNotFoundError: if do_load_checkpoint: raise @@ -340,8 +347,7 @@ def init_state_and_model(model_init, training_key, is_trainable): loaded_model = load_checkpoint( saveable_state_shape.model, initial_model_path, - axis_mapping=axis_mapping, - mesh=mesh, + env=self.param_env, subpath="model", ) @@ -448,7 +454,7 @@ def replicated_loader(self, dataset: Dataset[X], batch_axis: Axis) -> Replicated Returns: ReplicatedBatchLoader: the batch loader """ - return ReplicatedBatchLoader(dataset, batch_axis, self.device_mesh, self.compute_axis_mapping) + return ReplicatedBatchLoader(dataset, batch_axis, self.config.compute_env) def sharded_loader(self, dataset: ShardableDataset[X], batch_axis: Axis) -> ShardedBatchLoader[X]: """Creates a sharded batch loader for the given dataset. 
Generally you should use this @@ -461,7 +467,7 @@ def sharded_loader(self, dataset: ShardableDataset[X], batch_axis: Axis) -> Shar Returns: ShardedBatchLoader: the batch loader """ - return ShardedBatchLoader(dataset, batch_axis, self.device_mesh, self.compute_axis_mapping) + return ShardedBatchLoader(dataset, batch_axis, self.config.compute_env) @cached_property def _jit_train_step_fn(self): @@ -569,7 +575,6 @@ class TrainerConfig: fsdp_axis: Optional[Union[str, List[str]]] = "embed" # Axis/Axes to use for FSDP tensor_parallel_axes: Optional[List[str]] = None # Axes, if any, to use for tensor parallelism - # TODO: in theory we can support tuples of physical axis names, but I don't think anyone actually uses that. axis_resources: Mapping[str, str] = field(default_factory=dict) """mapping from logical axis to physical axis. batch_axis, fsdp_axis, and tensor_parallel_axes are preferred""" parameter_axis_resources: Mapping[str, str] = field(default_factory=dict) # overrides axis_mapping for parameter @@ -670,6 +675,16 @@ def data_axis_size(self): assert jax.device_count() % self.model_axis_size == 0 return jax.device_count() // self.model_axis_size + + @cached_property + def compute_env(self) -> hax.ResourceEnv: + return hax.ResourceEnv(self.compute_axis_mapping, self.mp, self.device_mesh) + + @cached_property + def param_env(self) -> hax.ResourceEnv: + return hax.ResourceEnv(self.parameter_axis_mapping, self.mp, self.device_mesh) + + @cached_property def compute_axis_mapping(self) -> ResourceMapping: """Mapping from logical axis to physical axis for compute.""" diff --git a/tests/test_lora.py b/tests/test_lora.py index 5ba011bce..46cc2e0f8 100644 --- a/tests/test_lora.py +++ b/tests/test_lora.py @@ -198,6 +198,7 @@ def test_lora_load_in_peft(): @skip_if_no_torch def test_lora_merged_load_in_hf(): + jax.config.update("jax_traceback_filtering", "off") import torch converter: HFCheckpointConverter = Gpt2Config.default_hf_checkpoint_converter @@ -212,7 +213,7 @@ def test_lora_merged_load_in_hf(): causal_mask = hax.nn.attention.causal_mask(model.Pos, config.KeyPos) - with (tempfile.TemporaryDirectory() as tmpdir): + with tempfile.TemporaryDirectory() as tmpdir: converter.save_pretrained(model, f"{tmpdir}/model") lora_config = LoraConfig(r=8, target_modules=["c_attn"]) diff --git a/tests/test_replicated_loader.py b/tests/test_replicated_loader.py index ea6272165..64774546b 100644 --- a/tests/test_replicated_loader.py +++ b/tests/test_replicated_loader.py @@ -45,7 +45,7 @@ def test_local_batched_data_loading_model_axis_2(): seq_len = 128 cache = _small_dataset(seq_len) Batch = Axis("batch", len(devices)) - loader = ReplicatedBatchLoader(cache, Batch, mesh) + loader = ReplicatedBatchLoader(cache, Batch) batches = list(itertools.islice(loader, 10)) for batch in batches: @@ -65,7 +65,7 @@ def test_local_batched_data_loading_model_axis_1(): seq_len = 128 cache = _small_dataset(seq_len) Batch = Axis("batch", len(devices)) - loader = ReplicatedBatchLoader(cache, Batch, mesh) + loader = ReplicatedBatchLoader(cache, Batch) batches = list(itertools.islice(loader, 10)) for batch in batches: @@ -109,7 +109,7 @@ def test_structured_batches_model_axis_1(): seq_len = 128 dataset = StructuredDataset(seq_len, 0, 256, 1) Batch = Axis("batch", len(devices)) - loader = ReplicatedBatchLoader(dataset, Batch, mesh) + loader = ReplicatedBatchLoader(dataset, Batch) batches = list(itertools.islice(loader, 10)) for batch in batches: @@ -129,7 +129,7 @@ def test_structured_batches_model_axis_2(): seq_len = 128 dataset = 
StructuredDataset(seq_len, 0, 256, 1) Batch = Axis("batch", len(devices)) - loader = ReplicatedBatchLoader(dataset, Batch, mesh) + loader = ReplicatedBatchLoader(dataset, Batch) batches = list(itertools.islice(loader, 10)) for batch in batches: @@ -185,7 +185,7 @@ def test_structured_batches_model_axis_1_with_names(): Width = Axis("Width", 16) dataset = StructuredDatasetWithNames(Height, Width, 0, 256, 1) Batch = Axis("batch", len(devices)) - loader = ReplicatedBatchLoader(dataset, Batch, mesh) + loader = ReplicatedBatchLoader(dataset, Batch) batches = list(itertools.islice(loader, 10)) for batch in batches: @@ -208,7 +208,7 @@ def test_structured_batches_model_axis_2_with_names(): Width = Axis("Width", 16) dataset = StructuredDatasetWithNames(Height, Width, 0, 256, 1) Batch = Axis("batch", len(devices)) - loader = ReplicatedBatchLoader(dataset, Batch, mesh) + loader = ReplicatedBatchLoader(dataset, Batch) batches = list(itertools.islice(loader, 10)) for batch in batches: @@ -230,7 +230,7 @@ def test_structured_batches_model_axis_2_subsharded(): with mesh, haliax.axis_mapping({"batch": ResourceAxis.DATA, Height.name: ResourceAxis.MODEL}): dataset = StructuredDatasetWithNames(Height, Width, 0, 256, 1) Batch = Axis("batch", len(devices)) - loader = ReplicatedBatchLoader(dataset, Batch, mesh) + loader = ReplicatedBatchLoader(dataset, Batch) batches = list(itertools.islice(loader, 10)) for batch in batches: diff --git a/tests/test_sharded_loader.py b/tests/test_sharded_loader.py index 19f72bcfe..dfcef498a 100644 --- a/tests/test_sharded_loader.py +++ b/tests/test_sharded_loader.py @@ -48,7 +48,7 @@ def test_sharded_data_loading_model_axis_2(): seq_len = 128 cache = _small_dataset(seq_len) Batch = Axis("batch", len(devices)) - loader = ShardedBatchLoader(cache, mesh, Batch) + loader = ShardedBatchLoader(cache, Batch) batches = list(itertools.islice(loader, 10)) for batch in batches: @@ -67,7 +67,7 @@ def test_sharded_data_loading_model_axis_1(): seq_len = 128 cache = _small_dataset(seq_len) Batch = Axis("batch", len(devices)) - loader = ShardedBatchLoader(cache, mesh, Batch) + loader = ShardedBatchLoader(cache, Batch) batches = list(itertools.islice(loader, 10)) for batch in batches: @@ -111,7 +111,7 @@ def test_structured_batches_model_axis_1(): seq_len = 128 dataset = StructuredDataset(seq_len, 0, 256, 1) Batch = Axis("batch", len(devices)) - loader = ShardedBatchLoader(dataset, mesh, Batch) + loader = ShardedBatchLoader(dataset, Batch) batches = list(itertools.islice(loader, 10)) for batch in batches: @@ -144,7 +144,7 @@ def test_can_batch_named_scalars(): with mesh, hax.axis_mapping({"batch": ResourceAxis.DATA}): dataset = ScalarDataset(0, 256, 1) Batch = Axis("batch", len(devices)) - loader = ShardedBatchLoader(dataset, mesh, Batch) + loader = ShardedBatchLoader(dataset, Batch) batches = list(itertools.islice(loader, 10)) for batch in batches: @@ -164,7 +164,7 @@ def test_structured_batches_model_axis_2(): seq_len = 128 dataset = StructuredDataset(seq_len, 0, 256, 1) Batch = Axis("batch", len(devices)) - loader = ShardedBatchLoader(dataset, mesh, Batch) + loader = ShardedBatchLoader(dataset, Batch) batches = list(itertools.islice(loader, 10)) for batch in batches: @@ -221,7 +221,7 @@ def test_structured_batches_model_axis_1_with_names(): Width = Axis("Width", 16) dataset = StructuredDatasetWithNames(Height, Width, 0, 256, 1) Batch = Axis("batch", len(devices)) - loader = ShardedBatchLoader(dataset, mesh, Batch) + loader = ShardedBatchLoader(dataset, Batch) batches = 
list(itertools.islice(loader, 10)) for batch in batches: @@ -242,7 +242,7 @@ def test_structured_batches_model_axis_2_with_names(): Width = Axis("Width", 16) dataset = StructuredDatasetWithNames(Height, Width, 0, 256, 1) Batch = Axis("batch", len(devices)) - loader = ShardedBatchLoader(dataset, mesh, Batch) + loader = ShardedBatchLoader(dataset, Batch) batches = list(itertools.islice(loader, 10)) for batch in batches: @@ -264,7 +264,7 @@ def test_structured_batches_model_axis_2_subsharded(): with mesh, hax.axis_mapping({"batch": ResourceAxis.DATA, Height.name: ResourceAxis.MODEL}): dataset = StructuredDatasetWithNames(Height, Width, 0, 256, 1) Batch = Axis("batch", len(devices)) - loader = ShardedBatchLoader(dataset, mesh, Batch) + loader = ShardedBatchLoader(dataset, Batch) batches = list(itertools.islice(loader, 10)) for batch in batches: @@ -282,7 +282,7 @@ def test_sharded_loader_doesnt_throw_away_data(): with mesh, hax.axis_mapping({"batch": ResourceAxis.DATA}): dataset = ScalarDataset(0, 256, 1) Batch = Axis("batch", len(devices)) - loader = ShardedBatchLoader(dataset, mesh, Batch) + loader = ShardedBatchLoader(dataset, Batch) batches = list(itertools.islice(loader, 10)) dataset_examples = list(itertools.islice(dataset, 10 * Batch.size)) From 78d934235d856f6888a7c41cba22c60d957228ff Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 31 Jan 2024 14:32:43 -0800 Subject: [PATCH 165/205] stable point: migrating to resourceenvs --- src/levanter/checkpoint.py | 20 ++++++-------------- src/levanter/compat/torch_serialization.py | 4 ++-- src/levanter/data/shard_cache.py | 2 +- src/levanter/doremi.py | 9 +++++++-- src/levanter/main/viz_logprobs.py | 2 +- src/levanter/tensorstore_serialization.py | 9 +++------ src/levanter/trainer.py | 9 ++------- tests/test_llama.py | 2 +- 8 files changed, 23 insertions(+), 34 deletions(-) diff --git a/src/levanter/checkpoint.py b/src/levanter/checkpoint.py index 42c15ec01..2232b085f 100644 --- a/src/levanter/checkpoint.py +++ b/src/levanter/checkpoint.py @@ -105,12 +105,11 @@ def load_checkpoint( path: Optional[PathLike] = None, *, discover_latest: bool = True, - axis_mapping: Optional[haliax.partitioning.ResourceMapping] = None, - mesh: Optional[haliax.partitioning.Mesh] = None, + env: Optional[haliax.ResourceEnv] = None, ) -> Optional[M]: if path is None: path = self.base_path - return load_checkpoint(state, path, discover_latest=discover_latest, axis_mapping=axis_mapping, mesh=mesh) + return load_checkpoint(state, path, discover_latest=discover_latest, env=env) def load_model( self, @@ -118,16 +117,13 @@ def load_model( path: Optional[str] = None, *, discover_latest: bool = True, - axis_mapping: Optional[haliax.partitioning.ResourceMapping] = None, - mesh: Optional[haliax.partitioning.Mesh] = None, + env: Optional[haliax.ResourceEnv] = None, ) -> Optional[M]: """ Convenience method/holdover from previous API for loading checkpoints. Loads just the model assuming the model is in the `model` subdir of the discovered checkpoint. 
""" - ret_dict = self.load_checkpoint( - {"model": model}, path, discover_latest=discover_latest, axis_mapping=axis_mapping, mesh=mesh - ) + ret_dict = self.load_checkpoint({"model": model}, path, discover_latest=discover_latest, env=env) if ret_dict is None: return None return ret_dict["model"] @@ -315,18 +311,14 @@ def load_checkpoint( logger.warning("Attempting to load old-style checkpoint") model, training_state = tree.model, (tree.opt_state, tree.training_key) - model = tree_deserialize_leaves_tensorstore( - os.path.join(checkpoint_path, "model"), model, axis_mapping=axis_mapping, mesh=mesh - ) + model = tree_deserialize_leaves_tensorstore(os.path.join(checkpoint_path, "model"), model, env) if training_state is None: opt_state = None key = None else: training_state = tree_deserialize_leaves_tensorstore( - os.path.join(checkpoint_path, "training_state"), - training_state, - env + os.path.join(checkpoint_path, "training_state"), training_state, env ) opt_state, key = training_state diff --git a/src/levanter/compat/torch_serialization.py b/src/levanter/compat/torch_serialization.py index 1c326911b..0df3be216 100644 --- a/src/levanter/compat/torch_serialization.py +++ b/src/levanter/compat/torch_serialization.py @@ -89,8 +89,8 @@ def jax_tree_from_state_dict(tree: PyTree, state_dict: StateDict, prefix: Option raise ValueError("Cannot extract a leaf value from a torch dict without a prefix") array = state_dict[prefix] - mesh = haliax.partitioning._get_mesh() - if mesh.devices.size > 1: # this happens with the default mesh + mesh = haliax.current_resource_env().mesh + if mesh is not None: # this happens with the default mesh pspec = haliax.partitioning.pspec_for_axis(tree.axes) sharding = jax.sharding.NamedSharding(mesh, pspec) array = jax.make_array_from_callback(tree.array.shape, sharding, lambda indices: array[indices]) diff --git a/src/levanter/data/shard_cache.py b/src/levanter/data/shard_cache.py index f3565665b..91b19a87c 100644 --- a/src/levanter/data/shard_cache.py +++ b/src/levanter/data/shard_cache.py @@ -404,7 +404,7 @@ def shutdown(self): if self._processing_thread.is_alive(): self._processing_thread.join() - def _loop(self): + def _loop(self: "PriorityProcessorActor"): should_sleep = False backpressure_queue: list[ray.ObjectRef] = [] diff --git a/src/levanter/doremi.py b/src/levanter/doremi.py index fcc863c3e..39e26cf22 100644 --- a/src/levanter/doremi.py +++ b/src/levanter/doremi.py @@ -16,7 +16,7 @@ from levanter.data import ShardableDataset from levanter.data.mixture import MixtureDataset from levanter.logging import capture_time -from levanter.trainer import M, StepInfo, Trainer, TrainerConfig, TrainerState +from levanter.trainer import M, StepInfo, Trainer, TrainerConfig, TrainerState, take_opt_step from levanter.types import ComputeLossFunction, ModuleComputeLoss from levanter.utils.tree_utils import inference_mode @@ -167,7 +167,12 @@ def doremi_step(state: DoremiState, ref, batch, domains): loss, grad_loss = eqx.filter_value_and_grad(_domain_weighted_loss)(excess_losses, Domain, domains, alpha) grad = proxy_loss_bwd(grad_loss)[0] - new_state = trainer._take_train_step(state, proxy, grad) + partial_loss = lambda model: loss_fn(model, *batch) + model, opt_state = take_opt_step( + optimizer, state.model, state.opt_state, grad, partial_loss, state.is_trainable + ) + + new_state = dataclasses.replace(state, model=model, opt_state=opt_state, _step=state._step + 1) new_state = new_state.update_alpha(alpha) # log metrics diff --git a/src/levanter/main/viz_logprobs.py 
b/src/levanter/main/viz_logprobs.py index 798bd152e..9e8a37059 100644 --- a/src/levanter/main/viz_logprobs.py +++ b/src/levanter/main/viz_logprobs.py @@ -47,7 +47,7 @@ def main(config: VizGpt2Config): validation_set = config.data.validation_set(Pos.size) assert validation_set is not None eval_loader = ReplicatedBatchLoader( - CausalLmDataset(validation_set, Pos, KeyPos), EvalBatch, config.trainer.device_mesh + CausalLmDataset(validation_set, Pos, KeyPos), EvalBatch, config.trainer.compute_env ) compute_axis_mapping = config.trainer.compute_axis_mapping diff --git a/src/levanter/tensorstore_serialization.py b/src/levanter/tensorstore_serialization.py index 67249e6d0..25f7fa594 100644 --- a/src/levanter/tensorstore_serialization.py +++ b/src/levanter/tensorstore_serialization.py @@ -13,12 +13,10 @@ import jax.tree_util as jtu import numpy as np import tensorstore -from jax.sharding import Mesh from tensorstore import TensorStore import haliax as hax import haliax.tree_util as htu -from haliax.partitioning import ResourceMapping from haliax.util import is_named_array from levanter.utils import jax_utils @@ -117,13 +115,11 @@ async def _deserialize_named_array(like, spec, env): assert sharding.is_equivalent_to(array.sharding, len(like.array.shape)) return hax.NamedArray(array, like.axes) else: - array = await _deserialize_one_leaf(like.array, spec, axis_mapping, mesh) + array = await _deserialize_one_leaf(like.array, spec, env) return hax.NamedArray(array, like.axes) -def tree_deserialize_leaves_tensorstore( - checkpoint_dir, pytree, env: Optional[hax.ResourceEnv] = None -): +def tree_deserialize_leaves_tensorstore(checkpoint_dir, pytree, env: Optional[hax.ResourceEnv] = None): """ Deserializes a PyTree of Arrays and NamedArrays from a Tensorstore checkpoint, returning a pytree with the same shape as the one provided. 
This method is capable of deserializing NamedArrays that are the result of an eval_shape call @@ -138,6 +134,7 @@ def tree_deserialize_leaves_tensorstore( :return: a pytree with the same shape as the exemplar pytree, but with the arrays deserialized from the checkpoint """ # TODO: support ShapeDtypeStructs that are not NamedArrays + env = env or hax.current_resource_env() leaf_key_paths = jax_utils.leaf_key_paths(pytree, is_leaf=is_named_array) specs = htu.tree_map(partial(_tensorstore_spec_for, checkpoint_dir), leaf_key_paths) diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index 9ecb82210..a14cea5f9 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -315,8 +315,6 @@ def initial_state( checkpoint_path = self.config.checkpointer.expanded_path(self.run_id) do_load_checkpoint = self.config.load_checkpoint - axis_mapping = self.parameter_axis_mapping - mesh = self.device_mesh initial_model_path = self.config.initialize_from # we don't save the full trainer state, so we need to filter out the non-trainable parameters @@ -355,7 +353,7 @@ def init_state_and_model(model_init, training_key, is_trainable): model_init = jax.tree_util.Partial(lambda m, f: eqx.combine(m, f()), loaded_model, model_init) # now we initialize a fresh trainer state, possibly just to finish any missing fields - @named_jit(axis_resources=axis_mapping, donate_args=(True, True, True, False)) + @named_jit(axis_resources=self.param_env, donate_args=(True, True, True, False)) def init_state(partial_state, model_init, training_key, is_trainable): model = model_init() fresh_state = self._initialize_state_from_scratch(model, training_key, is_trainable) @@ -494,8 +492,7 @@ def loss_fn(model, *batch, **batch_kwargs): new_state = dataclasses.replace(state, model=model, opt_state=opt_state) - new_state = dataclasses.replace(new_state, training_key=new_key) - new_state = dataclasses.replace(new_state, _step=state._step + 1) + new_state = dataclasses.replace(new_state, _step=state._step + 1, training_key=new_key) return loss, new_state @@ -675,7 +672,6 @@ def data_axis_size(self): assert jax.device_count() % self.model_axis_size == 0 return jax.device_count() // self.model_axis_size - @cached_property def compute_env(self) -> hax.ResourceEnv: return hax.ResourceEnv(self.compute_axis_mapping, self.mp, self.device_mesh) @@ -684,7 +680,6 @@ def compute_env(self) -> hax.ResourceEnv: def param_env(self) -> hax.ResourceEnv: return hax.ResourceEnv(self.parameter_axis_mapping, self.mp, self.device_mesh) - @cached_property def compute_axis_mapping(self) -> ResourceMapping: """Mapping from logical axis to physical axis for compute.""" diff --git a/tests/test_llama.py b/tests/test_llama.py index 15a5ab452..46040afd2 100644 --- a/tests/test_llama.py +++ b/tests/test_llama.py @@ -191,7 +191,7 @@ def test_llama_decoder_layer(): state = llama_decoder_layer.to_state_dict() state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()} - hf_decoder_layer = HFLlamaDecoderLayer(llama_config.to_hf_config(32000)) + hf_decoder_layer = HFLlamaDecoderLayer(llama_config.to_hf_config(32000), 0) hf_decoder_layer.load_state_dict(state, strict=True) x, mask = _get_random_inputs(llama_config) From 8e4e183863475617d46a16b29791a66e2f222669 Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 31 Jan 2024 14:45:39 -0800 Subject: [PATCH 166/205] require the jamp branch --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 342e9bd45..ed207afc1 100644 --- 
a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ dependencies = [ # jax = {version = ">=0.4.10,<0.5.0"} # "haliax>=1.3,<2.0", # Haliax changes in step with levanter, so we'll just use the git version except for releases. - "haliax @ git+https://github.com/stanford-crfm/haliax.git@dev", + "haliax @ git+https://github.com/stanford-crfm/haliax.git@jamp", "equinox>=0.10.7", "jaxtyping>=0.2.20", "transformers>=4.22.0", From 9a0ea6d1de15f7193eee43e0108c842909c83e75 Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 31 Jan 2024 15:25:59 -0800 Subject: [PATCH 167/205] wip: debug which resource env the batch loader sees --- src/levanter/data/loader.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/levanter/data/loader.py b/src/levanter/data/loader.py index bd02cfd9c..5c2d22452 100644 --- a/src/levanter/data/loader.py +++ b/src/levanter/data/loader.py @@ -39,12 +39,10 @@ class BatchLoader(Iterable[Ex], abc.ABC): resource_env: the resource environment, if None then use the current one max_capacity: if not None, the maximum number of batches to keep in memory at once. If <0 then load in the main thread """ - Batch: hax.Axis + Batch: hax.Axis - def __init__( - self, Batch: hax.Axis, resource_env: hax.ResourceEnv, max_capacity: Optional[int] - ): + def __init__(self, Batch: hax.Axis, resource_env: hax.ResourceEnv, max_capacity: Optional[int]): self.max_capacity = max_capacity self.resource_env = resource_env or hax.current_resource_env() @@ -52,6 +50,7 @@ def __init__( def __iter__(self) -> Iterator[Ex]: def produce_batches(): + print("ZZZ", self.resource_env) with self.resource_env: for batch in self._produce_batches(): yield batch From 7c19f47069332e04a3028920b811dc8dcc3d4a23 Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 31 Jan 2024 15:41:58 -0800 Subject: [PATCH 168/205] try this? --- src/levanter/data/loader.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/levanter/data/loader.py b/src/levanter/data/loader.py index 5c2d22452..d7a038b58 100644 --- a/src/levanter/data/loader.py +++ b/src/levanter/data/loader.py @@ -50,8 +50,8 @@ def __init__(self, Batch: hax.Axis, resource_env: hax.ResourceEnv, max_capacity: def __iter__(self) -> Iterator[Ex]: def produce_batches(): - print("ZZZ", self.resource_env) - with self.resource_env: + # print("ZZZ", self.resource_env) + with hax.axis_mapping(self.resource_env.axis_mapping): for batch in self._produce_batches(): yield batch From d98a885e3ca79a88fe4c81521e13fc75dbe17e9c Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 31 Jan 2024 15:53:18 -0800 Subject: [PATCH 169/205] cleanup and explain the issue --- src/levanter/data/loader.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/levanter/data/loader.py b/src/levanter/data/loader.py index d7a038b58..fd9b05463 100644 --- a/src/levanter/data/loader.py +++ b/src/levanter/data/loader.py @@ -50,7 +50,8 @@ def __init__(self, Batch: hax.Axis, resource_env: hax.ResourceEnv, max_capacity: def __iter__(self) -> Iterator[Ex]: def produce_batches(): - # print("ZZZ", self.resource_env) + # NB: we do not want to use the mesh here, because it makes JAX unhappy + # TODO: figure out why we can't use mesh here and fix.
with hax.axis_mapping(self.resource_env.axis_mapping): for batch in self._produce_batches(): yield batch From 015dfb3afa0d65f356b0ba10ca36864aa83e1371 Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 31 Jan 2024 15:54:28 -0800 Subject: [PATCH 170/205] see if we get the just-in-time conversion to bf16 that we want --- src/levanter/trainer.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index a14cea5f9..0c3bfa553 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -476,8 +476,6 @@ def _train_step(self, state: TrainerState, *batch, **batch_kwargs) -> tuple[Scal model = inference_mode(state.model, False) def loss_fn(model, *batch, **batch_kwargs): - # TODO: when we get ResourceEnvs in place, we can remove this cast_to_compute - model = self.mp.cast_to_compute(model) with hax.axis_mapping(self.compute_axis_mapping): return self.loss_fn(model, *batch, **batch_kwargs).scalar() From cddaf20a31f1576715ec6e5a75ea73661dc81a74 Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 31 Jan 2024 16:04:32 -0800 Subject: [PATCH 171/205] wtf --- src/levanter/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index 0c3bfa553..7e56bbc84 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -476,7 +476,7 @@ def _train_step(self, state: TrainerState, *batch, **batch_kwargs) -> tuple[Scal model = inference_mode(state.model, False) def loss_fn(model, *batch, **batch_kwargs): - with hax.axis_mapping(self.compute_axis_mapping): + with self.compute_env: return self.loss_fn(model, *batch, **batch_kwargs).scalar() # only train on the trainable parameters. We're leaning on JAX to do dead code elimination for us From 27949f83150194a971ce1cc2bc8619c967421ffc Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 31 Jan 2024 16:04:32 -0800 Subject: [PATCH 172/205] bypass microbatching if we don't need it? 
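The gist of the change below, as a sketch (not the real `microbatched` signature): when the whole batch fits in a single microbatch, wrapping the step function in an accumulation scan only adds overhead, so we hand the function back unchanged.

```python
def microbatched_sketch(fn, batch_size: int, microbatch_size: int):
    num_micro_steps = batch_size // microbatch_size
    assert num_micro_steps * microbatch_size == batch_size
    if num_micro_steps == 1:
        return fn  # nothing to accumulate: bypass the scan entirely

    def accumulated(*args, **kwargs):
        # placeholder for the real scan-over-microbatches accumulation
        raise NotImplementedError

    return accumulated
```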
--- src/levanter/grad_accum.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/levanter/grad_accum.py b/src/levanter/grad_accum.py index 3e472c8ef..6a8e64fad 100644 --- a/src/levanter/grad_accum.py +++ b/src/levanter/grad_accum.py @@ -74,6 +74,10 @@ def microbatched( raise ValueError(f"Bad value for {microbatch_size=}") num_micro_steps = batch_size // microbatch_size + + if num_micro_steps == 1: + return fn + Microbatch = Batch.resize(microbatch_size) AccumStep = Axis("accum_step", num_micro_steps) assert num_micro_steps * microbatch_size == batch_size From c3a9ce1d3b758acc75e69a3208873af1423e5bae Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 31 Jan 2024 19:17:05 -0800 Subject: [PATCH 173/205] switch to using hnn.Embedding in gpt2, which means we get the mixed precision right --- src/levanter/models/gpt2.py | 37 ++++++++++++++++++++++++------------- 1 file changed, 24 insertions(+), 13 deletions(-) diff --git a/src/levanter/models/gpt2.py b/src/levanter/models/gpt2.py index 5822fab30..a2e1e65d8 100644 --- a/src/levanter/models/gpt2.py +++ b/src/levanter/models/gpt2.py @@ -312,40 +312,51 @@ class Gpt2Embeddings(StateDictSerializationMixin, eqx.Module): Vocab: Axis = eqx.static_field() config: Gpt2Config = eqx.static_field() - token_embeddings: NamedArray - position_embeddings: NamedArray + token_embeddings: hnn.Embedding + position_embeddings: hnn.Embedding dropout: hnn.Dropout @staticmethod def init(Vocab: Axis, config: Gpt2Config, *, key) -> "Gpt2Embeddings": k_wte, k_wpe, k_out = jrandom.split(key, 3) - token_embeddings = hax.random.normal(k_wte, (Vocab, config.Embed)) * config.initializer_range - position_embeddings = hax.random.normal(k_wpe, (config.Pos, config.Embed)) * (config.initializer_range / 2) + # token_embeddings = hax.random.normal(k_wte, (Vocab, config.Embed)) * config.initializer_range + # position_embeddings = hax.random.normal(k_wpe, (config.Pos, config.Embed)) * (config.initializer_range / 2) + token_embeddings = hnn.Embedding.init( + Vocab, config.Embed, key=k_wte, initializer_range=config.initializer_range + ) + position_embeddings = hnn.Embedding.init( + config.Pos, config.Embed, key=k_wpe, initializer_range=config.initializer_range / 2 + ) dropout = hnn.Dropout(pdrop=config.embed_pdrop) return Gpt2Embeddings(Vocab, config, token_embeddings, position_embeddings, dropout) @named_call def embed(self, input_ids, *, key): - input_embeds = self.token_embeddings.take("vocab", input_ids) - position_embeds = self.position_embeddings - - input_len = input_ids.resolve_axis("position").size - x = input_embeds + position_embeds["position", hax.dslice(0, input_len)] + # input_embeds = self.token_embeddings.take("vocab", input_ids) + # position_embeds = self.position_embeddings + input_embeds = self.token_embeddings(input_ids) + + input_Pos = input_ids.resolve_axis("position") + position_embeds = self.position_embeddings.embed(hax.arange(input_Pos)) + # x = input_embeds + position_embeds["position", hax.dslice(0, input_len)] + x = input_embeds + position_embeds x = self.dropout(x, key=key) return x def unembed(self, x: NamedArray): - return hax.dot("embed", x, self.token_embeddings) + return hax.dot("embed", x, self.token_embeddings.weight) def _state_dict_key_map(self) -> Dict[str, Optional[str]]: - return {"token_embeddings": "wte.weight", "position_embeddings": "wpe.weight"} + return {"token_embeddings": "wte", "position_embeddings": "wpe"} def resize_embeddings(self, new_size: int, key: Optional[PRNGKeyArray] = None): - new_weights = 
hax.tree_util.resize_axis(self.token_embeddings, self.Vocab, new_size, key=key) - return dataclasses.replace(self, Vocab=self.Vocab.resize(new_size), token_embeddings=new_weights) + # new_weights = hax.tree_util.resize_axis(self.token_embeddings, self.Vocab, new_size, key=key) + # return dataclasses.replace(self, Vocab=self.Vocab.resize(new_size), token_embeddings=new_weights) + new_token_embeddings = self.token_embeddings.resize_embeddings(new_size, key=key) + return dataclasses.replace(self, Vocab=self.Vocab.resize(new_size), token_embeddings=new_token_embeddings) class Gpt2LMHeadModel(eqx.Module, LmWithHfSerializationMixin[Gpt2Config]): From 0e913527cc6a6878aaad429fd5b7744148f20717 Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 31 Jan 2024 21:52:29 -0800 Subject: [PATCH 174/205] switch to using compute_envs where possible use .shard instead --- src/levanter/compat/hf_checkpoints.py | 5 +---- src/levanter/doremi.py | 2 +- src/levanter/grad_accum.py | 4 ++-- src/levanter/main/eval_lm.py | 2 +- src/levanter/models/mpt.py | 3 +-- src/levanter/trainer.py | 5 +++-- tests/test_backpack.py | 3 +-- tests/test_levanter_hf_consistency.py | 13 +++++-------- 8 files changed, 15 insertions(+), 22 deletions(-) diff --git a/src/levanter/compat/hf_checkpoints.py b/src/levanter/compat/hf_checkpoints.py index 25339d245..138b33200 100644 --- a/src/levanter/compat/hf_checkpoints.py +++ b/src/levanter/compat/hf_checkpoints.py @@ -519,10 +519,7 @@ def load_pretrained( f"Model vocab size ({Vocab.size}) does not match tokenizer vocab size ({tokenizer_Vocab.size})" ) - if axis_mapping is not None: - lev_model = haliax.shard_with_axis_mapping(lev_model, axis_mapping) - else: - lev_model = haliax.auto_sharded(lev_model) + lev_model = haliax.shard(lev_model, axis_mapping) # once more for good measure gc.collect() diff --git a/src/levanter/doremi.py b/src/levanter/doremi.py index 39e26cf22..3e3c3c84b 100644 --- a/src/levanter/doremi.py +++ b/src/levanter/doremi.py @@ -115,7 +115,7 @@ def estimate_mixture_weights( @eqx.filter_jit def eval_loss(model, *batch, **batch_kwargs): model = inference_mode(model, True) - with hax.axis_mapping(trainer.compute_axis_mapping): + with trainer.compute_env: return trainer.loss_fn(model, *batch, **batch_kwargs, key=None) for domain, dataset in validation_sets.items(): diff --git a/src/levanter/grad_accum.py b/src/levanter/grad_accum.py index 6a8e64fad..8ac6e9395 100644 --- a/src/levanter/grad_accum.py +++ b/src/levanter/grad_accum.py @@ -113,7 +113,7 @@ def loop(acc, microbatch_and_key): with jax.named_scope("accum"): acc = eqx.apply_updates(acc, this_r) - acc = hax.shard_with_axis_mapping(acc, accum_axis_mapping) + acc = hax.shard(acc, accum_axis_mapping) return acc @@ -134,7 +134,7 @@ def _reshape(x): if not x.has_axis(Batch.name): return x x = x.unflatten_axis(Batch, (AccumStep, Microbatch)) - return hax.shard_with_axis_mapping(x, axis_mapping) + return hax.shard(x, axis_mapping) elif isinstance(x, jnp.ndarray): x = x.reshape((AccumStep.size, Microbatch.size) + x.shape[1:]) return with_sharding_constraint(x, PartitionSpec(None, ResourceAxis.DATA, *(None,) * (len(x.shape) - 2))) diff --git a/src/levanter/main/eval_lm.py b/src/levanter/main/eval_lm.py index 2937218bc..24be4b11d 100644 --- a/src/levanter/main/eval_lm.py +++ b/src/levanter/main/eval_lm.py @@ -87,7 +87,7 @@ def compute_loss(model: LmHeadModel, example: LmExample): # TODO: don't load the entire checkpoint into CPU memory when we only need our share of the model model = load_checkpoint(model,
config.checkpoint_path, subpath="model") - model = hax.shard_with_axis_mapping(model, parameter_axis_mapping) + model = hax.shard(model, config.trainer.param_env) loss = callbacks.eval_loss_loop(compute_loss, model, eval_loader, max_batches=total) diff --git a/src/levanter/models/mpt.py b/src/levanter/models/mpt.py index 9c31e63b6..e01c43c0c 100644 --- a/src/levanter/models/mpt.py +++ b/src/levanter/models/mpt.py @@ -475,8 +475,7 @@ def from_hf_pretrained( lev_model = eqx.filter_eval_shape(MptLmHeadModel.init, Vocab, lev_config, key=PRNGKey(0)) lev_model = lev_model.from_state_dict(state_dict) - if axis_mapping is not None: - lev_model = haliax.shard_with_axis_mapping(lev_model, axis_mapping) + lev_model = haliax.shard(lev_model, axis_mapping) return lev_model diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index 7e56bbc84..b8ccb4631 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -428,10 +428,11 @@ def add_eval_hook(self, eval_dataset, name: Optional[str] = None): if eval_loader and (self.config.max_eval_batches is None or self.config.max_eval_batches > 0): - @eqx.filter_jit + @named_jit(axis_resources=self.param_env, donate_args=(False)) def eval_loss(model, *batch, **batch_kwargs): model = inference_mode(model, True) - with hax.axis_mapping(self.compute_axis_mapping): + # TODO: should we do in full precision? + with self.compute_env: return self.loss_fn(model, *batch, **batch_kwargs, key=None) self.add_hook( diff --git a/tests/test_backpack.py b/tests/test_backpack.py index 0b1af96e1..f6b841f5a 100644 --- a/tests/test_backpack.py +++ b/tests/test_backpack.py @@ -8,7 +8,6 @@ import haliax import haliax as hax from haliax import Axis -from haliax.partitioning import round_axis_for_partitioning from levanter.models.backpack import BackpackConfig, BackpackLMHeadModel from levanter.trainer import TrainerConfig @@ -22,7 +21,7 @@ def test_backpack_predict(): trainer_config = TrainerConfig() - Vocab = round_axis_for_partitioning(Axis("vocab", VOCAB_SIZE), trainer_config.compute_axis_mapping) + Vocab = Axis("vocab", VOCAB_SIZE) model_config = BackpackConfig() model_key = PRNGKey(0) model = BackpackLMHeadModel.init(Vocab, model_config, key=model_key) diff --git a/tests/test_levanter_hf_consistency.py b/tests/test_levanter_hf_consistency.py index 9a0aadb61..5fafe9791 100644 --- a/tests/test_levanter_hf_consistency.py +++ b/tests/test_levanter_hf_consistency.py @@ -5,7 +5,6 @@ import haliax as hax from haliax import Axis -from haliax.partitioning import round_axis_for_partitioning from levanter.checkpoint import load_checkpoint from levanter.models.backpack import BackpackLMHeadModel @@ -34,8 +33,7 @@ def test_hf_backpack_consistency(): model_config: BackpackConfig = BackpackConfig.from_hf_config(hf_model_config) trainer_config = TrainerConfig() - vocab_size = hf_model_config.vocab_size - Vocab = round_axis_for_partitioning(Axis("vocab", vocab_size), trainer_config.compute_axis_mapping) + Vocab = Axis("vocab", hf_model_config.vocab_size) model_key = PRNGKey(0) model_levanter = BackpackLMHeadModel.init(Vocab, model_config, key=model_key) model_levanter, (_, _), _ = load_checkpoint( @@ -59,18 +57,17 @@ def test_hf_gpt2_consistency(): from levanter.models.gpt2 import Gpt2Config - model_config: GPT2Config = Gpt2Config.from_hf_config(hf_model_config) + model_config = Gpt2Config.from_hf_config(hf_model_config) trainer_config = TrainerConfig() - vocab_size = hf_model_config.vocab_size - Vocab = round_axis_for_partitioning(Axis("vocab", vocab_size), 
trainer_config.compute_axis_mapping) + Vocab = Axis("vocab", hf_model_config.vocab_size) model_key = PRNGKey(0) model_levanter = Gpt2LMHeadModel.init(Vocab, model_config, key=model_key) - model_levanter, (_, _), _ = load_checkpoint( + model_levanter = load_checkpoint( model_levanter, - (None, None), checkpoint_path=LEVANTER_GPT2_CHECKPOINT, discover_latest=True, + subpath="model", ) mp = trainer_config.mp model_levanter = mp.cast_to_param(model_levanter) From b57e1c7e98066a230c27e68aa4e24f6b442f2802 Mon Sep 17 00:00:00 2001 From: David Hall Date: Thu, 1 Feb 2024 12:59:08 -0800 Subject: [PATCH 175/205] please pre-commit --- src/levanter/main/doremi_lm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/levanter/main/doremi_lm.py b/src/levanter/main/doremi_lm.py index 850a2c730..294d75c72 100644 --- a/src/levanter/main/doremi_lm.py +++ b/src/levanter/main/doremi_lm.py @@ -90,7 +90,7 @@ def main(config: TrainLmConfig): else: ref_model_shape = eqx.filter_eval_shape(config.model.build, Vocab, key=jrandom.PRNGKey(0)) ref_model = levanter.checkpoint.load_checkpoint( - ref_model_shape, config.ref_model_path, axis_mapping=parameter_axis_mapping, subpath="model" + ref_model_shape, config.ref_model_path, env=config.trainer.param_env, subpath="model" ) ref_model = inference_mode(ref_model, True) From 7fd46cbe7e08aaeedc964a9c8463d85bf5fa0c4c Mon Sep 17 00:00:00 2001 From: David Hall Date: Thu, 1 Feb 2024 13:19:00 -0800 Subject: [PATCH 176/205] ok maybe we can do it? --- src/levanter/data/loader.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/levanter/data/loader.py b/src/levanter/data/loader.py index fd9b05463..2ec2e6169 100644 --- a/src/levanter/data/loader.py +++ b/src/levanter/data/loader.py @@ -50,9 +50,7 @@ def __init__(self, Batch: hax.Axis, resource_env: hax.ResourceEnv, max_capacity: def __iter__(self) -> Iterator[Ex]: def produce_batches(): - # NB: we do not want to use the mesh here, because it makes JAX unhappy - # TODO: figure out why we can't use mesh here and fix. 
- with hax.axis_mapping(self.resource_env.axis_mapping): + with self.resource_env: for batch in self._produce_batches(): yield batch From 2ca4d97e7264a652162850baf27bdf8793a0b68d Mon Sep 17 00:00:00 2001 From: David Hall Date: Thu, 1 Feb 2024 13:25:00 -0800 Subject: [PATCH 177/205] sigh --- tests/test_viz_lm.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/test_viz_lm.py b/tests/test_viz_lm.py index 25d5e8fb0..71d117055 100644 --- a/tests/test_viz_lm.py +++ b/tests/test_viz_lm.py @@ -18,7 +18,11 @@ def setup_module(module): ray_designated_cores = max(1, logical_cpu_core_count()) - ray.init("local", num_cpus=ray_designated_cores) + try: + ray.init("local", num_cpus=ray_designated_cores) + except AssertionError: + # don't get upset if ray is already running + pass def teardown_module(module): From a237a57809f6a4d96c2977dd83a228965a55a7de Mon Sep 17 00:00:00 2001 From: David Hall Date: Thu, 1 Feb 2024 14:29:40 -0800 Subject: [PATCH 178/205] fix test_weight_decay_mask.py --- tests/test_weight_decay_mask.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_weight_decay_mask.py b/tests/test_weight_decay_mask.py index c47231116..cc94c5749 100644 --- a/tests/test_weight_decay_mask.py +++ b/tests/test_weight_decay_mask.py @@ -18,8 +18,8 @@ def apply_weight_decay(tree): nodes = [] # apply on embedding - nodes.append(tree.embeddings.token_embeddings.array) - nodes.append(tree.embeddings.position_embeddings.array) + nodes.append(tree.embeddings.token_embeddings.weight.array) + nodes.append(tree.embeddings.position_embeddings.weight.array) # apply on attention nodes.append(tree.transformer.blocks.stacked.attn.c_attn.weight.array) @@ -49,8 +49,8 @@ def apply_weight_decay(tree): "attn.c_proj.weight", "mlp.c_fc.weight", "mlp.c_proj.weight", - "token_embeddings", - "position_embeddings", + "token_embeddings.weight", + "position_embeddings.weight", ] ) regex_config = AdamConfig( From 528269452a919a3770a1e57fc05c6e306a469a35 Mon Sep 17 00:00:00 2001 From: David Hall Date: Thu, 1 Feb 2024 14:39:16 -0800 Subject: [PATCH 179/205] use param_env everywhere --- src/levanter/doremi.py | 2 +- src/levanter/main/viz_logprobs.py | 2 +- src/levanter/trainer.py | 2 +- tests/test_replicated_loader.py | 14 +++++++------- tests/test_sharded_loader.py | 18 +++++++++--------- 5 files changed, 19 insertions(+), 19 deletions(-) diff --git a/src/levanter/doremi.py b/src/levanter/doremi.py index 3e3c3c84b..1924ebe29 100644 --- a/src/levanter/doremi.py +++ b/src/levanter/doremi.py @@ -147,7 +147,7 @@ def eval_loss(model, *batch, **batch_kwargs): @hax.named_jit(axis_resources=trainer.parameter_axis_mapping, donate_args=(True,)) def doremi_step(state: DoremiState, ref, batch, domains): proxy = inference_mode(state.model, False) - with hax.axis_mapping(trainer.compute_axis_mapping): + with trainer.compute_env: # calculate per-token losses for proxy and ref proxy_losses, proxy_loss_bwd = eqx.filter_vjp(lambda p: loss_fn(p, batch, reduction_axis=()), proxy) ref_losses = loss_fn(ref, batch, reduction_axis=()) diff --git a/src/levanter/main/viz_logprobs.py b/src/levanter/main/viz_logprobs.py index 9e8a37059..c43525cd1 100644 --- a/src/levanter/main/viz_logprobs.py +++ b/src/levanter/main/viz_logprobs.py @@ -53,7 +53,7 @@ def main(config: VizGpt2Config): compute_axis_mapping = config.trainer.compute_axis_mapping parameter_axis_mapping = config.trainer.parameter_axis_mapping - with config.trainer.device_mesh, hax.axis_mapping(parameter_axis_mapping): + with 
config.trainer.param_env: key = jax.random.PRNGKey(0) vocab_size = len(tokenizer) diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index b8ccb4631..3842699fe 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -483,7 +483,7 @@ def loss_fn(model, *batch, **batch_kwargs): # only train on the trainable parameters. We're leaning on JAX to do dead code elimination for us loss, grads = self._compute_gradients_microbatched(loss_fn, model, batch, **batch_kwargs, key=key) - with hax.axis_mapping(self.parameter_axis_mapping): + with self.param_env: partial_loss = lambda model: loss_fn(model, *batch, **batch_kwargs) model, opt_state = take_opt_step( self.optimizer, model, state.opt_state, grads, partial_loss, state.is_trainable diff --git a/tests/test_replicated_loader.py b/tests/test_replicated_loader.py index 64774546b..347c153d5 100644 --- a/tests/test_replicated_loader.py +++ b/tests/test_replicated_loader.py @@ -40,7 +40,7 @@ def test_local_batched_data_loading_model_axis_2(): np.array(devices).reshape(-1, model_axis_size), (ResourceAxis.DATA, ResourceAxis.MODEL), ) - with mesh, haliax.axis_mapping({"batch": ResourceAxis.DATA}): + with haliax.resource_env({"batch": ResourceAxis.DATA}, mesh=mesh): seq_len = 128 cache = _small_dataset(seq_len) @@ -60,7 +60,7 @@ def test_local_batched_data_loading_model_axis_1(): np.array(devices).reshape(-1, model_axis_size), (ResourceAxis.DATA, ResourceAxis.MODEL), ) - with mesh, haliax.axis_mapping({"batch": ResourceAxis.DATA}): + with haliax.resource_env({"batch": ResourceAxis.DATA}, mesh=mesh): seq_len = 128 cache = _small_dataset(seq_len) @@ -105,7 +105,7 @@ def test_structured_batches_model_axis_1(): np.array(devices).reshape(-1, model_axis_size), (ResourceAxis.DATA, ResourceAxis.MODEL), ) - with mesh, haliax.axis_mapping({"batch": ResourceAxis.DATA}): + with haliax.resource_env({"batch": ResourceAxis.DATA}, mesh=mesh): seq_len = 128 dataset = StructuredDataset(seq_len, 0, 256, 1) Batch = Axis("batch", len(devices)) @@ -125,7 +125,7 @@ def test_structured_batches_model_axis_2(): np.array(devices).reshape(-1, model_axis_size), (ResourceAxis.DATA, ResourceAxis.MODEL), ) - with mesh, haliax.axis_mapping({"batch": ResourceAxis.DATA}): + with haliax.resource_env({"batch": ResourceAxis.DATA}, mesh=mesh): seq_len = 128 dataset = StructuredDataset(seq_len, 0, 256, 1) Batch = Axis("batch", len(devices)) @@ -180,7 +180,7 @@ def test_structured_batches_model_axis_1_with_names(): np.array(devices).reshape(-1, model_axis_size), (ResourceAxis.DATA, ResourceAxis.MODEL), ) - with mesh, haliax.axis_mapping({"batch": ResourceAxis.DATA}): + with haliax.resource_env({"batch": ResourceAxis.DATA}, mesh=mesh): Height = Axis("Height", 16) Width = Axis("Width", 16) dataset = StructuredDatasetWithNames(Height, Width, 0, 256, 1) @@ -203,7 +203,7 @@ def test_structured_batches_model_axis_2_with_names(): np.array(devices).reshape(-1, model_axis_size), (ResourceAxis.DATA, ResourceAxis.MODEL), ) - with mesh, haliax.axis_mapping({"batch": ResourceAxis.DATA}): + with haliax.resource_env({"batch": ResourceAxis.DATA}, mesh=mesh): Height = Axis("Height", 16) Width = Axis("Width", 16) dataset = StructuredDatasetWithNames(Height, Width, 0, 256, 1) @@ -227,7 +227,7 @@ def test_structured_batches_model_axis_2_subsharded(): ) Height = Axis("Height", 16) Width = Axis("Width", 16) - with mesh, haliax.axis_mapping({"batch": ResourceAxis.DATA, Height.name: ResourceAxis.MODEL}): + with haliax.resource_env({"batch": ResourceAxis.DATA, Height.name: ResourceAxis.MODEL}, 
mesh=mesh): dataset = StructuredDatasetWithNames(Height, Width, 0, 256, 1) Batch = Axis("batch", len(devices)) loader = ReplicatedBatchLoader(dataset, Batch) diff --git a/tests/test_sharded_loader.py b/tests/test_sharded_loader.py index dfcef498a..e83a21e02 100644 --- a/tests/test_sharded_loader.py +++ b/tests/test_sharded_loader.py @@ -44,7 +44,7 @@ def test_sharded_data_loading_model_axis_2(): np.array(devices).reshape(-1, model_axis_size), (ResourceAxis.DATA, ResourceAxis.MODEL), ) - with mesh, hax.axis_mapping({"batch": ResourceAxis.DATA}): + with hax.resource_env({"batch": ResourceAxis.DATA}, mesh=mesh): seq_len = 128 cache = _small_dataset(seq_len) Batch = Axis("batch", len(devices)) @@ -63,7 +63,7 @@ def test_sharded_data_loading_model_axis_1(): np.array(devices).reshape(-1, model_axis_size), (ResourceAxis.DATA, ResourceAxis.MODEL), ) - with mesh, hax.axis_mapping({"batch": ResourceAxis.DATA}): + with hax.resource_env({"batch": ResourceAxis.DATA}, mesh=mesh): seq_len = 128 cache = _small_dataset(seq_len) Batch = Axis("batch", len(devices)) @@ -107,7 +107,7 @@ def test_structured_batches_model_axis_1(): np.array(devices).reshape(-1, model_axis_size), (ResourceAxis.DATA, ResourceAxis.MODEL), ) - with mesh, hax.axis_mapping({"batch": ResourceAxis.DATA}): + with hax.resource_env({"batch": ResourceAxis.DATA}, mesh=mesh): seq_len = 128 dataset = StructuredDataset(seq_len, 0, 256, 1) Batch = Axis("batch", len(devices)) @@ -141,7 +141,7 @@ def test_can_batch_named_scalars(): model_axis_size = 1 mesh = Mesh(np.array(devices).reshape(-1, model_axis_size), (ResourceAxis.DATA, ResourceAxis.MODEL)) - with mesh, hax.axis_mapping({"batch": ResourceAxis.DATA}): + with hax.resource_env({"batch": ResourceAxis.DATA}, mesh=mesh): dataset = ScalarDataset(0, 256, 1) Batch = Axis("batch", len(devices)) loader = ShardedBatchLoader(dataset, Batch) @@ -160,7 +160,7 @@ def test_structured_batches_model_axis_2(): np.array(devices).reshape(-1, model_axis_size), (ResourceAxis.DATA, ResourceAxis.MODEL), ) - with mesh, hax.axis_mapping({"batch": ResourceAxis.DATA}): + with hax.resource_env({"batch": ResourceAxis.DATA}, mesh=mesh): seq_len = 128 dataset = StructuredDataset(seq_len, 0, 256, 1) Batch = Axis("batch", len(devices)) @@ -216,7 +216,7 @@ def test_structured_batches_model_axis_1_with_names(): np.array(devices).reshape(-1, model_axis_size), (ResourceAxis.DATA, ResourceAxis.MODEL), ) - with mesh, hax.axis_mapping({"batch": ResourceAxis.DATA}): + with hax.resource_env({"batch": ResourceAxis.DATA}, mesh=mesh): Height = Axis("Height", 16) Width = Axis("Width", 16) dataset = StructuredDatasetWithNames(Height, Width, 0, 256, 1) @@ -237,7 +237,7 @@ def test_structured_batches_model_axis_2_with_names(): np.array(devices).reshape(-1, model_axis_size), (ResourceAxis.DATA, ResourceAxis.MODEL), ) - with mesh, hax.axis_mapping({"batch": ResourceAxis.DATA}): + with hax.resource_env({"batch": ResourceAxis.DATA}, mesh=mesh): Height = Axis("Height", 16) Width = Axis("Width", 16) dataset = StructuredDatasetWithNames(Height, Width, 0, 256, 1) @@ -261,7 +261,7 @@ def test_structured_batches_model_axis_2_subsharded(): ) Height = Axis("Height", 16) Width = Axis("Width", 16) - with mesh, hax.axis_mapping({"batch": ResourceAxis.DATA, Height.name: ResourceAxis.MODEL}): + with hax.resource_env({"batch": ResourceAxis.DATA, Height.name: ResourceAxis.MODEL}, mesh=mesh): dataset = StructuredDatasetWithNames(Height, Width, 0, 256, 1) Batch = Axis("batch", len(devices)) loader = ShardedBatchLoader(dataset, Batch) @@ -279,7 +279,7 @@ def 
test_sharded_loader_doesnt_throw_away_data(): np.array(devices).reshape(-1, model_axis_size), (ResourceAxis.DATA, ResourceAxis.MODEL), ) - with mesh, hax.axis_mapping({"batch": ResourceAxis.DATA}): + with hax.resource_env({"batch": ResourceAxis.DATA}, mesh=mesh): dataset = ScalarDataset(0, 256, 1) Batch = Axis("batch", len(devices)) loader = ShardedBatchLoader(dataset, Batch) From a013c4c20799998e1d637ecd5691bc338cd9d1c6 Mon Sep 17 00:00:00 2001 From: David Hall Date: Thu, 1 Feb 2024 14:48:36 -0800 Subject: [PATCH 180/205] update LoRA docs for the new Trainer API --- docs/LoRA.md | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/docs/LoRA.md b/docs/LoRA.md index 7d5c15c74..4d8cdf099 100644 --- a/docs/LoRA.md +++ b/docs/LoRA.md @@ -107,6 +107,7 @@ parameters are sharded correctly, if you're using more than one device. @dataclass class TrainArgs: lora: LoraConfig = LoraConfig() + trainer: TrainerConfig = TrainerConfig() # ... some other stuff hf_save_path: Optional[str] = None # Path to save the HuggingFace checkpoint. @@ -120,7 +121,7 @@ class TrainArgs: def train(config: TrainArgs): ... - with config.trainer.device_mesh: + with Trainer(config.trainer, optimizer) as trainer: ... @hax.named_jit(axis_resources=parameter_axis_mapping, donate_args=(True)) @@ -143,12 +144,12 @@ using the `lora_trainable_params_filter` function, which takes a model and retur ```python def train(config: TrainArgs): ... - with config.trainer.device_mesh: + with Trainer(config.trainer, optimizer) as trainer: ... lora_param_filter = lora_trainable_params_filter(model) - trainer = Trainer(config.trainer, optimizer, is_trainable=lora_param_filter) + state = trainer.initial_state(training_key, model=model, is_trainable=lora_param_filter) ``` ### 3. Serialize a PEFT-compatible checkpoint From 58ca1d756b44b70406274e48ce4d285b1b76b3da Mon Sep 17 00:00:00 2001 From: David Hall Date: Fri, 2 Feb 2024 12:51:35 -0800 Subject: [PATCH 181/205] wip debugging devices --- src/levanter/data/loader.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/levanter/data/loader.py b/src/levanter/data/loader.py index 2ec2e6169..40510578b 100644 --- a/src/levanter/data/loader.py +++ b/src/levanter/data/loader.py @@ -43,7 +43,6 @@ class BatchLoader(Iterable[Ex], abc.ABC): Batch: hax.Axis def __init__(self, Batch: hax.Axis, resource_env: hax.ResourceEnv, max_capacity: Optional[int]): - self.max_capacity = max_capacity self.resource_env = resource_env or hax.current_resource_env() self.Batch = Batch @@ -194,6 +193,9 @@ def _produce_batches(self) -> Iterator[PyTree]: batch_offset = self.process_data_pos * self.local_batch_size local_batch: List[PyTree] = next(batched) + leaves = jtu.tree_leaves(local_batch[0]) + print([a.devices() for a in leaves if hasattr(a, "devices")]) + batch = self._construct_global_array_for_tree( item_exemplar=local_batch[0], get_batch_items=lambda begin, end: local_batch[(begin - batch_offset) : (end - batch_offset)], From 6ee6d8fe38773218ebc23730e27d0c70d32d86ac Mon Sep 17 00:00:00 2001 From: David Hall Date: Fri, 2 Feb 2024 14:34:46 -0800 Subject: [PATCH 182/205] let's try this?
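The text.py change below jits the per-example construction but pins its outputs to the host CPU, so freshly built examples stay off the accelerators until the batch loader decides how to shard them. A minimal sketch of the pattern (assumes a CPU backend is available; `make_example` stands in for the real `_create_lm_example`):

```python
import functools

import jax
import jax.numpy as jnp

cpu = jax.local_devices(backend="cpu")[0]
cpu_sharding = jax.sharding.SingleDeviceSharding(cpu)

@functools.partial(jax.jit, out_shardings=cpu_sharding)
def make_example(tokens):
    # stand-in for LmExample.causal(...); any jitted construction works here
    return {"tokens": tokens, "loss_mask": jnp.ones_like(tokens)}
```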
--- infra/helpers/setup-tpu-vm.sh | 2 +- src/levanter/data/text.py | 52 +++++++++++++++++++---------------- 2 files changed, 30 insertions(+), 24 deletions(-) diff --git a/infra/helpers/setup-tpu-vm.sh b/infra/helpers/setup-tpu-vm.sh index e1facd8cb..d7fc87653 100755 --- a/infra/helpers/setup-tpu-vm.sh +++ b/infra/helpers/setup-tpu-vm.sh @@ -93,7 +93,7 @@ pip install -U wheel # jax and jaxlib # libtpu sometimes has issues installing for clinical (probably firewall?) #retry pip install -U "jax[tpu]==0.4.5" libtpu-nightly==0.1.dev20230216 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html -retry pip install -U "jax[tpu]==0.4.21" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html +retry pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html # clone levanter git clone $REPO levanter diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index 212318433..6c2c89ce2 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -11,6 +11,7 @@ import braceexpand import datasets +import equinox as eqx import fsspec import jax import numpy as np @@ -50,7 +51,9 @@ ) from levanter.data.sharded_dataset import ShardedDataset, TextUrlDataset, WrappedHFDataset # noqa from levanter.shapes import NamedShapeSpec, ShapeSpec # noqa -from levanter.utils.jax_utils import use_cpu_device # noqa + + +# from levanter.utils.jax_utils import use_cpu_device # noqa logger = logging.getLogger("levanter.data.text") @@ -91,29 +94,32 @@ def shard(self, shard_id: int, num_shards: int) -> "CausalLmDataset": def __iter__(self) -> Iterator[LmExample]: key = self.key + + sharding = jax.sharding.SingleDeviceSharding(jax.local_devices(backend="cpu")[0]) + + @functools.partial(eqx.filter_jit, out_shardings=sharding) + def _create_lm_example(tokens, key): + tokens = hax.named(tokens, self.QPos) + + example = LmExample.causal(tokens=tokens, ignore_id=self.ignore_id) + + if self.fcm_prob > 0: + # masks for attention + # We support forgetful causal masking (FCM) which is a technique that improves training speed by + # randomly masking out some of the context. This is a bit like dropout, but it's applied to the attention + # mask instead of the activations. It's described in https://arxiv.org/abs/2210.13432 + assert self.key is not None + this_key, key = jax.random.split(key) + fcm_mask = hax.nn.attention.forgetful_causal_mask(self.KPos, self.fcm_prob, key=this_key) + attn_mask = example.attn_mask & AttentionMask.explicit(fcm_mask) + example = dataclasses.replace(example, attn_mask=attn_mask) + + return example + for tokens in self.dataset: - with use_cpu_device(): - example = self._create_lm_example(tokens, key) - yield example - - @functools.partial(jax.jit, static_argnums=(0)) - def _create_lm_example(self, tokens, key): - tokens = hax.named(tokens, self.QPos) - - example = LmExample.causal(tokens=tokens, ignore_id=self.ignore_id) - - if self.fcm_prob > 0: - # masks for attention - # We support forgetful causal masking (FCM) which is a technique that improves training speed by - # randomly masking out some of the context. This is a bit like dropout, but it's applied to the attention - # mask instead of the activations. 
It's described in https://arxiv.org/abs/2210.13432 - assert self.key is not None - this_key, key = jax.random.split(key) - fcm_mask = hax.nn.attention.forgetful_causal_mask(self.KPos, self.fcm_prob, key=this_key) - attn_mask = example.attn_mask & AttentionMask.explicit(fcm_mask) - example = dataclasses.replace(example, attn_mask=attn_mask) - - return example + example = _create_lm_example(tokens, key) + print("?", example.tokens.array.devices()) + yield example class TokenSeqDataset(ShardableDataset[np.ndarray]): From b485673186b186af46b25c8582e17306d6ae3b8d Mon Sep 17 00:00:00 2001 From: David Hall Date: Fri, 2 Feb 2024 14:45:00 -0800 Subject: [PATCH 183/205] so confused --- src/levanter/data/text.py | 53 ++++++++++++++++----------------- src/levanter/utils/jax_utils.py | 16 ++++++++-- 2 files changed, 40 insertions(+), 29 deletions(-) diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index 6c2c89ce2..be424fee1 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -51,9 +51,7 @@ ) from levanter.data.sharded_dataset import ShardedDataset, TextUrlDataset, WrappedHFDataset # noqa from levanter.shapes import NamedShapeSpec, ShapeSpec # noqa - - -# from levanter.utils.jax_utils import use_cpu_device # noqa +from levanter.utils.jax_utils import use_cpu_device # noqa logger = logging.getLogger("levanter.data.text") @@ -94,32 +92,33 @@ def shard(self, shard_id: int, num_shards: int) -> "CausalLmDataset": def __iter__(self) -> Iterator[LmExample]: key = self.key - sharding = jax.sharding.SingleDeviceSharding(jax.local_devices(backend="cpu")[0]) - @functools.partial(eqx.filter_jit, out_shardings=sharding) - def _create_lm_example(tokens, key): - tokens = hax.named(tokens, self.QPos) - - example = LmExample.causal(tokens=tokens, ignore_id=self.ignore_id) - - if self.fcm_prob > 0: - # masks for attention - # We support forgetful causal masking (FCM) which is a technique that improves training speed by - # randomly masking out some of the context. This is a bit like dropout, but it's applied to the attention - # mask instead of the activations. It's described in https://arxiv.org/abs/2210.13432 - assert self.key is not None - this_key, key = jax.random.split(key) - fcm_mask = hax.nn.attention.forgetful_causal_mask(self.KPos, self.fcm_prob, key=this_key) - attn_mask = example.attn_mask & AttentionMask.explicit(fcm_mask) - example = dataclasses.replace(example, attn_mask=attn_mask) - - return example - - for tokens in self.dataset: - example = _create_lm_example(tokens, key) - print("?", example.tokens.array.devices()) - yield example + with use_cpu_device(): + + @functools.partial(eqx.filter_jit, out_shardings=sharding) + def _create_lm_example(tokens, key): + tokens = hax.named(tokens, self.QPos) + + example = LmExample.causal(tokens=tokens, ignore_id=self.ignore_id) + + if self.fcm_prob > 0: + # masks for attention + # We support forgetful causal masking (FCM) which is a technique that improves training speed by + # randomly masking out some of the context. This is a bit like dropout, but it's applied to the attention + # mask instead of the activations. 
It's described in https://arxiv.org/abs/2210.13432 + assert self.key is not None + this_key, key = jax.random.split(key) + fcm_mask = hax.nn.attention.forgetful_causal_mask(self.KPos, self.fcm_prob, key=this_key) + attn_mask = example.attn_mask & AttentionMask.explicit(fcm_mask) + example = dataclasses.replace(example, attn_mask=attn_mask) + + return example + + for tokens in self.dataset: + example = _create_lm_example(tokens, key) + print("?", example.tokens.array.devices()) + yield example class TokenSeqDataset(ShardableDataset[np.ndarray]): diff --git a/src/levanter/utils/jax_utils.py b/src/levanter/utils/jax_utils.py index a1253b500..0a56f07e2 100644 --- a/src/levanter/utils/jax_utils.py +++ b/src/levanter/utils/jax_utils.py @@ -7,6 +7,7 @@ import equinox as eqx import jax from jax import numpy as jnp +from jax.sharding import Mesh from jaxtyping import PRNGKeyArray, PyTree from haliax.jax_utils import is_jax_array_like @@ -27,8 +28,19 @@ def jnp_to_python(a: jnp.ndarray): @contextlib.contextmanager def use_cpu_device(): """Temporarily sets the default device to CPU""" - with jax.default_device(jax.local_devices(backend="cpu")[0]): - yield + # If we have a mesh, we need to make a new version of that mesh + from haliax import current_resource_env + + mesh = current_resource_env().mesh + cpu = jax.local_devices(backend="cpu")[0] + if mesh is None: + with jax.default_device(cpu): + yield + else: + mesh_axis_names = mesh.axis_names + new_mesh = Mesh(jnp.array([cpu]).reshape((1,) * len(mesh_axis_names)), axis_names=mesh_axis_names) + with jax.default_device(cpu), new_mesh: + yield def is_inside_jit(): From de3162b14593f2726b508e9b451c1904322231cc Mon Sep 17 00:00:00 2001 From: David Hall Date: Fri, 2 Feb 2024 14:49:12 -0800 Subject: [PATCH 184/205] sigh --- src/levanter/utils/jax_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/levanter/utils/jax_utils.py b/src/levanter/utils/jax_utils.py index 0a56f07e2..92af33648 100644 --- a/src/levanter/utils/jax_utils.py +++ b/src/levanter/utils/jax_utils.py @@ -6,6 +6,7 @@ import equinox as eqx import jax +import numpy as np from jax import numpy as jnp from jax.sharding import Mesh from jaxtyping import PRNGKeyArray, PyTree @@ -38,7 +39,7 @@ def use_cpu_device(): yield else: mesh_axis_names = mesh.axis_names - new_mesh = Mesh(jnp.array([cpu]).reshape((1,) * len(mesh_axis_names)), axis_names=mesh_axis_names) + new_mesh = Mesh(np.array([cpu]).reshape((1,) * len(mesh_axis_names)), axis_names=mesh_axis_names) with jax.default_device(cpu), new_mesh: yield From 343367f5f42bb91901538a46eb65ca8e5aaa644a Mon Sep 17 00:00:00 2001 From: David Hall Date: Fri, 2 Feb 2024 14:53:56 -0800 Subject: [PATCH 185/205] ok i think i got it --- src/levanter/data/loader.py | 3 --- src/levanter/data/text.py | 1 - 2 files changed, 4 deletions(-) diff --git a/src/levanter/data/loader.py b/src/levanter/data/loader.py index 40510578b..a879cc06a 100644 --- a/src/levanter/data/loader.py +++ b/src/levanter/data/loader.py @@ -193,9 +193,6 @@ def _produce_batches(self) -> Iterator[PyTree]: batch_offset = self.process_data_pos * self.local_batch_size local_batch: List[PyTree] = next(batched) - leaves = jtu.tree_leaves(local_batch[0]) - print([a.devices() for a in leaves if hasattr(a, "devices")]) - batch = self._construct_global_array_for_tree( item_exemplar=local_batch[0], get_batch_items=lambda begin, end: local_batch[(begin - batch_offset) : (end - batch_offset)], diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index 
be424fee1..f89a75ca2 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -117,7 +117,6 @@ def _create_lm_example(tokens, key): for tokens in self.dataset: example = _create_lm_example(tokens, key) - print("?", example.tokens.array.devices()) yield example From 71b755ef8a6408d52cdf4bbf57e7dc248619d48a Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 5 Feb 2024 10:42:38 -0800 Subject: [PATCH 186/205] wtf --- src/levanter/data/shard_cache.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/levanter/data/shard_cache.py b/src/levanter/data/shard_cache.py index f33329dd5..9fe927209 100644 --- a/src/levanter/data/shard_cache.py +++ b/src/levanter/data/shard_cache.py @@ -544,6 +544,8 @@ def execute(self) -> tuple[bool, Optional[ray.ObjectRef]]: total_chunk_rows = 0 # the total number of rows in the chunk batch_result_ref = None + self.group.logger.info(f"Reading one chunk of shard {self.shard_name}: {self.chunk_idx}") + try: while not chunk_filled: batch = next(self.reader, None) From 07b579745405664777e6af33244d7147d7e9d2f2 Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 5 Feb 2024 11:07:26 -0800 Subject: [PATCH 187/205] this async seems like a bad idea --- src/levanter/data/shard_cache.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/levanter/data/shard_cache.py b/src/levanter/data/shard_cache.py index 9fe927209..2c39072b5 100644 --- a/src/levanter/data/shard_cache.py +++ b/src/levanter/data/shard_cache.py @@ -460,6 +460,7 @@ class ShardGroupToBeProcessed(PriorityWorkTaskGroupSpec): processor_actor: ray.actor.ActorHandle # BatchProcessorQueue batch_size: int num_rows_per_chunk: int + group_id: int def build(self) -> "PriorityWorkTaskGroup": return ShardGroupTaskGroup(self) @@ -468,7 +469,7 @@ def build(self) -> "PriorityWorkTaskGroup": class ShardGroupTaskGroup(PriorityWorkTaskGroup): def __init__(self, spec: ShardGroupToBeProcessed): self.spec = spec - self.logger = pylogging.getLogger(f"shard_reader.{self.spec.name}") + self.logger = pylogging.getLogger(f"shard_reader.{spec.name}.{spec.group_id}") try: metadata: dict[str, ShardMetadata] = _initial_shard_metadatas( @@ -565,7 +566,6 @@ def execute(self) -> tuple[bool, Optional[ray.ObjectRef]]: self.shard_name, self.chunk_idx, chunk_batch_idx, RefBox(batch_result_ref) ) chunk_batch_idx += 1 - # enqueue_to_backpressure(batch, batch_result_ref) del batch if total_chunk_rows >= self.spec.num_rows_per_chunk or exhausted_shard: @@ -580,7 +580,7 @@ def execute(self) -> tuple[bool, Optional[ray.ObjectRef]]: if exhausted_shard: writer.shard_finished_reading.remote(self.shard_name, self.chunk_idx) - logger.debug(f"Finished reading one chunk of shard {self.shard_name}: {self.chunk_idx} {exhausted_shard}") + logger.info(f"Finished reading one chunk of shard {self.shard_name}: {self.chunk_idx} {exhausted_shard}") return exhausted_shard, batch_result_ref except Exception as e: # noqa @@ -857,10 +857,10 @@ def __init__(self, parent_ref, cache_dir: str, shard_names: Sequence[str]): def current_metadata(self, shard_name: str): return self.shard_writers[shard_name].current_metadata() - async def chunk_batch_finished(self, shard_name: str, chunk_id: int, batch_idx: int, batch: RefBox): + def chunk_batch_finished(self, shard_name: str, chunk_id: int, batch_idx: int, batch: RefBox): # batch is a pa.RecordBatch ref box try: - batch = await batch.ref + batch = ray.get(batch.ref) return self.shard_writers[shard_name].chunk_batch_finished(chunk_id, batch_idx, batch) except Exception as e: 
print(f"Error while processing batch {batch_idx} of chunk {chunk_id} of shard {shard_name}", flush=True) @@ -1129,6 +1129,7 @@ def priority_fn(shard_idx, chunk_idx): processor_actor=processor_actor, batch_size=processor.batch_size, num_rows_per_chunk=rows_per_chunk, + group_id=group_id, ) # we want global names so that different tasks can coordinate priorities From b6e0c1d04c1d3afaf61b38b0436cccff369b919f Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 5 Feb 2024 11:12:31 -0800 Subject: [PATCH 188/205] log perf numbers? --- src/levanter/data/shard_cache.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/src/levanter/data/shard_cache.py b/src/levanter/data/shard_cache.py index 2c39072b5..6c5b8625f 100644 --- a/src/levanter/data/shard_cache.py +++ b/src/levanter/data/shard_cache.py @@ -559,11 +559,14 @@ def execute(self) -> tuple[bool, Optional[ray.ObjectRef]]: if batch: priority = self.spec.priority_fn(self.shard_idx, self.chunk_idx) + # these times aren't exact because the times might be from different machines + # but they're just for logging + time_in = time.time() batch_result_ref = ray.get( self.spec.processor_actor.submit.remote(priority=priority, batch=RefBox(ray.put(batch))) ) writer.chunk_batch_finished.remote( - self.shard_name, self.chunk_idx, chunk_batch_idx, RefBox(batch_result_ref) + self.shard_name, self.chunk_idx, chunk_batch_idx, RefBox(batch_result_ref), time_in ) chunk_batch_idx += 1 del batch @@ -857,10 +860,19 @@ def __init__(self, parent_ref, cache_dir: str, shard_names: Sequence[str]): def current_metadata(self, shard_name: str): return self.shard_writers[shard_name].current_metadata() - def chunk_batch_finished(self, shard_name: str, chunk_id: int, batch_idx: int, batch: RefBox): + def chunk_batch_finished(self, shard_name: str, chunk_id: int, batch_idx: int, batch: RefBox, time_in): # batch is a pa.RecordBatch ref box try: + time_mid = time.time() + logger.info( + f"Received in progress batch {batch_idx} of chunk {chunk_id} of shard {shard_name} in" + f" {time_mid - time_in}" + ) batch = ray.get(batch.ref) + logger.info( + f"Received finished batch {batch_idx} of chunk {chunk_id} of shard {shard_name} in total" + f" {time.time() - time_in}" + ) return self.shard_writers[shard_name].chunk_batch_finished(chunk_id, batch_idx, batch) except Exception as e: print(f"Error while processing batch {batch_idx} of chunk {chunk_id} of shard {shard_name}", flush=True) From a3f9c7f3c378b51f8e0df53367b6e78d70a7428f Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 5 Feb 2024 11:34:56 -0800 Subject: [PATCH 189/205] more logging --- src/levanter/data/shard_cache.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/src/levanter/data/shard_cache.py b/src/levanter/data/shard_cache.py index 6c5b8625f..7d6dc1e74 100644 --- a/src/levanter/data/shard_cache.py +++ b/src/levanter/data/shard_cache.py @@ -948,13 +948,9 @@ def chunk_failed(self, chunk_id: int, error: ExceptionInfo): self.parent_ref.shard_failed.remote(self.shard_name, error) def _finished_chunk(self, idx: int, chunk: ChunkMetadata): - if idx < self.metadata_writer.num_chunks: - logger.error(f"Received chunk {idx} for {self.shard_name} but it's already finished") - error = RuntimeError(f"Received chunk {idx} for {self.shard_name} but it's already finished") - self.parent_ref.shard_failed.remote(self.shard_name, ser_exc_info(error)) - raise error - - if self._expected_num_chunks is not None and idx >= self._expected_num_chunks: + if (idx < 
self.metadata_writer.num_chunks) or ( + self._expected_num_chunks is not None and idx >= self._expected_num_chunks + ): logger.error(f"Received chunk {idx} for {self.shard_name} but it's already finished") error = RuntimeError(f"Received chunk {idx} for {self.shard_name} but it's already finished") self.parent_ref.shard_failed.remote(self.shard_name, ser_exc_info(error)) @@ -975,6 +971,8 @@ def _attempt_to_commit_chunks(self): chunks_committed = [] while len(self.uncommited_chunks) > 0 and self.uncommited_chunks[0][0] == self.metadata_writer.num_chunks: _, chunk = heapq.heappop(self.uncommited_chunks) + chunk_number = self.metadata_writer.num_chunks + logger.info(f"Committing chunk {chunk.name} of shard {self.shard_name}. It is chunk {chunk_number}") self.metadata_writer.commit_chunk(chunk) chunks_committed.append(chunk) @@ -1284,6 +1282,9 @@ def __init__(self, cache_dir: str, source: ShardedDataset[T], processor: BatchPr # used to subscribe to metrics updates self._latest_metrics = InProgressCacheMetrics() self._metrics_condition = asyncio.Condition() + path_for_name = os.path.join(*self._cache_dir.split("/")[-2:]) + name = f"broker::{path_for_name}" + self.logger = pylogging.getLogger(f"{name}") # initialize writer task # first see if we need to do anything: check the ledger for is_finished @@ -1295,7 +1296,6 @@ def __init__(self, cache_dir: str, source: ShardedDataset[T], processor: BatchPr except FileNotFoundError: self_ref = ray.runtime_context.get_runtime_context().current_actor # only use the last two components of the name since it gets kind of long - path_for_name = os.path.join(*self._cache_dir.split("/")[-2:]) name = f"builder::{path_for_name}" self._builder_actor = ChunkCacheBuilder.remote(self_ref, self._cache_dir, name, self._source, self._processor, self._rows_per_chunk) # type: ignore @@ -1322,7 +1322,6 @@ async def get_chunk(self, chunk_idx: int) -> Optional[ChunkMetadata]: elif self._is_finished: return None else: - # we don't have this chunk yet, so we need to wait if chunk_idx not in self._reader_promises: self._reader_promises[chunk_idx] = asyncio.Future() return await self._reader_promises[chunk_idx] @@ -1337,7 +1336,9 @@ def _append_chunks(self, *chunks: ChunkMetadata): for chunk in chunks: self.chunks.append(chunk) chunk_idx = len(self.chunks) - 1 + self.logger.info(f"Received chunk {chunk_idx}") if chunk_idx in self._reader_promises: + self.logger.info(f"Resolving promise for chunk {chunk_idx}") self._reader_promises[chunk_idx].set_result(chunk) del self._reader_promises[chunk_idx] @@ -1545,13 +1546,14 @@ def _get_chunk_unmapped(self, mapped_index: int, *, timeout: Optional[float] = N time_in = time.time() # we want to also log if we're waiting for a long time, so we do this in a loop while timeout is None or time.time() - time_in < timeout: + next_time = time.time() current_timeout = 20.0 if timeout is not None: - current_timeout = min(current_timeout, timeout - (time.time() - time_in)) + current_timeout = min(current_timeout, timeout - (next_time - time_in)) try: chunk = ray.get(self._broker.get_chunk.remote(mapped_index), timeout=current_timeout) except GetTimeoutError: - self.logger.warning(f"Waiting for chunk {mapped_index} for {int(time.time() - time_in)} seconds") + self.logger.warning(f"Waiting for chunk {mapped_index} for {int(next_time - time_in)} seconds") current_timeout *= 2 current_timeout = min(current_timeout, 80) continue From ce2db7b6c338e059219aa66bbdb05f1dfb0e7f6b Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 5 Feb 2024 11:51:43 -0800 
Subject: [PATCH 190/205] moar --- src/levanter/data/shard_cache.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/levanter/data/shard_cache.py b/src/levanter/data/shard_cache.py index 7d6dc1e74..0ef9df2bc 100644 --- a/src/levanter/data/shard_cache.py +++ b/src/levanter/data/shard_cache.py @@ -1241,18 +1241,23 @@ def _attempt_to_flush_buffers(self): next_chunk = status.pop_chunk_to_send() if next_chunk is not None: # we can send a chunk from this shard - logger.debug(f"Sending chunk from {name}") + self.logger.info(f"Sending chunk from {name}") self._current_round_robin.pop(0) self._current_round_robin.append(name) chunks_to_send.append(next_chunk) continue else: - logger.debug(f"Shard {name} has no chunks to send and is not known to be finished") + chunks_waiting = [f"{n2} ({len(s2.current_buffer)})" for n2, s2 in self.shard_status.items()] + msg = ( + f"Shard {name} has no chunks to send and is not known to be finished. We have this many queued" + f" chunks: {chunks_waiting}" + ) + self.logger.info(msg) # we can't send a chunk from this shard, so we can't send any additional chunks break if len(chunks_to_send) > 0: - logger.debug(f"Sending {len(chunks_to_send)} chunks to broker") + logger.info(f"Sending {len(chunks_to_send)} chunks to broker") ray.get(self.broker_ref._append_chunks.remote(*chunks_to_send)) def _finish(self): From ea57bde33f103dd430dfa460e8c84948bc4b5f7d Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 5 Feb 2024 11:57:23 -0800 Subject: [PATCH 191/205] oops --- src/levanter/data/shard_cache.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/levanter/data/shard_cache.py b/src/levanter/data/shard_cache.py index 0ef9df2bc..aeb483cb9 100644 --- a/src/levanter/data/shard_cache.py +++ b/src/levanter/data/shard_cache.py @@ -1107,6 +1107,7 @@ def __init__( self._processor_actors = [] for shard_name in source.shard_names: + self._current_round_robin.append(shard_name) self.shard_status[shard_name] = _ShardStatus() num_shards = len(source.shard_names) @@ -1118,7 +1119,6 @@ def priority_fn(shard_idx, chunk_idx): shard_groups: list[list[str]] = [[] for _ in range(num_shard_groups)] for i, shard_name in enumerate(source.shard_names): - self._current_round_robin.append(shard_name) shard_groups[i % num_shard_groups].append(shard_name) for group_id, shard_group in enumerate(shard_groups): @@ -1247,7 +1247,11 @@ def _attempt_to_flush_buffers(self): chunks_to_send.append(next_chunk) continue else: - chunks_waiting = [f"{n2} ({len(s2.current_buffer)})" for n2, s2 in self.shard_status.items()] + chunks_waiting = [ + f"{n2} ({len(s2.current_buffer)})" + for n2, s2 in self.shard_status.items() + if len(s2.current_buffer) > 0 + ] msg = ( f"Shard {name} has no chunks to send and is not known to be finished. 
We have this many queued" f" chunks: {chunks_waiting}" From d352b3750e5d5caf428c6b7ac326ded524781168 Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 5 Feb 2024 13:02:51 -0800 Subject: [PATCH 192/205] reduce logging some, try to figure out this stupid books problem --- src/levanter/data/shard_cache.py | 77 +++++++++++++++++++++----------- src/levanter/data/text.py | 2 +- 2 files changed, 52 insertions(+), 27 deletions(-) diff --git a/src/levanter/data/shard_cache.py b/src/levanter/data/shard_cache.py index aeb483cb9..ec36c4a83 100644 --- a/src/levanter/data/shard_cache.py +++ b/src/levanter/data/shard_cache.py @@ -469,7 +469,7 @@ def build(self) -> "PriorityWorkTaskGroup": class ShardGroupTaskGroup(PriorityWorkTaskGroup): def __init__(self, spec: ShardGroupToBeProcessed): self.spec = spec - self.logger = pylogging.getLogger(f"shard_reader.{spec.name}.{spec.group_id}") + self.logger = pylogging.getLogger(f"shard_reader.{spec.group_id}.{spec.name}") try: metadata: dict[str, ShardMetadata] = _initial_shard_metadatas( @@ -545,7 +545,7 @@ def execute(self) -> tuple[bool, Optional[ray.ObjectRef]]: total_chunk_rows = 0 # the total number of rows in the chunk batch_result_ref = None - self.group.logger.info(f"Reading one chunk of shard {self.shard_name}: {self.chunk_idx}") + self.group.logger.debug(f"Reading one chunk of shard {self.shard_name}: {self.chunk_idx}") try: while not chunk_filled: @@ -563,7 +563,11 @@ def execute(self) -> tuple[bool, Optional[ray.ObjectRef]]: # but they're just for logging time_in = time.time() batch_result_ref = ray.get( - self.spec.processor_actor.submit.remote(priority=priority, batch=RefBox(ray.put(batch))) + self.spec.processor_actor.submit.remote( + priority=priority, + desc=f"{self.shard_name}.{self.chunk_idx}.{chunk_batch_idx}", + batch=RefBox(ray.put(batch)), + ) ) writer.chunk_batch_finished.remote( self.shard_name, self.chunk_idx, chunk_batch_idx, RefBox(batch_result_ref), time_in @@ -776,12 +780,16 @@ def is_finished_and_buffer_empty(self): def _mk_queue_aware_process_task(processor: BatchProcessor[T], queue: ActorHandle): @ray.remote(num_cpus=processor.num_cpus, num_gpus=processor.num_gpus, resources=processor.resources) - def process_task(batch: List[T]) -> pa.RecordBatch: + def process_task(desc, batch: List[T]) -> pa.RecordBatch: pylogging.basicConfig(level=pylogging.INFO) queue.task_running.remote() - result = processor(batch) - del batch - return as_record_batch(result) + try: + result = processor(batch) + del batch + return as_record_batch(result) + except Exception as e: + logger.exception(f"Error while processing batch {desc}") + raise e return process_task @@ -789,6 +797,7 @@ def process_task(batch: List[T]) -> pa.RecordBatch: @dataclass(order=True, frozen=True) class _QueueItem: priority: float + desc: str batch: ray.ObjectRef = dataclasses.field(compare=False) task_id: int task_future: asyncio.Future = dataclasses.field(compare=False) @@ -823,13 +832,13 @@ def __init__(self, batch_processor: BatchProcessor[T]): # we don't need/want to dereference the batch, so we wrap it in a RefBox # one virtue of doing things this way is that we can let Ray try to schedule the compute near the data. - async def submit(self, priority: float, batch: RefBox): + async def submit(self, priority: float, desc: str, batch: RefBox): """Returns a future that is set to the *ObjectRef* of the processed batch. The future is "complete" when the task starts, not when it finishes. 
You then call ray.get on the future's result to get the actual batch.""" task_id = self._next_task_id self._next_task_id += 1 f: asyncio.Future = asyncio.Future() - self.pqueue.put(_QueueItem(priority, batch.ref, task_id, f)) + self.pqueue.put(_QueueItem(priority, desc, batch.ref, task_id, f)) self._maybe_start_task() return await f @@ -838,7 +847,7 @@ def _maybe_start_task(self): self.ready = False item = self.pqueue.get() batch = item.batch - item.task_future.set_result(self._task_processor.remote(batch)) + item.task_future.set_result(self._task_processor.remote(item.desc, batch)) def task_running(self): self.ready = True @@ -860,18 +869,34 @@ def __init__(self, parent_ref, cache_dir: str, shard_names: Sequence[str]): def current_metadata(self, shard_name: str): return self.shard_writers[shard_name].current_metadata() - def chunk_batch_finished(self, shard_name: str, chunk_id: int, batch_idx: int, batch: RefBox, time_in): + async def chunk_batch_finished(self, shard_name: str, chunk_id: int, batch_idx: int, batch: RefBox, time_in): # batch is a pa.RecordBatch ref box try: time_mid = time.time() - logger.info( + logger.debug( f"Received in progress batch {batch_idx} of chunk {chunk_id} of shard {shard_name} in" f" {time_mid - time_in}" ) - batch = ray.get(batch.ref) - logger.info( + # do a backoff loop until the batch is actually processed. log if it's been a while + timeout_interval = 5 + total_time_waited = 0 + + while True: + try: + batch = await asyncio.wait_for(asyncio.shield(batch.ref), timeout_interval) + # to keep to round numbers, we log how much we asked for rather than how much we got + total_time_waited += timeout_interval + timeout_interval = min(2 * timeout_interval, 60) + break + except asyncio.TimeoutError: + logger.info( + f"Waiting for batch {batch_idx} of chunk {chunk_id} of shard {shard_name} to be processed." + f"Waited {total_time_waited} seconds." + ) + + logger.debug( f"Received finished batch {batch_idx} of chunk {chunk_id} of shard {shard_name} in total" - f" {time.time() - time_in}" + f" {(time.time() - time_in):.2f} seconds." ) return self.shard_writers[shard_name].chunk_batch_finished(chunk_id, batch_idx, batch) except Exception as e: @@ -1241,23 +1266,23 @@ def _attempt_to_flush_buffers(self): next_chunk = status.pop_chunk_to_send() if next_chunk is not None: # we can send a chunk from this shard - self.logger.info(f"Sending chunk from {name}") self._current_round_robin.pop(0) self._current_round_robin.append(name) chunks_to_send.append(next_chunk) continue else: - chunks_waiting = [ - f"{n2} ({len(s2.current_buffer)})" - for n2, s2 in self.shard_status.items() - if len(s2.current_buffer) > 0 - ] - msg = ( - f"Shard {name} has no chunks to send and is not known to be finished. We have this many queued" - f" chunks: {chunks_waiting}" - ) - self.logger.info(msg) # we can't send a chunk from this shard, so we can't send any additional chunks + if self.logger.level <= pylogging.DEBUG: + chunks_waiting = [ + f"{n2} ({len(s2.current_buffer)})" + for n2, s2 in self.shard_status.items() + if len(s2.current_buffer) > 0 + ] + msg = ( + f"Shard {name} has no chunks to send and is not known to be finished. 
We have this many queued" + f" chunks: {chunks_waiting}" + ) + self.logger.debug(msg) break if len(chunks_to_send) > 0: diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index f89a75ca2..67e922a5d 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -357,7 +357,7 @@ def num_gpus(self) -> int: @property def batch_size(self) -> int: - return 1024 + return 128 def concatenate_and_group_texts( From 0effef068dd4f86b47f49824b60d1ac4c5c7c9be Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 5 Feb 2024 13:43:43 -0800 Subject: [PATCH 193/205] ka dkla dkl --- src/levanter/data/shard_cache.py | 51 ++++++++++++++++++++++++++------ 1 file changed, 42 insertions(+), 9 deletions(-) diff --git a/src/levanter/data/shard_cache.py b/src/levanter/data/shard_cache.py index ec36c4a83..0f888d677 100644 --- a/src/levanter/data/shard_cache.py +++ b/src/levanter/data/shard_cache.py @@ -587,7 +587,9 @@ def execute(self) -> tuple[bool, Optional[ray.ObjectRef]]: if exhausted_shard: writer.shard_finished_reading.remote(self.shard_name, self.chunk_idx) - logger.info(f"Finished reading one chunk of shard {self.shard_name}: {self.chunk_idx} {exhausted_shard}") + self.group.logger.info( + f"Finished reading one chunk of shard {self.shard_name}: {self.chunk_idx} {exhausted_shard}" + ) return exhausted_shard, batch_result_ref except Exception as e: # noqa @@ -778,11 +780,35 @@ def is_finished_and_buffer_empty(self): return self.expected_num_chunks is not None and self.num_chunks_sent >= self.expected_num_chunks +class WaitTimeReportingThread(threading.Thread): + def __init__(self, report, interval=60): + super().__init__() + self.report = report + self.interval = interval + self.shutdown_event = threading.Event() + + def run(self): + total_waited = 0 + while not self.shutdown_event.is_set(): + if total_waited > 0: + self.report(total_waited) + total_waited += self.interval + time.sleep(self.interval) + + def shutdown(self): + self.shutdown_event.set() + + def _mk_queue_aware_process_task(processor: BatchProcessor[T], queue: ActorHandle): @ray.remote(num_cpus=processor.num_cpus, num_gpus=processor.num_gpus, resources=processor.resources) def process_task(desc, batch: List[T]) -> pa.RecordBatch: + logger.info(f"Processing batch {desc}") pylogging.basicConfig(level=pylogging.INFO) queue.task_running.remote() + timer_thread = WaitTimeReportingThread( + lambda t: logger.info(f"Waiting for {desc} to be processed for {t} seconds"), interval=30 + ) + timer_thread.start() try: result = processor(batch) del batch @@ -790,6 +816,9 @@ def process_task(desc, batch: List[T]) -> pa.RecordBatch: except Exception as e: logger.exception(f"Error while processing batch {desc}") raise e + finally: + timer_thread.shutdown() + timer_thread.join() return process_task @@ -878,26 +907,30 @@ async def chunk_batch_finished(self, shard_name: str, chunk_id: int, batch_idx: f" {time_mid - time_in}" ) # do a backoff loop until the batch is actually processed. log if it's been a while - timeout_interval = 5 + timeout_interval = 10 total_time_waited = 0 while True: try: batch = await asyncio.wait_for(asyncio.shield(batch.ref), timeout_interval) + break + except asyncio.TimeoutError: # to keep to round numbers, we log how much we asked for rather than how much we got total_time_waited += timeout_interval timeout_interval = min(2 * timeout_interval, 60) - break - except asyncio.TimeoutError: logger.info( - f"Waiting for batch {batch_idx} of chunk {chunk_id} of shard {shard_name} to be processed." 
+ f"Waiting for {shard_name}.{chunk_id}.{batch_idx} to be processed. " f"Waited {total_time_waited} seconds." ) - logger.debug( - f"Received finished batch {batch_idx} of chunk {chunk_id} of shard {shard_name} in total" - f" {(time.time() - time_in):.2f} seconds." - ) + if logger.isEnabledFor(pylogging.DEBUG): + logger.debug( + f"Received finished {shard_name}.{chunk_id}.{batch_idx} in {(time.time() - time_in):.2f} seconds." + ) + elif total_time_waited > 10: + logger.info( + f"Waited {total_time_waited} seconds for {shard_name}.{chunk_id}.{batch_idx} to be processed." + ) return self.shard_writers[shard_name].chunk_batch_finished(chunk_id, batch_idx, batch) except Exception as e: print(f"Error while processing batch {batch_idx} of chunk {chunk_id} of shard {shard_name}", flush=True) From 1e85d16ebe92b5ffb6efd65a92f0bb61e2258961 Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 5 Feb 2024 14:12:57 -0800 Subject: [PATCH 194/205] admaldl --- src/levanter/data/shard_cache.py | 2 +- src/levanter/data/text.py | 2 +- src/levanter/utils/hf_utils.py | 4 +++- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/levanter/data/shard_cache.py b/src/levanter/data/shard_cache.py index 0f888d677..2615f59a3 100644 --- a/src/levanter/data/shard_cache.py +++ b/src/levanter/data/shard_cache.py @@ -802,8 +802,8 @@ def shutdown(self): def _mk_queue_aware_process_task(processor: BatchProcessor[T], queue: ActorHandle): @ray.remote(num_cpus=processor.num_cpus, num_gpus=processor.num_gpus, resources=processor.resources) def process_task(desc, batch: List[T]) -> pa.RecordBatch: - logger.info(f"Processing batch {desc}") pylogging.basicConfig(level=pylogging.INFO) + logger.info(f"Processing batch {desc}") queue.task_running.remote() timer_thread = WaitTimeReportingThread( lambda t: logger.info(f"Waiting for {desc} to be processed for {t} seconds"), interval=30 diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index 67e922a5d..93e602041 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -357,7 +357,7 @@ def num_gpus(self) -> int: @property def batch_size(self) -> int: - return 128 + return 512 def concatenate_and_group_texts( diff --git a/src/levanter/utils/hf_utils.py b/src/levanter/utils/hf_utils.py index 408a8c8da..879162a65 100644 --- a/src/levanter/utils/hf_utils.py +++ b/src/levanter/utils/hf_utils.py @@ -18,7 +18,9 @@ def num_cpus_used_by_tokenizer(tokenizer) -> int: else: # This is a bit hacky, but HF's fast tokenizers are parallelized under the hood. # we reserve a couple of cores just so Ray has somewhere to run the coordinator. - return min(max(1, logical_cpu_core_count() - 2), 32) + # Really it's dependent on the number of docs, but that's not something we + # can easily know here. 
+ return min(max(1, logical_cpu_core_count() - 2), 16) else: return 1 From a25a8cebfbd90eaf528bff05de9390e67732ab0a Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 5 Feb 2024 19:05:13 -0800 Subject: [PATCH 195/205] fix the unnecessarily long time outs --- src/levanter/data/shard_cache.py | 47 ++++++++++++++++++-------------- tests/test_shard_cache.py | 4 +-- 2 files changed, 29 insertions(+), 22 deletions(-) diff --git a/src/levanter/data/shard_cache.py b/src/levanter/data/shard_cache.py index 2615f59a3..98a043b51 100644 --- a/src/levanter/data/shard_cache.py +++ b/src/levanter/data/shard_cache.py @@ -49,6 +49,8 @@ DEFAULT_MAX_BYTES_PER_BATCH = 256 * 1024 * 1024 # 256 MB, this is pre-preprocessing python object size LEDGER_FILE_NAME = "cache_ledger.json" +LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + def build_cache( cache_dir: str, @@ -363,7 +365,7 @@ def __le__(self, other: "PriorityWorkItem"): @ray.remote(num_cpus=1, scheduling_strategy="SPREAD") class PriorityProcessorActor: def __init__(self, max_in_flight: Optional[int] = 200): - pylogging.basicConfig(level=pylogging.INFO) + pylogging.basicConfig(level=pylogging.INFO, format=LOG_FORMAT) self._queue: list[PriorityWorkItem] = [] # heapq self._queue_lock = threading.Lock() self._shutdown_event = threading.Event() @@ -387,7 +389,7 @@ def is_group_finished(self, group: PriorityWorkTaskGroupSpec): if self._current_item is not None and self._current_item.spec == group: return False - logger.info(f"Group {group.name} is finished.") + logger.debug(f"Group {group.name} is finished.") return True @@ -444,9 +446,9 @@ def drain_backpressure_to(count): if not item_is_finished: heapq.heappush(self._queue, item) - logger.info("Shutting down PriorityProcessorActor. Waiting for backpressure to drain.") + logger.debug("Shutting down PriorityProcessorActor. Waiting for backpressure to drain.") drain_backpressure_to(0) - logger.info("Backpressure drained. Shutting down PriorityProcessorActor.") + logger.debug("Backpressure drained. 
Shutting down PriorityProcessorActor.") @dataclass @@ -789,11 +791,12 @@ def __init__(self, report, interval=60): def run(self): total_waited = 0 - while not self.shutdown_event.is_set(): + while True: + if self.shutdown_event.wait(self.interval): + break if total_waited > 0: self.report(total_waited) total_waited += self.interval - time.sleep(self.interval) def shutdown(self): self.shutdown_event.set() @@ -802,7 +805,7 @@ def shutdown(self): def _mk_queue_aware_process_task(processor: BatchProcessor[T], queue: ActorHandle): @ray.remote(num_cpus=processor.num_cpus, num_gpus=processor.num_gpus, resources=processor.resources) def process_task(desc, batch: List[T]) -> pa.RecordBatch: - pylogging.basicConfig(level=pylogging.INFO) + pylogging.basicConfig(level=pylogging.INFO, format=LOG_FORMAT) logger.info(f"Processing batch {desc}") queue.task_running.remote() timer_thread = WaitTimeReportingThread( @@ -812,7 +815,9 @@ def process_task(desc, batch: List[T]) -> pa.RecordBatch: try: result = processor(batch) del batch - return as_record_batch(result) + result = as_record_batch(result) + logger.info(f"Finished processing batch {desc}") + return result except Exception as e: logger.exception(f"Error while processing batch {desc}") raise e @@ -888,7 +893,7 @@ def task_running(self): @ray.remote(num_cpus=0.0, scheduling_strategy="SPREAD") # type: ignore class _GroupShardWriterWorker: def __init__(self, parent_ref, cache_dir: str, shard_names: Sequence[str]): - pylogging.basicConfig(level=pylogging.INFO) + pylogging.basicConfig(level=pylogging.INFO, format=LOG_FORMAT) self.cache_dir = cache_dir self.shard_names = shard_names self.shard_writers: dict[str, _ShardWriterWorker] = { @@ -912,7 +917,8 @@ async def chunk_batch_finished(self, shard_name: str, chunk_id: int, batch_idx: while True: try: - batch = await asyncio.wait_for(asyncio.shield(batch.ref), timeout_interval) + # batch = await asyncio.wait_for(asyncio.shield(batch.ref), timeout_interval) + batch = await batch.ref break except asyncio.TimeoutError: # to keep to round numbers, we log how much we asked for rather than how much we got @@ -961,7 +967,7 @@ def __init__( cache_dir: str, shard_name: str, ): - pylogging.basicConfig(level=pylogging.INFO) + pylogging.basicConfig(level=pylogging.INFO, format=LOG_FORMAT) self.parent_ref = parent_ref self.cache_dir = cache_dir self.shard_name = shard_name @@ -1030,7 +1036,7 @@ def _attempt_to_commit_chunks(self): while len(self.uncommited_chunks) > 0 and self.uncommited_chunks[0][0] == self.metadata_writer.num_chunks: _, chunk = heapq.heappop(self.uncommited_chunks) chunk_number = self.metadata_writer.num_chunks - logger.info(f"Committing chunk {chunk.name} of shard {self.shard_name}. It is chunk {chunk_number}") + logger.debug(f"Committing chunk {chunk.name} of shard {self.shard_name}. 
It is chunk {chunk_number}") self.metadata_writer.commit_chunk(chunk) chunks_committed.append(chunk) @@ -1144,7 +1150,7 @@ def __init__( processor: BatchProcessor[T], rows_per_chunk: int, ): - pylogging.basicConfig(level=pylogging.INFO) + pylogging.basicConfig(level=pylogging.INFO, format=LOG_FORMAT) self.logger = pylogging.getLogger(f"{__name__}.{name}") self.broker_ref = broker_ref self.shard_status: Dict[str, _ShardStatus] = dict() @@ -1319,7 +1325,7 @@ def _attempt_to_flush_buffers(self): break if len(chunks_to_send) > 0: - logger.info(f"Sending {len(chunks_to_send)} chunks to broker") + logger.debug(f"Sending {len(chunks_to_send)} chunks to broker") ray.get(self.broker_ref._append_chunks.remote(*chunks_to_send)) def _finish(self): @@ -1337,7 +1343,7 @@ class ChunkCacheBroker: _finished_promise: asyncio.Future[None] def __init__(self, cache_dir: str, source: ShardedDataset[T], processor: BatchProcessor[T], rows_per_chunk: int): - pylogging.basicConfig(level=pylogging.INFO) + pylogging.basicConfig(level=pylogging.INFO, format=LOG_FORMAT) self.chunks = [] self._reader_promises = {} self._is_finished = False @@ -1403,9 +1409,9 @@ def _append_chunks(self, *chunks: ChunkMetadata): for chunk in chunks: self.chunks.append(chunk) chunk_idx = len(self.chunks) - 1 - self.logger.info(f"Received chunk {chunk_idx}") + self.logger.debug(f"Received chunk {chunk_idx}") if chunk_idx in self._reader_promises: - self.logger.info(f"Resolving promise for chunk {chunk_idx}") + self.logger.debug(f"Resolving promise for chunk {chunk_idx}") self._reader_promises[chunk_idx].set_result(chunk) del self._reader_promises[chunk_idx] @@ -1611,9 +1617,9 @@ def _get_chunk_unmapped(self, mapped_index: int, *, timeout: Optional[float] = N else: assert self._broker is not None time_in = time.time() + next_time = time.time() # we want to also log if we're waiting for a long time, so we do this in a loop - while timeout is None or time.time() - time_in < timeout: - next_time = time.time() + while timeout is None or next_time - time_in < timeout: current_timeout = 20.0 if timeout is not None: current_timeout = min(current_timeout, timeout - (next_time - time_in)) @@ -1621,8 +1627,9 @@ def _get_chunk_unmapped(self, mapped_index: int, *, timeout: Optional[float] = N chunk = ray.get(self._broker.get_chunk.remote(mapped_index), timeout=current_timeout) except GetTimeoutError: self.logger.warning(f"Waiting for chunk {mapped_index} for {int(next_time - time_in)} seconds") + next_time += current_timeout current_timeout *= 2 - current_timeout = min(current_timeout, 80) + current_timeout = min(current_timeout, 100) continue if chunk is None: diff --git a/tests/test_shard_cache.py b/tests/test_shard_cache.py index c1b55ca1e..4ff040bca 100644 --- a/tests/test_shard_cache.py +++ b/tests/test_shard_cache.py @@ -13,10 +13,12 @@ def setup_module(module): + print("setting up") ray.init("local", num_cpus=2 * logical_cpu_core_count()) # 2x cpu count is faster on my m1 def teardown_module(module): + print("shutting down") ray.shutdown() @@ -175,9 +177,7 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[List[int]]: ) # now block until the cache is done - print("at wait") cache.await_finished(timeout=10) - print("done waiting") # now check that the chunks are in the right order # TODO: this is a bit gross From 7c163a81afb525db483d6864ebb5349f3ff8d44d Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 5 Feb 2024 19:06:26 -0800 Subject: [PATCH 196/205] break really long docs into shorter docs b/c tokenizers is quadratic --- 
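The core of this patch is a workaround for https://github.com/huggingface/tokenizers/issues/1449: tokenizing one very long document can be quadratic in its length when the fast tokenizer applies a `Replace` normalizer. The fix is to break any document longer than a threshold at a whitespace boundary, tokenize the pieces, and merge the encodings back together. A minimal sketch of the splitting step, assuming the stdlib `re` module in place of `regex`; `split_long_doc` and its default are illustrative names, not code from the patch:

```python
import re

WS = re.compile(r"\s")

def split_long_doc(doc: str, max_len: int = 50_000) -> list[str]:
    """Split a long document into pieces of roughly max_len characters,
    cutting at the first whitespace past the limit so words stay intact."""
    pieces = []
    while len(doc) > max_len:
        match = WS.search(doc, max_len)
        if match is None:
            break  # vanishingly unlikely: no whitespace past the limit; keep the rest whole
        pieces.append(doc[: match.start()])
        doc = doc[match.start() :]
    pieces.append(doc)
    return pieces
```

Each piece is then tokenized separately, and the `needs_merge` bookkeeping in the diff below concatenates the per-piece `input_ids` so the merged encoding matches what tokenizing the whole document would have produced, at roughly linear rather than quadratic cost.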
src/levanter/data/text.py | 109 ++++++++++++++++++++++++++++++++++++-- tests/test_text.py | 22 +++++++- 2 files changed, 125 insertions(+), 6 deletions(-) diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index 93e602041..a8edf7c1e 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -16,8 +16,10 @@ import jax import numpy as np import pyarrow as pa +import regex from draccus import field from jaxtyping import PRNGKeyArray +from tokenizers import normalizers import haliax as hax from haliax import Axis @@ -310,13 +312,25 @@ def _maybe_force_tokenizer_parallelism(tokenizer: PreTrainedTokenizerBase): os.environ["TOKENIZERS_PARALLELISM"] = "true" +LONG_STRING_WORKAROUND = 50_000 + + +ws = regex.compile(r"\s") + + class BatchTokenizer(BatchProcessor[str]): """ A batch processor that tokenizes a batch of strings using a tokenizer. By default, this will append eos to the end of the string, even if the tokenizer doesn't. """ - def __init__(self, tokenizer: PreTrainedTokenizerBase, enforce_eos=True, override_resources=None): + def __init__( + self, + tokenizer: PreTrainedTokenizerBase, + enforce_eos=True, + override_resources=None, + _workaround_len=LONG_STRING_WORKAROUND, + ): _maybe_force_tokenizer_parallelism(tokenizer) self.tokenizer = tokenizer self.override_resources = override_resources @@ -331,16 +345,101 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase, enforce_eos=True, overrid should_append_eos = False self._need_to_add_eos = should_append_eos + self._workaround_len = _workaround_len def __call__(self, batch: Sequence[str]) -> BatchEncoding: + orig_lengths = [len(d) for d in batch] if self._need_to_add_eos: - encoding = self.tokenizer( - [d + " " + self.tokenizer.eos_token for d in batch], return_attention_mask=False, verbose=False - ) + batch = [d + " " + self.tokenizer.eos_token for d in batch] + + if self._needs_long_sequence_workaround: + # break any strings that are longer than 50K characters into smaller chunks + orig_batch = batch + batch = [] + needs_merge = [] + for i, d in enumerate(orig_batch): + needs_merge.append(False) + orig_len = orig_lengths[i] + while len(d) > self._workaround_len: + # we'd rather break strings at whitespace, so find the first whitespace + match = ws.search(d, self._workaround_len) + # this is vanishingly unlikely, but if we can't find a whitespace, just break it at the limit + if match is None: + split = len(d) + else: + split = match.start() + + batch.append(d[:split]) + needs_merge.append(True) + + d = d[split:] + orig_len -= split + + batch.append(d) else: - encoding = self.tokenizer(batch, return_attention_mask=False, verbose=False) # type: ignore + needs_merge = [] + + encoding = self.tokenizer(batch, return_attention_mask=False, verbose=False) # type: ignore + + if needs_merge: + new_encoding = self._merge_split_encodings(batch, encoding, needs_merge) + encoding = BatchEncoding(new_encoding) + return encoding + @staticmethod + def _merge_split_encodings(batch, encoding, needs_merge): + # merge the encodings back together + # we might need to merge multiple encodings together + # needs merge marks the first n-1 encodings that need to be merged for each document + new_encoding = {} + for k, v in encoding.items(): + if len(v) == 0: + continue + if isinstance(v[0], np.ndarray): + assert len(v) == len(batch) + v_out = [] + vs_to_merge = [] + for i in range(len(batch)): + if not needs_merge[i]: + v_out.append(np.concatenate(vs_to_merge)) + vs_to_merge = [] + vs_to_merge.append(v[i]) + + if len(vs_to_merge) 
> 0: + v_out.append(np.concatenate(vs_to_merge)) + + new_encoding[k] = v_out + elif isinstance(v[0], list): + v_out = [] + vs_to_merge = [] + for i in range(len(batch)): + if not needs_merge[i]: + if len(vs_to_merge) > 0: + v_out.append(list(chain(*vs_to_merge))) + vs_to_merge = [] + vs_to_merge.append(v[i]) + + if len(vs_to_merge) > 0: + v_out.append(list(chain(*vs_to_merge))) + new_encoding[k] = v_out + else: + raise ValueError(f"Unknown type {type(v[0])}") + return new_encoding + + # TODO remove this when it's resolved https://github.com/huggingface/tokenizers/issues/1449 + @cached_property + def _needs_long_sequence_workaround(self): + if isinstance(self.tokenizer, PreTrainedTokenizerFast): + normalizer = self.tokenizer.backend_tokenizer.normalizer + if normalizer is None: + return False + # if there's a "Replace" normalizer, then we need to do the workaround + # inexplicably there's no way to see inside a Sequence so we also have to assume it needs it + return any(isinstance(n, (normalizers.Replace, normalizers.Sequence)) for n in normalizer) + else: + return False + @property def num_cpus(self) -> int: if self.override_resources is not None: diff --git a/tests/test_text.py b/tests/test_text.py index a9d407b44..90521d4a1 100644 --- a/tests/test_text.py +++ b/tests/test_text.py @@ -1,10 +1,11 @@ import tempfile import jax.numpy as jnp +from transformers import AutoTokenizer import haliax as hax -from levanter.data.text import LMDatasetConfig +from levanter.data.text import BatchTokenizer, LMDatasetConfig from levanter.models.lm_model import LmExample from levanter.models.loss import next_token_loss @@ -39,3 +40,22 @@ def test_lm_example_handles_ignore_id(): no_ignore_loss = next_token_loss(Pos, Vocab, distr, tokens, loss_mask=ex_no_ignore.loss_mask) assert no_ignore_loss.item() >= ignored_loss.item() + 100 / Pos.size + + +def test_merge_split_encodings(): + tokenizer = AutoTokenizer.from_pretrained("gpt2") + # make this very short for testing + + lorem = """Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. 
Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.""" + + short_batch_tokenizer = BatchTokenizer(tokenizer, _workaround_len=len(lorem) // 3) + # force this + short_batch_tokenizer._needs_long_sequence_workaround = True + + batch_tokenizer = BatchTokenizer(tokenizer, _workaround_len=50000) + batch = [lorem] + + short_out = short_batch_tokenizer(batch) + reg_out = batch_tokenizer(batch) + + assert short_out == reg_out From 4125d3f96dae8e0024c5db14c8de35ce2cfb1215 Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 5 Feb 2024 19:12:36 -0800 Subject: [PATCH 197/205] kmklamdklad --- src/levanter/data/text.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index a8edf7c1e..fb1a841b6 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -436,7 +436,7 @@ def _needs_long_sequence_workaround(self): return False # if there's a "Replace" normalizer, then we need to do the workaround # inexplicably there's no way to see inside a Sequence so we also have to assume it needs it - return any(isinstance(n, (normalizers.Replace, normalizers.Sequence)) for n in normalizer) + return isinstance(normalizer, (normalizers.Replace, normalizers.Sequence)) else: return False From 99a87e8c16acee6c730c9feb1f0c5b3f31dbd1ea Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 5 Feb 2024 22:33:28 -0800 Subject: [PATCH 198/205] maybe don't do the workaround so often? --- src/levanter/data/text.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index fb1a841b6..36fb1488a 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -312,7 +312,7 @@ def _maybe_force_tokenizer_parallelism(tokenizer: PreTrainedTokenizerBase): os.environ["TOKENIZERS_PARALLELISM"] = "true" -LONG_STRING_WORKAROUND = 50_000 +LONG_STRING_WORKAROUND = 100_000 ws = regex.compile(r"\s") From 5245e10cd7ea3892afd19e023950083668a383d4 Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 5 Feb 2024 22:36:29 -0800 Subject: [PATCH 199/205] is this the leak?!? 
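This commit comments out the per-task `WaitTimeReportingThread` to test whether spawning a fresh watchdog thread inside every Ray task is behind the suspected leak. If the watchdog earns its keep, one tidier shape is a daemon thread owned by a context manager, so the start/shutdown/join choreography lives in one place and runs even when the processor raises. A sketch under that assumption; `wait_time_reporter` is an illustrative name, not code from this series:

```python
import threading
from contextlib import contextmanager

@contextmanager
def wait_time_reporter(report, interval: float = 30.0):
    """Periodically report elapsed wait time while the managed block runs,
    and always stop the watchdog when the block exits."""
    stop = threading.Event()

    def _loop():
        waited = 0.0
        while not stop.wait(interval):  # False means the interval elapsed without shutdown
            waited += interval
            report(waited)

    thread = threading.Thread(target=_loop, daemon=True)  # daemon: never blocks interpreter exit
    thread.start()
    try:
        yield
    finally:
        stop.set()
        thread.join()
```

Used as `with wait_time_reporter(lambda t: logger.info(f"waited {t}s")): result = processor(batch)`, the thread cannot outlive the task body.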
--- src/levanter/data/shard_cache.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/levanter/data/shard_cache.py b/src/levanter/data/shard_cache.py index 98a043b51..142ba57df 100644 --- a/src/levanter/data/shard_cache.py +++ b/src/levanter/data/shard_cache.py @@ -808,10 +808,10 @@ def process_task(desc, batch: List[T]) -> pa.RecordBatch: pylogging.basicConfig(level=pylogging.INFO, format=LOG_FORMAT) logger.info(f"Processing batch {desc}") queue.task_running.remote() - timer_thread = WaitTimeReportingThread( - lambda t: logger.info(f"Waiting for {desc} to be processed for {t} seconds"), interval=30 - ) - timer_thread.start() + # timer_thread = WaitTimeReportingThread( + # lambda t: logger.info(f"Waiting for {desc} to be processed for {t} seconds"), interval=30 + # ) + # timer_thread.start() try: result = processor(batch) del batch @@ -822,8 +822,9 @@ def process_task(desc, batch: List[T]) -> pa.RecordBatch: logger.exception(f"Error while processing batch {desc}") raise e finally: - timer_thread.shutdown() - timer_thread.join() + # timer_thread.shutdown() + # timer_thread.join() + pass return process_task From 8a6f59b2d8518723e6edec02ee0df84a128f54bd Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 6 Feb 2024 09:38:38 -0800 Subject: [PATCH 200/205] update for latest datasets --- config/gpt2_nano.yaml | 4 ++-- src/levanter/data/sharded_dataset.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/config/gpt2_nano.yaml b/config/gpt2_nano.yaml index 267d83b51..5612fc104 100644 --- a/config/gpt2_nano.yaml +++ b/config/gpt2_nano.yaml @@ -1,5 +1,5 @@ -#data: -# id: dlwh/wikitext_103_detokenized +data: + id: dlwh/wikitext_103_detokenized model: type: gpt2 hidden_dim: 32 diff --git a/src/levanter/data/sharded_dataset.py b/src/levanter/data/sharded_dataset.py index 1ceae6366..0ec178e08 100644 --- a/src/levanter/data/sharded_dataset.py +++ b/src/levanter/data/sharded_dataset.py @@ -171,7 +171,7 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[dict]: dataset = self._load_dataset() if isinstance(dataset, datasets.IterableDataset) and shard_name != "data": # ex_iterable has a key that gets discarded typically - shard = map(lambda t: t[1], dataset._ex_iterable.shard_data_sources([int(shard_name)])) + shard = map(lambda t: t[1], dataset._ex_iterable.shard_data_sources(int(shard_name), dataset.n_shards)) else: shard = dataset From 002989b17717796b57b4560363683accd0d64dde Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 6 Feb 2024 09:46:05 -0800 Subject: [PATCH 201/205] add a test to ensure we use the workaround for llama tokenizer --- config/gpt2_xl.yaml | 2 +- tests/test_text.py | 8 ++++++++ tests/test_utils.py | 8 ++++---- 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/config/gpt2_xl.yaml b/config/gpt2_xl.yaml index 026fc077e..a58c7ceb0 100644 --- a/config/gpt2_xl.yaml +++ b/config/gpt2_xl.yaml @@ -12,7 +12,7 @@ trainer: project: "levanter" tags: [ "openwebtext", "gpt2"] mp: p=f32,c=bfloat16 - per_device_parallelism: 1 + per_device_parallelism: -1 optimizer: learning_rate: 1E-4 weight_decay: 0.1 diff --git a/tests/test_text.py b/tests/test_text.py index 90521d4a1..70b2d26a7 100644 --- a/tests/test_text.py +++ b/tests/test_text.py @@ -8,6 +8,7 @@ from levanter.data.text import BatchTokenizer, LMDatasetConfig from levanter.models.lm_model import LmExample from levanter.models.loss import next_token_loss +from test_utils import skip_if_hf_model_not_accessible def test_dont_blow_up_without_validation_set(): @@ 
-59,3 +60,10 @@ def test_merge_split_encodings(): reg_out = batch_tokenizer(batch) assert short_out == reg_out + + +@skip_if_hf_model_not_accessible("meta-llama/Llama-2-7b-hf") +def test_llama_tokenizer_needs_long_sequence_workaround(): + tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") + batch_tokenizer = BatchTokenizer(tokenizer) + assert batch_tokenizer._needs_long_sequence_workaround diff --git a/tests/test_utils.py b/tests/test_utils.py index b2b060c28..08df42f69 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -139,21 +139,21 @@ def try_load_path(path): else: return True - return pytest.mark.skipif(not try_load_path(path), reason="Checkpoint not accessible")(lambda x: x) + return pytest.mark.skipif(not try_load_path(path), reason="Checkpoint not accessible") def skip_if_hf_model_not_accessible(model_id: str): def try_load_hf(model_id): try: - from transformers import AutoModelForCausalLM + from transformers import AutoConfig - AutoModelForCausalLM.from_pretrained(model_id) + AutoConfig.from_pretrained(model_id) except Exception: return False else: return True - return pytest.mark.skipif(not try_load_hf(model_id), reason="HuggingFace model not accessible")(lambda x: x) + return pytest.mark.skipif(not try_load_hf(model_id), reason="HuggingFace model not accessible") class IdentityProcessor(BatchProcessor[BatchEncoding]): From 3dfebe22f9c7c62519d48f0d4bd90e1af052af74 Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 6 Feb 2024 09:46:23 -0800 Subject: [PATCH 202/205] tweak timeouts in test --- src/levanter/data/shard_cache.py | 2 +- tests/test_shard_cache.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/levanter/data/shard_cache.py b/src/levanter/data/shard_cache.py index 142ba57df..417b15862 100644 --- a/src/levanter/data/shard_cache.py +++ b/src/levanter/data/shard_cache.py @@ -1618,7 +1618,7 @@ def _get_chunk_unmapped(self, mapped_index: int, *, timeout: Optional[float] = N else: assert self._broker is not None time_in = time.time() - next_time = time.time() + next_time = time_in # we want to also log if we're waiting for a long time, so we do this in a loop while timeout is None or next_time - time_in < timeout: current_timeout = 20.0 diff --git a/tests/test_shard_cache.py b/tests/test_shard_cache.py index 4ff040bca..6b54970c5 100644 --- a/tests/test_shard_cache.py +++ b/tests/test_shard_cache.py @@ -233,7 +233,7 @@ def back_to_py(batch: pa.RecordBatch): assert [list(x) for x in chunk] == [[i] * 10 for i in range(10)] with pytest.raises(TimeoutError): - cache.get_chunk(1, timeout=0.1) + cache.get_chunk(1, timeout=0.5) ray.get(blocker_to_wait_on_test.unblock.remote()) From 5a5a1f1c79a5144b23f670837c9a49f1d4628c95 Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 6 Feb 2024 09:49:07 -0800 Subject: [PATCH 203/205] less spammy logging --- src/levanter/data/shard_cache.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/levanter/data/shard_cache.py b/src/levanter/data/shard_cache.py index 417b15862..16e75f8b5 100644 --- a/src/levanter/data/shard_cache.py +++ b/src/levanter/data/shard_cache.py @@ -589,7 +589,7 @@ def execute(self) -> tuple[bool, Optional[ray.ObjectRef]]: if exhausted_shard: writer.shard_finished_reading.remote(self.shard_name, self.chunk_idx) - self.group.logger.info( + self.group.logger.debug( f"Finished reading one chunk of shard {self.shard_name}: {self.chunk_idx} {exhausted_shard}" ) @@ -806,7 +806,7 @@ def _mk_queue_aware_process_task(processor: BatchProcessor[T], queue: 
ActorHandl @ray.remote(num_cpus=processor.num_cpus, num_gpus=processor.num_gpus, resources=processor.resources) def process_task(desc, batch: List[T]) -> pa.RecordBatch: pylogging.basicConfig(level=pylogging.INFO, format=LOG_FORMAT) - logger.info(f"Processing batch {desc}") + logger.debug(f"Processing batch {desc}") queue.task_running.remote() # timer_thread = WaitTimeReportingThread( # lambda t: logger.info(f"Waiting for {desc} to be processed for {t} seconds"), interval=30 @@ -913,7 +913,7 @@ async def chunk_batch_finished(self, shard_name: str, chunk_id: int, batch_idx: f" {time_mid - time_in}" ) # do a backoff loop until the batch is actually processed. log if it's been a while - timeout_interval = 10 + timeout_interval = 20 total_time_waited = 0 while True: @@ -924,7 +924,7 @@ async def chunk_batch_finished(self, shard_name: str, chunk_id: int, batch_idx: except asyncio.TimeoutError: # to keep to round numbers, we log how much we asked for rather than how much we got total_time_waited += timeout_interval - timeout_interval = min(2 * timeout_interval, 60) + timeout_interval = min(2 * timeout_interval, 100) logger.info( f"Waiting for {shard_name}.{chunk_id}.{batch_idx} to be processed. " f"Waited {total_time_waited} seconds." @@ -934,7 +934,7 @@ async def chunk_batch_finished(self, shard_name: str, chunk_id: int, batch_idx: logger.debug( f"Received finished {shard_name}.{chunk_id}.{batch_idx} in {(time.time() - time_in):.2f} seconds." ) - elif total_time_waited > 10: + elif total_time_waited > 40: logger.info( f"Waited {total_time_waited} seconds for {shard_name}.{chunk_id}.{batch_idx} to be processed." ) From 4e6df525c205b4a4401b4b846c42de7ffab1f21d Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 6 Feb 2024 09:53:04 -0800 Subject: [PATCH 204/205] cleanup, see if we can avoid crashing when one cache finishes --- src/levanter/data/shard_cache.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/levanter/data/shard_cache.py b/src/levanter/data/shard_cache.py index 16e75f8b5..ff1256716 100644 --- a/src/levanter/data/shard_cache.py +++ b/src/levanter/data/shard_cache.py @@ -300,7 +300,7 @@ def _commit(self): # The difficulty is that we want parallelism and we want to control the order of chunks. # reading batches requires CPU and network. This means we should limit the number to roughly the number of nodes, maybe times 2. -# We want to prioritize so that we read 1 chunks worth of batches from each shard before reading more from any shard. +# We want to prioritize so that we read 1 chunks worth of batches from each shard before reading more from another shard. # We also want to prioritize reading earlier shards before later shards (within a chunk generation round). # Ray also seems to get upset about having too many processes, and we can't serialize the iterators over shards. @@ -1449,7 +1449,9 @@ def _finalize(self): _serialize_json_and_commit(os.path.join(self._cache_dir, LEDGER_FILE_NAME), CacheLedger(self.chunks)) self._reader_promises = {} - self._builder_actor = None + # TODO: For some reason this crashes other actors with weird reference counting assertion errors. 
+ # pretty sure it's a ray bug + # self._builder_actor = None self._finished_promise.set_result(None) # notify metrics subscribers From c2dccf2284372b4d1ec072e334f8b0900fc9596f Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 6 Feb 2024 12:19:01 -0800 Subject: [PATCH 205/205] tweaks to tokenization/shard_cache throughput (#456) --- config/gpt2_nano.yaml | 4 +- config/gpt2_xl.yaml | 2 +- src/levanter/data/shard_cache.py | 176 ++++++++++++++++++++------- src/levanter/data/sharded_dataset.py | 2 +- src/levanter/data/text.py | 111 ++++++++++++++++- src/levanter/utils/hf_utils.py | 4 +- tests/test_shard_cache.py | 6 +- tests/test_text.py | 30 ++++- tests/test_utils.py | 8 +- 9 files changed, 283 insertions(+), 60 deletions(-) diff --git a/config/gpt2_nano.yaml b/config/gpt2_nano.yaml index 267d83b51..5612fc104 100644 --- a/config/gpt2_nano.yaml +++ b/config/gpt2_nano.yaml @@ -1,5 +1,5 @@ -#data: -# id: dlwh/wikitext_103_detokenized +data: + id: dlwh/wikitext_103_detokenized model: type: gpt2 hidden_dim: 32 diff --git a/config/gpt2_xl.yaml b/config/gpt2_xl.yaml index 026fc077e..a58c7ceb0 100644 --- a/config/gpt2_xl.yaml +++ b/config/gpt2_xl.yaml @@ -12,7 +12,7 @@ trainer: project: "levanter" tags: [ "openwebtext", "gpt2"] mp: p=f32,c=bfloat16 - per_device_parallelism: 1 + per_device_parallelism: -1 optimizer: learning_rate: 1E-4 weight_decay: 0.1 diff --git a/src/levanter/data/shard_cache.py b/src/levanter/data/shard_cache.py index f33329dd5..ff1256716 100644 --- a/src/levanter/data/shard_cache.py +++ b/src/levanter/data/shard_cache.py @@ -49,6 +49,8 @@ DEFAULT_MAX_BYTES_PER_BATCH = 256 * 1024 * 1024 # 256 MB, this is pre-preprocessing python object size LEDGER_FILE_NAME = "cache_ledger.json" +LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + def build_cache( cache_dir: str, @@ -298,7 +300,7 @@ def _commit(self): # The difficulty is that we want parallelism and we want to control the order of chunks. # reading batches requires CPU and network. This means we should limit the number to roughly the number of nodes, maybe times 2. -# We want to prioritize so that we read 1 chunks worth of batches from each shard before reading more from any shard. +# We want to prioritize so that we read 1 chunks worth of batches from each shard before reading more from another shard. # We also want to prioritize reading earlier shards before later shards (within a chunk generation round). # Ray also seems to get upset about having too many processes, and we can't serialize the iterators over shards. @@ -363,7 +365,7 @@ def __le__(self, other: "PriorityWorkItem"): @ray.remote(num_cpus=1, scheduling_strategy="SPREAD") class PriorityProcessorActor: def __init__(self, max_in_flight: Optional[int] = 200): - pylogging.basicConfig(level=pylogging.INFO) + pylogging.basicConfig(level=pylogging.INFO, format=LOG_FORMAT) self._queue: list[PriorityWorkItem] = [] # heapq self._queue_lock = threading.Lock() self._shutdown_event = threading.Event() @@ -387,7 +389,7 @@ def is_group_finished(self, group: PriorityWorkTaskGroupSpec): if self._current_item is not None and self._current_item.spec == group: return False - logger.info(f"Group {group.name} is finished.") + logger.debug(f"Group {group.name} is finished.") return True @@ -444,9 +446,9 @@ def drain_backpressure_to(count): if not item_is_finished: heapq.heappush(self._queue, item) - logger.info("Shutting down PriorityProcessorActor. Waiting for backpressure to drain.") + logger.debug("Shutting down PriorityProcessorActor. 
Waiting for backpressure to drain.") drain_backpressure_to(0) - logger.info("Backpressure drained. Shutting down PriorityProcessorActor.") + logger.debug("Backpressure drained. Shutting down PriorityProcessorActor.") @dataclass @@ -460,6 +462,7 @@ class ShardGroupToBeProcessed(PriorityWorkTaskGroupSpec): processor_actor: ray.actor.ActorHandle # BatchProcessorQueue batch_size: int num_rows_per_chunk: int + group_id: int def build(self) -> "PriorityWorkTaskGroup": return ShardGroupTaskGroup(self) @@ -468,7 +471,7 @@ def build(self) -> "PriorityWorkTaskGroup": class ShardGroupTaskGroup(PriorityWorkTaskGroup): def __init__(self, spec: ShardGroupToBeProcessed): self.spec = spec - self.logger = pylogging.getLogger(f"shard_reader.{self.spec.name}") + self.logger = pylogging.getLogger(f"shard_reader.{spec.group_id}.{spec.name}") try: metadata: dict[str, ShardMetadata] = _initial_shard_metadatas( @@ -544,6 +547,8 @@ def execute(self) -> tuple[bool, Optional[ray.ObjectRef]]: total_chunk_rows = 0 # the total number of rows in the chunk batch_result_ref = None + self.group.logger.debug(f"Reading one chunk of shard {self.shard_name}: {self.chunk_idx}") + try: while not chunk_filled: batch = next(self.reader, None) @@ -556,14 +561,20 @@ def execute(self) -> tuple[bool, Optional[ray.ObjectRef]]: if batch: priority = self.spec.priority_fn(self.shard_idx, self.chunk_idx) + # these times aren't exact because the times might be from different machines + # but they're just for logging + time_in = time.time() batch_result_ref = ray.get( - self.spec.processor_actor.submit.remote(priority=priority, batch=RefBox(ray.put(batch))) + self.spec.processor_actor.submit.remote( + priority=priority, + desc=f"{self.shard_name}.{self.chunk_idx}.{chunk_batch_idx}", + batch=RefBox(ray.put(batch)), + ) ) writer.chunk_batch_finished.remote( - self.shard_name, self.chunk_idx, chunk_batch_idx, RefBox(batch_result_ref) + self.shard_name, self.chunk_idx, chunk_batch_idx, RefBox(batch_result_ref), time_in ) chunk_batch_idx += 1 - # enqueue_to_backpressure(batch, batch_result_ref) del batch if total_chunk_rows >= self.spec.num_rows_per_chunk or exhausted_shard: @@ -578,7 +589,9 @@ def execute(self) -> tuple[bool, Optional[ray.ObjectRef]]: if exhausted_shard: writer.shard_finished_reading.remote(self.shard_name, self.chunk_idx) - logger.debug(f"Finished reading one chunk of shard {self.shard_name}: {self.chunk_idx} {exhausted_shard}") + self.group.logger.debug( + f"Finished reading one chunk of shard {self.shard_name}: {self.chunk_idx} {exhausted_shard}" + ) return exhausted_shard, batch_result_ref except Exception as e: # noqa @@ -769,14 +782,49 @@ def is_finished_and_buffer_empty(self): return self.expected_num_chunks is not None and self.num_chunks_sent >= self.expected_num_chunks +class WaitTimeReportingThread(threading.Thread): + def __init__(self, report, interval=60): + super().__init__() + self.report = report + self.interval = interval + self.shutdown_event = threading.Event() + + def run(self): + total_waited = 0 + while True: + if self.shutdown_event.wait(self.interval): + break + if total_waited > 0: + self.report(total_waited) + total_waited += self.interval + + def shutdown(self): + self.shutdown_event.set() + + def _mk_queue_aware_process_task(processor: BatchProcessor[T], queue: ActorHandle): @ray.remote(num_cpus=processor.num_cpus, num_gpus=processor.num_gpus, resources=processor.resources) - def process_task(batch: List[T]) -> pa.RecordBatch: - pylogging.basicConfig(level=pylogging.INFO) + def process_task(desc, 
batch: List[T]) -> pa.RecordBatch: + pylogging.basicConfig(level=pylogging.INFO, format=LOG_FORMAT) + logger.debug(f"Processing batch {desc}") queue.task_running.remote() - result = processor(batch) - del batch - return as_record_batch(result) + # timer_thread = WaitTimeReportingThread( + # lambda t: logger.info(f"Waiting for {desc} to be processed for {t} seconds"), interval=30 + # ) + # timer_thread.start() + try: + result = processor(batch) + del batch + result = as_record_batch(result) + logger.info(f"Finished processing batch {desc}") + return result + except Exception as e: + logger.exception(f"Error while processing batch {desc}") + raise e + finally: + # timer_thread.shutdown() + # timer_thread.join() + pass return process_task @@ -784,6 +832,7 @@ def process_task(batch: List[T]) -> pa.RecordBatch: @dataclass(order=True, frozen=True) class _QueueItem: priority: float + desc: str batch: ray.ObjectRef = dataclasses.field(compare=False) task_id: int task_future: asyncio.Future = dataclasses.field(compare=False) @@ -818,13 +867,13 @@ def __init__(self, batch_processor: BatchProcessor[T]): # we don't need/want to dereference the batch, so we wrap it in a RefBox # one virtue of doing things this way is that we can let Ray try to schedule the compute near the data. - async def submit(self, priority: float, batch: RefBox): + async def submit(self, priority: float, desc: str, batch: RefBox): """Returns a future that is set to the *ObjectRef* of the processed batch. The future is "complete" when the task starts, not when it finishes. You then call ray.get on the future's result to get the actual batch.""" task_id = self._next_task_id self._next_task_id += 1 f: asyncio.Future = asyncio.Future() - self.pqueue.put(_QueueItem(priority, batch.ref, task_id, f)) + self.pqueue.put(_QueueItem(priority, desc, batch.ref, task_id, f)) self._maybe_start_task() return await f @@ -833,7 +882,7 @@ def _maybe_start_task(self): self.ready = False item = self.pqueue.get() batch = item.batch - item.task_future.set_result(self._task_processor.remote(batch)) + item.task_future.set_result(self._task_processor.remote(item.desc, batch)) def task_running(self): self.ready = True @@ -845,7 +894,7 @@ def task_running(self): @ray.remote(num_cpus=0.0, scheduling_strategy="SPREAD") # type: ignore class _GroupShardWriterWorker: def __init__(self, parent_ref, cache_dir: str, shard_names: Sequence[str]): - pylogging.basicConfig(level=pylogging.INFO) + pylogging.basicConfig(level=pylogging.INFO, format=LOG_FORMAT) self.cache_dir = cache_dir self.shard_names = shard_names self.shard_writers: dict[str, _ShardWriterWorker] = { @@ -855,10 +904,40 @@ def __init__(self, parent_ref, cache_dir: str, shard_names: Sequence[str]): def current_metadata(self, shard_name: str): return self.shard_writers[shard_name].current_metadata() - async def chunk_batch_finished(self, shard_name: str, chunk_id: int, batch_idx: int, batch: RefBox): + async def chunk_batch_finished(self, shard_name: str, chunk_id: int, batch_idx: int, batch: RefBox, time_in): # batch is a pa.RecordBatch ref box try: - batch = await batch.ref + time_mid = time.time() + logger.debug( + f"Received in progress batch {batch_idx} of chunk {chunk_id} of shard {shard_name} in" + f" {time_mid - time_in}" + ) + # do a backoff loop until the batch is actually processed. 
log if it's been a while
+                    timeout_interval = 20
+                    total_time_waited = 0
+
+                    while True:
+                        try:
+                            # shield so a timeout doesn't cancel the underlying work; we just retry with a longer timeout
+                            batch = await asyncio.wait_for(asyncio.shield(batch.ref), timeout_interval)
+                            break
+                        except asyncio.TimeoutError:
+                            # to keep to round numbers, we log how much we asked for rather than how much we got
+                            total_time_waited += timeout_interval
+                            timeout_interval = min(2 * timeout_interval, 100)
+                            logger.info(
+                                f"Waiting for {shard_name}.{chunk_id}.{batch_idx} to be processed. "
+                                f"Waited {total_time_waited} seconds."
+                            )
+
+                    if logger.isEnabledFor(pylogging.DEBUG):
+                        logger.debug(
+                            f"Received finished {shard_name}.{chunk_id}.{batch_idx} in {(time.time() - time_in):.2f} seconds."
+                        )
+                    elif total_time_waited > 40:
+                        logger.info(
+                            f"Waited {total_time_waited} seconds for {shard_name}.{chunk_id}.{batch_idx} to be processed."
+                        )
                     return self.shard_writers[shard_name].chunk_batch_finished(chunk_id, batch_idx, batch)
             except Exception as e:
                 print(f"Error while processing batch {batch_idx} of chunk {chunk_id} of shard {shard_name}", flush=True)
@@ -889,7 +968,7 @@ def __init__(
         cache_dir: str,
         shard_name: str,
     ):
-        pylogging.basicConfig(level=pylogging.INFO)
+        pylogging.basicConfig(level=pylogging.INFO, format=LOG_FORMAT)
         self.parent_ref = parent_ref
         self.cache_dir = cache_dir
         self.shard_name = shard_name
@@ -934,13 +1013,9 @@ def chunk_failed(self, chunk_id: int, error: ExceptionInfo):
         self.parent_ref.shard_failed.remote(self.shard_name, error)

     def _finished_chunk(self, idx: int, chunk: ChunkMetadata):
-        if idx < self.metadata_writer.num_chunks:
-            logger.error(f"Received chunk {idx} for {self.shard_name} but it's already finished")
-            error = RuntimeError(f"Received chunk {idx} for {self.shard_name} but it's already finished")
-            self.parent_ref.shard_failed.remote(self.shard_name, ser_exc_info(error))
-            raise error
-
-        if self._expected_num_chunks is not None and idx >= self._expected_num_chunks:
-            logger.error(f"Received chunk {idx} for {self.shard_name} but it's already finished")
-            error = RuntimeError(f"Received chunk {idx} for {self.shard_name} but it's already finished")
+        if (idx < self.metadata_writer.num_chunks) or (
+            self._expected_num_chunks is not None and idx >= self._expected_num_chunks
+        ):
+            logger.error(f"Received chunk {idx} for {self.shard_name} but it's already finished or out of range")
+            error = RuntimeError(f"Received chunk {idx} for {self.shard_name} but it's already finished or out of range")
             self.parent_ref.shard_failed.remote(self.shard_name, ser_exc_info(error))
@@ -961,6 +1036,8 @@ def _attempt_to_commit_chunks(self):
         chunks_committed = []
         while len(self.uncommited_chunks) > 0 and self.uncommited_chunks[0][0] == self.metadata_writer.num_chunks:
             _, chunk = heapq.heappop(self.uncommited_chunks)
+            chunk_number = self.metadata_writer.num_chunks
+            logger.debug(f"Committing chunk {chunk.name} of shard {self.shard_name}. It is chunk {chunk_number}")
             self.metadata_writer.commit_chunk(chunk)
             chunks_committed.append(chunk)

@@ -1074,7 +1151,7 @@ def __init__(
         processor: BatchProcessor[T],
         rows_per_chunk: int,
     ):
-        pylogging.basicConfig(level=pylogging.INFO)
+        pylogging.basicConfig(level=pylogging.INFO, format=LOG_FORMAT)
         self.logger = pylogging.getLogger(f"{__name__}.{name}")
         self.broker_ref = broker_ref
         self.shard_status: Dict[str, _ShardStatus] = dict()
@@ -1095,6 +1172,7 @@ def __init__(
         self._processor_actors = []

         for shard_name in source.shard_names:
+            self._current_round_robin.append(shard_name)
             self.shard_status[shard_name] = _ShardStatus()

         num_shards = len(source.shard_names)
@@ -1106,7 +1184,6 @@ def priority_fn(shard_idx, chunk_idx):

         shard_groups: list[list[str]] = [[] for _ in range(num_shard_groups)]
         for i, shard_name in enumerate(source.shard_names):
-            self._current_round_robin.append(shard_name)
             shard_groups[i % num_shard_groups].append(shard_name)

         for group_id, shard_group in enumerate(shard_groups):
@@ -1127,6 +1204,7 @@ def priority_fn(shard_idx, chunk_idx):
                 processor_actor=processor_actor,
                 batch_size=processor.batch_size,
                 num_rows_per_chunk=rows_per_chunk,
+                group_id=group_id,
             )

         # we want global names so that different tasks can coordinate priorities
@@ -1228,14 +1306,23 @@ def _attempt_to_flush_buffers(self):
             next_chunk = status.pop_chunk_to_send()
             if next_chunk is not None:
                 # we can send a chunk from this shard
-                logger.debug(f"Sending chunk from {name}")
                 self._current_round_robin.pop(0)
                 self._current_round_robin.append(name)
                 chunks_to_send.append(next_chunk)
                 continue
             else:
-                logger.debug(f"Shard {name} has no chunks to send and is not known to be finished")
                 # we can't send a chunk from this shard, so we can't send any additional chunks
+                if self.logger.isEnabledFor(pylogging.DEBUG):
+                    chunks_waiting = [
+                        f"{n2} ({len(s2.current_buffer)})"
+                        for n2, s2 in self.shard_status.items()
+                        if len(s2.current_buffer) > 0
+                    ]
+                    msg = (
+                        f"Shard {name} has no chunks to send and is not known to be finished. We have this many queued"
+                        f" chunks: {chunks_waiting}"
+                    )
+                    self.logger.debug(msg)
                 break

         if len(chunks_to_send) > 0:
@@ -1257,7 +1344,7 @@ class ChunkCacheBroker:
     _finished_promise: asyncio.Future[None]

     def __init__(self, cache_dir: str, source: ShardedDataset[T], processor: BatchProcessor[T], rows_per_chunk: int):
-        pylogging.basicConfig(level=pylogging.INFO)
+        pylogging.basicConfig(level=pylogging.INFO, format=LOG_FORMAT)
         self.chunks = []
         self._reader_promises = {}
         self._is_finished = False
@@ -1269,6 +1356,9 @@ def __init__(self, cache_dir: str, source: ShardedDataset[T], processor: BatchPr
         # used to subscribe to metrics updates
         self._latest_metrics = InProgressCacheMetrics()
         self._metrics_condition = asyncio.Condition()
+        path_for_name = os.path.join(*self._cache_dir.split("/")[-2:])
+        name = f"broker::{path_for_name}"
+        self.logger = pylogging.getLogger(name)

         # initialize writer task
         # first see if we need to do anything: check the ledger for is_finished
@@ -1280,7 +1370,6 @@ def __init__(self, cache_dir: str, source: ShardedDataset[T], processor: BatchPr
         except FileNotFoundError:
             self_ref = ray.runtime_context.get_runtime_context().current_actor
             # only use the last two components of the name since it gets kind of long
-            path_for_name = os.path.join(*self._cache_dir.split("/")[-2:])
             name = f"builder::{path_for_name}"
             self._builder_actor = ChunkCacheBuilder.remote(self_ref, self._cache_dir, name, self._source, self._processor, self._rows_per_chunk)  # type: ignore
@@ -1307,7 +1396,6 @@ async def get_chunk(self, chunk_idx: int) -> Optional[ChunkMetadata]:
         elif self._is_finished:
             return None
         else:
-            # we don't have this chunk yet, so we need to wait
             if chunk_idx not in self._reader_promises:
                 self._reader_promises[chunk_idx] = asyncio.Future()
             return await self._reader_promises[chunk_idx]
@@ -1322,7 +1410,9 @@ def _append_chunks(self, *chunks: ChunkMetadata):
         for chunk in chunks:
             self.chunks.append(chunk)
             chunk_idx = len(self.chunks) - 1
+            self.logger.debug(f"Received chunk {chunk_idx}")
             if chunk_idx in self._reader_promises:
+                self.logger.debug(f"Resolving promise for chunk {chunk_idx}")
                 self._reader_promises[chunk_idx].set_result(chunk)
                 del self._reader_promises[chunk_idx]

@@ -1359,7 +1449,9 @@ def _finalize(self):
         _serialize_json_and_commit(os.path.join(self._cache_dir, LEDGER_FILE_NAME), CacheLedger(self.chunks))

         self._reader_promises = {}
-        self._builder_actor = None
+        # TODO: For some reason this crashes other actors with weird reference counting assertion errors.
+        # pretty sure it's a ray bug
+        # self._builder_actor = None
         self._finished_promise.set_result(None)

         # notify metrics subscribers
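Both the batch-wait loop near the top of this patch and the chunk-wait loop in the next hunk follow the same pattern: poll with a capped, exponentially growing timeout and log cumulative progress, so a stuck pipeline stays visible in the logs without spamming them. A minimal standalone sketch of the pattern (not part of the patch; the helper name and signature are illustrative):

    import asyncio
    import logging

    logger = logging.getLogger(__name__)

    async def await_with_progress_logging(awaitable, what: str, initial_timeout: float = 20.0, cap: float = 100.0):
        # ensure_future lets us re-await the same work across retries; shield keeps
        # a timeout from cancelling the underlying task
        task = asyncio.ensure_future(awaitable)
        timeout = initial_timeout
        waited = 0.0
        while True:
            try:
                return await asyncio.wait_for(asyncio.shield(task), timeout)
            except asyncio.TimeoutError:
                # log how much we asked for rather than how much we got, to keep round numbers
                waited += timeout
                timeout = min(2 * timeout, cap)
                logger.info(f"Waiting for {what}. Waited {waited} seconds.")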
@@ -1528,17 +1620,21 @@ def _get_chunk_unmapped(self, mapped_index: int, *, timeout: Optional[float] = N
         else:
             assert self._broker is not None
             time_in = time.time()
+            next_time = time_in
+            # start with a short timeout and let it grow (capped) across retries
+            current_timeout = 20.0
             # we want to also log if we're waiting for a long time, so we do this in a loop
-            while timeout is None or time.time() - time_in < timeout:
-                current_timeout = 20.0
+            while timeout is None or next_time - time_in < timeout:
                 if timeout is not None:
-                    current_timeout = min(current_timeout, timeout - (time.time() - time_in))
+                    current_timeout = min(current_timeout, timeout - (next_time - time_in))
                 try:
                     chunk = ray.get(self._broker.get_chunk.remote(mapped_index), timeout=current_timeout)
                 except GetTimeoutError:
-                    self.logger.warning(f"Waiting for chunk {mapped_index} for {int(time.time() - time_in)} seconds")
+                    # to keep to round numbers, we log how much we asked for rather than how much we got
+                    next_time += current_timeout
+                    self.logger.warning(f"Waiting for chunk {mapped_index} for {int(next_time - time_in)} seconds")
                     current_timeout *= 2
-                    current_timeout = min(current_timeout, 80)
+                    current_timeout = min(current_timeout, 100)
                     continue

                 if chunk is None:
diff --git a/src/levanter/data/sharded_dataset.py b/src/levanter/data/sharded_dataset.py
index 1ceae6366..0ec178e08 100644
--- a/src/levanter/data/sharded_dataset.py
+++ b/src/levanter/data/sharded_dataset.py
@@ -171,7 +171,7 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[dict]:
         dataset = self._load_dataset()
         if isinstance(dataset, datasets.IterableDataset) and shard_name != "data":
             # ex_iterable has a key that gets discarded typically
-            shard = map(lambda t: t[1], dataset._ex_iterable.shard_data_sources([int(shard_name)]))
+            shard = map(lambda t: t[1], dataset._ex_iterable.shard_data_sources(int(shard_name), dataset.n_shards))
         else:
             shard = dataset

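The text.py diff that follows works around huggingface/tokenizers#1449: tokenizers with a Replace normalizer slow to a crawl on very long strings, so documents longer than LONG_STRING_WORKAROUND characters are split at whitespace, tokenized piecewise, and the resulting encodings are merged back together. Splitting at whitespace is what makes the merge lossless for whitespace-pre-tokenized BPE vocabularies. A rough sketch of the invariant the new test checks (illustrative, not part of the patch; it should hold for gpt2 because its pre-tokenizer never merges across whitespace):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    text = "the quick brown fox jumps over the lazy dog " * 50

    # split at a whitespace boundary, as the workaround does
    split_at = text.index(" ", len(text) // 2)
    pieces = [text[:split_at], text[split_at:]]

    whole = tokenizer(text, return_attention_mask=False)["input_ids"]
    merged = [tok for p in pieces for tok in tokenizer(p, return_attention_mask=False)["input_ids"]]
    assert whole == merged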
""" - def __init__(self, tokenizer: PreTrainedTokenizerBase, enforce_eos=True, override_resources=None): + def __init__( + self, + tokenizer: PreTrainedTokenizerBase, + enforce_eos=True, + override_resources=None, + _workaround_len=LONG_STRING_WORKAROUND, + ): _maybe_force_tokenizer_parallelism(tokenizer) self.tokenizer = tokenizer self.override_resources = override_resources @@ -331,16 +345,101 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase, enforce_eos=True, overrid should_append_eos = False self._need_to_add_eos = should_append_eos + self._workaround_len = _workaround_len def __call__(self, batch: Sequence[str]) -> BatchEncoding: + orig_lengths = [len(d) for d in batch] if self._need_to_add_eos: - encoding = self.tokenizer( - [d + " " + self.tokenizer.eos_token for d in batch], return_attention_mask=False, verbose=False - ) + batch = [d + " " + self.tokenizer.eos_token for d in batch] + + if self._needs_long_sequence_workaround: + # break any strings that are longer than 50K characters into smaller chunks + orig_batch = batch + batch = [] + needs_merge = [] + for i, d in enumerate(orig_batch): + needs_merge.append(False) + orig_len = orig_lengths[i] + while len(d) > self._workaround_len: + # we'd rather break strings at whitespace, so find the first whitespace + match = ws.search(d, self._workaround_len) + # this is vanishingly unlikely, but if we can't find a whitespace, just break it at the limit + if match is None: + split = len(d) + else: + split = match.start() + + batch.append(d[:split]) + needs_merge.append(True) + + d = d[split:] + orig_len -= split + + batch.append(d) else: - encoding = self.tokenizer(batch, return_attention_mask=False, verbose=False) # type: ignore + needs_merge = [] + + encoding = self.tokenizer(batch, return_attention_mask=False, verbose=False) # type: ignore + + if needs_merge: + new_encoding = self._merge_split_encodings(batch, encoding, needs_merge) + encoding = BatchEncoding(new_encoding) + return encoding + @staticmethod + def _merge_split_encodings(batch, encoding, needs_merge): + # merge the encodings back together + # we might need to merge multiple encodings together + # needs merge marks the first n-1 encodings that need to be merged for each document + new_encoding = {} + for k, v in encoding.items(): + if len(v) == 0: + continue + if isinstance(v[0], np.ndarray): + assert len(v) == len(batch) + v_out = [] + vs_to_merge = [] + for i in range(len(batch)): + if not needs_merge[i]: + v_out.append(np.concatenate(vs_to_merge)) + vs_to_merge = [] + vs_to_merge.append(v[i]) + + if len(vs_to_merge) > 0: + v_out.append(np.concatenate(vs_to_merge)) + + new_encoding[k] = v_out + elif isinstance(v[0], list): + v_out = [] + vs_to_merge = [] + for i in range(len(batch)): + if not needs_merge[i]: + if len(vs_to_merge) > 0: + v_out.append(list(chain(*vs_to_merge))) + vs_to_merge = [] + vs_to_merge.append(v[i]) + + if len(vs_to_merge) > 0: + v_out.append(list(chain(*vs_to_merge))) + new_encoding[k] = v_out + else: + raise ValueError(f"Unknown type {type(v[0])}") + return new_encoding + + # TODO remove this when it's resolved https://github.com/huggingface/tokenizers/issues/1449 + @cached_property + def _needs_long_sequence_workaround(self): + if isinstance(self.tokenizer, PreTrainedTokenizerFast): + normalizer = self.tokenizer.backend_tokenizer.normalizer + if normalizer is None: + return False + # if there's a "Replace" normalizer, then we need to do the workaround + # inexplicably there's no way to see inside a Sequence so we also have to assume it 
needs it + return isinstance(normalizer, (normalizers.Replace, normalizers.Sequence)) + else: + return False + @property def num_cpus(self) -> int: if self.override_resources is not None: @@ -357,7 +456,7 @@ def num_gpus(self) -> int: @property def batch_size(self) -> int: - return 1024 + return 512 def concatenate_and_group_texts( diff --git a/src/levanter/utils/hf_utils.py b/src/levanter/utils/hf_utils.py index 408a8c8da..879162a65 100644 --- a/src/levanter/utils/hf_utils.py +++ b/src/levanter/utils/hf_utils.py @@ -18,7 +18,9 @@ def num_cpus_used_by_tokenizer(tokenizer) -> int: else: # This is a bit hacky, but HF's fast tokenizers are parallelized under the hood. # we reserve a couple of cores just so Ray has somewhere to run the coordinator. - return min(max(1, logical_cpu_core_count() - 2), 32) + # Really it's dependent on the number of docs, but that's not something we + # can easily know here. + return min(max(1, logical_cpu_core_count() - 2), 16) else: return 1 diff --git a/tests/test_shard_cache.py b/tests/test_shard_cache.py index c1b55ca1e..6b54970c5 100644 --- a/tests/test_shard_cache.py +++ b/tests/test_shard_cache.py @@ -13,10 +13,12 @@ def setup_module(module): + print("setting up") ray.init("local", num_cpus=2 * logical_cpu_core_count()) # 2x cpu count is faster on my m1 def teardown_module(module): + print("shutting down") ray.shutdown() @@ -175,9 +177,7 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[List[int]]: ) # now block until the cache is done - print("at wait") cache.await_finished(timeout=10) - print("done waiting") # now check that the chunks are in the right order # TODO: this is a bit gross @@ -233,7 +233,7 @@ def back_to_py(batch: pa.RecordBatch): assert [list(x) for x in chunk] == [[i] * 10 for i in range(10)] with pytest.raises(TimeoutError): - cache.get_chunk(1, timeout=0.1) + cache.get_chunk(1, timeout=0.5) ray.get(blocker_to_wait_on_test.unblock.remote()) diff --git a/tests/test_text.py b/tests/test_text.py index a9d407b44..70b2d26a7 100644 --- a/tests/test_text.py +++ b/tests/test_text.py @@ -1,12 +1,14 @@ import tempfile import jax.numpy as jnp +from transformers import AutoTokenizer import haliax as hax -from levanter.data.text import LMDatasetConfig +from levanter.data.text import BatchTokenizer, LMDatasetConfig from levanter.models.lm_model import LmExample from levanter.models.loss import next_token_loss +from test_utils import skip_if_hf_model_not_accessible def test_dont_blow_up_without_validation_set(): @@ -39,3 +41,29 @@ def test_lm_example_handles_ignore_id(): no_ignore_loss = next_token_loss(Pos, Vocab, distr, tokens, loss_mask=ex_no_ignore.loss_mask) assert no_ignore_loss.item() >= ignored_loss.item() + 100 / Pos.size + + +def test_merge_split_encodings(): + tokenizer = AutoTokenizer.from_pretrained("gpt2") + # make this very short for testing + + lorem = """Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. 
Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.""" + + short_batch_tokenizer = BatchTokenizer(tokenizer, _workaround_len=len(lorem) // 3) + # force this + short_batch_tokenizer._needs_long_sequence_workaround = True + + batch_tokenizer = BatchTokenizer(tokenizer, _workaround_len=50000) + batch = [lorem] + + short_out = short_batch_tokenizer(batch) + reg_out = batch_tokenizer(batch) + + assert short_out == reg_out + + +@skip_if_hf_model_not_accessible("meta-llama/Llama-2-7b-hf") +def test_llama_tokenizer_needs_long_sequence_workaround(): + tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") + batch_tokenizer = BatchTokenizer(tokenizer) + assert batch_tokenizer._needs_long_sequence_workaround diff --git a/tests/test_utils.py b/tests/test_utils.py index b2b060c28..08df42f69 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -139,21 +139,21 @@ def try_load_path(path): else: return True - return pytest.mark.skipif(not try_load_path(path), reason="Checkpoint not accessible")(lambda x: x) + return pytest.mark.skipif(not try_load_path(path), reason="Checkpoint not accessible") def skip_if_hf_model_not_accessible(model_id: str): def try_load_hf(model_id): try: - from transformers import AutoModelForCausalLM + from transformers import AutoConfig - AutoModelForCausalLM.from_pretrained(model_id) + AutoConfig.from_pretrained(model_id) except Exception: return False else: return True - return pytest.mark.skipif(not try_load_hf(model_id), reason="HuggingFace model not accessible")(lambda x: x) + return pytest.mark.skipif(not try_load_hf(model_id), reason="HuggingFace model not accessible") class IdentityProcessor(BatchProcessor[BatchEncoding]):
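On the test_utils.py change just above: pytest.mark.skipif(...) is itself usable as a decorator. The old code returned skipif(...)(lambda x: x), which applied the skip mark to a throwaway identity lambda; when that lambda was then used as the decorator, it simply returned the test function with no mark attached, so these tests were never actually skipped. Returning the bare mark lets @skip_if_... attach the mark to the real test. A tiny sketch of the corrected shape (skip_unless_env is a hypothetical helper, not part of the patch):

    import os

    import pytest

    def skip_unless_env(var: str):
        # return the mark itself; pytest applies it to the decorated test
        return pytest.mark.skipif(var not in os.environ, reason=f"{var} not set")

    @skip_unless_env("RUN_SLOW_TESTS")
    def test_something_slow():
        assert True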