Skip to content

Commit

Permalink
Add functionality to include label_dataset_id or label_file_path fea…
Browse files Browse the repository at this point in the history
…ture in each example

PiperOrigin-RevId: 575762101
  • Loading branch information
lami-genius authored and copybara-github committed Oct 23, 2023
1 parent 00b5f40 commit 2142b93
Showing 1 changed file with 17 additions and 7 deletions.
24 changes: 17 additions & 7 deletions src/skai/cloud_labeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,15 +678,23 @@ def _merge_single_example_file_and_labels(
)

label_tuples = labels.get(example_id, [])
for string_label, numeric_label, dataset_id in label_tuples:
for string_label, numeric_label, dataset_id_or_label_path in label_tuples:
labeled_example = Example()
labeled_example.CopyFrom(example)
features = labeled_example.features
features.feature['string_label'].bytes_list.value[:] = [
string_label.encode()
]
features.feature['label_dataset_id'].bytes_list.value.append(
dataset_id.encode())

if tf.io.gfile.exists(dataset_id_or_label_path):
features.feature['label_file_path'].bytes_list.value.append(
dataset_id_or_label_path.encode()
)
else:
features.feature['label_dataset_id'].bytes_list.value.append(
dataset_id_or_label_path.encode()
)

label_feature = features.feature['label'].float_list
if not label_feature.value:
label_feature.value.append(numeric_label)
Expand Down Expand Up @@ -777,7 +785,7 @@ def _get_labels_from_dataset(
export_dir: GCS directory to export annotations to.
Returns:
List of (example id, string label) tuples.
List of (example id, string label, dataset id) tuples.
"""

aiplatform.init(project=project, location=location)
Expand Down Expand Up @@ -811,7 +819,7 @@ def _read_label_file(path: str) -> List[Tuple[str, str, str]]:
path: Path to file.
Returns:
List of (example id, string label) tuples.
List of (example id, string label, label file path) tuples.
"""
with tf.io.gfile.GFile(path) as f:
df = pd.read_csv(f)
Expand Down Expand Up @@ -879,15 +887,17 @@ def create_labeled_examples(

logging.info('Read %d labels total.', len(labels))
ids_to_labels = collections.defaultdict(list)
for example_id, string_label, dataset_id in labels:
for example_id, string_label, dataset_id_or_label_path in labels:
example_labels = ids_to_labels[example_id]
if string_label in [l[0] for l in example_labels]:
# Don't add multiple labels with the same value for a single example.
continue
numeric_label = string_to_numeric_map.get(string_label, None)
if numeric_label is None:
raise ValueError(f'Label "{string_label}" has no numeric mapping.')
example_labels.append((string_label, numeric_label, dataset_id))
example_labels.append(
(string_label, numeric_label, dataset_id_or_label_path)
)

_merge_examples_and_labels(
examples_pattern,
Expand Down

0 comments on commit 2142b93

Please sign in to comment.