Skip to content

Commit

Permalink
renamed get_current_handler -> get_rand_handler. changed get_closest_…
Browse files Browse the repository at this point in the history
…lat_lon to use min(dist) instead of kdtree. kept as static method since it is called by class methods.
  • Loading branch information
bnb32 committed Sep 12, 2023
1 parent 0cf4b41 commit 6e84e7a
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 96 deletions.
20 changes: 9 additions & 11 deletions sup3r/preprocessing/batch_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,7 +536,7 @@ def get_handler_index(self):
indices = np.arange(0, len(self.data_handlers))
return np.random.choice(indices, p=self.handler_weights)

def get_current_handler(self):
def get_rand_handler(self):
"""Get random handler based on handler weights"""
self.current_handler_index = self.get_handler_index()
return self.data_handlers[self.current_handler_index]
Expand Down Expand Up @@ -952,7 +952,7 @@ def __next__(self):
"""
self.current_batch_indices = []
if self._i < self.n_batches:
handler = self.get_current_handler()
handler = self.get_rand_handler()
high_res = np.zeros(
(self.batch_size, self.sample_shape[0], self.sample_shape[1],
self.sample_shape[2], self.shape[-1]),
Expand Down Expand Up @@ -1018,7 +1018,7 @@ def __next__(self):
if self._i >= self.n_batches:
raise StopIteration

handler = self.get_current_handler()
handler = self.get_rand_handler()

low_res = None
high_res = None
Expand Down Expand Up @@ -1124,7 +1124,7 @@ def __next__(self):
if self._i >= self.n_batches:
raise StopIteration

handler = self.get_current_handler()
handler = self.get_rand_handler()

high_res = None

Expand Down Expand Up @@ -1178,7 +1178,7 @@ class SpatialBatchHandler(BatchHandler):

def __next__(self):
if self._i < self.n_batches:
handler = self.get_current_handler()
handler = self.get_rand_handler()
high_res = np.zeros((self.batch_size, self.sample_shape[0],
self.sample_shape[1], self.shape[-1]),
dtype=np.float32)
Expand Down Expand Up @@ -1255,10 +1255,8 @@ def _get_val_indices(self):
np.arange(h.data.shape[-1])
])
val_indices[s + self.N_TIME_BINS].append({
'handler_index':
h_idx,
'tuple_index':
tuple_index
'handler_index': h_idx,
'tuple_index': tuple_index
})
return val_indices

Expand Down Expand Up @@ -1370,7 +1368,7 @@ def __iter__(self):
def __next__(self):
self.current_batch_indices = []
if self._i < self.n_batches:
handler = self.get_current_handler()
handler = self.get_rand_handler()
high_res = np.zeros(
(self.batch_size, self.sample_shape[0], self.sample_shape[1],
self.sample_shape[2], self.shape[-1]),
Expand Down Expand Up @@ -1458,7 +1456,7 @@ def __iter__(self):
def __next__(self):
self.current_batch_indices = []
if self._i < self.n_batches:
handler = self.get_current_handler()
handler = self.get_rand_handler()
high_res = np.zeros((self.batch_size, self.sample_shape[0],
self.sample_shape[1], self.shape[-1],
),
Expand Down
Loading

0 comments on commit 6e84e7a

Please sign in to comment.