Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

builtin cost rework with fixes and entry point info #875

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ rstest = "0.23.0"
test-case = "3.3"
walkdir = "2.5.0"
serde_json = { version = "1.0.128" }
rayon = "1.10.0"

[build-dependencies]
cc = "1.1.28"
Expand Down
5 changes: 5 additions & 0 deletions programs/benches/factorial_2M.c
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ typedef struct factorial_return_values
} result;
} factorial_return_values_t;

extern uint64_t* builtin_costs;

static void run_bench(factorial_return_values_t*, uint64_t)
__attribute__((weakref("_mlir_ciface_factorial_2M::factorial_2M::main(f1)")));
Expand All @@ -25,6 +26,10 @@ int main()
{
factorial_return_values_t return_values;

uint64_t BuiltinCosts[7] = {1, 4050, 583, 4085, 491, 230, 604};

builtin_costs = &BuiltinCosts[0];

run_bench(&return_values, 0);
assert(return_values.result.discriminant == 0);

Expand Down
5 changes: 5 additions & 0 deletions programs/benches/fib_2M.c
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,18 @@ typedef struct fib_return_values
} result;
} fib_return_values_t;

extern uint64_t* builtin_costs;

static void run_bench(fib_return_values_t *, uint64_t)
__attribute__((weakref("_mlir_ciface_fib_2M::fib_2M::main(f1)")));


int main()
{
uint64_t BuiltinCosts[7] = {1, 4050, 583, 4085, 491, 230, 604};

builtin_costs = &BuiltinCosts[0];

fib_return_values_t return_values;

run_bench(&return_values, 0);
Expand Down
5 changes: 5 additions & 0 deletions programs/benches/logistic_map.c
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,18 @@ typedef struct map_return_values
} result;
} map_return_values_t;

extern uint64_t* builtin_costs;

static void run_bench(map_return_values_t *, uint64_t)
__attribute__((weakref("_mlir_ciface_logistic_map::logistic_map::main(f2)")));


