Skip to content

Commit

Permalink
Actually set the TLS info based on the binary arguments (#1680)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1680

This completes the connection of the TLS feature to the TLS implementation at the binary layer.

- Created a util function to populate the TlsInfo struct based on the arguments passed in
- Modified the main file of each app to create a TlsInfo struct
- passed the TlsInfo into the util functions

Reviewed By: RuiyuZhu

Differential Revision: D39643733

fbshipit-source-id: 9eb2f2fd82c4d06855ff4ea96c83579795e104f2
  • Loading branch information
adshastri authored and facebook-github-bot committed Oct 4, 2022
1 parent 7ccd61d commit d3f93dc
Show file tree
Hide file tree
Showing 11 changed files with 145 additions and 52 deletions.
2 changes: 2 additions & 0 deletions fbpcs/emp_games/common/Util.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

#pragma once

#include <cstdlib>
#include <memory>
#include <sstream>

#include "folly/dynamic.h"
Expand Down
58 changes: 58 additions & 0 deletions fbpcs/emp_games/common/test/UtilTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <gtest/gtest.h>

#include <cstdlib>
#include <tuple>
#include <vector>

#include "../Util.h"

namespace private_measurement {

TEST(UtilTest, TestGetTlsInfoFromArguments) {
auto tlsInfo = common::getTlsInfoFromArgs(
false,
"cert_path",
"server_cert_path",
"private_key_path",
"passphrase_path");

EXPECT_FALSE(tlsInfo.useTls);
EXPECT_STREQ(tlsInfo.rootCaCertPath.c_str(), "");
EXPECT_STREQ(tlsInfo.certPath.c_str(), "");
EXPECT_STREQ(tlsInfo.keyPath.c_str(), "");
EXPECT_STREQ(tlsInfo.passphrasePath.c_str(), "");

const char* home_dir = std::getenv("HOME");
if (home_dir == nullptr) {
home_dir = "";
}

std::string home_dir_string(home_dir);

tlsInfo = common::getTlsInfoFromArgs(
true,
"cert_path",
"server_cert_path",
"private_key_path",
"passphrase_path");

EXPECT_TRUE(tlsInfo.useTls);
EXPECT_STREQ(
tlsInfo.rootCaCertPath.c_str(), (home_dir_string + "/cert_path").c_str());
EXPECT_STREQ(
tlsInfo.certPath.c_str(),
(home_dir_string + "/server_cert_path").c_str());
EXPECT_STREQ(
tlsInfo.keyPath.c_str(), (home_dir_string + "/private_key_path").c_str());
EXPECT_STREQ(
tlsInfo.passphrasePath.c_str(),
(home_dir_string + "/passphrase_path").c_str());
}
} // namespace private_measurement
12 changes: 7 additions & 5 deletions fbpcs/emp_games/compactor/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "fbpcf/engine/communication/SocketPartyCommunicationAgentFactory.h"
#include "fbpcf/io/api/FileIOWrappers.h"
#include "fbpcf/scheduler/LazySchedulerFactory.h"
#include "fbpcs/emp_games/common/Util.h"
#include "fbpcs/emp_games/compactor/AttributionOutput.h"
#include "fbpcs/emp_games/compactor/CompactorGame.h"
#include "fbpcs/performance_tools/CostEstimation.h"
Expand Down Expand Up @@ -101,11 +102,12 @@ int main(int argc, char** argv) {

XLOG(INFO) << "Creating communication agent factory\n";

fbpcf::engine::communication::SocketPartyCommunicationAgent::TlsInfo tlsInfo;
tlsInfo.certPath = "";
tlsInfo.keyPath = "";
tlsInfo.passphrasePath = "";
tlsInfo.useTls = false;
auto tlsInfo = common::getTlsInfoFromArgs(
FLAGS_use_tls,
FLAGS_ca_cert_path,
FLAGS_server_cert_path,
FLAGS_private_key_path,
"");

std::map<
int,
Expand Down
9 changes: 3 additions & 6 deletions fbpcs/emp_games/dotproduct/MainUtil.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,14 @@ inline common::SchedulerStatistics startDotProductApp(
std::string& outFilePath,
int numFeatures,
int labelWidth,
bool debugMode) {
bool debugMode,
fbpcf::engine::communication::SocketPartyCommunicationAgent::TlsInfo&
tlsInfo) {
std::map<
int,
fbpcf::engine::communication::SocketPartyCommunicationAgentFactory::
PartyInfo>
partyInfos({{0, {serverIp, port}}, {1, {serverIp, port}}});
fbpcf::engine::communication::SocketPartyCommunicationAgent::TlsInfo tlsInfo;
tlsInfo.certPath = "";
tlsInfo.keyPath = "";
tlsInfo.passphrasePath = "";
tlsInfo.useTls = false;

auto metricCollector =
std::make_shared<fbpcf::util::MetricCollector>("dotproduct");
Expand Down
13 changes: 11 additions & 2 deletions fbpcs/emp_games/dotproduct/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,13 @@ int main(int argc, char* argv[]) {
XLOGF(INFO, "Base output path: {}", FLAGS_output_base_path);

common::SchedulerStatistics schedulerStatistics;

auto tlsInfo = common::getTlsInfoFromArgs(
FLAGS_use_tls,
FLAGS_ca_cert_path,
FLAGS_server_cert_path,
FLAGS_private_key_path,
"");
try {
if (FLAGS_party == common::PUBLISHER) {
XLOG(INFO)
Expand All @@ -57,7 +64,8 @@ int main(int argc, char* argv[]) {
FLAGS_output_base_path,
FLAGS_num_features,
FLAGS_label_width,
FLAGS_debug);
FLAGS_debug,
tlsInfo);

} else if (FLAGS_party == common::PARTNER) {
XLOG(INFO)
Expand All @@ -71,7 +79,8 @@ int main(int argc, char* argv[]) {
FLAGS_output_base_path,
FLAGS_num_features,
FLAGS_label_width,
FLAGS_debug);
FLAGS_debug,
tlsInfo);
} else {
XLOGF(FATAL, "Invalid Party: {}", FLAGS_party);
}
Expand Down
21 changes: 10 additions & 11 deletions fbpcs/emp_games/pcf2_aggregation/MainUtil.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,9 @@ inline common::SchedulerStatistics startAggregationAppsForShardedFilesHelper(
std::string aggregationFormats,
std::vector<std::string>& inputSecretShareFilenames,
std::vector<std::string>& inputClearTextFilenames,
std::vector<std::string>& outputFilenames) {
std::vector<std::string>& outputFilenames,
fbpcf::engine::communication::SocketPartyCommunicationAgent::TlsInfo&
tlsInfo) {
// aggregate scheduler statistics across apps
common::SchedulerStatistics schedulerStatistics{
0, 0, 0, 0, folly::dynamic::object()};
Expand All @@ -77,13 +79,6 @@ inline common::SchedulerStatistics startAggregationAppsForShardedFilesHelper(
{{0, {serverIp, port + index * 100}},
{1, {serverIp, port + index * 100}}});

fbpcf::engine::communication::SocketPartyCommunicationAgent::TlsInfo
tlsInfo;
tlsInfo.certPath = "";
tlsInfo.keyPath = "";
tlsInfo.passphrasePath = "";
tlsInfo.useTls = false;

auto metricCollector = std::make_shared<fbpcf::util::MetricCollector>(
"aggregation_metrics_for_thread_" + std::to_string(index));

Expand Down Expand Up @@ -126,7 +121,8 @@ inline common::SchedulerStatistics startAggregationAppsForShardedFilesHelper(
aggregationFormats,
inputSecretShareFilenames,
inputClearTextFilenames,
outputFilenames);
outputFilenames,
tlsInfo);
schedulerStatistics.add(remainingStats);
}
}
Expand All @@ -146,7 +142,9 @@ inline common::SchedulerStatistics startAggregationAppsForShardedFiles(
int16_t concurrency,
std::string serverIp,
int port,
std::string aggregationFormats) {
std::string aggregationFormats,
fbpcf::engine::communication::SocketPartyCommunicationAgent::TlsInfo&
tlsInfo) {
// use only as many threads as the number of files
auto numThreads =
std::min((int)inputSecretShareFilenames.size(), (int)concurrency);
Expand All @@ -162,7 +160,8 @@ inline common::SchedulerStatistics startAggregationAppsForShardedFiles(
aggregationFormats,
inputSecretShareFilenames,
inputClearTextFilenames,
outputFilenames);
outputFilenames,
tlsInfo);
}

} // namespace pcf2_aggregation
13 changes: 11 additions & 2 deletions fbpcs/emp_games/pcf2_aggregation/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,13 @@ int main(int argc, char* argv[]) {
inputEncryption = common::InputEncryption::Plaintext;
}

