Skip to content

Commit

Permalink
improve filtering indices (#673)
Browse files Browse the repository at this point in the history
* improve filter and more time measuring

* fix macro

* deploy

* revert deploy
  • Loading branch information
philsippl authored Nov 9, 2024
1 parent 3e4e757 commit 6535d81
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 39 deletions.
88 changes: 55 additions & 33 deletions iris-mpc-gpu/src/server/actor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,15 @@ use std::{collections::HashMap, mem, sync::Arc, time::Instant};
use tokio::sync::{mpsc, oneshot};

macro_rules! record_stream_time {
($manager:expr, $streams:expr, $map:expr, $label:expr, $block:block) => {
($manager:expr, $streams:expr, $map:expr, $label:expr, $block:block) => {{
let evt0 = $manager.create_events();
let evt1 = $manager.create_events();
$manager.record_event($streams, &evt0);
$block
let res = $block;
$manager.record_event($streams, &evt1);
$map.entry($label).or_default().extend(vec![evt0, evt1])
};
$map.entry($label).or_default().extend(vec![evt0, evt1]);
res
}};
}

#[derive(Debug, Clone)]
Expand Down Expand Up @@ -510,7 +511,6 @@ impl ServerActor {
///////////////////////////////////////////////////////////////////
// PERFORM DELETIONS (IF ANY)
///////////////////////////////////////////////////////////////////

if !batch.deletion_requests_indices.is_empty() {
tracing::info!("Performing deletions");
// Prepare dummy deletion shares
Expand Down Expand Up @@ -552,11 +552,13 @@ impl ServerActor {
///////////////////////////////////////////////////////////////////
// SYNC BATCH CONTENTS AND FILTER OUT INVALID ENTRIES
///////////////////////////////////////////////////////////////////
let tmp_now = Instant::now();
tracing::info!("Syncing batch entries");
let valid_entries = self.sync_batch_entries(&batch.valid_entries)?;
let valid_entry_idxs = valid_entries.iter().positions(|&x| x).collect::<Vec<_>>();
batch_size = valid_entry_idxs.len();
batch.retain(&valid_entry_idxs);
tracing::info!("Sync and filter done in {:?}", tmp_now.elapsed());

///////////////////////////////////////////////////////////////////
// COMPARE LEFT EYE QUERIES
Expand All @@ -571,20 +573,30 @@ impl ServerActor {
};
let query_store_left = batch.store_left;

// THIS needs to be max_batch_size, even though the query can be shorter to have
// enough padding for GEMM
let compact_device_queries_left = compact_query_left.htod_transfer(
let (compact_device_queries_left, compact_device_sums_left) = record_stream_time!(
&self.device_manager,
&self.streams[0],
self.max_batch_size,
)?;
events,
"query_preprocess",
{
// This needs to be max_batch_size, even though the query can be shorter to have
// enough padding for GEMM
let compact_device_queries_left = compact_query_left.htod_transfer(
&self.device_manager,
&self.streams[0],
self.max_batch_size,
)?;

let compact_device_sums_left = compact_device_queries_left.query_sums(
&self.codes_engine,
&self.masks_engine,
&self.streams[0],
&self.cublas_handles[0],
)?;
let compact_device_sums_left = compact_device_queries_left.query_sums(
&self.codes_engine,
&self.masks_engine,
&self.streams[0],
&self.cublas_handles[0],
)?;

(compact_device_queries_left, compact_device_sums_left)
}
);

tracing::info!("Comparing left eye queries against DB and self");
self.compare_query_against_db_and_self(
Expand All @@ -607,27 +619,37 @@ impl ServerActor {
};
let query_store_right = batch.store_right;

// THIS needs to be MAX_BATCH_SIZE, even though the query can be shorter to have
// enough padding for GEMM
let compact_device_queries_right = compact_query_right.htod_transfer(
let (compact_device_queries_right, compact_device_sums_right) = record_stream_time!(
&self.device_manager,
&self.streams[0],
self.max_batch_size,
)?;
events,
"query_preprocess",
{
// This needs to be MAX_BATCH_SIZE, even though the query can be shorter to have
// enough padding for GEMM
let compact_device_queries_right = compact_query_right.htod_transfer(
&self.device_manager,
&self.streams[0],
self.max_batch_size,
)?;

let compact_device_sums_right = compact_device_queries_right.query_sums(
&self.codes_engine,
&self.masks_engine,
&self.streams[0],
&self.cublas_handles[0],
)?;
let compact_device_sums_right = compact_device_queries_right.query_sums(
&self.codes_engine,
&self.masks_engine,
&self.streams[0],
&self.cublas_handles[0],
)?;

tracing::info!("Comparing right eye queries against DB and self");
self.compare_query_against_db_and_self(
&compact_device_queries_right,
&compact_device_sums_right,
&mut events,
Eye::Right,
);

tracing::info!("Comparing right eye queries against DB and self");
self.compare_query_against_db_and_self(
&compact_device_queries_right,
&compact_device_sums_right,
&mut events,
Eye::Right,
(compact_device_queries_right, compact_device_sums_right)
}
);

///////////////////////////////////////////////////////////////////
Expand Down
12 changes: 6 additions & 6 deletions iris-mpc-gpu/src/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,21 +96,21 @@ macro_rules! filter_by_indices {
macro_rules! filter_by_indices_with_rotations {
($data:expr, $indices:expr) => {
$data = $data
.iter()
.chunks(ROTATIONS)
.enumerate()
.filter(|(i, _)| $indices.contains((&(i / ROTATIONS))))
.map(|(_, v)| v.clone())
.filter(|(i, _)| $indices.contains(i))
.flat_map(|(_, chunk)| chunk.iter().cloned())
.collect();
};
}

macro_rules! filter_by_indices_with_rotations_and_code_length {
($data:expr, $indices:expr, $code_length:expr) => {
$data = $data
.iter()
.chunks($code_length * ROTATIONS)
.enumerate()
.filter(|(i, _)| $indices.contains((&(i / ROTATIONS / $code_length))))
.map(|(_, v)| v.clone())
.filter(|(i, _)| $indices.contains(i))
.flat_map(|(_, chunk)| chunk.iter().cloned())
.collect();
};
}
Expand Down

0 comments on commit 6535d81

Please sign in to comment.