int main()
{
uint64_t BuiltinCosts[7] = {1, 4050, 583, 4085, 491, 230, 604};

builtin_costs = &BuiltinCosts[0];

map_return_values_t return_values;

run_bench(&return_values, 0);
Expand Down
29 changes: 27 additions & 2 deletions src/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ use melior::{
arith::CmpiPredicate,
cf, func, index,
llvm::{self, LoadStoreOptions},
memref,
memref, ods,
},
ir::{
attribute::{
Expand Down Expand Up @@ -135,6 +135,31 @@ pub fn compile(
}
}

{
// Add the builtin_costs global.
// We always add it because symbol look up otherwise can panic.
let region = Region::new();
let location = Location::unknown(context);
let block = region.append_block(Block::new(&[]));
let value = block.append_op_result(
ods::llvm::mlir_zero(context, llvm::r#type::pointer(context, 0), location).into(),
)?;
block.append_operation(melior::dialect::llvm::r#return(Some(value), location));

module.body().append_operation({
let op = ods::llvm::mlir_global(
context,
region,
TypeAttribute::new(llvm::r#type::pointer(context, 0)),
StringAttribute::new(context, "builtin_costs"),
Attribute::parse(context, "#llvm.linkage<external>").unwrap(),
location,
);

op.into()
});
}

// Sierra programs have the following structure:
// 1. Type declarations, one per line.
// 2. Libfunc declarations, one per line.
Expand Down Expand Up @@ -446,7 +471,7 @@ fn compile_func(
initial_state,
|statement_idx, mut state| {
if let Some(gas_metadata) = metadata.get::<GasMetadata>() {
let gas_cost = gas_metadata.get_gas_cost_for_statement(statement_idx);
let gas_cost = gas_metadata.get_gas_costs_for_statement(statement_idx);
metadata.remove::<GasCost>();
metadata.insert(GasCost(gas_cost));
}
Expand Down
3 changes: 3 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ pub enum Error {
#[error("integer conversion failed")]
IntegerConversion,

#[error("missing BuiltinCosts global symbol, should never happen, this is a bug")]
MissingBuiltinCostsSymbol,

#[error(transparent)]
IoError(#[from] std::io::Error),

Expand Down
69 changes: 49 additions & 20 deletions src/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::{
execution_result::{BuiltinStats, ExecutionResult},
starknet::{handler::StarknetSyscallHandlerCallbacks, StarknetSyscallHandler},
types::TypeBuilder,
utils::{libc_free, RangeExt},
utils::{libc_free, BuiltinCosts, RangeExt},
values::Value,
};
use bumpalo::Bump;
Expand Down Expand Up @@ -69,6 +69,7 @@ extern "C" {
fn invoke_dynamic(
registry: &ProgramRegistry<CoreType, CoreLibfunc>,
function_ptr: *const c_void,
builtin_costs_ptr: Option<*mut c_void>,
function_signature: &FunctionSignature,
args: &[Value],
gas: u128,
Expand Down Expand Up @@ -141,6 +142,15 @@ fn invoke_dynamic(
previous_syscall_handler
});

// Order matters, for the libfunc impl
let builtin_costs: [u64; 7] = BuiltinCosts::default().into();

if let Some(builtin_costs_ptr) = builtin_costs_ptr {
unsafe {
*builtin_costs_ptr.cast() = builtin_costs.as_ptr();
}
}

// Generate argument list.
let mut iter = args.iter();
for item in function_signature.param_types.iter().filter_map(|type_id| {
Expand All @@ -166,6 +176,13 @@ fn invoke_dynamic(
(syscall_handler as *mut StarknetSyscallHandlerCallbacks<_>)
.to_bytes(&mut invoke_data)?;
}
CoreTypeConcrete::BuiltinCosts(_) => {
if let Some(builtin_costs_ptr) = builtin_costs_ptr {
builtin_costs_ptr.to_bytes(&mut invoke_data)?;
} else {
(builtin_costs.as_ptr()).to_bytes(&mut invoke_data)?;
}
}
type_info if type_info.is_builtin() => 0u64.to_bytes(&mut invoke_data)?,
type_info => ValueWithInfoWrapper {
value: iter.next().unwrap(),
Expand Down Expand Up @@ -250,26 +267,38 @@ fn invoke_dynamic(
}
_ if type_info.is_builtin() => {
if !type_info.is_zst(registry)? {
let value = match &mut return_ptr {
Some(return_ptr) => unsafe { *read_value::<u64>(return_ptr) },
None => ret_registers[0],
} as usize;

match type_info {
CoreTypeConcrete::Bitwise(_) => builtin_stats.bitwise = value,
CoreTypeConcrete::EcOp(_) => builtin_stats.ec_op = value,
CoreTypeConcrete::RangeCheck(_) => builtin_stats.range_check = value,
CoreTypeConcrete::Pedersen(_) => builtin_stats.pedersen = value,
CoreTypeConcrete::Poseidon(_) => builtin_stats.poseidon = value,
CoreTypeConcrete::SegmentArena(_) => builtin_stats.segment_arena = value,
CoreTypeConcrete::RangeCheck96(_) => builtin_stats.range_check_96 = value,
CoreTypeConcrete::Circuit(CircuitTypeConcrete::AddMod(_)) => {
builtin_stats.circuit_add = value
}
CoreTypeConcrete::Circuit(CircuitTypeConcrete::MulMod(_)) => {
builtin_stats.circuit_mul = value
if let CoreTypeConcrete::BuiltinCosts(_) = type_info {
// todo: should we use this value?
let _value = match &mut return_ptr {
Some(return_ptr) => unsafe { *read_value::<*mut u64>(return_ptr) },
None => ret_registers[0] as *mut u64,
};
} else {
let value = match &mut return_ptr {
Some(return_ptr) => unsafe { *read_value::<u64>(return_ptr) },
None => ret_registers[0],
} as usize;

match type_info {
CoreTypeConcrete::Bitwise(_) => builtin_stats.bitwise = value,
CoreTypeConcrete::EcOp(_) => builtin_stats.ec_op = value,
CoreTypeConcrete::RangeCheck(_) => builtin_stats.range_check = value,
CoreTypeConcrete::Pedersen(_) => builtin_stats.pedersen = value,
CoreTypeConcrete::Poseidon(_) => builtin_stats.poseidon = value,
CoreTypeConcrete::SegmentArena(_) => {
builtin_stats.segment_arena = value
}
CoreTypeConcrete::RangeCheck96(_) => {
builtin_stats.range_check_96 = value
}
CoreTypeConcrete::Circuit(CircuitTypeConcrete::AddMod(_)) => {
builtin_stats.circuit_add = value
}
CoreTypeConcrete::Circuit(CircuitTypeConcrete::MulMod(_)) => {
builtin_stats.circuit_mul = value
}
_ => unreachable!("{type_id:?}"),
}
_ => unreachable!("{type_id:?}"),
}
}
}
Expand Down
12 changes: 12 additions & 0 deletions src/executor/aot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ impl AotNativeExecutor {
super::invoke_dynamic(
&self.registry,
self.find_function_ptr(function_id),
self.find_symbol_ptr("builtin_costs"),
self.extract_signature(function_id),
args,
available_gas,
Expand All @@ -103,6 +104,7 @@ impl AotNativeExecutor {
super::invoke_dynamic(
&self.registry,
self.find_function_ptr(function_id),
self.find_symbol_ptr("builtin_costs"),
self.extract_signature(function_id),
args,
available_gas,
Expand All @@ -125,6 +127,7 @@ impl AotNativeExecutor {
ContractExecutionResult::from_execution_result(super::invoke_dynamic(
&self.registry,
self.find_function_ptr(function_id),
self.find_symbol_ptr("builtin_costs"),
self.extract_signature(function_id),
&[Value::Struct {
fields: vec![Value::Array(
Expand Down Expand Up @@ -152,6 +155,15 @@ impl AotNativeExecutor {
}
}

pub fn find_symbol_ptr(&self, name: &str) -> Option<*mut c_void> {
unsafe {
self.library
.get::<*mut ()>(name.as_bytes())
.ok()
.map(|x| x.into_raw().into_raw())
}
}

fn extract_signature(&self, function_id: &FunctionId) -> &FunctionSignature {
&self.registry.get_function(function_id).unwrap().signature
}
Expand Down
Loading
Loading