auto tlsInfo = common::getTlsInfoFromArgs(
FLAGS_use_tls,
FLAGS_ca_cert_path,
FLAGS_server_cert_path,
FLAGS_private_key_path,
"");

if (FLAGS_party == common::PUBLISHER) {
XLOGF(INFO, "Aggregation Format: {}", FLAGS_aggregators);

Expand All @@ -106,7 +113,8 @@ int main(int argc, char* argv[]) {
concurrency,
FLAGS_server_ip,
FLAGS_port,
FLAGS_aggregators);
FLAGS_aggregators,
tlsInfo);
} else if (FLAGS_party == common::PARTNER) {
XLOG(INFO)
<< "Starting private aggregation as Partner, will wait for Publisher...";
Expand All @@ -121,7 +129,8 @@ int main(int argc, char* argv[]) {
concurrency,
FLAGS_server_ip,
FLAGS_port,
FLAGS_aggregators);
FLAGS_aggregators,
tlsInfo);

} else {
XLOGF(FATAL, "Invalid Party: {}", FLAGS_party);
Expand Down
21 changes: 10 additions & 11 deletions fbpcs/emp_games/pcf2_attribution/MainUtil.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,9 @@ inline common::SchedulerStatistics startAttributionAppsForShardedFilesHelper(
int port,
std::string attributionRules,
std::vector<std::string>& inputFilenames,
std::vector<std::string>& outputFilenames) {
std::vector<std::string>& outputFilenames,
fbpcf::engine::communication::SocketPartyCommunicationAgent::TlsInfo&
tlsInfo) {
// aggregate scheduler statistics across apps
common::SchedulerStatistics schedulerStatistics{
0, 0, 0, 0, folly::dynamic::object()};
Expand All @@ -86,13 +88,6 @@ inline common::SchedulerStatistics startAttributionAppsForShardedFilesHelper(
{{0, {serverIp, port + static_cast<int>(index) * 100}},
{1, {serverIp, port + static_cast<int>(index) * 100}}});

fbpcf::engine::communication::SocketPartyCommunicationAgent::TlsInfo
tlsInfo;
tlsInfo.certPath = "";
tlsInfo.keyPath = "";
tlsInfo.passphrasePath = "";
tlsInfo.useTls = false;

auto metricCollector = std::make_shared<fbpcf::util::MetricCollector>(
"attribution_metrics_for_thread_" + std::to_string(index));

Expand Down Expand Up @@ -133,7 +128,8 @@ inline common::SchedulerStatistics startAttributionAppsForShardedFilesHelper(
port,
attributionRules,
inputFilenames,
outputFilenames);
outputFilenames,
tlsInfo);
schedulerStatistics.add(remainingStats);
}
}
Expand All @@ -150,7 +146,9 @@ inline common::SchedulerStatistics startAttributionAppsForShardedFiles(
int16_t concurrency,
std::string serverIp,
int port,
std::string attributionRules) {
std::string attributionRules,
fbpcf::engine::communication::SocketPartyCommunicationAgent::TlsInfo&
tlsInfo) {
// use only as many threads as the number of files
auto numThreads =
std::min(static_cast<std::int16_t>(inputFilenames.size()), concurrency);
Expand All @@ -166,7 +164,8 @@ inline common::SchedulerStatistics startAttributionAppsForShardedFiles(
port,
attributionRules,
inputFilenames,
outputFilenames);
outputFilenames,
tlsInfo);
}

} // namespace pcf2_attribution
25 changes: 19 additions & 6 deletions fbpcs/emp_games/pcf2_attribution/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,13 @@ int main(int argc, char* argv[]) {
CHECK_LE(concurrency, pcf2_attribution::kMaxConcurrency)
<< "Concurrency must be at most " << pcf2_attribution::kMaxConcurrency;

auto tlsInfo = common::getTlsInfoFromArgs(
FLAGS_use_tls,
FLAGS_ca_cert_path,
FLAGS_server_cert_path,
FLAGS_private_key_path,
"");

if (FLAGS_party == common::PUBLISHER) {
XLOGF(INFO, "Attribution Rules: {}", FLAGS_attribution_rules);

Expand All @@ -81,7 +88,8 @@ int main(int argc, char* argv[]) {
concurrency,
FLAGS_server_ip,
FLAGS_port,
FLAGS_attribution_rules);
FLAGS_attribution_rules,
tlsInfo);
} else if (FLAGS_input_encryption == 2) {
schedulerStatistics =
pcf2_attribution::startAttributionAppsForShardedFiles<
Expand All @@ -93,7 +101,8 @@ int main(int argc, char* argv[]) {
concurrency,
FLAGS_server_ip,
FLAGS_port,
FLAGS_attribution_rules);
FLAGS_attribution_rules,
tlsInfo);
} else {
schedulerStatistics =
pcf2_attribution::startAttributionAppsForShardedFiles<
Expand All @@ -105,7 +114,8 @@ int main(int argc, char* argv[]) {
concurrency,
FLAGS_server_ip,
FLAGS_port,
FLAGS_attribution_rules);
FLAGS_attribution_rules,
tlsInfo);
}

} else if (FLAGS_party == common::PARTNER) {
Expand All @@ -123,7 +133,8 @@ int main(int argc, char* argv[]) {
concurrency,
FLAGS_server_ip,
FLAGS_port,
FLAGS_attribution_rules);
FLAGS_attribution_rules,
tlsInfo);
} else if (FLAGS_input_encryption == 2) {
schedulerStatistics =
pcf2_attribution::startAttributionAppsForShardedFiles<
Expand All @@ -135,7 +146,8 @@ int main(int argc, char* argv[]) {
concurrency,
FLAGS_server_ip,
FLAGS_port,
FLAGS_attribution_rules);
FLAGS_attribution_rules,
tlsInfo);

} else {
schedulerStatistics =
Expand All @@ -148,7 +160,8 @@ int main(int argc, char* argv[]) {
concurrency,
FLAGS_server_ip,
FLAGS_port,
FLAGS_attribution_rules);
FLAGS_attribution_rules,
tlsInfo);
}

} else {
Expand Down
Loading

0 comments on commit d3f93dc

Please sign in to comment.