Add functionality that avoids putting neighbouring examples into different train/test splits

PiperOrigin-RevId: 573779093
lami-genius authored and copybara-github committed Oct 16, 2023
1 parent f9fde5c commit cacf3f8
Showing 6 changed files with 420 additions and 41 deletions.
1 change: 1 addition & 0 deletions requirements.txt
@@ -25,3 +25,4 @@ tqdm
openlocationcode
xmanager
tensorflow-text
scipy
3 changes: 3 additions & 0 deletions src/create_labeled_dataset.py
@@ -42,6 +42,8 @@
                     'If specified, random seed for train/test split.')
flags.DEFINE_float('test_fraction', 0.2,
                   'Fraction of labeled examples to use for testing.')
flags.DEFINE_float('connecting_distance_meters', 77.0,
                   'Maximum distance for two points to be connected.')
flags.DEFINE_list(
    'string_to_numeric_labels',
    [
@@ -75,6 +77,7 @@ def main(unused_argv):
      FLAGS.test_fraction,
      FLAGS.train_output_path,
      FLAGS.test_output_path,
      FLAGS.connecting_distance_meters,
      FLAGS.use_multiprocessing)


170 changes: 150 additions & 20 deletions src/skai/cloud_labeling.py
@@ -22,16 +22,18 @@
import queue
import random
import time
from typing import Dict, Iterable, List, Optional, Tuple, Set
from typing import Dict, Iterable, List, Optional, Set, Tuple

from absl import logging

import geopandas as gpd
from google.cloud import aiplatform
from google.cloud import aiplatform_v1
import numpy as np
import pandas as pd
import PIL.Image
import PIL.ImageDraw
import PIL.ImageFont
import scipy
from skai import utils
import tensorflow as tf

@@ -490,40 +492,158 @@ def _write_tfrecord(examples: Iterable[Example], path: str) -> None:
      writer.write(example.SerializeToString())


def get_connection_matrix(
    longitudes: List[float],
    latitudes: List[float],
    encoded_coordinates: List[str],
    connecting_distance_meters: float,
) -> Tuple[gpd.GeoDataFrame, np.ndarray]:
  """Gets a connection matrix for a set of points.

  Args:
    longitudes: Longitudes of points.
    latitudes: Latitudes of points.
    encoded_coordinates: Encoded coordinates of points.
    connecting_distance_meters: Maximum distance for two points to be connected.

  Returns:
    Tuple of (GeoDataFrame, connection_matrix).
  """
  points = gpd.GeoSeries(gpd.points_from_xy(
      longitudes,
      latitudes,
  )).set_crs(4326)

  centroid = points.unary_union.centroid
  utm_points = points.to_crs(utils.convert_wgs_to_utm(centroid.x, centroid.y))

  gpd_df = gpd.GeoDataFrame(
      {'encoded_coordinates': encoded_coordinates},
      geometry=utm_points
  )

  def calculate_euclidean_distance(row):
    return gpd_df.distance(row.geometry)

  distances = np.array(gpd_df.apply(calculate_euclidean_distance, axis=1))
  connection_matrix = (distances < connecting_distance_meters).astype('int')

  assert connection_matrix.shape == (
      len(encoded_coordinates),
      len(encoded_coordinates),
  )

  return gpd_df, connection_matrix
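
For intuition, a hypothetical call (the coordinates, IDs, and use of the 77 m default below are made-up illustration, not part of the commit): three points on one latitude, where the first two are roughly 20 m apart and the third is about 700 m away.

# Illustrative sketch: points 0 and 1 fall within the connecting distance.
gpd_df, connection_matrix = get_connection_matrix(
    longitudes=[10.0000, 10.0003, 10.0100],
    latitudes=[50.0, 50.0, 50.0],
    encoded_coordinates=['a', 'b', 'c'],
    connecting_distance_meters=77.0,
)
# connection_matrix is a dense 0/1 array (diagonal is 1, since each point
# is at distance 0 from itself):
# [[1 1 0]
#  [1 1 0]
#  [0 0 1]]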


def get_connected_labels(
    connection_matrix: np.ndarray,
) -> List[int]:
  """Gets the labels of connected components.

  Args:
    connection_matrix: Connection matrix.

  Returns:
    List of component labels, one per point. Points with the same label
    belong to the same connected component.
  """
  graph = scipy.sparse.csr_matrix(connection_matrix)
  _, labels = scipy.sparse.csgraph.connected_components(
      csgraph=graph, directed=False, return_labels=True
  )

  return list(labels)
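
Connected-component labelling makes the grouping transitive: if A is within range of B and B is within range of C, all three receive the same label even when A and C are far apart. A minimal sketch of the behaviour on a hand-written matrix:

import numpy as np

# Points 0 and 1 are mutually connected; point 2 is isolated.
conn = np.array([[1, 1, 0],
                 [1, 1, 0],
                 [0, 0, 1]])
labels = get_connected_labels(conn)  # -> [0, 0, 1]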


def _split_examples(
    examples: List[Example],
    test_fraction: float
    test_fraction: float,
    connecting_distance_meters: float,
) -> Tuple[List[Example], List[Example]]:
  """Splits a list of examples into training and test sets.

  Examples with the same encoded coordinates will always end up in the same
  split to prevent leaking information between training and test sets.
  split to prevent leaking information between training and test sets. Any two
  examples separated by less than connecting_distance_meters will always be in
  the same split.

  Args:
    examples: Input examples.
    test_fraction: Fraction of examples to use for testing.
    connecting_distance_meters: Maximum distance for two points to be connected.

  Returns:
    Tuple of (training examples, test examples).
  """
  coordinates_to_examples = collections.defaultdict(list)
  longitudes = []
  latitudes = []
  encoded_coordinates = []
  for example in examples:
    c = example.features.feature['encoded_coordinates'].bytes_list.value[0]
    coordinates_to_examples[c].append(example)

  shuffled = random.sample(sorted(coordinates_to_examples.keys()),
                           len(coordinates_to_examples))
  num_test = int(len(shuffled) * test_fraction)
  test_examples = []
  for coordinate in shuffled[:num_test]:
    test_examples.extend(coordinates_to_examples[coordinate])

  train_examples = []
  for coordinate in shuffled[num_test:]:
    train_examples.extend(coordinates_to_examples[coordinate])
    encoded_coordinate = utils.get_string_feature(
        example, 'encoded_coordinates'
    )
    longitude, latitude = utils.get_float_feature(example, 'coordinates')
    longitudes.append(longitude)
    latitudes.append(latitude)
    encoded_coordinates.append(encoded_coordinate)

  gpd_df, connection_matrix = get_connection_matrix(
      longitudes, latitudes, encoded_coordinates, connecting_distance_meters
  )
  labels = get_connected_labels(connection_matrix)
  connected_groups = collections.defaultdict(list)
  for idx, key in enumerate(labels):
    connected_groups[key].append(idx)

  list_of_connected_examples = []
  for _, connected_group in connected_groups.items():
    list_of_connected_examples.append(connected_group)

  num_test = int(len(gpd_df) * test_fraction)
  test_indices = get_testset_indices(num_test, list_of_connected_examples)
  test_examples = [examples[idx] for idx in test_indices]
  train_examples = [
      examples[idx] for idx in range(len(examples)) if idx not in test_indices
  ]

  return train_examples, test_examples
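
The end-to-end guarantee, then: neighbours never straddle the split boundary. A rough sketch of how one might spot-check this on a list of labeled Example protos named examples (an assumed variable; the 77.0 value is the flag's default):

# Spot-check: the split partitions the input, and no pair of examples closer
# than the connecting distance crosses the train/test boundary.
train, test = _split_examples(examples, test_fraction=0.2,
                              connecting_distance_meters=77.0)
assert len(train) + len(test) == len(examples)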


def get_testset_indices(
    num_test: int, list_of_connected_examples: List[List[int]]
) -> List[int]:
  """Gets a random list of indices corresponding to test examples.

  Args:
    num_test: Number of test examples.
    list_of_connected_examples: List of groups of example indices; each group
      is one connected component.

  Returns:
    List of indices corresponding to test examples.
  """
  max_num_attempts_train_test_splits = 10000
  best_test_indices = []
  min_diff_best_num_test = num_test

  for _ in range(max_num_attempts_train_test_splits):
    # Shuffle the connected components so each attempt tries a different split.
    random_list_of_connected_examples = random.sample(
        list_of_connected_examples, len(list_of_connected_examples)
    )
    current_test_indices = []

    for connected_component in random_list_of_connected_examples:
      current_test_indices.extend(connected_component)
      if abs(len(current_test_indices) - num_test) < min_diff_best_num_test:
        best_test_indices = current_test_indices.copy()
        min_diff_best_num_test = abs(len(best_test_indices) - num_test)

    # Stop the attempts early once an exact-size test split is found.
    if min_diff_best_num_test == 0:
      return best_test_indices

  return best_test_indices
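
Because whole components are always added together, the test set can only approximate num_test when component sizes do not line up exactly; the loop keeps the closest split seen across up to 10000 shuffles. A small worked example with made-up groupings:

# Hypothetical connected components over six example indices.
groups = [[0, 1], [2], [3, 4, 5]]
# With num_test = 3, a shuffle beginning with [3, 4, 5], or with [0, 1] and
# [2] in either order, hits exactly 3 indices, so the search stops early.
test_indices = get_testset_indices(3, groups)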


def _merge_single_example_file_and_labels(
    example_file: str, labels: Dict[str, List[Tuple[str, float, str]]]
) -> List[Example]:
@@ -583,6 +703,7 @@ def _merge_examples_and_labels(
    test_fraction: float,
    train_output_path: str,
    test_output_path: str,
    connecting_distance_meters: float,
    use_multiprocessing: bool,
) -> None:
  """Merges examples with labels into train and test TFRecords.
@@ -594,6 +715,7 @@ def _merge_examples_and_labels(
    test_fraction: Fraction of examples to write to test output.
    train_output_path: Path to training examples TFRecord output.
    test_output_path: Path to test examples TFRecord output.
    connecting_distance_meters: Maximum distance for two points to be connected.
    use_multiprocessing: If true, create multiple processes to create labeled
      examples.
  """
@@ -629,11 +751,16 @@ def _merge_examples_and_labels(
    all_labeled_examples.extend(result)

  train_examples, test_examples = _split_examples(
      all_labeled_examples, test_fraction
      all_labeled_examples, test_fraction, connecting_distance_meters
  )

  _write_tfrecord(train_examples, train_output_path)
  _write_tfrecord(test_examples, test_output_path)
  logging.info(
      'Wrote %d test examples and %d train examples.',
      len(test_examples),
      len(train_examples),
  )


def _get_labels_from_dataset(
@@ -669,7 +796,6 @@ def _get_labels_from_dataset(
    labels.update(_read_label_annotations_file(path))
    tf.io.gfile.remove(path)

  logging.info('Read %d labels total.', len(labels))
  return [
      (example_id, label, dataset_id) for example_id, label in labels.items()
  ]
@@ -709,6 +835,7 @@ def create_labeled_examples(
    test_fraction: float,
    train_output_path: str,
    test_output_path: str,
    connecting_distance_meters: float,
    use_multiprocessing: bool) -> None:
  """Creates a labeled dataset by merging cloud labels and unlabeled examples.
@@ -724,6 +851,7 @@ def create_labeled_examples(
    test_fraction: Fraction of examples to write to test output.
    train_output_path: Path to training examples TFRecord output.
    test_output_path: Path to test examples TFRecord output.
    connecting_distance_meters: Maximum distance for two points to be connected.
    use_multiprocessing: If true, create multiple processes to create labeled
      examples.
  """
@@ -749,6 +877,7 @@
  for path in label_file_paths:
    labels.extend(_read_label_file(path))

  logging.info('Read %d labels total.', len(labels))
  ids_to_labels = collections.defaultdict(list)
  for example_id, string_label, dataset_id in labels:
    example_labels = ids_to_labels[example_id]
@@ -766,5 +895,6 @@
      test_fraction,
      train_output_path,
      test_output_path,
      connecting_distance_meters,
      use_multiprocessing,
  )