Skip to content

Commit

Permalink
chore: Shared Permutation+Lookup relation arithmetic (#559)
Browse files Browse the repository at this point in the history
Co-authored-by: ledwards2225 <[email protected]>
  • Loading branch information
zac-williamson and ledwards2225 authored Jul 21, 2023
1 parent af91bc8 commit 1672005
Show file tree
Hide file tree
Showing 23 changed files with 564 additions and 510 deletions.
7 changes: 3 additions & 4 deletions cpp/src/barretenberg/honk/composer/standard_composer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@ namespace proof_system::honk {
* @tparam Program settings needed to establish if w_4 is being used.
* */
template <StandardFlavor Flavor>
void StandardComposer_<Flavor>::compute_witness(const CircuitBuilder& circuit_constructor,
const size_t minimum_circuit_size)
void StandardComposer_<Flavor>::compute_witness(const CircuitBuilder& circuit_constructor, const size_t /*unused*/)
{
if (computed_witness) {
return;
Expand Down Expand Up @@ -72,7 +71,7 @@ std::shared_ptr<typename Flavor::ProvingKey> StandardComposer_<Flavor>::compute_
* */
template <StandardFlavor Flavor>
std::shared_ptr<typename Flavor::VerificationKey> StandardComposer_<Flavor>::compute_verification_key(
const CircuitBuilder& circuit_constructor)
const CircuitBuilder& /*unused*/)
{
if (verification_key) {
return verification_key;
Expand Down Expand Up @@ -124,7 +123,7 @@ StandardProver_<Flavor> StandardComposer_<Flavor>::create_prover(const CircuitBu
compute_proving_key(circuit_constructor);
compute_witness(circuit_constructor);

compute_commitment_key(proving_key->circuit_size, crs_factory_);
compute_commitment_key(proving_key->circuit_size);

StandardProver_<Flavor> output_state(proving_key, commitment_key);

Expand Down
2 changes: 1 addition & 1 deletion cpp/src/barretenberg/honk/composer/standard_composer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ template <StandardFlavor Flavor> class StandardComposer_ {

void compute_witness(const CircuitBuilder& circuit_constructor, const size_t minimum_circuit_size = 0);

void compute_commitment_key(size_t circuit_size, std::shared_ptr<srs::factories::CrsFactory> crs_factory)
void compute_commitment_key(size_t circuit_size)
{
commitment_key = std::make_shared<typename PCSParams::CommitmentKey>(circuit_size, crs_factory_);
};
Expand Down
3 changes: 1 addition & 2 deletions cpp/src/barretenberg/honk/composer/ultra_composer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ void UltraComposer_<Flavor>::compute_circuit_size_parameters(CircuitBuilder& cir
lookups_size += table.lookup_gates.size();
}

const size_t num_gates = circuit_constructor.num_gates;
num_public_inputs = circuit_constructor.public_inputs.size();

// minimum circuit size due to the length of lookups plus tables
Expand Down Expand Up @@ -170,7 +169,7 @@ UltraProver_<Flavor> UltraComposer_<Flavor>::create_prover(CircuitBuilder& circu
compute_proving_key(circuit_constructor);
compute_witness(circuit_constructor);

compute_commitment_key(proving_key->circuit_size, crs_factory_);
compute_commitment_key(proving_key->circuit_size);

UltraProver_<Flavor> output_state(proving_key, commitment_key);

Expand Down
2 changes: 1 addition & 1 deletion cpp/src/barretenberg/honk/composer/ultra_composer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ template <UltraFlavor Flavor> class UltraComposer_ {

void add_table_column_selector_poly_to_proving_key(polynomial& small, const std::string& tag);

void compute_commitment_key(size_t circuit_size, std::shared_ptr<srs::factories::CrsFactory> crs_factory)
void compute_commitment_key(size_t circuit_size)
{
commitment_key = std::make_shared<typename PCSParams::CommitmentKey>(circuit_size, crs_factory_);
};
Expand Down
1 change: 1 addition & 0 deletions cpp/src/barretenberg/honk/flavor/standard.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ class Standard {
// The total number of witness entities not including shifts.
static constexpr size_t NUM_WITNESS_ENTITIES = 4;

using GrandProductRelations = std::tuple<sumcheck::PermutationRelation<FF>>;
// define the tuple of Relations that comprise the Sumcheck relation
using Relations = std::tuple<sumcheck::ArithmeticRelation<FF>, sumcheck::PermutationRelation<FF>>;

Expand Down
2 changes: 2 additions & 0 deletions cpp/src/barretenberg/honk/flavor/standard_grumpkin.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ class StandardGrumpkin {
// The total number of witness entities not including shifts.
static constexpr size_t NUM_WITNESS_ENTITIES = 4;

// define the tuple of Relations that require grand products
using GrandProductRelations = std::tuple<sumcheck::PermutationRelation<FF>>;
// define the tuple of Relations that comprise the Sumcheck relation
using Relations = std::tuple<sumcheck::ArithmeticRelation<FF>, sumcheck::PermutationRelation<FF>>;

Expand Down
1 change: 1 addition & 0 deletions cpp/src/barretenberg/honk/flavor/ultra.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ class Ultra {
// The total number of witness entities not including shifts.
static constexpr size_t NUM_WITNESS_ENTITIES = 11;

using GrandProductRelations = std::tuple<sumcheck::UltraPermutationRelation<FF>, sumcheck::LookupRelation<FF>>;
// define the tuple of Relations that comprise the Sumcheck relation
using Relations = std::tuple<sumcheck::UltraArithmeticRelation<FF>,
sumcheck::UltraPermutationRelation<FF>,
Expand Down
1 change: 1 addition & 0 deletions cpp/src/barretenberg/honk/flavor/ultra_grumpkin.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class UltraGrumpkin {
// The total number of witness entities not including shifts.
static constexpr size_t NUM_WITNESS_ENTITIES = 11;

using GrandProductRelations = std::tuple<sumcheck::UltraPermutationRelation<FF>, sumcheck::LookupRelation<FF>>;
// define the tuple of Relations that comprise the Sumcheck relation
using Relations = std::tuple<sumcheck::UltraArithmeticRelation<FF>,
sumcheck::UltraPermutationRelation<FF>,
Expand Down
166 changes: 166 additions & 0 deletions cpp/src/barretenberg/honk/proof_system/grand_product_library.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
#pragma once
#include "barretenberg/honk/sumcheck/sumcheck.hpp"
#include "barretenberg/plonk/proof_system/proving_key/proving_key.hpp"
#include "barretenberg/polynomials/polynomial.hpp"
#include <typeinfo>

namespace proof_system::honk::grand_product_library {

// TODO(luke): This contains utilities for grand product computation and is not specific to the permutation grand
// product. Update comments accordingly.
/**
* @brief Compute a permutation grand product polynomial Z_perm(X)
* *
* @details
* Z_perm may be defined in terms of its values on X_i = 0,1,...,n-1 as Z_perm[0] = 1 and for i = 1:n-1
* relation::numerator(j)
* Z_perm[i] = ∏ --------------------------------------------------------------------------------
* relation::denominator(j)
*
* where ∏ := ∏_{j=0:i-1}
*
* The specific algebraic relation used by Z_perm is defined by Flavor::GrandProductRelations
*
* For example, in Flavor::Standard the relation describes:
*
* (w_1(j) + β⋅id_1(j) + γ) ⋅ (w_2(j) + β⋅id_2(j) + γ) ⋅ (w_3(j) + β⋅id_3(j) + γ)
* Z_perm[i] = ∏ --------------------------------------------------------------------------------
* (w_1(j) + β⋅σ_1(j) + γ) ⋅ (w_2(j) + β⋅σ_2(j) + γ) ⋅ (w_3(j) + β⋅σ_3(j) + γ)
* where ∏ := ∏_{j=0:i-1} and id_i(X) = id(X) + n*(i-1)
*
* For Flavor::Ultra both the UltraPermutation and Lookup grand products are computed by this method.
*
* The grand product is constructed over the course of three steps.
*
* For expositional simplicity, write Z_perm[i] as
*
* A(j)
* Z_perm[i] = ∏ --------------------------
* B(h)
*
* Step 1) Compute 2 length-n polynomials A, B
* Step 2) Compute 2 length-n polynomials numerator = ∏ A(j), nenominator = ∏ B(j)
* Step 3) Compute Z_perm[i + 1] = numerator[i] / denominator[i] (recall: Z_perm[0] = 1)
*
* Note: Step (3) utilizes Montgomery batch inversion to replace n-many inversions with
*/
template <typename Flavor, typename GrandProdRelation>
void compute_grand_product(const size_t circuit_size,
auto& full_polynomials,
sumcheck::RelationParameters<typename Flavor::FF>& relation_parameters)
{
using FF = typename Flavor::FF;
using Polynomial = typename Flavor::Polynomial;
using ValueAccumTypes = typename GrandProdRelation::ValueAccumTypes;

// Allocate numerator/denominator polynomials that will serve as scratch space
// TODO(zac) we can re-use the permutation polynomial as the numerator polynomial. Reduces readability
Polynomial numerator = Polynomial{ circuit_size };
Polynomial denominator = Polynomial{ circuit_size };

// Step (1)
// Populate `numerator` and `denominator` with the algebra described by Relation
const size_t num_threads = circuit_size >= get_num_cpus_pow2() ? get_num_cpus_pow2() : 1;
const size_t block_size = circuit_size / num_threads;
parallel_for(num_threads, [&](size_t thread_idx) {
const size_t start = thread_idx * block_size;
const size_t end = (thread_idx + 1) * block_size;
for (size_t i = start; i < end; ++i) {

typename Flavor::ClaimedEvaluations evaluations;
for (size_t k = 0; k < Flavor::NUM_ALL_ENTITIES; ++k) {
evaluations[k] = full_polynomials[k].size() > i ? full_polynomials[k][i] : 0;
}
numerator[i] = GrandProdRelation::template compute_grand_product_numerator<ValueAccumTypes>(
evaluations, relation_parameters, i);
denominator[i] = GrandProdRelation::template compute_grand_product_denominator<ValueAccumTypes>(
evaluations, relation_parameters, i);
}
});

// Step (2)
// Compute the accumulating product of the numerator and denominator terms.
// This step is split into three parts for efficient multithreading:
// (i) compute ∏ A(j), ∏ B(j) subproducts for each thread
// (ii) compute scaling factor required to convert each subproduct into a single running product
// (ii) combine subproducts into a single running product
//
// For example, consider 4 threads and a size-8 numerator { a0, a1, a2, a3, a4, a5, a6, a7 }
// (i) Each thread computes 1 element of N = {{ a0, a0a1 }, { a2, a2a3 }, { a4, a4a5 }, { a6, a6a7 }}
// (ii) Take partial products P = { 1, a0a1, a2a3, a4a5 }
// (iii) Each thread j computes N[i][j]*P[j]=
// {{a0,a0a1},{a0a1a2,a0a1a2a3},{a0a1a2a3a4,a0a1a2a3a4a5},{a0a1a2a3a4a5a6,a0a1a2a3a4a5a6a7}}
std::vector<FF> partial_numerators(num_threads);
std::vector<FF> partial_denominators(num_threads);

parallel_for(num_threads, [&](size_t thread_idx) {
const size_t start = thread_idx * block_size;
const size_t end = (thread_idx + 1) * block_size;
for (size_t i = start; i < end - 1; ++i) {
numerator[i + 1] *= numerator[i];
denominator[i + 1] *= denominator[i];
}
partial_numerators[thread_idx] = numerator[end - 1];
partial_denominators[thread_idx] = denominator[end - 1];
});

parallel_for(num_threads, [&](size_t thread_idx) {
const size_t start = thread_idx * block_size;
const size_t end = (thread_idx + 1) * block_size;
if (thread_idx > 0) {
FF numerator_scaling = 1;
FF denominator_scaling = 1;

for (size_t j = 0; j < thread_idx; ++j) {
numerator_scaling *= partial_numerators[j];
denominator_scaling *= partial_denominators[j];
}
for (size_t i = start; i < end; ++i) {
numerator[i] *= numerator_scaling;
denominator[i] *= denominator_scaling;
}
}

// Final step: invert denominator
FF::batch_invert(std::span{ &denominator[start], block_size });
});

// Step (3) Compute z_perm[i] = numerator[i] / denominator[i]
auto& grand_product_polynomial = GrandProdRelation::get_grand_product_polynomial(full_polynomials);
grand_product_polynomial[0] = 0;
parallel_for(num_threads, [&](size_t thread_idx) {
const size_t start = thread_idx * block_size;
const size_t end = (thread_idx == num_threads - 1) ? circuit_size - 1 : (thread_idx + 1) * block_size;
for (size_t i = start; i < end; ++i) {
grand_product_polynomial[i + 1] = numerator[i] * denominator[i];
}
});
}

template <typename Flavor>
void compute_grand_products(std::shared_ptr<typename Flavor::ProvingKey>& key,
typename Flavor::ProverPolynomials& full_polynomials,
sumcheck::RelationParameters<typename Flavor::FF>& relation_parameters)
{
using GrandProductRelations = typename Flavor::GrandProductRelations;
using FF = typename Flavor::FF;

constexpr size_t NUM_RELATIONS = std::tuple_size<GrandProductRelations>{};
barretenberg::constexpr_for<0, NUM_RELATIONS, 1>([&]<size_t i>() {
using GrandProdRelation = typename std::tuple_element<i, GrandProductRelations>::type;

// Assign the grand product polynomial to the relevant std::span member of `full_polynomials` (and its shift)
// For example, for UltraPermutationRelation, this will be `full_polynomials.z_perm`
// For example, for LookupRelation, this will be `full_polynomials.z_lookup`
std::span<FF>& full_polynomial = GrandProdRelation::get_grand_product_polynomial(full_polynomials);
auto& key_polynomial = GrandProdRelation::get_grand_product_polynomial(*key);
full_polynomial = key_polynomial;

compute_grand_product<Flavor, GrandProdRelation>(key->circuit_size, full_polynomials, relation_parameters);
std::span<FF>& full_polynomial_shift =
GrandProdRelation::get_shifted_grand_product_polynomial(full_polynomials);
full_polynomial_shift = key_polynomial.shifted();
});
}

} // namespace proof_system::honk::grand_product_library
6 changes: 2 additions & 4 deletions cpp/src/barretenberg/honk/proof_system/prover.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "prover.hpp"
#include "barretenberg/honk/proof_system/grand_product_library.hpp"
#include "barretenberg/honk/proof_system/prover_library.hpp"
#include "barretenberg/honk/sumcheck/sumcheck.hpp"
#include "barretenberg/honk/transcript/transcript.hpp"
Expand Down Expand Up @@ -112,12 +113,9 @@ template <StandardFlavor Flavor> void StandardProver_<Flavor>::execute_grand_pro
.public_input_delta = public_input_delta,
};

key->z_perm = prover_library::compute_permutation_grand_product<Flavor>(key, beta, gamma);
grand_product_library::compute_grand_products<Flavor>(key, prover_polynomials, relation_parameters);

queue.add_commitment(key->z_perm, commitment_labels.z_perm);

prover_polynomials.z_perm = key->z_perm;
prover_polynomials.z_perm_shift = key->z_perm.shifted();
}

/**
Expand Down
Loading

0 comments on commit 1672005

Please sign in to comment.