Sortformer Diarizer 4spk v1 model PR Part 2: Unit-tests for Sortformer Diarizer (#11336)

* Adding the first pr files models and dataset

Signed-off-by: taejinp <[email protected]>

* Tested all unit-test files

Signed-off-by: taejinp <[email protected]>

* Name changes on yaml files and train example

Signed-off-by: taejinp <[email protected]>

* Apply isort and black reformatting

Signed-off-by: tango4j <[email protected]>

* Reflecting comments and removing unnecessary parts for this PR

Signed-off-by: taejinp <[email protected]>

* Apply isort and black reformatting

Signed-off-by: tango4j <[email protected]>

* Adding docstrings to reflect the PR comments

Signed-off-by: taejinp <[email protected]>

* removed the unused find_first_nonzero

Signed-off-by: taejinp <[email protected]>

* Apply isort and black reformatting

Signed-off-by: tango4j <[email protected]>

* Fixed all pylint issues

Signed-off-by: taejinp <[email protected]>

* Apply isort and black reformatting

Signed-off-by: tango4j <[email protected]>

* Resolving pylint issues

Signed-off-by: taejinp <[email protected]>

* Apply isort and black reformatting

Signed-off-by: tango4j <[email protected]>

* Removing unused variable in audio_to_diar_label.py

Signed-off-by: taejinp <[email protected]>

* Fixed docstrings in training script

Signed-off-by: taejinp <[email protected]>

* Line-too-long issue from Pylint fixed

Signed-off-by: taejinp <[email protected]>

* Adding get_subsegments_scriptable to prevent jit.script error

Signed-off-by: taejinp <[email protected]>

* Apply isort and black reformatting

Signed-off-by: tango4j <[email protected]>

* Addressed Code-QL issues

Signed-off-by: taejinp <[email protected]>

* Resolved conflicts on bce_loss.py

Signed-off-by: taejinp <[email protected]>

* Apply isort and black reformatting

Signed-off-by: tango4j <[email protected]>

* Adding all the diarization related unit-tests

Signed-off-by: taejinp <[email protected]>

* Moving speaker task related unit test files to speaker_tasks folder

Signed-off-by: taejinp <[email protected]>

* Fixed uninitialized variable issue in bce_loss.py spotted by CodeQL

Signed-off-by: taejinp <[email protected]>

* Apply isort and black reformatting

Signed-off-by: tango4j <[email protected]>

* Fixing code-QL issues

Signed-off-by: taejinp <[email protected]>

* Apply isort and black reformatting

Signed-off-by: tango4j <[email protected]>

* Reflecting PR comments from weiqingw

Signed-off-by: taejinp <[email protected]>

* Apply isort and black reformatting

Signed-off-by: tango4j <[email protected]>

* Line too long pylint issue resolved in e2e_diarize_speech.py

Signed-off-by: taejinp <[email protected]>

* Apply isort and black reformatting

Signed-off-by: tango4j <[email protected]>

* Resolved unused variable issue in model test

Signed-off-by: taejinp <[email protected]>

* Reflecting the comment from Nov 21st, 2024.

Signed-off-by: taejinp <[email protected]>

* Apply isort and black reformatting

Signed-off-by: tango4j <[email protected]>

* Removed unused import of time

Signed-off-by: taejinp <[email protected]>

* Adding docstrings to score_labels() function in der.py

Signed-off-by: taejinp <[email protected]>

* Apply isort and black reformatting

Signed-off-by: tango4j <[email protected]>

* Reflecting comments on YAML files and model file variable changes.

Signed-off-by: taejinp <[email protected]>

* Apply isort and black reformatting

Signed-off-by: tango4j <[email protected]>

* Added get_subsegments_scriptable for legacy get_subsegment functions

Signed-off-by: taejinp <[email protected]>

* Apply isort and black reformatting

Signed-off-by: tango4j <[email protected]>

* Resolved line too long pylint issues

Signed-off-by: taejinp <[email protected]>

* Apply isort and black reformatting

Signed-off-by: tango4j <[email protected]>

* Added training and inference CI-tests

Signed-off-by: taejinp <[email protected]>

* Added the missing parse_func in preprocessing/collections.py

Signed-off-by: taejinp <[email protected]>

* Adding the missing parse_func in preprocessing/collections.py

Signed-off-by: taejinp <[email protected]>

* Fixed an indentation error

Signed-off-by: taejinp <[email protected]>

* Resolved multi_bin_acc and bce_loss issues

Signed-off-by: taejinp <[email protected]>

* Resolved line-too-long for msdd_models.py

Signed-off-by: taejinp <[email protected]>

* Apply isort and black reformatting

Signed-off-by: tango4j <[email protected]>

* Fixed CodeQL issues and test errors

Signed-off-by: taejinp <[email protected]>

* Apply isort and black reformatting

Signed-off-by: tango4j <[email protected]>

* Fixed line-too-long issue in audio_to_diar_label.py

Signed-off-by: taejinp <[email protected]>

* Apply isort and black reformatting

Signed-off-by: tango4j <[email protected]>

* Resolving CICD test issues

Signed-off-by: taejinp <[email protected]>

* Fixing codeQL issues

Signed-off-by: taejinp <[email protected]>

* Set pin_memory to False for inference

Signed-off-by: taejinp <[email protected]>

---------

Signed-off-by: taejinp <[email protected]>
Signed-off-by: tango4j <[email protected]>
Co-authored-by: tango4j <[email protected]>
tango4j and tango4j authored Dec 16, 2024
1 parent 356a3a6 commit 523936e
Showing 15 changed files with 1,686 additions and 10 deletions.
29 changes: 29 additions & 0 deletions .github/workflows/cicd-main.yml
@@ -816,6 +816,33 @@ jobs:
+trainer.fast_dev_run=True \
exp_manager.exp_dir=/tmp/speaker_diarization_results
L2_Speaker_dev_run_EndtoEnd_Speaker_Diarization_Sortformer:
needs: [cicd-test-container-setup]
uses: ./.github/workflows/_test_template.yml
if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_Speaker_dev_run_EndtoEnd_Speaker_Diarization_Sortformer') || needs.cicd-test-container-setup.outputs.all == 'true'
with:
RUNNER: self-hosted-azure-gpus-1
SCRIPT: |
python examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py \
trainer.devices="[0]" \
batch_size=3 \
model.train_ds.manifest_filepath=/home/TestData/an4_diarizer/simulated_train/eesd_train_tiny.json \
model.validation_ds.manifest_filepath=/home/TestData/an4_diarizer/simulated_valid/eesd_valid_tiny.json \
exp_manager.exp_dir=/tmp/speaker_diarization_results \
+trainer.fast_dev_run=True
L2_Speaker_dev_run_EndtoEnd_Diarizer_Inference:
needs: [cicd-test-container-setup]
uses: ./.github/workflows/_test_template.yml
if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_Speaker_dev_run_EndtoEnd_Diarizer_Inference') || needs.cicd-test-container-setup.outputs.all == 'true'
with:
RUNNER: self-hosted-azure
SCRIPT: |
python examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py \
model_path=/home/TestData/an4_diarizer/diar_sortformer_4spk-v1-tiny.nemo \
dataset_manifest=/home/TestData/an4_diarizer/simulated_valid/eesd_valid_tiny.json \
batch_size=1
L2_Speaker_dev_run_Speech_to_Label:
needs: [cicd-test-container-setup]
uses: ./.github/workflows/_test_template.yml
@@ -4753,6 +4780,8 @@ jobs:
- L2_Speech_to_Text_EMA
- L2_Speaker_dev_run_Speaker_Recognition
- L2_Speaker_dev_run_Speaker_Diarization
- L2_Speaker_dev_run_EndtoEnd_Speaker_Diarization_Sortformer
- L2_Speaker_dev_run_EndtoEnd_Diarizer_Inference
- L2_Speaker_dev_run_Speech_to_Label
- L2_Speaker_dev_run_Speaker_Diarization_with_ASR_Inference
- L2_Speaker_dev_run_Clustering_Diarizer_Inference
examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py
@@ -386,6 +386,7 @@ def main(cfg: DiarizationConfig) -> Union[DiarizationConfig]:
diar_model._cfg.test_ds.manifest_filepath = cfg.dataset_manifest
infer_audio_rttm_dict = audio_rttm_map(cfg.dataset_manifest)
diar_model._cfg.test_ds.batch_size = cfg.batch_size
diar_model._cfg.test_ds.pin_memory = False

# Model setup for inference
diar_model._cfg.test_ds.num_workers = cfg.num_workers
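The pin_memory override above likely pairs with the device placement added to the E2E diarization dataset below: once a dataset returns tensors that may already sit on the GPU, the DataLoader must not try to page-lock them, since only dense CPU tensors can be pinned. A minimal sketch of that interaction, with an illustrative stand-in dataset rather than the NeMo class:

```python
import torch
from torch.utils.data import DataLoader, Dataset


class DeviceTensorDataset(Dataset):
    """Illustrative stand-in: yields tensors already placed on a target device."""

    def __init__(self, device: str):
        self.device = device

    def __len__(self):
        return 4

    def __getitem__(self, idx):
        audio = torch.randn(16000, device=self.device)     # 1 s of fake 16 kHz audio
        audio_len = torch.tensor(16000, device=self.device)
        return audio, audio_len


device = 'cuda' if torch.cuda.is_available() else 'cpu'
# pin_memory must stay False: pinning a CUDA tensor raises a RuntimeError.
loader = DataLoader(DeviceTensorDataset(device), batch_size=2, pin_memory=False)
for audio, audio_len in loader:
    print(audio.device, audio.shape, audio_len)
```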
12 changes: 9 additions & 3 deletions nemo/collections/asr/data/audio_to_diar_label.py
@@ -1065,6 +1065,7 @@ def __init__(
round_digits: int = 2,
soft_targets: bool = False,
subsampling_factor: int = 8,
device: str = 'cpu',
):
super().__init__()
self.collection = EndtoEndDiarizationSpeechLabel(
@@ -1084,6 +1085,7 @@ def __init__(
self.soft_targets = soft_targets
self.round_digits = 2
self.floor_decimal = 10**self.round_digits
self.device = device

def __len__(self):
return len(self.collection)
@@ -1232,11 +1234,13 @@ def __getitem__(self, index):
audio_signal = audio_signal[: round(self.featurizer.sample_rate * session_len_sec)]

audio_signal_length = torch.tensor(audio_signal.shape[0]).long()
audio_signal, audio_signal_length = audio_signal.to('cpu'), audio_signal_length.to('cpu')
target_len = self.get_segment_timestamps(duration=session_len_sec, sample_rate=self.featurizer.sample_rate)
audio_signal, audio_signal_length = audio_signal.to(self.device), audio_signal_length.to(self.device)
target_len = self.get_segment_timestamps(duration=session_len_sec, sample_rate=self.featurizer.sample_rate).to(
self.device
)
targets = self.parse_rttm_for_targets_and_lens(
rttm_file=sample.rttm_file, offset=offset, duration=session_len_sec, target_len=target_len
)
).to(self.device)
return audio_signal, audio_signal_length, targets, target_len


@@ -1355,6 +1359,7 @@ def __init__(
window_stride,
global_rank: int,
soft_targets: bool,
device: str,
):
super().__init__(
manifest_filepath=manifest_filepath,
@@ -1365,6 +1370,7 @@
window_stride=window_stride,
global_rank=global_rank,
soft_targets=soft_targets,
device=device,
)

def eesd_train_collate_fn(self, batch):
3 changes: 2 additions & 1 deletion nemo/collections/asr/models/sortformer_diar_models.py
@@ -175,6 +175,7 @@ def __setup_dataloader_from_config(self, config):
window_stride=self._cfg.preprocessor.window_stride,
global_rank=global_rank,
soft_targets=config.soft_targets if 'soft_targets' in config else False,
device=self.device,
)

self.data_collection = dataset.collection
@@ -557,13 +558,13 @@ def test_batch(
audio_signal=audio_signal,
audio_signal_length=audio_signal_length,
)
self._get_aux_test_batch_evaluations(batch_idx, preds, targets, target_lens)
preds = preds.detach().to('cpu')
if preds.shape[0] == 1: # batch size = 1
self.preds_total_list.append(preds)
else:
self.preds_total_list.extend(torch.split(preds, [1] * preds.shape[0]))
torch.cuda.empty_cache()
self._get_aux_test_batch_evaluations(batch_idx, preds, targets, target_lens)

logging.info(f"Batch F1Acc. MEAN: {torch.mean(torch.tensor(self.batch_f1_accs_list))}")
logging.info(f"Batch Precision MEAN: {torch.mean(torch.tensor(self.batch_precision_list))}")
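For reference, the preds bookkeeping in test_batch above splits a batched prediction tensor into per-sample tensors so that results from any batch size end up in one flat list. A small standalone sketch of the same pattern, with illustrative shapes:

```python
import torch

preds = torch.rand(3, 20, 4)  # (batch, frames, speakers) -- illustrative shape
preds = preds.detach().to('cpu')

preds_total_list = []
if preds.shape[0] == 1:        # batch size 1: keep the tensor as-is
    preds_total_list.append(preds)
else:                          # otherwise split into (1, frames, speakers) chunks
    preds_total_list.extend(torch.split(preds, [1] * preds.shape[0]))

assert len(preds_total_list) == 3
assert all(p.shape == (1, 20, 4) for p in preds_total_list)
```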
110 changes: 110 additions & 0 deletions tests/collections/speaker_tasks/test_diar_datasets.py
@@ -0,0 +1,110 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import tempfile

import pytest
import torch.cuda

from nemo.collections.asr.data.audio_to_diar_label import AudioToSpeechE2ESpkDiarDataset
from nemo.collections.asr.parts.preprocessing.features import WaveformFeaturizer
from nemo.collections.asr.parts.utils.speaker_utils import get_vad_out_from_rttm_line, read_rttm_lines


def is_rttm_length_too_long(rttm_file_path, wav_len_in_sec):
"""
Check if the maximum RTTM duration exceeds the length of the provided audio file.
Args:
rttm_file_path (str): Path to the RTTM file.
wav_len_in_sec (float): Length of the audio file in seconds.
Returns:
bool: True if the maximum RTTM duration is less than or equal to the length of the audio file, False otherwise.
"""
rttm_lines = read_rttm_lines(rttm_file_path)
max_rttm_sec = 0
for line in rttm_lines:
start, dur = get_vad_out_from_rttm_line(line)
max_rttm_sec = max(max_rttm_sec, start + dur)
return max_rttm_sec <= wav_len_in_sec


class TestAudioToSpeechE2ESpkDiarDataset:

@pytest.mark.unit
def test_e2e_speaker_diar_dataset(self, test_data_dir):
manifest_path = os.path.abspath(os.path.join(test_data_dir, 'asr/diarizer/lsm_val.json'))

batch_size = 4
num_samples = 8
device = 'cuda' if torch.cuda.is_available() else 'cpu'
data_dict_list = []
with tempfile.NamedTemporaryFile(mode='w', encoding='utf-8') as f:
with open(manifest_path, 'r', encoding='utf-8') as mfile:
for ix, line in enumerate(mfile):
if ix >= num_samples:
break

line = line.replace("tests/data/", test_data_dir + "/").replace("\n", "")
f.write(f"{line}\n")
data_dict = json.loads(line)
data_dict_list.append(data_dict)

f.seek(0)
featurizer = WaveformFeaturizer(sample_rate=16000, int_values=False, augmentor=None)

dataset = AudioToSpeechE2ESpkDiarDataset(
manifest_filepath=f.name,
soft_label_thres=0.5,
session_len_sec=90,
num_spks=4,
featurizer=featurizer,
window_stride=0.01,
global_rank=0,
soft_targets=False,
device=device,
)
dataloader_instance = torch.utils.data.DataLoader(
dataset=dataset,
batch_size=batch_size,
collate_fn=dataset.eesd_train_collate_fn,
drop_last=False,
shuffle=False,
num_workers=1,
pin_memory=False,
)
assert len(dataloader_instance) == (num_samples / batch_size) # Check if the number of batches is correct
batch_counts = len(dataloader_instance)

deviation_thres_rate = 0.01 # 1% deviation allowed
for batch_index, batch in enumerate(dataloader_instance):
if batch_index != batch_counts - 1:
assert len(batch) == batch_size, "Batch size does not match the expected value"
audio_signals, audio_signal_len, targets, target_lens = batch
for sample_index in range(audio_signals.shape[0]):
dataloader_audio_in_sec = audio_signal_len[sample_index].item()
data_dur_in_sec = abs(
data_dict_list[batch_size * batch_index + sample_index]['duration'] * featurizer.sample_rate
- dataloader_audio_in_sec
)
assert (
data_dur_in_sec <= deviation_thres_rate * dataloader_audio_in_sec
), "Duration deviation exceeds 1%"
assert not torch.isnan(audio_signals).any(), "audio_signals tensor contains NaN values"
assert not torch.isnan(audio_signal_len).any(), "audio_signal_len tensor contains NaN values"
assert not torch.isnan(targets).any(), "targets tensor contains NaN values"
assert not torch.isnan(target_lens).any(), "target_lens tensor contains NaN values"
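The is_rttm_length_too_long helper above relies on the standard RTTM layout, where each SPEAKER line carries the turn onset and duration in its fourth and fifth fields. A self-contained sketch of the same end-time check without the NeMo helpers (the sample lines are made up):

```python
# RTTM SPEAKER line: type file channel onset duration <NA> <NA> speaker <NA> <NA>
rttm_lines = [
    "SPEAKER session0 1 0.50 2.00 <NA> <NA> speaker_0 <NA> <NA>",
    "SPEAKER session0 1 2.20 1.30 <NA> <NA> speaker_1 <NA> <NA>",
]

wav_len_in_sec = 4.0
max_rttm_sec = 0.0
for line in rttm_lines:
    fields = line.split()
    start, dur = float(fields[3]), float(fields[4])  # onset and duration in seconds
    max_rttm_sec = max(max_rttm_sec, start + dur)

# True when the annotations fit inside the audio, mirroring the helper's return value
print(max_rttm_sec <= wav_len_in_sec)  # 3.5 <= 4.0 -> True
```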
@@ -16,6 +16,7 @@
import torch
from omegaconf import DictConfig

from nemo.collections.asr.losses import BCELoss
from nemo.collections.asr.models import EncDecDiarLabelModel


@@ -24,7 +25,12 @@ def msdd_model():

preprocessor = {
'cls': 'nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor',
'params': {"features": 80, "window_size": 0.025, "window_stride": 0.01, "sample_rate": 16000,},
'params': {
"features": 80,
"window_size": 0.025,
"window_stride": 0.01,
"sample_rate": 16000,
},
}

speaker_model_encoder = {
@@ -165,3 +171,37 @@ def test_forward_infer(self, msdd_model):
assert diff <= 1e-6
diff = torch.max(torch.abs(scale_weights_instance - scale_weights_batch))
assert diff <= 1e-6


class TestBCELoss:
@pytest.mark.unit
@pytest.mark.parametrize(
"probs, labels, target_lens, reduction, expected_output",
[
(
torch.tensor([[[0.5, 0.5], [0.5, 0.5]]], dtype=torch.float32),
torch.tensor([[[1, 0], [0, 1]]], dtype=torch.float32),
torch.tensor([2]),
"mean",
torch.tensor(0.693147, dtype=torch.float32),
),
(
torch.tensor([[[0.5, 0.5], [0.0, 1.0]]], dtype=torch.float32),
torch.tensor([[[1, 0], [0, 1]]], dtype=torch.float32),
torch.tensor([1]),
"mean",
torch.tensor(0.693147, dtype=torch.float32),
),
(
torch.tensor([[[0, 1], [1, 0]]], dtype=torch.float32),
torch.tensor([[[1, 0], [0, 1]]], dtype=torch.float32),
torch.tensor([2]),
"mean",
torch.tensor(100, dtype=torch.float32),
),
],
)
def test_loss(self, probs, labels, target_lens, reduction, expected_output):
loss = BCELoss(reduction=reduction)
result = loss(probs=probs, labels=labels, target_lens=target_lens)
assert torch.allclose(result, expected_output), f"Expected {expected_output}, but got {result}"
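The expected values in these parametrized cases follow from plain binary cross-entropy, assuming NeMo's BCELoss masks frames beyond target_lens and otherwise behaves like torch.nn.BCELoss: uniform 0.5 predictions contribute -ln(0.5) ≈ 0.693147 per element, the second case keeps only the uniform first frame, and the third case is confidently wrong everywhere, where torch clamps log(0) to -100 so the mean saturates at 100. A quick standalone check against torch.nn.BCELoss:

```python
import math
import torch

bce = torch.nn.BCELoss(reduction='mean')
labels = torch.tensor([[1.0, 0.0], [0.0, 1.0]])

# Uniform predictions: every element contributes -ln(0.5) = ln 2
uniform = torch.tensor([[0.5, 0.5], [0.5, 0.5]])
print(bce(uniform, labels), math.log(2))   # both ~0.693147

# Keeping only the first frame (as target_lens = [1] would) gives the same mean
print(bce(uniform[:1], labels[:1]))        # ~0.693147

# Confidently wrong everywhere: log(0) is clamped to -100, so the mean is 100
wrong = torch.tensor([[0.0, 1.0], [1.0, 0.0]])
print(bce(wrong, labels))                  # tensor(100.)
```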