diff --git a/language/llama2-70b/Dockerfile b/language/llama2-70b/Dockerfile new file mode 100644 index 000000000..b04910d73 --- /dev/null +++ b/language/llama2-70b/Dockerfile @@ -0,0 +1,48 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +FROM nvidia/cuda:11.8.0-cudnn8-devel-ubuntu20.04 +SHELL ["/bin/bash", "-c"] + +ENV LC_ALL=C.UTF-8 +ENV LANG=C.UTF-8 + +ENV TZ=US/Pacific +ENV DEBIAN_FRONTEND=noninteractive + +RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone +RUN rm -rf /var/lib/apt/lists/* && rm /etc/apt/sources.list.d/* \ + && apt update \ + && apt install -y --no-install-recommends build-essential autoconf \ + libtool git ccache curl wget pkg-config sudo ca-certificates \ + automake libssl-dev bc python3-dev python3-pip google-perftools \ + gdb libglib2.0-dev clang sshfs libre2-dev libboost-dev \ + libnuma-dev numactl sysstat sshpass ntpdate less iputils-ping \ + && apt -y autoremove \ + && apt remove -y cmake \ + && apt install -y --no-install-recommends pkg-config zip g++ zlib1g-dev \ + unzip libarchive-dev +RUN apt install -y --no-install-recommends rsync + +# Install setuptools +RUN python3 -m pip install --upgrade pip \ + && python3 -m pip install --upgrade setuptools wheel virtualenv + +# Install conda +WORKDIR /tmp +RUN wget https://repo.anaconda.com/miniconda/Miniconda3-py310_23.5.2-0-Linux-x86_64.sh \ + && bash Miniconda3-* -b -p /opt/miniconda3 +ENV PATH="$PATH:/opt/miniconda3/bin" +RUN conda create -n llama2-70b python=3.10 +RUN chmod -R 777 /opt/miniconda3 diff --git a/language/llama2-70b/README.md b/language/llama2-70b/README.md index 7827fd4b6..0fefec95f 100644 --- a/language/llama2-70b/README.md +++ b/language/llama2-70b/README.md @@ -7,6 +7,9 @@ ## Prepare environment + +For a CPU-only run: + ``` conda create -n llama2-70b python=3.9 conda activate llama2-70b @@ -26,9 +29,35 @@ git merge llm-server python -m pip install . ``` +For a GPU-based run: + +A Dockerfile is provided, along with scripts to help launch it. First, add any Docker volume mounts you want in +`launch.sh`. There is a section at the top of the file that looks like: +``` +# Add any volume mounts here with the following syntax +# /path/to/src:/path/to/dir/in/container +MOUNTS=( + $MLCOMMONS_REPO_PATH:$MLCOMMONS_REPO_PATH +) +``` + +For example, if you have a RAID volume located at `/raid/data` on your local machine, you can add it to the same path in the container like so: +``` +# Add any volume mounts here with the following syntax +# /path/to/src:/path/to/dir/in/container +MOUNTS=( + $MLCOMMONS_REPO_PATH:$MLCOMMONS_REPO_PATH + /raid/data:/raid/data +) +``` +Once you have added all your mounts, launch the container with `bash launch.sh`. + +Inside the container, set up the environment with `bash build.sh`. This will install all the dependencies from the +CPU-only setup, as well as any GPU versions for applicable libraries like PyTorch.
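+
+As an optional sanity check once `build.sh` finishes, you can confirm from inside the container that the GPU build of PyTorch is visible (the `llama2-70b` conda environment is created by the Dockerfile and activated automatically in new shells):
+```
+python -c "import torch; print(torch.cuda.is_available())"
+```
+This should print `True` on a machine with working NVIDIA drivers before you attempt a GPU run.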
+ ## Get Model -+ For now, MLCommons is not hosting the checkpoing, so you must first go to [llama2-request-link](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) and make a request, sign in to huggingface (if you don't have account, you'd need to create one). **Please note your authentication credentials** as you may be required to provide them when cloninng below ++ For now, MLCommons is not hosting the checkpoint, so you must first go to [llama2-request-link](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) and make a request, sign in to Hugging Face (if you don't have an account, you'll need to create one). **Please note your authentication credentials** as you may be required to provide them when cloning below + Requires Git Large Files Storage ``` export CHECKPOINT_PATH=${PWD}/Llama-2-70b-chat-hf @@ -49,7 +78,7 @@ EXPORT_DIR=${PWD}/processed-openorca export DATASET_PATH=${PWD}/processed-data.pkl # Process the dataset according the Taskforce's agreed criteria -python3 processorca.py --dataset_pq_path=${OPENORCA_PARQUET} --model_dir=${CHECKPOINT_PATH} --seqlen_limit=2048 --export_dir=${EXPORT_DIR} --num_total_samples=24576 +python3 processorca.py --dataset_pq_path=${OPENORCA_PARQUET} --model_dir=${CHECKPOINT_PATH} --seqlen_limit=1024 --export_dir=${EXPORT_DIR} --num_total_samples=24576 mv ${EXPORT_DIR}/open_orca_gpt4_tokenized_llama.sampled_24576.pkl ${DATASET_PATH} ``` @@ -65,11 +94,24 @@ python -u main.py --scenario Offline \ --user-conf user.conf \ --total-sample-count 24576 \ --device cpu \ - --dataset-path ${DATASET_PATH} \ + --dataset-path ${DATASET_PATH} \ --output-log-dir offline-logs ``` +For a GPU-based run: +``` +python3 -u main.py --scenario Offline \ + --model-path ${CHECKPOINT_PATH} \ + --mlperf-conf mlperf.conf \ + --user-conf user.conf \ + --total-sample-count 24576 \ + --dataset-path ${DATASET_PATH} \ + --output-log-dir offline-logs \ + --dtype float32 \ + --device cuda:0 2>&1 | tee offline_performance_log.log +``` + ### Server ``` python -u main.py --scenario Server \ @@ -82,6 +124,8 @@ python -u main.py --scenario Server \ --output-log-dir server-logs ``` +The ServerSUT was not tested for GPU runs. + ## Run Accuracy Benchmarks @@ -89,6 +133,8 @@ python -u main.py --scenario Server \ ``` OUTPUT_LOG_DIR=offline-accuracy-logs +mkdir -p "run_outputs" # The script will dump all the outputs to 'run_outputs'. + python -u main.py --scenario Offline \ --model-path ${CHECKPOINT_PATH} \ --accuracy \ @@ -105,8 +151,23 @@ if [ -e ${ACCURACY_LOG_FILE} ]; then python evaluate-accuracy.py --checkpoint-path ${CHECKPOINT_PATH} \ --mlperf-accuracy-file ${ACCURACY_LOG_FILE} --dataset-file ${DATASET_PATH} --dtype int32 fi + +# Optional: Create a pickled pandas DataFrame that is the original dataset with extra columns containing output data from the +# accuracy run. The following columns will be added: +# - "gen_output_tok_id": A list of ints representing the tokenized output sequence. +# - "gen_output_text": A str representing the untokenized output sequence. +# - "gen_output_tok_len": An int representing the number of output tokens. +# - "rouge1": The rouge1 score for this sample +# - "rouge2": The rouge2 score for this sample +# - "rougeL": The rougeL score for this sample +# This file will by default be saved to 'full_output.pkl'. You can modify this with --output-pkl-path. +python consolidate_results.py --dataset-path ${DATASET_PATH} --model-dir ${CHECKPOINT_PATH} ``` +For the GPU run, the above steps have been automated in `run_accuracy.sh`.
You can also modify this script to use +`--device cpu` to adapt it to a CPU-only run. + + ### Server ``` OUTPUT_LOG_DIR=server-accuracy-logs @@ -129,3 +190,15 @@ if [ -e ${ACCURACY_LOG_FILE} ]; then fi ``` +The ServerSUT was not tested for GPU runs. You can try setting `--device cuda:0`, but YMMV. + + +## Accuracy Target +Running the GPU implementation in FP32 precision resulted in the following FP32 accuracy targets (normalized to a 0-100 +scale from a 0.0-1.0 scale): +- Rouge1: 43.88 +- Rouge2: 21.7108 +- RougeL: 28.2502 +- RougeLsum: 41.4821 + +This was run on an 8xH100 node. Total runtime was ~4.5 days. diff --git a/language/llama2-70b/SUT.py b/language/llama2-70b/SUT.py index d29f789bf..9cf009c08 100644 --- a/language/llama2-70b/SUT.py +++ b/language/llama2-70b/SUT.py @@ -8,11 +8,15 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM from transformers.generation.streamers import BaseStreamer +import pickle +import time import threading +import tqdm import queue import logging from typing import TYPE_CHECKING, Optional, List +from pathlib import Path import mlperf_loadgen as lg from dataset import Dataset @@ -74,11 +78,26 @@ def get_out_tokens(self): class SUT(): - def __init__(self, model_path=None, dtype="bfloat16", device="cpu", total_sample_count=24576, dataset_path=None, workers=1): + def __init__(self, + model_path=None, + dtype="bfloat16", + device="cpu", + batch_size=None, + total_sample_count=24576, + dataset_path=None, + use_cached_outputs=False, # Set this to True *only for test accuracy runs* in case your prior session was killed partway through + workers=1): self.model_path = model_path or "meta-llama/Llama-2-70b-chat-hf" self.device = device + if not batch_size: + if device == "cpu": + batch_size = 1 + else: + batch_size = 32 # Reduce to 8 if using 4 GPUs, 16 for 8. + self.batch_size = batch_size + # dtype if dtype == 'bfloat16': self.amp_enabled = True @@ -94,7 +113,10 @@ def __init__(self, model_path=None, dtype="bfloat16", device="cpu", total_sample assert torch.cuda.is_available(), "torch gpu is not available, exiting..."
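+        # Construct the dataset and the LoadGen QSL; the target device is forwarded to the Dataset.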
self.dataset_path = dataset_path - self.data_object = Dataset(self.model_path, dataset_path=self.dataset_path, total_sample_count=total_sample_count) + self.data_object = Dataset(self.model_path, + dataset_path=self.dataset_path, + total_sample_count=total_sample_count, + device=self.device) self.qsl = lg.ConstructQSL(self.data_object.total_sample_count, self.data_object.perf_count, self.data_object.LoadSamplesToRam, self.data_object.UnloadSamplesFromRam) @@ -104,9 +126,12 @@ def __init__(self, model_path=None, dtype="bfloat16", device="cpu", total_sample self.worker_threads = [None] * self.num_workers self.query_queue = queue.Queue() + self.use_cached_outputs = use_cached_outputs + self.sample_counter = 0 + self.sample_counter_lock = threading.Lock() + def start(self): - # Create worker threads for j in range(self.num_workers): worker = threading.Thread(target=self.process_queries) @@ -129,38 +154,90 @@ def process_queries(self): if qitem is None: break - # TODO: If batching, call collator to batch the inputs here - input_ids_tensor = self.data_object.input_ids[qitem.index] - input_masks_tensor = self.data_object.attention_masks[qitem.index] - input_len = [self.data_object.input_lens[qitem.index]] - - - pred_output_tokens = self.model.generate( - input_ids=input_ids_tensor, - attention_mask=input_masks_tensor, - pad_token_id=self.tokenizer.pad_token_id, - **gen_kwargs - ) - - processed_output = self.data_object.postProcess(pred_output_tokens, input_len) - - response_array = array.array("B", processed_output[0].tobytes()) - bi = response_array.buffer_info() - response = [lg.QuerySampleResponse( - qitem.id, bi[0], bi[1])] - lg.QuerySamplesComplete(response) + query_ids = [q.index for q in qitem] + + fname = "q" + "_".join([str(i) for i in query_ids]) + fname = f"run_outputs/{fname}.pkl" + _p = Path(fname) + if self.use_cached_outputs and _p.exists(): + # Read cache + with _p.open(mode="rb") as f: + d = pickle.load(f) + processed_output = d["outputs"] + tik1 = None + tik2 = None + tik3 = None + tok = None + else: + # Construct / collate batch + max_seq_len = 1024 + + tik1 = time.time() + + input_ids_tensor = [] + input_masks_tensor = [] + input_len = [] + for q in qitem: + input_ids_tensor.append(pad(self.data_object.input_ids[q.index], + (0, max_seq_len - self.data_object.input_lens[q.index], 0, 0), + value=self.tokenizer.pad_token_id)) + input_masks_tensor.append(pad(self.data_object.attention_masks[q.index], + (0, max_seq_len - self.data_object.input_lens[q.index], 0, 0), + value=0)) + input_len.append(self.data_object.input_lens[q.index]) + input_ids_tensor = torch.cat(input_ids_tensor) + input_masks_tensor = torch.cat(input_masks_tensor) + + assert input_ids_tensor.shape == input_masks_tensor.shape + assert input_ids_tensor.shape[0] <= self.batch_size + + tik2 = time.time() + + pred_output_tokens = self.model.generate( + input_ids=input_ids_tensor, + attention_mask=input_masks_tensor, + pad_token_id=self.tokenizer.pad_token_id, + **gen_kwargs + ) + + tik3 = time.time() + + processed_output = self.data_object.postProcess(pred_output_tokens, + input_seq_lens=input_len, + query_id_list=query_ids) + + for i in range(len(qitem)): + response_array = array.array("B", processed_output[i].tobytes()) + bi = response_array.buffer_info() + response = [lg.QuerySampleResponse(qitem[i].id, bi[0], bi[1])] + lg.QuerySamplesComplete(response) + + tok = time.time() + + with self.sample_counter_lock: + self.sample_counter += len(qitem) + print(f"Samples run: {self.sample_counter}") + if tik1: + 
print(f"\tBatchMaker time: {tik2 - tik1}") + print(f"\tInference time: {tik3 - tik2}") + print(f"\tPostprocess time: {tok - tik3}") + print(f"\t==== Total time: {tok - tik1}") + else: + print(f"\tLoaded from cache: {_p}") def load_model(self): self.model = LlamaForCausalLM.from_pretrained( self.model_path, - device_map= "auto" if self.device=="cpu" else None, - low_cpu_mem_usage=True if self.device=="cpu" else False, + device_map="auto", + low_cpu_mem_usage=True, torch_dtype=self.amp_dtype ) + print("Loaded model") self.device = torch.device(self.device) - self.model.to(self.device) + if self.device == "cpu": + self.model = self.model.to(self.device) # Force CPU if your system has GPU and you specifically want CPU-only run self.model.eval() self.model = self.model.to(memory_format=torch.channels_last) @@ -172,6 +249,7 @@ def load_model(self): use_fast=False,) self.tokenizer.pad_token = self.tokenizer.eos_token + print("Loaded tokenizer") def get_sut(self): self.sut = lg.ConstructSUT(self.issue_queries, self.flush_queries) @@ -191,8 +269,11 @@ def issue_queries(self, query_samples): list_prompts_tokens = [] list_prompts_attn_masks = [] - for i in range(len(query_samples)): - self.query_queue.put(query_samples[i]) + print(f"IssueQuery started with {len(query_samples)} samples") + while len(query_samples) > 0: + self.query_queue.put(query_samples[:self.batch_size]) + query_samples = query_samples[self.batch_size:] + print(f"IssueQuery done") def flush_queries(self): @@ -245,7 +326,7 @@ def process_queries(self): qitem = self.query_queue.get() if qitem is None: break - + input_ids_tensor = self.data_object.input_ids[qitem.index] input_masks_tensor = self.data_object.attention_masks[qitem.index] @@ -267,7 +348,7 @@ def process_queries(self): response = [lg.QuerySampleResponse( qitem.id, bi[0], bi[1])] lg.QuerySamplesComplete(response) - + def issue_queries(self, query_samples): diff --git a/language/llama2-70b/build.sh b/language/llama2-70b/build.sh new file mode 100644 index 000000000..87afb992f --- /dev/null +++ b/language/llama2-70b/build.sh @@ -0,0 +1,8 @@ +set -e + +conda install pybind11==2.10.4 -c conda-forge -y +conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch-nightly -c nvidia +python -m pip install transformers==4.31.0 nltk==3.8.1 evaluate==0.4.0 absl-py==1.4.0 rouge-score==0.1.2 sentencepiece==0.1.99 accelerate==0.21.0 + + +cd ../../loadgen && python3 -m pip install . 
diff --git a/language/llama2-70b/consolidate_results.py b/language/llama2-70b/consolidate_results.py new file mode 100644 index 000000000..ad5fbe411 --- /dev/null +++ b/language/llama2-70b/consolidate_results.py @@ -0,0 +1,116 @@ +import argparse +import evaluate +import glob +import nltk +import numpy as np +import os +import pandas as pd +import pickle + +from pathlib import Path +from transformers import LlamaTokenizerFast +from tqdm import tqdm + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--dataset-path", type=str, default=None, help="Path to .pkl generated by processorca.py") + parser.add_argument("--run-outputs", type=str, default="run_outputs", help="Output dir generated by accuracy run.") + parser.add_argument("--model-dir", type=str, default=None, help="Path to Llamav2 HuggingFace repo clone") + parser.add_argument("--output-pkl-path", type=str, default="full_output.pkl", help="Path to dump output to") + args = parser.parse_args() + return args + + +def load_dataset(p: os.PathLike): + print(f"Loading from {p}...") + return pd.read_pickle(p) + + +def load_run_outputs(p: os.PathLike): + g = glob.glob(str(Path(p) / "q*.pkl")) + + by_query_idx = dict() + for pkl_file in g: + print(f"Loading from {pkl_file}...") + with open(pkl_file, 'rb') as f: + d = pickle.load(f) + assert len(d["query_ids"]) == len(d["outputs"]) + + for i in range(len(d["query_ids"])): + qid = d["query_ids"][i] + assert qid not in by_query_idx + by_query_idx[qid] = d["outputs"][i] + + return by_query_idx + + +def main(args): + # Set up decode and evaluation objects + tokenizer = LlamaTokenizerFast.from_pretrained(args.model_dir) + metric = evaluate.load("rouge") + nltk.download("punkt") + + # Load Data + df = load_dataset(args.dataset_path) + run_outputs = load_run_outputs(args.run_outputs) + assert len(run_outputs) == 24576 + + # Set up columns to add + output_tok_ids_col = [None] * 24576 + output_text_col = [None] * 24576 + output_lens = [None] * 24576 + + # Process data + no_eos_ids = [] + for qid, output in tqdm(run_outputs.items()): + L = list(output) + # Prune trailing 2s (EOS token) + try: + first2 = L.index(2) + L = L[:first2] + except ValueError: + # Do nothing + no_eos_ids.append(qid) + + assert L[-1] != 2 + output_tok_ids_col[qid] = L + output_lens[qid] = len(L) + + # Decode tokens + output_text_col[qid] = tokenizer.decode(output_tok_ids_col[qid], skip_special_tokens=True) + print(f"Found {len(no_eos_ids)} samples with no EOS token") + + print("Calculating rouge scores...") + _preproc = lambda s: "\n".join(nltk.sent_tokenize(s.strip())) + preds = list(map(_preproc, output_text_col)) + targets = list(map(_preproc, list(df["output"]))) + rouge_scores = metric.compute(predictions=preds, + references=targets, + use_stemmer=True, + use_aggregator=False) + + assert len(rouge_scores["rouge1"]) == 24576 + assert len(rouge_scores["rouge2"]) == 24576 + assert len(rouge_scores["rougeL"]) == 24576 + + agg = {k: round(np.mean(v) * 100, 4) for k, v in rouge_scores.items()} + print(agg) + print("Avg output seqlen:", np.mean(output_lens)) + + # Set columns + df["gen_output_tok_id"] = output_tok_ids_col + df["gen_output_text"] = output_text_col + df["gen_output_tok_len"] = output_lens + df["rouge1"] = rouge_scores["rouge1"] + df["rouge2"] = rouge_scores["rouge2"] + df["rougeL"] = rouge_scores["rougeL"] + + p = Path(args.output_pkl_path) + p.parent.mkdir(exist_ok=True) + df.to_pickle(p) + print(f"Dumped to {p}") + + +if __name__ == "__main__": + main(get_args()) diff --git 
a/language/llama2-70b/dataset.py b/language/llama2-70b/dataset.py index 9b5452050..a59fd7f55 100644 --- a/language/llama2-70b/dataset.py +++ b/language/llama2-70b/dataset.py @@ -10,6 +10,7 @@ import io #import utils import copy +import pickle import logging logging.basicConfig(level=logging.INFO) @@ -46,6 +47,7 @@ def load_processed_dataset(self): if not os.path.isfile(self.dataset_path): log.warn("Processed pickle file {} not found. Please check that the path is correct".format(self.dataset_path)) + print("Loading dataset...") import pandas as pd processed_data = pd.read_pickle(self.dataset_path) @@ -61,12 +63,14 @@ def load_processed_dataset(self): self.input_ids.append(input_ids) self.attention_masks.append(attn_mask) self.input_lens.append(input_ids.shape[-1]) + print("Finished loading dataset.") def postProcess(self, out_tokens, input_seq_lens=None, query_id_list=None, sample_index_list=None): """ Postprocesses output prediction """ #TODO: Create response object in postProcess(?) + """ preds = [] for i in range(out_tokens.shape[0]): #pred = out_tokens[i].reshape(-1).cpu().numpy() # Slice up to original input length as below? @@ -74,8 +78,21 @@ def postProcess(self, out_tokens, input_seq_lens=None, query_id_list=None, sampl input_len = input_seq_lens[i] if input_seq_lens else 0 pred = out_tokens[i, input_len:].reshape(-1).cpu().numpy() preds.append(pred) - - return preds + """ + # Everything is padded to max_len (1024), so prune the input and parse to numpy + output_seq = out_tokens[:, 1024:].cpu().numpy() + assert len(query_id_list) == output_seq.shape[0] + + # Save outputs + fname = "q" + "_".join([str(i) for i in query_id_list]) + fname = f"run_outputs/{fname}.pkl" + with open(fname, mode='wb') as f: + d = {"query_ids": query_id_list, + "outputs": output_seq} + print(f"Saving outputs to {fname}") + pickle.dump(d, f) + + return output_seq def LoadSamplesToRam(self, sample_list): pass diff --git a/language/llama2-70b/launch.sh b/language/llama2-70b/launch.sh new file mode 100644 index 000000000..c3389c516 --- /dev/null +++ b/language/llama2-70b/launch.sh @@ -0,0 +1,37 @@ +#!/bin/bash + +MLCOMMONS_REPO_PATH="$(dirname "$(dirname "$PWD")")" + +# Add any volume mounts here with the following syntax +# /path/to/src:/path/to/dir/in/container +MOUNTS=( + $MLCOMMONS_REPO_PATH:$MLCOMMONS_REPO_PATH +) + +# Set up docker environment file for current user +rm -f .docker_env +echo "CI_BUILD_USER=`id -u -n`" >> .docker_env +echo "CI_BUILD_UID=`id -u`" >> .docker_env +echo "CI_BUILD_GROUP=`id -g -n`" >> .docker_env +echo "CI_BUILD_GID=`id -g`" >> .docker_env +cat .docker_env + +# Build container +docker build . 
-t llm/gpubringup + +# Build mount flags +declare -a MOUNT_FLAGS +for _mount in ${MOUNTS[@]}; do + _split=($(echo $_mount | tr ':' '\n')); + MOUNT_FLAGS+=("--mount type=bind,source=${_split[0]},target=${_split[1]}"); +done + +set -x +nvidia-docker run -it --rm --net=host --runtime=nvidia --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 \ + --cap-add=SYS_PTRACE --cap-add=SYS_ADMIN --cap-add=DAC_READ_SEARCH \ + --security-opt seccomp=unconfined \ + -w $PWD \ + --env-file `pwd`/.docker_env \ + ${MOUNT_FLAGS[*]} \ + llm/gpubringup \ + bash ./with_the_same_user diff --git a/language/llama2-70b/processorca.py b/language/llama2-70b/processorca.py index 0cc9c3f8a..91dec5d51 100644 --- a/language/llama2-70b/processorca.py +++ b/language/llama2-70b/processorca.py @@ -19,6 +19,7 @@ import pandas as pd import numpy as np from dataclasses import dataclass +from functools import partial from pathlib import Path from transformers import LlamaTokenizerFast from typing import Dict @@ -41,6 +42,19 @@ def is_english(s): return True +def _tokenize_helper(x, llama_tokenizer=None, append_response_init_token=True): + if not isinstance(x, str): + return [] + + tokens = llama_tokenizer(x)["input_ids"] + + if append_response_init_token: + # Workaround to enable cheat checking for first token: Llama always outputs token 29871 first + # It is possible for submitters to just immediately output this token to achieve a very fast TTFT. + tokens.append(29871) + return tokens + + @dataclass class Keyphrase: col: str @@ -53,11 +67,13 @@ class OpenOrcaDatasetGenerator: def __init__(self, pq_path: os.PathLike, model_dir: os.PathLike, - io_token_limit: int): + io_token_limit: int, + calibration_subset_size: int = 1000): self.pq_path = Path(pq_path) self.model_dir = Path(model_dir) self.io_token_limit = io_token_limit self.keyphrases = [] + self.calibration_subset_size = calibration_subset_size def load_parquet(self) -> pd.DataFrame: llama_tokenizer = LlamaTokenizerFast.from_pretrained(self.model_dir) @@ -67,8 +83,11 @@ def load_parquet(self) -> pd.DataFrame: print(f"Tokenizing input") df.rename(columns={'response': 'output'}, inplace=True) df['input'] = df.apply(format_llama_input, axis=1) - df['tok_input'] = df['input'].apply(lambda x: llama_tokenizer(x)['input_ids'] if isinstance(x, str) else []) - df['tok_output'] = df['output'].apply(lambda x: llama_tokenizer(x)['input_ids'] if isinstance(x, str) else []) + + input_tokenizer = partial(_tokenize_helper, llama_tokenizer=llama_tokenizer) + output_tokenizer = partial(_tokenize_helper, llama_tokenizer=llama_tokenizer, append_response_init_token=False) + df['tok_input'] = df['input'].apply(input_tokenizer) + df['tok_output'] = df['output'].apply(output_tokenizer) tok = time.time() print(f"Loaded parquet and tokenized in {tok-tik} sec.") return df @@ -167,7 +186,11 @@ def sample(self, dfs_by_origin: Dict[str, pd.DataFrame], n_total, rng_seed: int sampled_df = sampled_df.reset_index(drop=True) return sampled_df - def generate(self, export_dir: os.PathLike, n_samples: int = 24576, use_cached: bool = True): + def generate(self, + export_dir: os.PathLike, + n_samples: int = 24576, + use_cached: bool = True, + calib_rng_seed: int = 12345): export_dir = Path(export_dir) if not export_dir.exists(): print(f"Creating {export_dir}") @@ -208,6 +231,13 @@ def generate(self, export_dir: os.PathLike, n_samples: int = 24576, use_cached: sampled_fpath = export_dir / f"open_orca_gpt4_tokenized_llama.sampled_{n_samples}.pkl" sampled_df.to_pickle(sampled_fpath) + # Calibration dataset + 
calib_ds = sampled_df.sample(n=self.calibration_subset_size, + random_state=calib_rng_seed) + calib_ds = calib_ds.reset_index(drop=True) + calib_fpath = export_dir / f"open_orca_gpt4_tokenized_llama.calibration_{self.calibration_subset_size}.pkl" + calib_ds.to_pickle(calib_fpath) + def parse_arguments(): parser = argparse.ArgumentParser() @@ -215,11 +245,12 @@ def parse_arguments(): default='/raid/data/mlperf-llm/OpenOrca/1M-GPT4-Augmented.parquet', help="the path to the open_orca GPT4 parquet.") parser.add_argument('--model_dir', type=str, default='/raid/data/mlperf-llm/Llama-2-70b-chat-hf') - parser.add_argument('--seqlen_limit', type=int, default=2048, help="Upper limit of the input/output sequence lengths") + parser.add_argument('--seqlen_limit', type=int, default=1024, help="Upper limit of the input/output sequence lengths") parser.add_argument('--export_dir', type=str, default="/raid/data/mlperf-llm/OpenOrca/llama/filtered", help="Path to the output pkl file.") parser.add_argument('--num_total_samples', type=int, default=24576, help="Number of samples to generate") + parser.add_argument('--calibration_subset_size', type=int, default=1000, help="Number of samples for calibration subset") return parser.parse_args() @@ -229,6 +260,7 @@ def parse_arguments(): pq_path=args.dataset_pq_path, model_dir=args.model_dir, io_token_limit=args.seqlen_limit, + calibration_subset_size=args.calibration_subset_size, ) ds_gen.generate( export_dir=args.export_dir, @@ -236,4 +268,4 @@ def parse_arguments(): ) # Sample command to run: - # python3 processorca.py --dataset_pq_path=/raid/data/mlperf-llm/OpenOrca/1M-GPT4-Augmented.parquet --model_dir=/raid/data/mlperf-llm/Llama-2-70b-chat-hf --seqlen_limit=2048 --export_dir=/raid/data/mlperf-llm/OpenOrca/llama/filtered --num_total_samples=24576 + # python3 processorca.py --dataset_pq_path=/raid/data/mlperf-llm/OpenOrca/1M-GPT4-Augmented.parquet --model_dir=/raid/data/mlperf-llm/Llama-2-70b-chat-hf --seqlen_limit=1024 --export_dir=/raid/data/mlperf-llm/OpenOrca/llama/filtered --num_total_samples=24576 diff --git a/language/llama2-70b/run_accuracy.sh b/language/llama2-70b/run_accuracy.sh new file mode 100644 index 000000000..b4f7f8ad9 --- /dev/null +++ b/language/llama2-70b/run_accuracy.sh @@ -0,0 +1,22 @@ +CHECKPOINT_PATH="${CHECKPOINT_PATH:-meta-llama/Llama-2-70b-chat-hf}" +DATASET_PATH="${DATASET_PATH:-open-orca-val-set.pkl}" + +mkdir -p "run_outputs" + +python3 -u main.py --scenario Offline \ + --model-path ${CHECKPOINT_PATH} \ + --accuracy \ + --mlperf-conf mlperf.conf \ + --user-conf user.conf \ + --total-sample-count 24576 \ + --dataset-path ${DATASET_PATH} \ + --output-log-dir offline_accuracy_loadgen_logs \ + --dtype float32 \ + --device cuda:0 2>&1 | tee offline_accuracy_log.log + +python3 evaluate-accuracy.py --checkpoint-path ${CHECKPOINT_PATH} \ + --mlperf-accuracy-file offline_accuracy_loadgen_logs/mlperf_log_accuracy.json \ + --dataset-file ${DATASET_PATH} \ + --dtype int32 + +python3 consolidate_results.py --dataset-path ${DATASET_PATH} --model-dir ${CHECKPOINT_PATH} diff --git a/language/llama2-70b/run_offline.sh b/language/llama2-70b/run_offline.sh index 2ea28c436..7153ea7ca 100644 --- a/language/llama2-70b/run_offline.sh +++ b/language/llama2-70b/run_offline.sh @@ -1,5 +1,3 @@ - - CHECKPOINT_PATH="${CHECKPOINT_PATH:-meta-llama/Llama-2-70b-chat-hf}" DATASET_PATH="${DATASET_PATH:-open-orca-val-set.pkl}" diff --git a/language/llama2-70b/with_the_same_user b/language/llama2-70b/with_the_same_user new file mode 100755 index 000000000..6b98baf7d 
--- /dev/null +++ b/language/llama2-70b/with_the_same_user @@ -0,0 +1,39 @@ +#!/usr/bin/env bash +# The CI_BUILD_* user info must be set in the environment before running (launch.sh provides it via .docker_env) + +set -ex + +if [ "$#" -eq 0 ]; then + COMMAND=(bash) +else + COMMAND=("$@") +fi + +apt-get update && apt-get install -y sudo + +getent group "${CI_BUILD_GID}" || addgroup --gid "${CI_BUILD_GID}" "${CI_BUILD_GROUP}" +getent passwd "${CI_BUILD_UID}" || adduser --gid "${CI_BUILD_GID}" --uid "${CI_BUILD_UID}" --gecos "${CI_BUILD_USER} (generated by with_the_same_user script)" --disabled-password --quiet "${CI_BUILD_USER}" + +usermod -a -G dip "${CI_BUILD_USER}" +usermod -a -G sudo "${CI_BUILD_USER}" +usermod -a -G root "${CI_BUILD_USER}" + +echo '%sudo ALL=(ALL) NOPASSWD:ALL' >> /etc/sudoers +mkdir -p /home/"${CI_BUILD_USER}" +touch /home/"${CI_BUILD_USER}"/.bashrc +echo 'export PATH="$PATH:/opt/miniconda3/bin"' >> /home/"${CI_BUILD_USER}"/.bashrc + +sudo -H -u "#${CI_BUILD_UID}" --preserve-env \ + PATH="${PATH}" \ + LD_LIBRARY_PATH="${LD_LIBRARY_PATH}" \ + PYTHONPATH="${PYTHONPATH}" \ + bash -c "conda init bash" + +echo 'conda activate llama2-70b' >> /home/"${CI_BUILD_USER}"/.bashrc + + +sudo -H -u "#${CI_BUILD_UID}" --preserve-env \ + PATH="${PATH}" \ + LD_LIBRARY_PATH="${LD_LIBRARY_PATH}" \ + PYTHONPATH="${PYTHONPATH}" \ + "${COMMAND[@]}"