From aaa8bbf705b6f090fb07ad36503f39b5e922a6df Mon Sep 17 00:00:00 2001
From: Hynek Kydlíček
Date: Wed, 4 Sep 2024 11:33:29 +0200
Subject: [PATCH] Standalone nanotron config (#285)

What does this implement/fix? Explain your changes.
---------------------------------------------------
This PR moves the lighteval config to the lighteval codebase.

- Enforces the lighteval_config_path as the only way to read the lighteval config.
  The nanotron part is ignored, so the breaking changes won't be as disruptive.
- Some typing corrections

---------

Co-authored-by: Nathan Habib <30601243+NathanHB@users.noreply.github.com>
Co-authored-by: Nathan Habib
Co-authored-by: Hynek Kydlicek
---
 .../lighteval_config_override_template.yaml |  18 ++--
 src/lighteval/__main__.py                   |   2 +-
 src/lighteval/config/lighteval_config.py    | 100 ++++++++++++++++++
 src/lighteval/logging/evaluation_tracker.py |  11 +-
 src/lighteval/main_nanotron.py              |  25 +++--
 src/lighteval/models/base_model.py          |   7 --
 src/lighteval/models/nanotron_model.py      |  40 +++----
 src/lighteval/parsers.py                    |  10 +-
 src/lighteval/pipeline.py                   |  20 ++--
 src/lighteval/tasks/lighteval_task.py       |   4 +-
 tests/utils.py                              |   2 +-
 11 files changed, 163 insertions(+), 76 deletions(-)
 create mode 100644 src/lighteval/config/lighteval_config.py

diff --git a/examples/nanotron/lighteval_config_override_template.yaml b/examples/nanotron/lighteval_config_override_template.yaml
index 12955216..03b65596 100644
--- a/examples/nanotron/lighteval_config_override_template.yaml
+++ b/examples/nanotron/lighteval_config_override_template.yaml
@@ -1,15 +1,15 @@
-batch_size: 16
-checkpoints_path: null
+# As of right now, auto batch size doesn't work, so we use a fixed default
+batch_size: 8
 generation: null
 logging:
-  hub_repo_details: null
-  hub_repo_results: null
-  hub_repo_tensorboard: null
-  local_output_path: ./output_dir
-  push_details_to_hub: false
+  output_dir: "outputs"
+  save_details: false
   push_results_to_hub: false
-  push_results_to_tensorboard: true
-  tensorboard_metric_prefix: e
+  push_details_to_hub: false
+  push_results_to_tensorboard: false
+  public_run: false
+  results_org: null
+  tensorboard_metric_prefix: "eval"
 parallelism:
   dp: 1
   pp: 1
diff --git a/src/lighteval/__main__.py b/src/lighteval/__main__.py
index fcc4e0f2..054a06a6 100644
--- a/src/lighteval/__main__.py
+++ b/src/lighteval/__main__.py
@@ -60,7 +60,7 @@ def cli_evaluate():
     elif args.subcommand == "nanotron":
         from lighteval.main_nanotron import main as main_nanotron
 
-        main_nanotron(args.checkpoint_config_path, args.lighteval_override, args.cache_dir)
+        main_nanotron(args.checkpoint_config_path, args.lighteval_config_path, args.cache_dir)
 
     elif args.subcommand == "tasks":
         if args.list:
diff --git a/src/lighteval/config/lighteval_config.py b/src/lighteval/config/lighteval_config.py
new file mode 100644
index 00000000..3b8a3332
--- /dev/null
+++ b/src/lighteval/config/lighteval_config.py
@@ -0,0 +1,100 @@
+# MIT License
+
+# Copyright (c) 2024 The HuggingFace Team
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or
substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +from dataclasses import dataclass +from typing import Dict, Optional, Union + +from nanotron.config import Config +from nanotron.config.parallelism_config import ParallelismArgs +from nanotron.generation.sampler import SamplerType +from nanotron.logging import get_logger + + +logger = get_logger(__name__) + +DEFAULT_GENERATION_SEED = 42 + + +@dataclass +class GenerationArgs: + sampler: Optional[Union[str, SamplerType]] = None + temperature: Optional[float] = None + top_k: Optional[int] = None + top_p: Optional[float] = None + n_samples: Optional[int] = None + eos: Optional[str] = None + seed: Optional[int] = None + use_cache: Optional[bool] = False + + def __post_init__(self): + if isinstance(self.sampler, str): + self.sampler = SamplerType[self.sampler.upper()] + if self.seed is None: + self.seed = DEFAULT_GENERATION_SEED + + +@dataclass +class LightEvalLoggingArgs: + """Arguments related to logging for LightEval""" + + output_dir: str + save_details: bool = True + push_results_to_hub: bool = False + push_details_to_hub: bool = False + push_results_to_tensorboard: bool = False + public_run: bool = False + results_org: str | None = None + tensorboard_metric_prefix: str = "eval" + + +@dataclass +class LightEvalTasksArgs: + """Arguments related to tasks for LightEval""" + + tasks: str + custom_tasks: Optional[str] = None + max_samples: Optional[int] = None + num_fewshot_seeds: Optional[int] = None + + dataset_loading_processes: int = 8 + multichoice_continuations_start_space: Optional[bool] = None + + +@dataclass +class LightEvalConfig: + """Arguments related to running LightEval on checkpoints. + + All is optional because you can also use this class to later supply arguments to override + the saved config when running LightEval after training. 
+    """
+
+    logging: LightEvalLoggingArgs
+    tasks: LightEvalTasksArgs
+    parallelism: ParallelismArgs
+    batch_size: int = 0
+    generation: Optional[Union[GenerationArgs, Dict[str, GenerationArgs]]] = None
+
+
+@dataclass
+class FullNanotronConfig:
+    lighteval_config: LightEvalConfig
+    nanotron_config: Config
diff --git a/src/lighteval/logging/evaluation_tracker.py b/src/lighteval/logging/evaluation_tracker.py
index ebae02b6..e72982cf 100644
--- a/src/lighteval/logging/evaluation_tracker.py
+++ b/src/lighteval/logging/evaluation_tracker.py
@@ -94,8 +94,8 @@ class EvaluationTracker:
 
     def __init__(
         self,
-        output_dir: str = None,
-        hub_results_org: str = "",
+        output_dir: str,
+        hub_results_org: str | None = None,
         push_results_to_hub: bool = False,
         push_details_to_hub: bool = False,
         push_results_to_tensorboard: bool = False,
@@ -133,14 +133,13 @@ def __init__(
 
         self.output_dir = output_dir
 
-        self.hub_results_org = hub_results_org  # will also contain tensorboard results
-        if hub_results_org in ["", None] and any(
-            [push_details_to_hub, push_results_to_hub, push_results_to_tensorboard]
-        ):
+        if hub_results_org in [None] and any([push_details_to_hub, push_results_to_hub, push_results_to_tensorboard]):
             raise Exception(
                 "You need to select which org to push to, using `--results_org`, if you want to save information to the hub."
             )
 
+        self.hub_results_org = hub_results_org  # will also contain tensorboard results
+
         self.hub_results_repo = f"{hub_results_org}/results"
         self.hub_private_results_repo = f"{hub_results_org}/private-results"
         self.push_results_to_hub = push_results_to_hub
diff --git a/src/lighteval/main_nanotron.py b/src/lighteval/main_nanotron.py
index 6e219b30..2fa05f8f 100644
--- a/src/lighteval/main_nanotron.py
+++ b/src/lighteval/main_nanotron.py
@@ -24,6 +24,7 @@
 import os
 from typing import Optional
 
+from lighteval.config.lighteval_config import FullNanotronConfig, LightEvalConfig
 from lighteval.logging.evaluation_tracker import EvaluationTracker
 from lighteval.logging.hierarchical_logger import htrack, htrack_block
 from lighteval.pipeline import ParallelismManager, Pipeline, PipelineParameters
@@ -34,7 +35,7 @@
 if not is_nanotron_available():
     raise ImportError(NO_NANOTRON_ERROR_MSG)
 
-from nanotron.config import Config, LightEvalConfig, get_config_from_file
+from nanotron.config import Config, get_config_from_file
 
 
 SEED = 1234
@@ -60,28 +61,26 @@ def main(
         skip_unused_config_keys=True,
         skip_null_keys=True,
     )
-    if lighteval_config_path:
-        lighteval_config = get_config_from_file(lighteval_config_path, config_class=LightEvalConfig)
-        model_config.lighteval = lighteval_config
-    else:
-        lighteval_config = model_config.lighteval
+
+    # We get a type error here because get_config_from_file is not correctly typed
+    lighteval_config: LightEvalConfig = get_config_from_file(lighteval_config_path, config_class=LightEvalConfig)  # type: ignore
+    nanotron_config = FullNanotronConfig(lighteval_config, model_config)
 
     evaluation_tracker = EvaluationTracker(
-        token=os.getenv("HF_TOKEN"),
-        output_dir=lighteval_config.logging.local_output_path,
-        hub_results_org=lighteval_config.logging.hub_repo_tensorboard,
+        output_dir=lighteval_config.logging.output_dir,
+        hub_results_org=lighteval_config.logging.results_org,
         tensorboard_metric_prefix=lighteval_config.logging.tensorboard_metric_prefix,
-        nanotron_run_info=model_config.general,
+        nanotron_run_info=nanotron_config.nanotron_config.general,
    )
 
     pipeline_parameters = PipelineParameters(
         launcher_type=ParallelismManager.NANOTRON,
         env_config=env_config,
-
job_id=os.environ.get("SLURM_JOB_ID", None), + job_id=os.environ.get("SLURM_JOB_ID", 0), nanotron_checkpoint_path=checkpoint_config_path, dataset_loading_processes=lighteval_config.tasks.dataset_loading_processes, custom_tasks_directory=lighteval_config.tasks.custom_tasks, - override_batch_size=None, + override_batch_size=lighteval_config.batch_size, num_fewshot_seeds=1, max_samples=lighteval_config.tasks.max_samples, use_chat_template=False, @@ -92,7 +91,7 @@ def main( tasks=lighteval_config.tasks.tasks, pipeline_parameters=pipeline_parameters, evaluation_tracker=evaluation_tracker, - model_config=model_config, + model_config=nanotron_config, ) pipeline.evaluate() diff --git a/src/lighteval/models/base_model.py b/src/lighteval/models/base_model.py index a975d7cf..41d12437 100644 --- a/src/lighteval/models/base_model.py +++ b/src/lighteval/models/base_model.py @@ -73,7 +73,6 @@ def __init__( """Initializes a HuggingFace `AutoModel` and `AutoTokenizer` for evaluation.""" self._config = config.init_configs(env_config) self.accelerator = config.accelerator - self._batch_size = config.batch_size self._max_length = self._init_max_length(config.max_length) self.use_chat_template = config.use_chat_template @@ -285,12 +284,6 @@ def _init_max_length(self, max_length) -> int: # or no max length config setting is found in the model or tokenizer. return 2048 - @property - def batch_size(self) -> int: - if self._batch_size >= 0: - self._batch_size = self._get_batch_size(max_input_length=self.max_length) - return self._batch_size # * gpus - @property def device(self) -> Union[int, str, torch.device]: return self._device diff --git a/src/lighteval/models/nanotron_model.py b/src/lighteval/models/nanotron_model.py index c0b9c6b4..9c84d33a 100644 --- a/src/lighteval/models/nanotron_model.py +++ b/src/lighteval/models/nanotron_model.py @@ -34,6 +34,7 @@ from tqdm import tqdm from transformers import AutoTokenizer, BatchEncoding +from lighteval.config.lighteval_config import FullNanotronConfig from lighteval.data import ( GenDistributedSampler, GenerativeTaskDatasetNanotron, @@ -55,7 +56,7 @@ ) from lighteval.utils.imports import is_nanotron_available from lighteval.utils.parallelism import find_executable_batch_size -from lighteval.utils.utils import EnvConfig, as_list, boolstring_to_bool +from lighteval.utils.utils import EnvConfig, as_list os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -63,10 +64,8 @@ TokenSequence = Union[List[int], torch.LongTensor, torch.Tensor, BatchEncoding] if is_nanotron_available(): - import nanotron from nanotron import distributed as dist from nanotron import logging - from nanotron.config import LightEvalConfig, ModelArgs, TokenizerArgs from nanotron.generation.decode import decode_tokenized from nanotron.logging import human_format, log_rank from nanotron.models import build_model @@ -90,7 +89,7 @@ class NanotronLightevalModel(LightevalModel): def __init__( self, checkpoint_path: str, - nanotron_config: nanotron.config.Config, + nanotron_config: FullNanotronConfig, parallel_context: ParallelContext, max_gen_toks: Optional[int] = 256, max_length: Optional[int] = None, @@ -104,12 +103,11 @@ def __init__( """Initializes a nanotron model for evaluation. 
Args: """ - model_args: ModelArgs = nanotron_config.model - tokenizer: TokenizerArgs = nanotron_config.tokenizer - lighteval_config: LightEvalConfig = nanotron_config.lighteval - parallel_config: ParallelContext = nanotron_config.lighteval.parallelism + model_args = nanotron_config.nanotron_config.model + tokenizer = nanotron_config.nanotron_config.tokenizer + lighteval_config = nanotron_config.lighteval_config + parallel_config = nanotron_config.lighteval_config.parallelism - self._batch_size = lighteval_config.batch_size self._max_gen_toks = max_gen_toks self._max_length = max_length self.parallel_config = parallel_config @@ -120,9 +118,7 @@ def __init__( raise ValueError("PP parallelism is not supported yet") # multichoice_continuations_start_space can be True (forcing space), False (forcing no space) or None (no forcing) - multichoice_continuations_start_space = boolstring_to_bool( - lighteval_config.tasks.multichoice_continuations_start_space - ) + multichoice_continuations_start_space = lighteval_config.tasks.multichoice_continuations_start_space self.generation_config = lighteval_config.generation if isinstance(self.generation_config, dict): @@ -217,7 +213,9 @@ def __init__( self.multichoice_continuations_start_space = multichoice_continuations_start_space - self.model_info = ModelInfo(model_name=f"{nanotron_config.general.run}/{nanotron_config.general.step}") + self.model_info = ModelInfo( + model_name=f"{nanotron_config.nanotron_config.general.run}/{nanotron_config.nanotron_config.general.step}" + ) @property def tokenizer(self): @@ -299,12 +297,6 @@ def max_length(self) -> int: return self.tokenizer.model_max_length return self._DEFAULT_MAX_LENGTH - @property - def batch_size(self) -> int: - if self._batch_size >= 0: - self._batch_size = self._get_batch_size(max_input_length=self.max_length) - return self._batch_size # * gpus - @property def device(self) -> Union[int, str, torch.device]: return "cuda" @@ -415,7 +407,7 @@ def _check_continuations_start_space(self, continuation: str) -> str: return continuation def loglikelihood_single_token( - self, requests: List[Tuple[str, dict]], override_bs=None + self, requests: List[Tuple[str, dict]], override_bs=0 ) -> List[LoglikelihoodSingleTokenResponse]: """Tokenize the context and continuation and compute the log likelihood of those tokenized sequences. 
@@ -475,7 +467,7 @@ def loglikelihood(self, requests: List[LoglikelihoodRequest], override_bs=None) ) def loglikelihood_rolling( - self, requests: List[LoglikelihoodRollingRequest], override_bs=None + self, requests: List[LoglikelihoodRollingRequest], override_bs: int = 0 ) -> List[LoglikelihoodResponse]: """This function is used to compute the log likelihood of the context for perplexity metrics.""" for request in tqdm( @@ -652,7 +644,7 @@ def _get_subsets(self, dataset, num_dataset_splits): @torch.inference_mode() def _loglikelihood_single_token( - self, requests, disable_tqdm: bool = False, override_bs: int = -1, num_dataset_splits: int = 1 + self, requests, disable_tqdm: bool = False, override_bs: int = 0, num_dataset_splits: int = 1 ) -> List[LoglikelihoodSingleTokenResponse]: dataset = LoglikelihoodSingleTokenDataset(requests=requests) res = [] @@ -1115,7 +1107,7 @@ def greedy_until( self, requests: List[GreedyUntilRequest], disable_tqdm: bool = False, - override_bs=None, + override_bs: int = -1, num_dataset_splits: int = 1, ) -> List[GenerativeResponse]: """Greedy generation until a stop token is generated.""" @@ -1155,7 +1147,7 @@ def greedy_until( max_input_length = min(len(context_enc) + max_gen, self.max_length) batch_size = self._get_batch_size( - override_bs=self._batch_size, + override_bs=override_bs, max_input_length=max_input_length, starting_batch_size=starting_batch_size, ) diff --git a/src/lighteval/parsers.py b/src/lighteval/parsers.py index 499d945e..dbacc6be 100644 --- a/src/lighteval/parsers.py +++ b/src/lighteval/parsers.py @@ -70,6 +70,7 @@ def parser_accelerate(parser=None): "--results_org", type=str, help="Hub organisation where you want to store the results. Your current token must have write access to it", + default=None, ) # Common parameters parser.add_argument( @@ -110,15 +111,16 @@ def parser_nanotron(parser=None): ) parser.add_argument( - "--checkpoint-config-path", + "--checkpoint_config_path", type=str, required=True, - help="Path to the brr checkpoint YAML or python config file, potentially on S3", + help="Path to the nanotron checkpoint YAML or python config file, potentially on S3", ) parser.add_argument( - "--lighteval-override", + "--lighteval_config_path", type=str, - help="Path to an optional YAML or python Lighteval config to override part of the checkpoint Lighteval config", + help="Path to a YAML or python lighteval config to be used for the evaluation. 
Lighteval key in nanotron config is ignored!", + required=True, ) parser.add_argument( "--cache_dir", type=str, default=CACHE_DIR, help="Cache directory used to store datasets and models" diff --git a/src/lighteval/pipeline.py b/src/lighteval/pipeline.py index dbc28510..228c80d7 100644 --- a/src/lighteval/pipeline.py +++ b/src/lighteval/pipeline.py @@ -75,15 +75,15 @@ class PipelineParameters: env_config: EnvConfig = field(default_factory=EnvConfig) job_id: int = 0 dataset_loading_processes: int = 1 - nanotron_checkpoint_path: str = None # only for nanotron models + nanotron_checkpoint_path: str | None = None # only for nanotron models # Dataset - custom_tasks_directory: str = None + custom_tasks_directory: str | None = None # Generation parameters - override_batch_size: int = None + override_batch_size: int | None = None num_fewshot_seeds: int = 1 - max_samples: int = None + max_samples: int | None = None use_chat_template: bool = False - system_prompt: str = None + system_prompt: str | None = None def __post_init__(self): if self.launcher_type == ParallelismManager.ACCELERATE: @@ -140,9 +140,9 @@ def _init_parallelism_manager(self): raise ValueError("You are trying to launch a nanotron model, but nanotron is not installed") dist.initialize_torch_distributed() parallel_context = ParallelContext( - tensor_parallel_size=self.model_config.lighteval.parallelism.tp, - pipeline_parallel_size=self.model_config.lighteval.parallelism.pp, - data_parallel_size=self.model_config.lighteval.parallelism.dp, + tensor_parallel_size=self.model_config.lighteval_config.parallelism.tp, + pipeline_parallel_size=self.model_config.lighteval_config.parallelism.pp, + data_parallel_size=self.model_config.lighteval_config.parallelism.dp, ) test_all_gather(parallel_context=parallel_context) @@ -153,7 +153,9 @@ def _init_model(self, model_config, model): if model_config is not None: if self.parallel_context: return NanotronLightevalModel( - checkpoint_path=os.path.dirname(self.pipeline_parameters.nanotron_checkpoint_path), + checkpoint_path=os.path.dirname(self.pipeline_parameters.nanotron_checkpoint_path) + if self.pipeline_parameters.nanotron_checkpoint_path + else "", nanotron_config=self.model_config, parallel_context=self.parallel_context, debug_one_layer_model=False, diff --git a/src/lighteval/tasks/lighteval_task.py b/src/lighteval/tasks/lighteval_task.py index f6718389..0ff8f2aa 100644 --- a/src/lighteval/tasks/lighteval_task.py +++ b/src/lighteval/tasks/lighteval_task.py @@ -571,10 +571,10 @@ def create_requests_from_tasks( # noqa: C901 fewshot_dict: dict[str, list[Tuple[int, bool]]], num_fewshot_seeds: int, lm: BaseModel, - max_samples: int, + max_samples: int | None, evaluation_tracker: "EvaluationTracker", use_chat_template: bool, - system_prompt: str, + system_prompt: str | None, ) -> Tuple[dict[RequestType, list[Request]], dict[SampleUid, Doc]]: """ Takes a task dict and a fewshot dict and returns a dict of requests, a dict diff --git a/tests/utils.py b/tests/utils.py index cff23167..32ef2431 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -122,7 +122,7 @@ def fake_evaluate_task( task_name = f"{task.suite[0]}|{task.name}" task_dict = {task_name: task} - evaluation_tracker = EvaluationTracker() + evaluation_tracker = EvaluationTracker(output_dir="outputs") evaluation_tracker.task_config_logger.log(task_dict) # Create a mock Registry class
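For reference, a standalone config consumed through the new --lighteval_config_path flag follows the shape of the dataclasses added in src/lighteval/config/lighteval_config.py (LightEvalConfig, LightEvalLoggingArgs, LightEvalTasksArgs, plus nanotron's ParallelismArgs). A minimal sketch is given below; the file name, the task string and the tp value are illustrative and not taken from this patch:

    # lighteval_config.yaml (illustrative name)
    batch_size: 8
    generation: null
    logging:
      output_dir: "outputs"
      save_details: false
      push_results_to_hub: false
      push_details_to_hub: false
      push_results_to_tensorboard: false
      public_run: false
      results_org: null
      tensorboard_metric_prefix: "eval"
    parallelism:
      dp: 1
      pp: 1
      tp: 1  # assumed; the template hunk above only shows dp and pp
    tasks:
      tasks: "leaderboard|arc:challenge|0|0"  # illustrative task string
      custom_tasks: null
      max_samples: null
      dataset_loading_processes: 8
      multichoice_continuations_start_space: null

The logging block mirrors the updated examples/nanotron/lighteval_config_override_template.yaml, and the tasks block maps one-to-one onto the LightEvalTasksArgs fields.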
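With such a file in place, the two required arguments defined in parser_nanotron are supplied roughly as follows (both paths are placeholders), dispatching through src/lighteval/__main__.py to main_nanotron:

    python -m lighteval nanotron \
        --checkpoint_config_path path/to/nanotron/checkpoint/config.yaml \
        --lighteval_config_path path/to/lighteval_config.yaml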