Skip to content

Commit

Permalink
Merge pull request #2 from myshell-ai/v0.2.0
Browse files Browse the repository at this point in the history
V0.2.0
  • Loading branch information
ctlllll authored Apr 5, 2024
2 parents 33c2f55 + ac7b88a commit 85a5a05
Show file tree
Hide file tree
Showing 11 changed files with 327 additions and 57 deletions.
2 changes: 1 addition & 1 deletion docs/validator.md
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ Prerequisites:

To run the validator with auto-updates, use the following command:
```bash
pm2 start neurons/validator.py --name validator --interpreter python -- --wallet.name your_wallet --wallet.hotkey your_hotkey
pm2 start --name finetune-vali-updater --interpreter python scripts/start_validator.py -- --pm2_name finetune-vali --wallet.name coldkey --wallet.hotkey hotkey [other vali flags]
```

## Without auto-updates
Expand Down
13 changes: 10 additions & 3 deletions neurons/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@
import traceback
import bittensor as bt

import os

os.environ["TOKENIZERS_PARALLELISM"] = "false"


def iswin(loss_i, loss_j, block_i, block_j):
"""
Expand Down Expand Up @@ -712,6 +716,9 @@ async def run_step(self):
model_i.ckpt,
competition_parameters.competition_id,
seed,
samples=self.config.num_samples_per_eval,
batch_size=16,
group_size=16,
)

del model_i
Expand All @@ -729,8 +736,8 @@ async def run_step(self):
f"Unable to load the model for {uid_i} (perhaps a duplicate?). Setting loss to inifinity."
)
if len(losses) == 0:
# Currently evaluate on 10 samples to get 10 tone similarity losses and 10 word error rate losses.
losses = [math.inf] * 20
# 3 metrics, 64 samples, 16 per group
losses = [math.inf] * (3 * self.config.num_samples_per_eval // 16)

losses_per_uid[uid_i] = losses
average_model_loss = sum(losses) / len(losses)
Expand Down Expand Up @@ -937,7 +944,7 @@ async def run(self):
self.global_step += 1

if not self.config.dont_set_weights and not self.config.offline:
await self.try_set_weights(ttl=60)
await self.try_set_weights(ttl=120)
self.last_epoch = self.metagraph.block.item()
self.epoch_step += 1

Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "tts-subnet"
version = "0.1.0"
version = "0.2.0"
description = "MyShell TTS Subnet"
authors = ["TC <[email protected]>"]
readme = "README.md"
Expand All @@ -15,7 +15,7 @@ melotts = {path = "MeloTTS"}
wandb = "^0.16.4"
safetensors = "^0.4.2"
python-dotenv = "^1.0.1"

onnxruntime = "^1.17.1"

[build-system]
requires = ["poetry-core"]
Expand Down
169 changes: 169 additions & 0 deletions scripts/start_validator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
"""
This script runs a validator process and automatically updates it when a new version is released.
Command-line arguments will be forwarded to validator (`neurons/validator.py`), so you can pass
them like this:
python3 scripts/start_validator.py --wallet.name=my-wallet
Auto-updates are enabled by default and will make sure that the latest version is always running
by pulling the latest version from git and upgrading python packages. This is done periodically.
Local changes may prevent the update, but they will be preserved.
The script will use the same virtual environment as the one used to run it. If you want to run
validator within virtual environment, run this auto-update script from the virtual environment.
Pm2 is required for this script. This script will start a pm2 process using the name provided by
the --pm2_name argument.
"""

import argparse
import logging
import subprocess
import sys
import time
from datetime import timedelta
from shlex import split
from typing import List
import constants

log = logging.getLogger(__name__)
UPDATES_CHECK_TIME = timedelta(minutes=15)


def get_version() -> str:
"""Extract the version as current git commit hash"""
result = subprocess.run(
split("git rev-parse HEAD"),
check=True,
capture_output=True,
cwd=constants.ROOT_DIR,
)
commit = result.stdout.decode().strip()
assert len(commit) == 40, f"Invalid commit hash: {commit}"
return commit[:8]


def start_validator_process(pm2_name: str, args: List[str]) -> subprocess.Popen:
"""
Spawn a new python process running neurons.validator.
`sys.executable` ensures thet the same python interpreter is used as the one
used to run this auto-updater.
"""
assert sys.executable, "Failed to get python executable"

log.info("Starting validator process with pm2, name: %s", pm2_name)
process = subprocess.Popen(
(
"pm2",
"start",
sys.executable,
"--name",
pm2_name,
"--",
"-m",
"neurons.validator",
*args,
),
cwd=constants.ROOT_DIR,
)
process.pm2_name = pm2_name

return process


def stop_validator_process(process: subprocess.Popen) -> None:
"""Stop the validator process"""
subprocess.run(
("pm2", "delete", process.pm2_name), cwd=constants.ROOT_DIR, check=True
)


def pull_latest_version() -> None:
"""
Pull the latest version from git.
This uses `git pull --rebase`, so if any changes were made to the local repository,
this will try to apply them on top of origin's changes. This is intentional, as we
don't want to overwrite any local changes. However, if there are any conflicts,
this will abort the rebase and return to the original state.
The conflicts are expected to happen rarely since validator is expected
to be used as-is.
"""
try:
subprocess.run(
split("git pull --rebase --autostash"), check=True, cwd=constants.ROOT_DIR
)
except subprocess.CalledProcessError as exc:
log.error("Failed to pull, reverting: %s", exc)
subprocess.run(split("git rebase --abort"), check=True, cwd=constants.ROOT_DIR)


def upgrade_packages() -> None:
"""
Upgrade python packages by running `pip install --upgrade -r requirements.txt`.
Notice: this won't work if some package in `requirements.txt` is downgraded.
Ignored as this is unlikely to happen.
"""

log.info("Upgrading packages")
try:
subprocess.run(
split(f"{sys.executable} -m pip install -e ."),
check=True,
cwd=constants.ROOT_DIR,
)
except subprocess.CalledProcessError as exc:
log.error("Failed to upgrade packages, proceeding anyway. %s", exc)


def main(pm2_name: str, args: List[str]) -> None:
"""
Run the validator process and automatically update it when a new version is released.
This will check for updates every `UPDATES_CHECK_TIME` and update the validator
if a new version is available. Update is performed as simple `git pull --rebase`.
"""

validator = start_validator_process(pm2_name, args)
current_version = latest_version = get_version()
log.info("Current version: %s", current_version)

try:
while True:
pull_latest_version()
latest_version = get_version()
log.info("Latest version: %s", latest_version)

if latest_version != current_version:
log.info(
"Upgraded to latest version: %s -> %s",
current_version,
latest_version,
)
upgrade_packages()

stop_validator_process(validator)
validator = start_validator_process(pm2_name, args)
current_version = latest_version

time.sleep(UPDATES_CHECK_TIME.total_seconds())

finally:
stop_validator_process(validator)


if __name__ == "__main__":
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(levelname)s %(message)s",
handlers=[logging.StreamHandler(sys.stdout)],
)

parser = argparse.ArgumentParser(
description="Automatically update and restart the validator process when a new version is released.",
epilog="Example usage: python start_validator.py --pm2_name 'net3vali' --wallet_name 'wallet1' --wallet_hotkey 'key123'",
)

parser.add_argument(
"--pm2_name", default="net3vali", help="Name of the PM2 process."
)

flags, extra_args = parser.parse_known_args()

main(flags.pm2_name, extra_args)
Binary file added tts_rater/DNSMOS/bak_ovr.onnx
Binary file not shown.
Binary file added tts_rater/DNSMOS/model_v8.onnx
Binary file not shown.
Binary file added tts_rater/DNSMOS/sig.onnx
Binary file not shown.
Binary file added tts_rater/DNSMOS/sig_bak_ovr.onnx
Binary file not shown.
Loading

0 comments on commit 85a5a05

Please sign in to comment.