From 2142b93dc46fe5ce13977b23160af8dcd834f6b2 Mon Sep 17 00:00:00 2001 From: Boris Lami Fonyuy Date: Mon, 23 Oct 2023 03:14:37 -0700 Subject: [PATCH] Add functionality to included label_dataset_id or label_file_path feature in each example PiperOrigin-RevId: 575762101 --- src/skai/cloud_labeling.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/src/skai/cloud_labeling.py b/src/skai/cloud_labeling.py index 53587b19..4f6ed9ce 100644 --- a/src/skai/cloud_labeling.py +++ b/src/skai/cloud_labeling.py @@ -678,15 +678,23 @@ def _merge_single_example_file_and_labels( ) label_tuples = labels.get(example_id, []) - for string_label, numeric_label, dataset_id in label_tuples: + for string_label, numeric_label, dataset_id_or_label_path in label_tuples: labeled_example = Example() labeled_example.CopyFrom(example) features = labeled_example.features features.feature['string_label'].bytes_list.value[:] = [ string_label.encode() ] - features.feature['label_dataset_id'].bytes_list.value.append( - dataset_id.encode()) + + if tf.io.gfile.exists(dataset_id_or_label_path): + features.feature['label_file_path'].bytes_list.value.append( + dataset_id_or_label_path.encode() + ) + else: + features.feature['label_dataset_id'].bytes_list.value.append( + dataset_id_or_label_path.encode() + ) + label_feature = features.feature['label'].float_list if not label_feature.value: label_feature.value.append(numeric_label) @@ -777,7 +785,7 @@ def _get_labels_from_dataset( export_dir: GCS directory to export annotations to. Returns: - List of (example id, string label) tuples. + List of (example id, string label, dataset id) tuples. """ aiplatform.init(project=project, location=location) @@ -811,7 +819,7 @@ def _read_label_file(path: str) -> List[Tuple[str, str, str]]: path: Path to file. Returns: - List of (example id, string label) tuples. + List of (example id, string label, label file path) tuples. """ with tf.io.gfile.GFile(path) as f: df = pd.read_csv(f) @@ -879,7 +887,7 @@ def create_labeled_examples( logging.info('Read %d labels total.', len(labels)) ids_to_labels = collections.defaultdict(list) - for example_id, string_label, dataset_id in labels: + for example_id, string_label, dataset_id_or_label_path in labels: example_labels = ids_to_labels[example_id] if string_label in [l[0] for l in example_labels]: # Don't add multiple labels with the same value for a single example. @@ -887,7 +895,9 @@ def create_labeled_examples( numeric_label = string_to_numeric_map.get(string_label, None) if numeric_label is None: raise ValueError(f'Label "{string_label}" has no numeric mapping.') - example_labels.append((string_label, numeric_label, dataset_id)) + example_labels.append( + (string_label, numeric_label, dataset_id_or_label_path) + ) _merge_examples_and_labels( examples_pattern,