From 349751ed24400d9c0bbb52c22f07e9f3e11b7bc7 Mon Sep 17 00:00:00 2001
From: Filippo Pedrazzini
Date: Fri, 22 Sep 2023 09:53:56 +0200
Subject: [PATCH] using float32 on all archs for petals clients

---
 cht-petals/build.sh    | 2 +-
 cht-petals/download.py | 9 +++------
 2 files changed, 4 insertions(+), 7 deletions(-)

diff --git a/cht-petals/build.sh b/cht-petals/build.sh
index 56036dd..87c6491 100755
--- a/cht-petals/build.sh
+++ b/cht-petals/build.sh
@@ -1,6 +1,6 @@
 #!/bin/bash
 set -e
-export VERSION=1.0.1
+export VERSION=1.0.2
 
 source "$(dirname "${BASH_SOURCE[0]}")/../utils.sh"
 build_cpu ghcr.io/premai-io/chat-stable-beluga-2-cpu petals-team/StableBeluga2 ${@:1}
diff --git a/cht-petals/download.py b/cht-petals/download.py
index bddad6b..136edef 100644
--- a/cht-petals/download.py
+++ b/cht-petals/download.py
@@ -1,5 +1,4 @@
 import argparse
-from platform import machine
 
 import torch
 from petals import AutoDistributedModelForCausalLM
@@ -17,11 +16,9 @@ def download_model() -> None:
 
     Tokenizer = LlamaTokenizer if "llama" in args.model.lower() else AutoTokenizer
     _ = Tokenizer.from_pretrained(args.model)
-
-    kwargs = {}
-    if "x86_64" in machine():
-        kwargs["torch_dtype"] = torch.float32
-    _ = AutoDistributedModelForCausalLM.from_pretrained(args.model, **kwargs)
+    _ = AutoDistributedModelForCausalLM.from_pretrained(
+        args.model, torch_dtype=torch.float32
+    )
 
 
 download_model()
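
For context, a minimal sketch of the load path this patch produces: torch_dtype=torch.float32 is now passed unconditionally, where the removed code set it only when platform.machine() reported x86_64. The model name is an assumption taken from the build_cpu invocation above, not part of the patch itself:

    import torch
    from petals import AutoDistributedModelForCausalLM
    from transformers import AutoTokenizer

    # Assumed model name, taken from the build_cpu line in build.sh above.
    model_name = "petals-team/StableBeluga2"

    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # float32 is now requested on every architecture, not just x86_64
    # as in the removed platform check.
    model = AutoDistributedModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float32,
    )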