From 98226ea2a07c7da8f8037fc7641d45117df6b94b Mon Sep 17 00:00:00 2001 From: Ekaterina Broslavskaya Date: Thu, 17 Aug 2023 14:40:38 +0300 Subject: [PATCH] New feature release (#27) Add support for NTT, Update configs initialisation, fix module visibility --- .gitignore | 1 + Cargo.toml | 46 +++--- README.md | 86 ++++++++--- benches/ntt_bench.rs | 54 +++++++ configs/c1100_cfg.json | 12 -- configs/msm_bls377_dma_cfg.json | 7 - configs/msm_bls377_hbm_cfg.json | 6 - configs/msm_bls381_dma_cfg.json | 7 - configs/msm_bls381_hbm_cfg.json | 6 - configs/msm_bn254_dma_cfg.json | 7 - src/driver_client/dclient.rs | 149 ++++++++----------- src/driver_client/dclient_cfg.rs | 47 ++++++ src/driver_client/dclient_code.rs | 3 +- src/driver_client/mod.rs | 9 +- src/error.rs | 2 + src/ingo_hash/mod.rs | 9 +- src/ingo_hash/poseidon_api.rs | 32 ++-- src/ingo_msm/mod.rs | 8 +- src/ingo_msm/msm_api.rs | 77 ++-------- src/ingo_msm/msm_cfg.rs | 92 ++++++++++++ src/ingo_ntt/mod.rs | 5 + src/ingo_ntt/ntt_api.rs | 125 ++++++++++++++++ src/ingo_ntt/ntt_data.rs | 232 +++++++++++++++++++++++++++++ src/ingo_ntt/ntt_hw_code.rs | 89 ++++++++++++ src/lib.rs | 1 + src/utils.rs | 18 --- tests/integration_msm.rs | 233 ++++++++++++++---------------- tests/integration_msm_hbm.rs | 56 ++++--- tests/integration_ntt.rs | 146 +++++++++++++++++++ tests/integration_poseidon.rs | 16 +- 30 files changed, 1140 insertions(+), 441 deletions(-) create mode 100644 benches/ntt_bench.rs delete mode 100644 configs/c1100_cfg.json delete mode 100644 configs/msm_bls377_dma_cfg.json delete mode 100644 configs/msm_bls377_hbm_cfg.json delete mode 100644 configs/msm_bls381_dma_cfg.json delete mode 100644 configs/msm_bls381_hbm_cfg.json delete mode 100644 configs/msm_bn254_dma_cfg.json create mode 100644 src/driver_client/dclient_cfg.rs create mode 100644 src/ingo_msm/msm_cfg.rs create mode 100644 src/ingo_ntt/mod.rs create mode 100644 src/ingo_ntt/ntt_api.rs create mode 100644 src/ingo_ntt/ntt_data.rs create mode 100644 src/ingo_ntt/ntt_hw_code.rs create mode 100644 tests/integration_ntt.rs diff --git a/.gitignore b/.gitignore index 0ed481d..f62d851 100644 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,4 @@ Cargo.lock *DS_Store* *.csv .env +test_data \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index 278c0aa..8a712c2 100755 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,40 +1,38 @@ [package] -name = "ingo-blaze" -version = "0.3.0" -edition = "2021" -authors = [ "Ingonyama" ] +authors = ["Ingonyama"] description = "Library for ZK acceleration on Xilinx FPGAs." +edition = "2021" homepage = "https://www.ingonyama.com" +name = "ingo-blaze" repository = "https://github.com/ingonyama-zk/blaze" +version = "0.4.0" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -libc = "0.2.126" -ark-algebra-intro = "0.3.0" -ark-ec = "0.3.0" -ark-std = "0.3.0" -ark-ff = "0.3.0" +anyhow = "1.0.68" ark-bls12-377 = "0.3.0" -ark-bn254 = "0.3.0" ark-bls12-381 = "0.3.0" +ark-bn254 = "0.3.0" +ark-ec = "0.3.0" +ark-ff = "0.3.0" +ark-std = "0.3.0" +criterion = "0.4.0" +csv = "1.1" +dotenv = "0.15.0" env_logger = "0.10.0" +libc = "0.2.126" log = "0.4.0" -packed_struct = "0.10" -anyhow = "1.0.68" -csv = "1.1" -serde = { version = "1", features = ["derive"] } num = "0.4" -strum = "0.24" -strum_macros = "0.24" +num-bigint = "0.4" +num-traits = "0.2.15" +packed_struct = "0.10" rand = "0.8.5" -criterion = "0.4.0" -lazy_static = "1.4.0" -bindgen = "0.64.0" -dotenv = "0.15.0" rayon = "1.6.1" -serde_json = { version = "1.0" } -serde-hex = "0.1.0" +strum = "0.24" +strum_macros = "0.24" thiserror = "1.0" -num-bigint = "0.4" -num-traits = "0.2.15" + +[[bench]] +harness = false +name = "ntt_bench" diff --git a/README.md b/README.md index 0f260cc..6af69d8 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,5 @@ # BLAZE +
blaze is a Rust library for ZK acceleration on Xilinx FPGAs.
![ingo_BlazeFire_5d](https://github.com/ingonyama-zk/blaze/assets/2446179/6460173b-02af-4023-b055-c8274a0cbc21) @@ -24,14 +25,15 @@ On the connection side, we can provide an API to retrieve any necessary data (in The [DriverClient](src/driver_client/) module is designed to establish a connection between the FPGA/AWS and a known type of card, such as the C1100 card. It does not possess any knowledge about primitives. -The [DriverClient](src/driver_client/) provides basic IO methods and can load a binary, as well as provide specific and debug information about current HW. For a specific card type, the [DriverConfig](src/driver_client/dclient.rs) remains the same and can be accessed using the `driver_client_c1100_cfg` function. +The [DriverClient](src/driver_client/) provides basic IO methods and can load a binary, as well as provide specific and debug information about current HW. For a specific card type, the [DriverConfig](src/driver_client/dclient.rs) remains the same and can be accessed using the `driver_client_cfg` function. It is important to note that the high-level management layer determines which client and primitive should be used. The [DriverClient](src/driver_client/) can be overused in this process. How to create a new connection: ```rust -let dclient = DriverClient::new(&id, DriverConfig::driver_client_c1100_cfg()); +let dclient = DriverClient::new(&id, +DriverConfig::driver_client_cfg(CardType::C1100)); ``` ### DriverPrimitive @@ -43,14 +45,14 @@ The configuration (e.g. for msm there are addresses space and curve description) To create a new primitive instance for MSM, for example, one would use the following code: ```rust -let dclient = DriverClient::new(&id, DriverConfig::driver_client_c1100_cfg()); -let driver = msm_api::MSMClient::new( - msm_api::MSMInit { - mem_type: msm_api::PointMemoryType::DMA, - is_precompute: false, - curve: msm_api::Curve::BLS381, - }, - dclient, +let dclient = DriverClient::new(&id, DriverConfig::driver_client_cfg(CardType::C1100)); +let driver = MSMClient::new( + MSMInit { + mem_type: PointMemoryType::DMA, + is_precompute: true, + curve: Curve::BLS381, + }, + dclient, ); ``` @@ -60,14 +62,11 @@ For data encapsulation, methods specific to each primitive can be divided into p ### General Example of usage -We will refer to any type of primitive as `DriverPrimitiveClient` to show generality. And any abbreviation for a specific primitive will be replaced by `dpc` (e.g. `dpc_api` can be `msm_api` ) +We will refer to any type of primitive as `DriverPrimitiveClient` to show generality. ```rust - let dclient = DriverClient::new(&id, DriverConfig::driver_client_c1100_cfg()); - let driver = dpc_api:: DriverPrimitiveClient::new(dpc_api::dpc_type, dclient); - - let params = driver.get_loaded_binary_parameters(); - let params_parse = dpc_api::DPCImageParametrs::parse_image_params(params[1]); + let dclient = DriverClient::new(&id, DriverConfig::driver_client_cfg(CardType::C1100)); + let driver = DriverPrimitiveClient::new(dpc_type, dclient); let _ = driver.initialize(dpc_param); let _ = driver.set_data(dpc_input); @@ -114,6 +113,30 @@ MSMInput = { } ``` +## NTT (Number Theoretic Transform) Module + +This module implements the calculation of NTT of size `2^27`. To use it, the input byte vector of elements must be specified. Each element must be represented in little-endian. The result will be a similar byte vector. + +It is worth noting that the data transfer process is slightly different from other modules. The following is an example of how to use NTT. More details can be found here: [LINK TO BLOG] + +```rust +let dclient = DriverClient::new(&id, DriverConfig::driver_client_cfg(CardType::C1100)); +let driver = NTTClient::new(NTT::Ntt, dclient); +let buf_host = 0; +let buf_kernel = 0; +driver.set_data(NTTInput { + buf_host, + data: in_vec, +})?; +driver.driver_client.initialize_cms()?; +driver.driver_client.reset_sensor_data()?; + +driver.initialize(NttInit {})?; +driver.start_process(Some(buf_kernel))?; +driver.wait_result()?; +let res = driver.result(Some(buf_kernel))?.unwrap(); +``` + ## Poseidon Module ## Running tests and benchmark @@ -125,7 +148,6 @@ To run tests for the MSM primitive, use the following command: ``` RUST_LOG= cargo test -- -- - ``` Also, different tests can require additional parameters: @@ -137,8 +159,36 @@ Also, it's possible to set up a number of points in MSM in the `MSM_SIZE` variab If the values of `ID` and `MSM_SIZE` are not provided, they will be defaulted to `ID=0` and `MSM_SIZE=8192`. -### Poseidon tests +### NTT tests + +To run tests for the NTT primitive, use the following command: ``` +INFNAME= OUTFNAME= RUST_LOG= cargo test -- integration_ntt ``` + +Also, different tests can require additional parameters: +`ID` `INFNAME`, and `OUTFNAME`. + +Replace `` with the desired log level (e.g. info, debug). Set `INFNAME` with the path to the input vector in little-endian byte format. Since we are testing correctness, set the path to the file with which you want to compare the result for the `OUTFNAME` variable. It should also be a little-endian byte vector +file and `ID` with the number of the FPGA slot. + +If the value of `ID` is not provided, they will be defaulted to `ID=0`. + +### NTT benchmark + +Benchmarks for NTT are located in the benches directory, it's worth clarifying that there is no correctness check inside the benchmark - for that use the tests. + +To run bench for the NTT primitive, use the following command: + +``` + +INFNAME= RUST_LOG= cargo bench +``` + +Also, bench can require additional parameters: `ID` and `INFNAME`. Set `INFNAME` with the path to the input vector in little-endian byte format. + +If the value of `ID` is not provided, they will be defaulted to `ID=0`. + +### Poseidon tests diff --git a/benches/ntt_bench.rs b/benches/ntt_bench.rs new file mode 100644 index 0000000..0c34e1f --- /dev/null +++ b/benches/ntt_bench.rs @@ -0,0 +1,54 @@ +use std::{env, fs::File, io::Read}; + +use criterion::*; +use ingo_blaze::{driver_client::*, ingo_ntt::*}; +use log::info; + +fn bench_ntt_calc(c: &mut Criterion) { + env_logger::try_init().expect("Invalid logger initialisation"); + let id = env::var("ID").unwrap_or_else(|_| 0.to_string()); + let fname = env::var("FNAME").unwrap(); + let mut f = File::open(fname).expect("no file found"); + let mut in_vec: Vec = Default::default(); + let _ = f.read_to_end(&mut in_vec); + + let buf_host = 0; + let buf_kernel = 0; + + info!("Create Driver API instance"); + let dclient = DriverClient::new(&id, DriverConfig::driver_client_cfg(CardType::C1100)); + let driver = NTTClient::new(NTT::Ntt, dclient); + log::info!("Starting set NTT data"); + let _ = driver.set_data(NTTInput { + buf_host, + data: in_vec, + }); + log::info!("Successfully set NTT data"); + let _ = driver.driver_client.initialize_cms(); + let _ = driver.driver_client.reset_sensor_data(); + + let mut group = c.benchmark_group("NTT computation"); + + log::info!("Starting NTT"); + group.bench_function("NTT", |b| { + b.iter(|| { + let _ = driver.initialize(NttInit {}); + let _ = driver.start_process(Some(buf_kernel)); + let _ = driver.wait_result(); + let _ = driver.driver_client.reset(); + }) + }); + group.finish(); + log::info!("Finishing NTT"); + + log::info!("Try to get NTT result"); + let res = driver.result(Some(buf_kernel)).unwrap(); + log::info!("NTT result: {:?}", res.unwrap().len()); +} + +criterion_group! { + name = benches; + config = Criterion::default().sample_size(10); + targets = bench_ntt_calc +} +criterion_main!(benches); diff --git a/configs/c1100_cfg.json b/configs/c1100_cfg.json deleted file mode 100644 index b87f620..0000000 --- a/configs/c1100_cfg.json +++ /dev/null @@ -1,12 +0,0 @@ -{ - "ctrl_baseaddr": "0x00000000", - "ctrl_cms_baseaddr": "0x04000000", - "ctrl_qspi_baseaddr": "0x04040000", - "ctrl_hbicap_baseaddr": "0x04050000", - "ctrl_mgmt_ram_baseaddr": "0x04060000", - "ctrl_firewall_baseaddr": "0x04070000", - "dma_firewall_baseaddr": "0x04080000", - "ctrl_dfx_decoupler_baseaddr": "0x04090000", - "dma_baseaddr": "0x0000000000000000", - "dma_hbicap_baseaddr": "0x1000000000000000" -} \ No newline at end of file diff --git a/configs/msm_bls377_dma_cfg.json b/configs/msm_bls377_dma_cfg.json deleted file mode 100644 index 29fd50a..0000000 --- a/configs/msm_bls377_dma_cfg.json +++ /dev/null @@ -1,7 +0,0 @@ -{ - "result_point_size": 144, - "point_size": 96, - "scalar_size": 32, - "dma_scalars_addr": "0x0000020000000000", - "dma_points_addr": "0x0000010000000000" -} \ No newline at end of file diff --git a/configs/msm_bls377_hbm_cfg.json b/configs/msm_bls377_hbm_cfg.json deleted file mode 100644 index 687a2af..0000000 --- a/configs/msm_bls377_hbm_cfg.json +++ /dev/null @@ -1,6 +0,0 @@ -{ - "result_point_size": 144, - "point_size": 96, - "scalar_size": 32, - "dma_scalars_addr": "0x0000020000000000" -} \ No newline at end of file diff --git a/configs/msm_bls381_dma_cfg.json b/configs/msm_bls381_dma_cfg.json deleted file mode 100644 index 29fd50a..0000000 --- a/configs/msm_bls381_dma_cfg.json +++ /dev/null @@ -1,7 +0,0 @@ -{ - "result_point_size": 144, - "point_size": 96, - "scalar_size": 32, - "dma_scalars_addr": "0x0000020000000000", - "dma_points_addr": "0x0000010000000000" -} \ No newline at end of file diff --git a/configs/msm_bls381_hbm_cfg.json b/configs/msm_bls381_hbm_cfg.json deleted file mode 100644 index 687a2af..0000000 --- a/configs/msm_bls381_hbm_cfg.json +++ /dev/null @@ -1,6 +0,0 @@ -{ - "result_point_size": 144, - "point_size": 96, - "scalar_size": 32, - "dma_scalars_addr": "0x0000020000000000" -} \ No newline at end of file diff --git a/configs/msm_bn254_dma_cfg.json b/configs/msm_bn254_dma_cfg.json deleted file mode 100644 index 5d3fc9b..0000000 --- a/configs/msm_bn254_dma_cfg.json +++ /dev/null @@ -1,7 +0,0 @@ -{ - "result_point_size": 96, - "point_size": 64, - "scalar_size": 32, - "dma_scalars_addr": "0x0000010000000000", - "dma_points_addr": "0x0000000000000000" -} \ No newline at end of file diff --git a/src/driver_client/dclient.rs b/src/driver_client/dclient.rs index df8229c..f75beca 100755 --- a/src/driver_client/dclient.rs +++ b/src/driver_client/dclient.rs @@ -6,12 +6,11 @@ //! for custom modules. Each custom module is built on top of this foundation and //! includes its own specific fields and methods. //! +use super::{dclient_cfg::*, dclient_code::*}; use crate::{ - driver_client::dclient_code::*, error::*, - utils::{deserialize_hex, open_channel, AccessFlags}, + utils::{open_channel, AccessFlags}, }; -use serde::Deserialize; use std::{fmt::Debug, os::unix::fs::FileExt, thread::sleep, time::Duration}; /// A trait for defining functions related to parameters of specific core image. @@ -32,65 +31,25 @@ pub trait DriverPrimitive { /// The `loaded_binary_parameters` method returns /// a vector of 32-bit unsigned integers representing the loaded binary parameters. fn loaded_binary_parameters(&self) -> Vec; - /// The `reset` method resets the driver primitive to its initial state. - fn reset(&self) -> Result<()>; /// The `initialize` method initializes the driver primitive with the given parameter. fn initialize(&self, param: P) -> Result<()>; /// The `set_data` method sets the input data for the driver primitive. fn set_data(&self, input: I) -> Result<()>; + /// The `start_process` method starts the driver after setting all controls and data. + fn start_process(&self, param: Option) -> Result<()>; /// The `wait_result` method waits for the driver primitive to finish processing the input data. fn wait_result(&self) -> Result<()>; /// The `result` method returns the output data from the driver primitive, /// optionally with a specific parameter. If there is no output data available, it returns `None`. - fn result(&self, _param: Option) -> Result>; -} - -/// The [`DriverConfig`] is a struct that defines a set of 64-bit unsigned integer (`u64`) -/// representing addreses memory space for different components of a FPGA. -/// -/// The struct is divided into logical parts: AXI Lite space of addresses and AXI space of addresses -#[derive(Copy, Clone, Deserialize, Debug)] -pub struct DriverConfig { - // CTRL - #[serde(deserialize_with = "deserialize_hex")] - pub ctrl_baseaddr: u64, - #[serde(deserialize_with = "deserialize_hex")] - pub ctrl_cms_baseaddr: u64, - #[serde(deserialize_with = "deserialize_hex")] - pub ctrl_qspi_baseaddr: u64, - #[serde(deserialize_with = "deserialize_hex")] - pub ctrl_hbicap_baseaddr: u64, - #[serde(deserialize_with = "deserialize_hex")] - pub ctrl_mgmt_ram_baseaddr: u64, - #[serde(deserialize_with = "deserialize_hex")] - pub ctrl_firewall_baseaddr: u64, - #[serde(deserialize_with = "deserialize_hex")] - pub dma_firewall_baseaddr: u64, - #[serde(deserialize_with = "deserialize_hex")] - pub ctrl_dfx_decoupler_baseaddr: u64, - - // DMA - #[serde(deserialize_with = "deserialize_hex")] - pub dma_baseaddr: u64, - #[serde(deserialize_with = "deserialize_hex")] - pub dma_hbicap_baseaddr: u64, -} - -impl DriverConfig { - /// Create a new driver config from the given configuration in json format. - pub fn driver_client_c1100_cfg() -> Self { - let file = std::fs::File::open("configs/c1100_cfg.json").expect(""); - let reader = std::io::BufReader::new(file); - serde_json::from_reader(reader).unwrap() - } + fn result(&self, param: Option) -> Result>; } /// The [`DriverClient`] is described bunch of addreses on FPGA which called [`DriverConfig`] also /// it includes file descriptor for read-from and write-to channels using DMA bus and CTRL bus. pub struct DriverClient { /// Addreses space of current FPGA. - pub cfg: DriverConfig, + pub(crate) cfg: DriverConfig, /// Write only channel from host memory into custom core using DMA bus. pub dma_h2c_write: std::fs::File, /// Read only channel from core using DMA bus. @@ -115,7 +74,7 @@ impl DriverClient { /// ```rust /// use ingo_blaze::shell::shell_api::{DriverClient, DriverConfig}; /// - /// let dclient = DriverClient::new("0", DriverConfig::driver_client_c1100_cfg()); + /// let dclient = DriverClient::new("0", DriverConfig::driver_client_cfg(CardType::C1100)); /// ``` pub fn new(id: &str, cfg: DriverConfig) -> Self { DriverClient { @@ -125,6 +84,13 @@ impl DriverClient { ctrl: open_channel(&format!("/dev/xdma{}_user", id), AccessFlags::RdwrMode), } } + /// The `reset` method resets the driver primitive to its initial state. + pub fn reset(&self) -> Result<()> { + self.set_dfx_decoupling(1)?; + self.set_dfx_decoupling(0)?; + sleep(Duration::from_millis(100)); + Ok(()) + } // ==== DFX ==== /// Method for checking decouple status. @@ -136,16 +102,12 @@ impl DriverClient { } /// Setup decouple signal to isolate the user logic during reconfiguration, protecting the shell from spurious signals. - pub fn set_dfx_decoupling(&self, signal: u8) -> Result<()> { - self.ctrl - .write_all_at( - &[signal, 0, 0, 0], - self.cfg.ctrl_dfx_decoupler_baseaddr + DFX_DECOUPLER::DECOUPLE as u64, - ) - .map_err(|e| DriverClientError::WriteError { - offset: "DFX setup error".to_string(), - source: e, - })?; + pub fn set_dfx_decoupling(&self, signal: u32) -> Result<()> { + self.ctrl_write_u32( + self.cfg.ctrl_dfx_decoupler_baseaddr, + DFX_DECOUPLER::DECOUPLE, + signal, + )?; Ok(()) } @@ -156,23 +118,38 @@ impl DriverClient { CMS_ADDR::ADDR_CPU2HIF_CMS_INITIALIZE, 1, )?; + // Required for FPGA. In plan: replace with checking register loop + sleep(Duration::from_millis(10)); Ok(()) } /// This method setup 27 bit in CONTROL_REG for enabling hbm temperature monitoring. pub fn enable_hbm_temp_monitoring(&self) -> Result<()> { let ctrl_reg = self.ctrl_read_u32( - self.cfg.ctrl_cms_baseaddr + 0x028000, + self.cfg.ctrl_cms_baseaddr + CMS_ADDR::ADDR_HIF2CPU_CMS_REG_MAP as u64, CMS_ADDR::ADDR_HIF2CPU_CMS_CONTROL_REG, ); self.ctrl_write_u32( - self.cfg.ctrl_cms_baseaddr + 0x028000, + self.cfg.ctrl_cms_baseaddr + CMS_ADDR::ADDR_HIF2CPU_CMS_REG_MAP as u64, CMS_ADDR::ADDR_HIF2CPU_CMS_CONTROL_REG, ctrl_reg.unwrap() | 1 << 27, )?; Ok(()) } + pub fn reset_sensor_data(&self) -> Result<()> { + let ctrl_reg = self.ctrl_read_u32( + self.cfg.ctrl_cms_baseaddr + CMS_ADDR::ADDR_HIF2CPU_CMS_REG_MAP as u64, + CMS_ADDR::ADDR_HIF2CPU_CMS_CONTROL_REG, + ); + self.ctrl_write_u32( + self.cfg.ctrl_cms_baseaddr + CMS_ADDR::ADDR_HIF2CPU_CMS_REG_MAP as u64, + CMS_ADDR::ADDR_HIF2CPU_CMS_CONTROL_REG, + ctrl_reg.unwrap() | 1, + )?; + Ok(()) + } + // HBICAP /// Checking HBICAP status register. Return `true` if zero (previous operation done) and /// second (Indicates that the EOS is complete) bit setting to 1. @@ -205,9 +182,7 @@ impl DriverClient { HBICAP_ADDR::ADDR_HIF2CPU_HBICAP_ABORT_STATUS, )?; - self.set_firewall_block(self.cfg.ctrl_firewall_baseaddr, true)?; - self.set_firewall_block(self.cfg.dma_firewall_baseaddr, true)?; - + self.block_firewalls()?; Ok(()) } @@ -225,7 +200,7 @@ impl DriverClient { /// ```rust /// use ingo_blaze::shell::{shell_api::DriverClient, shell_hw_code::*}; /// - /// let dclient = DriverClient::new("0", DriverConfig::driver_client_c1100_cfg()); + /// let dclient = DriverClient::new("0", DriverConfig::driver_client_cfg(CardType::C1100)); /// /// // read binary data from given filename /// let buf = utils::read_binary_file(&filename); @@ -247,12 +222,13 @@ impl DriverClient { HBICAP_ADDR::ADDR_CPU2HIF_HBICAP_TRANSFER_SIZE, (binary.len() / 4) as u32, )?; - self.dma_write_by_chunks(self.cfg.ctrl_hbicap_baseaddr, DMA_RW::OFFSET, binary, 4)?; - + self.dma_write(self.cfg.dma_hbicap_baseaddr, DMA_RW::OFFSET, binary)?; while !self.is_hbicap_ready() { - sleep(Duration::from_millis(1)); + continue; } self.set_dfx_decoupling(0)?; + self.unblock_firewalls()?; + self.ctrl_read_u32( self.cfg.ctrl_hbicap_baseaddr, HBICAP_ADDR::ADDR_HIF2CPU_HBICAP_ABORT_STATUS, @@ -275,11 +251,11 @@ impl DriverClient { /// ```rust /// use ingo_blaze::shell::shell_api::DriverClient; /// - /// let dclient = DriverClient::new("0", DriverConfig::driver_client_c1100_cfg()); + /// let dclient = DriverConfig::driver_client_cfg(CardType::C1100); /// dclient.set_firewall_block(dclient.cfg.ctrl_firewall_baseaddr, true); // ctrl firewall is now blocked /// dclient.set_firewall_block(dclient.cfg.dma_firewall_baseaddr, false); // dma firewall is now unblocked /// ``` - pub fn set_firewall_block(&self, addr: u64, block: bool) -> Result<()> { + fn set_firewall_block(&self, addr: u64, block: bool) -> Result<()> { if block { self.ctrl_write_u32(addr, FIREWALL_ADDR::BLOCK, 0x100_0100)?; Ok(()) @@ -290,14 +266,15 @@ impl DriverClient { } } + pub fn block_firewalls(&self) -> Result<()> { + self.set_firewall_block(self.cfg.ctrl_firewall_baseaddr, true)?; + self.set_firewall_block(self.cfg.dma_firewall_baseaddr, true)?; + Ok(()) + } + pub fn unblock_firewalls(&self) -> Result<()> { self.set_firewall_block(self.cfg.ctrl_firewall_baseaddr, false)?; self.set_firewall_block(self.cfg.dma_firewall_baseaddr, false)?; - self.ctrl_write_u32( - self.cfg.ctrl_firewall_baseaddr, - FIREWALL_ADDR::DISABLE_BLOCK, - 0, - )?; Ok(()) } @@ -316,7 +293,7 @@ impl DriverClient { /// ```rust /// use ingo_blaze::shell::{shell_api::DriverClient, shell_hw_code::*}; /// - /// let dclient = DriverClient::new("0", DriverConfig::driver_client_c1100_cfg()); + /// let dclient = DriverClient::new("0", DriverConfig::driver_client_cfg(CardType::C1100)); /// /// let ret = dclient.ctrl_read_u32( /// dclient.cfg.ctrl_firewall_baseaddr, @@ -390,7 +367,7 @@ impl DriverClient { /// ```rust /// use ingo_blaze::shell::{shell_api::DriverClient, shell_hw_code::*}; /// - /// let dclient = DriverClient::new("0", DriverConfig::driver_client_c1100_cfg()); + /// let dclient = DriverClient::new("0", DriverConfig::driver_client_cfg(CardType::C1100)); /// /// dclient.ctrl_write_u32( /// dclient.cfg.ctrl_hbicap_baseaddr, @@ -458,7 +435,7 @@ impl DriverClient { /// /// * `base_address`: the base address in the DMA bus addresses space /// * `offset`: an enum which represent the specific offset for given `base_address`. - /// * `size`: an unsigned integer representing the size of the data to be read. + /// * `read_buffer`: existing mememory for reading data /// /// returns: Vec /// @@ -467,7 +444,7 @@ impl DriverClient { /// ```rust /// use ingo_blaze::shell::{shell_api::DriverClient, shell_hw_code::*}; /// - /// let dclient = DriverClient::new("0", DriverConfig::driver_client_c1100_cfg()); + /// let dclient = DriverClient::new("0", DriverConfig::driver_client_cfg(CardType::C1100)); /// let size_of_input = 16; /// let readen = dclient.dma_read( /// dclient.cfg.dma_baseaddr, @@ -480,19 +457,17 @@ impl DriverClient { &self, base_address: u64, offset: T, - size: usize, - ) -> Result> { - let mut read_data = vec![0; size]; - + read_buffer: &mut Vec, + ) -> Result<()> { self.dma_c2h_read - .read_exact_at(&mut read_data, base_address + offset.into()) + .read_exact_at(read_buffer, base_address + offset.into()) .map_err(|e| DriverClientError::ReadError { offset: format!("{:?}", offset), source: e, })?; - crate::getter_log!(read_data, offset); - Ok(read_data) + crate::getter_log!(read_buffer, offset); + Ok(()) } /// The method for writing data from host memory into FPGA. @@ -512,7 +487,7 @@ impl DriverClient { /// ```rust /// use ingo_blaze::shell::{shell_api::DriverClient, shell_hw_code::*}; /// - /// let dclient = DriverClient::new("0", DriverConfig::driver_client_c1100_cfg()); + /// let dclient = DriverClient::new("0", DriverConfig::driver_client_cfg(CardType::C1100)); /// let input = vec![1, 2, 3, 4, 5, 6, 7, 8]; /// let chunk_size = 4; /// @@ -558,7 +533,7 @@ impl DriverClient { /// ```rust /// use ingo_blaze::shell::{shell_api::DriverClient, shell_hw_code::*}; /// - /// let dclient = DriverClient::new("0", DriverConfig::driver_client_c1100_cfg()); + /// let dclient = DriverClient::new("0", DriverConfig::driver_client_cfg(CardType::C1100)); /// let input = vec![1, 2, 3, 4, 5, 6, 7, 8]; /// let chunk_size = 4; /// diff --git a/src/driver_client/dclient_cfg.rs b/src/driver_client/dclient_cfg.rs new file mode 100644 index 0000000..905ef57 --- /dev/null +++ b/src/driver_client/dclient_cfg.rs @@ -0,0 +1,47 @@ +pub enum CardType { + C1100, +} + +/// The [`DriverConfig`] is a struct that defines a set of 64-bit unsigned integer (`u64`) +/// representing addreses memory space for different components of a FPGA. +/// The struct is divided into logical parts: AXI Lite space of addresses and AXI space of addresses +#[derive(Copy, Clone, Debug)] +pub struct DriverConfig { + // CTRL + pub(crate) ctrl_baseaddr: u64, + pub(crate) ctrl_cms_baseaddr: u64, + // pub(crate) ctrl_qspi_baseaddr: u64, + pub(crate) ctrl_hbicap_baseaddr: u64, + // pub(crate) ctrl_mgmt_ram_baseaddr: u64, + pub(crate) ctrl_firewall_baseaddr: u64, + pub(crate) dma_firewall_baseaddr: u64, + pub(crate) ctrl_dfx_decoupler_baseaddr: u64, + + // DMA + pub(crate) dma_baseaddr: u64, + pub(crate) dma_hbicap_baseaddr: u64, +} + +impl DriverConfig { + /// Create a new driver config. + pub fn driver_client_cfg(card_type: CardType) -> Self { + match card_type { + CardType::C1100 => c1100_cfg(), + } + } +} + +fn c1100_cfg() -> DriverConfig { + DriverConfig { + ctrl_baseaddr: 0x00000000, + ctrl_cms_baseaddr: 0x04000000, + // ctrl_qspi_baseaddr: 0x04040000, + ctrl_hbicap_baseaddr: 0x04050000, + // ctrl_mgmt_ram_baseaddr: 0x04060000, + ctrl_firewall_baseaddr: 0x04070000, + dma_firewall_baseaddr: 0x04080000, + ctrl_dfx_decoupler_baseaddr: 0x04090000, + dma_baseaddr: 0x0000000000000000, + dma_hbicap_baseaddr: 0x1000000000000000, + } +} diff --git a/src/driver_client/dclient_code.rs b/src/driver_client/dclient_code.rs index c148ea0..a3da0e4 100755 --- a/src/driver_client/dclient_code.rs +++ b/src/driver_client/dclient_code.rs @@ -40,7 +40,8 @@ impl From for u64 { #[derive(Debug, Copy, Clone)] pub enum CMS_ADDR { ADDR_CPU2HIF_CMS_INITIALIZE = 0x020000, - ADDR_HIF2CPU_CMS_CONTROL_REG = 0x0018, + ADDR_HIF2CPU_CMS_CONTROL_REG = 0x028000, + ADDR_HIF2CPU_CMS_REG_MAP = 0x0018, } impl From for u64 { fn from(addr: CMS_ADDR) -> Self { diff --git a/src/driver_client/mod.rs b/src/driver_client/mod.rs index 10f34d8..9316a3d 100755 --- a/src/driver_client/mod.rs +++ b/src/driver_client/mod.rs @@ -1,2 +1,7 @@ -pub mod dclient; -pub(crate) mod dclient_code; +mod dclient; +mod dclient_cfg; +mod dclient_code; + +pub use dclient::*; +pub use dclient_cfg::{CardType, DriverConfig}; +pub(crate) use dclient_code::*; diff --git a/src/error.rs b/src/error.rs index 6650beb..b78f2f3 100644 --- a/src/error.rs +++ b/src/error.rs @@ -25,6 +25,8 @@ pub enum DriverClientError { CsvError(#[from] csv::Error), #[error("failed to load instruction set from: {:?}", path)] LoadFailed { path: String }, + #[error("failed open file")] + FileError(#[from] io::Error), #[error("unknown driver client error")] Unknown, } diff --git a/src/ingo_hash/mod.rs b/src/ingo_hash/mod.rs index 19fadeb..ca8813e 100755 --- a/src/ingo_hash/mod.rs +++ b/src/ingo_hash/mod.rs @@ -1,3 +1,6 @@ -pub mod hash_hw_code; -pub mod poseidon_api; -pub mod utils; +mod hash_hw_code; +mod poseidon_api; +mod utils; + +pub use poseidon_api::*; +pub use utils::*; diff --git a/src/ingo_hash/poseidon_api.rs b/src/ingo_hash/poseidon_api.rs index 0949c68..0369466 100755 --- a/src/ingo_hash/poseidon_api.rs +++ b/src/ingo_hash/poseidon_api.rs @@ -1,11 +1,9 @@ use packed_struct::prelude::PackedStruct; -use std::{option::Option, thread::sleep, time::Duration}; +use std::option::Option; use strum::IntoEnumIterator; -use crate::{ - driver_client::dclient::*, driver_client::dclient_code::*, error::*, - ingo_hash::hash_hw_code::*, ingo_hash::utils::*, utils::convert_to_32_byte_array, -}; +use super::{hash_hw_code::*, TreeMode}; +use crate::{driver_client::*, error::*, utils::convert_to_32_byte_array}; use csv; use num::{bigint::BigUint, Num}; @@ -95,15 +93,8 @@ impl DriverPrimitive>() } - fn reset(&self) -> Result<()> { - self.dclient.set_dfx_decoupling(1)?; - self.dclient.set_dfx_decoupling(0)?; - sleep(Duration::from_millis(100)); - Ok(()) - } - fn initialize(&self, param: PoseidonInitializeParameters) -> Result<()> { - self.reset()?; + self.dclient.reset()?; self.set_initialize_mode(true)?; self.load_instructions(¶m.instruction_path) @@ -119,6 +110,10 @@ impl DriverPrimitive) -> Result<()> { + todo!() + } + fn set_data(&self, input: &[u8]) -> Result<()> { self.dclient .dma_write(self.dclient.cfg.dma_baseaddr, DMA_RW::OFFSET, input)?; @@ -194,11 +189,10 @@ impl PoseidonClient { } pub fn get_raw_results(&self, num_of_results: u32) -> Result> { - self.dclient.dma_read( - self.dclient.cfg.dma_baseaddr, - DMA_RW::OFFSET, - (64 * num_of_results).try_into().unwrap(), - ) + let mut res = vec![0; (64 * num_of_results).try_into().unwrap()]; + self.dclient + .dma_read(self.dclient.cfg.dma_baseaddr, DMA_RW::OFFSET, &mut res)?; + Ok(res) } pub fn get_last_hash_sent_to_host(&self) -> Result { @@ -297,7 +291,7 @@ mod tests { info!("Create Driver API instance"); - let dclient = DriverClient::new(&id, DriverConfig::driver_client_c1100_cfg()); + let dclient = DriverClient::new(&id, DriverConfig::driver_client_cfg(CardType::C1100)); let driver: PoseidonClient = PoseidonClient::new(Hash::Poseidon, dclient); let params = driver.loaded_binary_parameters(); info!("Driver parameters: [{:?}, {:032b}]", params[0], params[1]); diff --git a/src/ingo_msm/mod.rs b/src/ingo_msm/mod.rs index b83d023..46136c1 100755 --- a/src/ingo_msm/mod.rs +++ b/src/ingo_msm/mod.rs @@ -1,2 +1,6 @@ -pub mod msm_api; -pub mod msm_hw_code; +mod msm_api; +mod msm_cfg; +mod msm_hw_code; + +pub use msm_api::*; +pub use msm_cfg::{Curve, PointMemoryType}; diff --git a/src/ingo_msm/msm_api.rs b/src/ingo_msm/msm_api.rs index 07457c4..5f39e06 100755 --- a/src/ingo_msm/msm_api.rs +++ b/src/ingo_msm/msm_api.rs @@ -1,54 +1,9 @@ -use crate::{ - driver_client::dclient::*, driver_client::dclient_code::*, error::*, ingo_msm::msm_hw_code::*, - utils::deserialize_option_hex, -}; +use super::{msm_cfg::*, msm_hw_code::*}; +use crate::{driver_client::*, error::*}; use packed_struct::prelude::*; -use serde::Deserialize; -use std::{os::unix::fs::FileExt, thread::sleep, time::Duration}; - +use std::os::unix::fs::FileExt; use strum::IntoEnumIterator; -use strum_macros::EnumString; - -#[derive(Debug, EnumString, PartialEq)] -pub enum Curve { - BLS377, - BLS381, - BN254, -} - -#[derive(Debug, PartialEq, Eq, Clone, Copy, EnumString)] -pub enum PointMemoryType { - HBM, - DMA, -} - -#[derive(Deserialize, Debug, Copy, Clone)] -struct MSMConfig { - // The size characteristic in points and scalars in a curve. - /// The size in bytes of result point. The point is expected to be in projective form. - result_point_size: usize, - /// The size of one point in bytes. Point is represented in affine form. - point_size: Option, - /// The size of scalar coordinate in bytes. - scalar_size: usize, - - // Ingo MSM Core additional addresses - #[serde(default, deserialize_with = "deserialize_option_hex")] - dma_scalars_addr: Option, - #[serde(default, deserialize_with = "deserialize_option_hex")] - dma_points_addr: Option, -} - -impl MSMConfig { - fn load_cfg(curve: Curve, mem: PointMemoryType) -> Self { - let name = format!("configs/msm_{:?}_{:?}_cfg.json", curve, mem).to_ascii_lowercase(); - log::info!("Config name: {}", name); - let file = std::fs::File::open(name).expect(""); - let reader = std::io::BufReader::new(file); - serde_json::from_reader(reader).unwrap() - } -} pub struct MSMClient { mem_type: PointMemoryType, @@ -94,7 +49,7 @@ impl DriverPrimitive for MSMClient { } else { PRECOMPUTE_FACTOR_BASE }, - msm_cfg: MSMConfig::load_cfg(init.curve, init.mem_type), + msm_cfg: MSMConfig::msm_cfg(init.curve, init.mem_type), driver_client: dclient, } } @@ -114,13 +69,6 @@ impl DriverPrimitive for MSMClient { .collect::>() } - fn reset(&self) -> Result<()> { - self.driver_client.set_dfx_decoupling(1)?; - self.driver_client.set_dfx_decoupling(0)?; - sleep(Duration::from_millis(100)); - Ok(()) - } - fn initialize(&self, params: MSMParams) -> Result<()> { log::info!("Start initialize driver"); @@ -159,14 +107,16 @@ impl DriverPrimitive for MSMClient { params.nof_elements, )?; + Ok(()) + } + + fn start_process(&self, _param: Option) -> Result<()> { log::info!("Pushing Task Signal"); self.driver_client.ctrl_write_u32( self.driver_client.cfg.ctrl_baseaddr, INGO_MSM_ADDR::ADDR_CPU2HIF_E_PUSH_MSM_TASK_TO_QUEUE, 1, - )?; - - Ok(()) + ) } /// This function sets data for compute MSM and has three different cases depending on the input parameters. @@ -362,12 +312,13 @@ impl MSMClient { Ok(()) } - pub fn get_data_from_hbm(&self, data: &[u8], addr: u64, offset: u64) -> Result> { + pub fn get_data_from_hbm(&self, data_len: usize, addr: u64, offset: u64) -> Result> { log::debug!("HBM adress: {:#X?}", addr); - log::debug!("Data length: {:#X?}", data.len()); - let res = self.driver_client.dma_read(addr, offset, data.len()); + log::debug!("Data length: {:#X?}", data_len); + let mut res = vec![0; data_len]; + self.driver_client.dma_read(addr, offset, &mut res)?; log::debug!("Successfully read data from hbm"); - res + Ok(res) } pub fn get_api(&self) { diff --git a/src/ingo_msm/msm_cfg.rs b/src/ingo_msm/msm_cfg.rs new file mode 100644 index 0000000..13c0d92 --- /dev/null +++ b/src/ingo_msm/msm_cfg.rs @@ -0,0 +1,92 @@ +use strum_macros::EnumString; + +#[derive(Debug, EnumString, PartialEq)] +pub enum Curve { + BLS377, + BLS381, + BN254, +} + +#[derive(Debug, PartialEq, Eq, Clone, Copy, EnumString)] +pub enum PointMemoryType { + HBM, + DMA, +} + +#[derive(Debug, Copy, Clone)] +pub(super) struct MSMConfig { + // The size characteristic in points and scalars in a curve. + /// The size in bytes of result point. The point is expected to be in projective form. + pub result_point_size: usize, + /// The size of one point in bytes. Point is represented in affine form. + pub point_size: Option, + /// The size of scalar coordinate in bytes. + pub scalar_size: usize, + + // Ingo MSM Core additional addresses + pub dma_scalars_addr: Option, + pub dma_points_addr: Option, +} + +impl MSMConfig { + pub(super) fn msm_cfg(curve: Curve, mem: PointMemoryType) -> Self { + match (curve, mem) { + (Curve::BLS377, PointMemoryType::HBM) => msm_bls377_hbm_cfg(), + (Curve::BLS377, PointMemoryType::DMA) => msm_bls377_dma_cfg(), + (Curve::BLS381, PointMemoryType::HBM) => msm_bls381_hbm_cfg(), + (Curve::BLS381, PointMemoryType::DMA) => msm_bls381_dma_cfg(), + (Curve::BN254, PointMemoryType::HBM) => todo!(), + (Curve::BN254, PointMemoryType::DMA) => msm_bn254_dma_cfg(), + } + } +} + +fn msm_bls377_hbm_cfg() -> MSMConfig { + MSMConfig { + result_point_size: 144, + point_size: Some(96), + scalar_size: 32, + dma_scalars_addr: Some(0x0000020000000000), + dma_points_addr: None, + } +} + +fn msm_bls377_dma_cfg() -> MSMConfig { + MSMConfig { + result_point_size: 144, + point_size: Some(96), + scalar_size: 32, + dma_scalars_addr: Some(0x0000020000000000), + dma_points_addr: Some(0x0000010000000000), + } +} + +fn msm_bls381_hbm_cfg() -> MSMConfig { + MSMConfig { + result_point_size: 144, + point_size: Some(96), + scalar_size: 32, + dma_scalars_addr: Some(0x0000020000000000), + dma_points_addr: None, + } +} + +fn msm_bls381_dma_cfg() -> MSMConfig { + MSMConfig { + result_point_size: 144, + point_size: Some(96), + scalar_size: 32, + dma_scalars_addr: Some(0x0000020000000000), + dma_points_addr: Some(0x0000010000000000), + } +} + +fn msm_bn254_dma_cfg() -> MSMConfig { + MSMConfig { + result_point_size: 96, + point_size: Some(64), + scalar_size: 32, + dma_scalars_addr: Some(0x0000010000000000), + dma_points_addr: Some(0x0000000000000000), + } +} diff --git a/src/ingo_ntt/mod.rs b/src/ingo_ntt/mod.rs new file mode 100644 index 0000000..9558b96 --- /dev/null +++ b/src/ingo_ntt/mod.rs @@ -0,0 +1,5 @@ +mod ntt_api; +mod ntt_data; +mod ntt_hw_code; + +pub use ntt_api::*; diff --git a/src/ingo_ntt/ntt_api.rs b/src/ingo_ntt/ntt_api.rs new file mode 100644 index 0000000..9a1b42c --- /dev/null +++ b/src/ingo_ntt/ntt_api.rs @@ -0,0 +1,125 @@ +use super::{ + ntt_data::{NTTBanks, NTTConfig, NOF_BANKS}, + ntt_hw_code::*, +}; +use crate::{driver_client::*, error::*}; +use std::{fmt::Debug, os::unix::fs::FileExt}; + +pub enum NTT { + Ntt, +} + +pub struct NTTClient { + ntt_cfg: NTTConfig, + pub driver_client: DriverClient, +} + +pub struct NttInit {} + +#[derive(Debug, Clone)] +pub struct NTTInput { + pub buf_host: usize, + pub data: Vec, +} + +impl DriverPrimitive> for NTTClient { + fn new(_ptype: NTT, dclient: DriverClient) -> Self { + NTTClient { + ntt_cfg: NTTConfig::ntt_cfg(), + driver_client: dclient, + } + } + + fn loaded_binary_parameters(&self) -> Vec { + todo!() + } + + fn initialize(&self, _: NttInit) -> Result<()> { + let enable_debug_program = 0x00; + let debug_program: Vec = vec![0xFF00000000, 0xFF00000000]; + + self.driver_client.ctrl_write_u32( + self.ntt_cfg.ntt_addrs.hbm_ss_baseaddr, + INGO_NTT_SUPER_PROGRAM_ADDR::XHBM_SS_CONTROL_ADDR_HIF_INPUT_ENABLE_DEBUG_PROGRAM_DATA, + enable_debug_program, + )?; + + for (i, input) in debug_program.iter().enumerate() { + let size = 8; + self.driver_client.ctrl_write( + self.ntt_cfg.ntt_addrs.hbm_ss_baseaddr + (size * i as u64), + INGO_NTT_SUPER_PROGRAM_ADDR::XHBM_SS_CONTROL_ADDR_HIF_INPUT_DEBUG_PROGRAM_BASE, + &input.to_le_bytes(), + )?; + } + Ok(()) + } + + fn start_process(&self, buf_kernel: Option) -> Result<()> { + self.driver_client.ctrl_write_u32( + self.ntt_cfg.ntt_addrs.hbm_ss_baseaddr, + INGO_NTT_SUPER_PROGRAM_ADDR::XHBM_SS_CONTROL_ADDR_HIF_INPUT_BUFFER_DATA, + buf_kernel.unwrap().try_into().unwrap(), + )?; + + self.driver_client.ctrl_write_u32( + self.ntt_cfg.ntt_addrs.hbm_ss_baseaddr, + INGO_NTT_SUPER_PROGRAM_ADDR::XHBM_SS_CONTROL_ADDR_AP_CTRL, + 1, + ) + } + + fn set_data(&self, input: NTTInput) -> Result<()> { + let data_banks = NTTBanks::preprocess(input.data); + + data_banks + .banks + .into_iter() + .enumerate() + .try_for_each(|(i, data_in)| { + let offset = self.ntt_cfg.ntt_bank_start_addr(i, input.buf_host); + self.driver_client.dma_write( + self.driver_client.cfg.dma_baseaddr, + offset, + data_in.as_slice(), + ) + }) + } + + fn wait_result(&self) -> Result<()> { + let mut result_valid = [0, 0, 0, 0]; + let mut done = false; + log::debug!("Waiting ready signal from offset: XHBM_SS_CONTROL_ADDR_AP_CTRL"); + while !done { + self.driver_client + .ctrl + .read_exact_at( + &mut result_valid, + self.driver_client.cfg.ctrl_baseaddr + + INGO_NTT_SUPER_PROGRAM_ADDR::XHBM_SS_CONTROL_ADDR_AP_CTRL as u64, + ) + .map_err(|e| DriverClientError::ReadError { + offset: "XHBM_SS_CONTROL_ADDR_AP_CTRL".to_string(), + source: e, + })?; + done = (result_valid[0] & 0x2) == 0x2; + } + Ok(()) + } + + fn result(&self, buf_num: Option) -> Result>> { + let mut res_banks: NTTBanks = Default::default(); + for i in 0..NOF_BANKS { + let offset = self.ntt_cfg.ntt_bank_start_addr(i, buf_num.unwrap()); + res_banks.banks[i] = vec![0; NTTConfig::NTT_BUFFER_SIZE]; + self.driver_client.dma_read( + self.driver_client.cfg.dma_baseaddr, + offset, + &mut res_banks.banks[i], + )?; + } + + let res = res_banks.postprocess(); + Ok(Some(res)) + } +} diff --git a/src/ingo_ntt/ntt_data.rs b/src/ingo_ntt/ntt_data.rs new file mode 100644 index 0000000..4a7a895 --- /dev/null +++ b/src/ingo_ntt/ntt_data.rs @@ -0,0 +1,232 @@ +pub(super) const NOF_BANKS: usize = 16; + +#[derive(Debug, Copy, Clone)] +pub(super) struct NTTAddrs { + pub hbm_ss_baseaddr: u64, + pub hbm_addrs: [u64; NOF_BANKS], +} + +fn ntt_addrs() -> NTTAddrs { + NTTAddrs { + hbm_ss_baseaddr: 0x0, + hbm_addrs: [ + 0x000000000, + 0x020000000, + 0x040000000, + 0x060000000, + 0x080000000, + 0x0A0000000, + 0x0C0000000, + 0x0E0000000, + 0x100000000, + 0x120000000, + 0x140000000, + 0x160000000, + 0x180000000, + 0x1A0000000, + 0x1C0000000, + 0x1E0000000, + ], + } +} + +#[derive(Debug, Copy, Clone)] +pub(super) struct NTTConfig { + pub ntt_addrs: NTTAddrs, +} + +impl NTTConfig { + // In essence, the HBM is subdivided into two rows and two columns. + // The two rows account for the HBM double buffer and + // the two columns account for the left and right NTTC sides. + pub const NTT_BUFFER_SIZE: usize = 268435456; // 2**28 - size of one buffer into HBM + + pub fn ntt_cfg() -> Self { + NTTConfig { + ntt_addrs: ntt_addrs(), + } + } + + pub(super) fn hbm_bank_start_addr(&self, bank_num: usize) -> u64 { + *self.ntt_addrs.hbm_addrs.get(bank_num).unwrap() + } + + pub(super) fn ntt_bank_start_addr(&self, bank_num: usize, buf_num: usize) -> u64 { + self.hbm_bank_start_addr(bank_num) + (Self::NTT_BUFFER_SIZE * buf_num) as u64 + } +} + +#[derive(Debug, Clone, Default)] +pub(super) struct NTTBanks { + pub banks: [Vec; NOF_BANKS], +} + +impl NTTBanks { + const NTT_SIZE: usize = 134217728; // Size of NTT = 2**27 + const NTT_WORD_SIZE: usize = 32; // Size of one element in NTT in bytes + const NTT_NOF_MMU_IN_CORE: usize = 8; // Number of MMUs into which one subNTT splits into + + // The NTT data (corresponding to a single buffer) consists of 512 Groups (NTT_NOF_GROUPS), + // each Group consisting of two Slices (NTT_NOF_SLICE), + // each Slice consisting of 16 Batches (NTT_NOF_BATCH), + // and each Batch consisting of 16 subNTTs (NTT_NOF_SUBNTT), + // each subNTTs consisting of 64 rows (NTT_NOF_ROW). + const NTT_NOF_GROUPS: usize = 512; + const NTT_NOF_SLICE: usize = 2; + const NTT_NOF_BATCH: usize = 16; + const NTT_NOF_SUBNTT: usize = 8; + const NTT_NOF_ROW: usize = 64; + + pub(super) fn preprocess(input: Vec) -> Self { + log::info!("Start preparing the input vector before NTT"); + let mut banks: Vec> = Vec::with_capacity(NOF_BANKS); + for _ in 0..NOF_BANKS { + banks.push(Default::default()); + } + let mut addr = 0; + for group in 0..Self::NTT_NOF_GROUPS { + for _ in 0..Self::NTT_NOF_SLICE { + for _ in 0..Self::NTT_NOF_BATCH { + for _ in 0..Self::NTT_NOF_SUBNTT { + for cores in [[0, 1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13, 14, 15]] { + for row in 0..Self::NTT_NOF_ROW { + for bank_num in cores.into_iter() { + let buf: &[u8] = &input[(bank_num % 8 + row * 8) * 32 + addr + ..(bank_num % 8 + row * 8 + 1) * 32 + addr]; + banks[bank_num].extend_from_slice(buf); + } + } + addr += + Self::NTT_NOF_MMU_IN_CORE * Self::NTT_NOF_ROW * Self::NTT_WORD_SIZE; + } + } + } + } + log::trace!("Group {} is ready", group) + } + + NTTBanks { + banks: banks.try_into().unwrap(), + } + } + + pub(super) fn postprocess(&self) -> Vec { + log::info!("Start processing the result after NTT"); + let mut res = vec![0u8; Self::NTT_SIZE * Self::NTT_WORD_SIZE]; + log::debug!("Allocate vector of size: {}", res.len()); + + let mut group_start = [0, 0]; + let offset = [0, 512]; + let mut bank_offset = vec![0usize; 16]; + for group in 0..Self::NTT_NOF_GROUPS { + let mut block = 0; + for i in 0..2 { + group_start[i] = offset[i] + group; + } + for _ in 0..Self::NTT_NOF_SLICE { + for _ in 0..Self::NTT_NOF_BATCH { + for _ in 0..Self::NTT_NOF_SUBNTT { + for icore in 0..2 { + let isubntt = group_start[icore] + 1024 * block; + let mut i = 0; + for _ in 0..Self::NTT_NOF_ROW { + let cores = if group % 2 == 0 { + [[0, 1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13, 14, 15]] + } else { + [[8, 9, 10, 11, 12, 13, 14, 15], [0, 1, 2, 3, 4, 5, 6, 7]] + }; + for bank_num in cores[icore].into_iter() { + let addr = 512 * isubntt + i; + res[addr * 32..(addr + 1) * 32].copy_from_slice( + &self.banks[bank_num] + [bank_offset[bank_num]..bank_offset[bank_num] + 32], + ); + bank_offset[bank_num] += 32; + i += 1; + } + } + } + block += 1; + } + } + } + log::trace!("Group {} is ready", group) + } + res + } +} + +#[cfg(test)] +mod tests { + use std::{env, fs::File, io::Read}; + + use super::{NTTBanks, NOF_BANKS}; + + #[test] + fn preprocess_correctness() { + let fdir = env::var("FDIR").unwrap(); + let exp = already_preprocess(fdir); + let fname = env::var("FNAME").unwrap(); + let mut f = File::open(&fname).expect("no file found"); + let mut in_vec: Vec = Default::default(); + let _ = f.read_to_end(&mut in_vec); + let got = NTTBanks::preprocess(in_vec); + + for (i, expb) in exp.iter().enumerate().take(16) { + if got.banks[i].eq(expb) { + println!("Bank {} is correct", i); + } + } + } + + fn already_preprocess(fdir: String) -> Vec> { + let mut banks: Vec> = Vec::with_capacity(NOF_BANKS); + for _ in 0..NOF_BANKS { + banks.push(Default::default()); + } + for (i, bank) in banks.iter_mut().enumerate().take(16) { + let fname = format!("{}/inbank{:02}.dat", fdir, i); + println!("Read {}", fname); + let mut f = File::open(&fname).expect("no file found"); + let _ = f.read_to_end(bank); + } + + banks + } + + #[test] + fn postprocess_correctness() { + let fdir = env::var("FDIR").unwrap(); + let fname = env::var("FNAME").unwrap(); + let in_banks: Vec> = already_postprocess(fdir); + let ntt_banks = NTTBanks { + banks: in_banks.try_into().unwrap(), + }; + let got = ntt_banks.postprocess(); + println!("Got result of size: {}", got.len()); + + let mut f = File::open(&fname).expect("no file found"); + let mut exp_out_vec: Vec = Default::default(); + let _ = f.read_to_end(&mut exp_out_vec); + println!("Result is read from: {}", fname); + + if exp_out_vec.eq(&got) { + println!("Result is correct"); + } + } + + fn already_postprocess(fdir: String) -> Vec> { + let mut banks: Vec> = Vec::with_capacity(NOF_BANKS); + for _ in 0..NOF_BANKS { + banks.push(Default::default()); + } + for (i, bank) in banks.iter_mut().enumerate().take(16) { + let fname = format!("{}/outbank{:02}.dat", fdir, i); + println!("Read {}", fname); + let mut f = File::open(&fname).expect("no file found"); + let _ = f.read_to_end(bank); + } + + banks + } +} diff --git a/src/ingo_ntt/ntt_hw_code.rs b/src/ingo_ntt/ntt_hw_code.rs new file mode 100644 index 0000000..25248d3 --- /dev/null +++ b/src/ingo_ntt/ntt_hw_code.rs @@ -0,0 +1,89 @@ +#![allow(non_camel_case_types)] +use strum_macros::EnumString; + +#[repr(u64)] +#[derive(Debug, Copy, Clone, PartialEq, EnumString)] +pub enum INGO_NTT_SUPER_PROGRAM_ADDR { + XHBM_SS_CONTROL_ADDR_AP_CTRL = 0x000, + XHBM_SS_CONTROL_ADDR_GIE = 0x004, + XHBM_SS_CONTROL_ADDR_IER = 0x008, + XHBM_SS_CONTROL_ADDR_ISR = 0x00c, + XHBM_SS_CONTROL_ADDR_HIF_INPUT_BUFFER_DATA = 0x010, + XHBM_SS_CONTROL_ADDR_HIF_INPUT_ENABLE_DEBUG_PROGRAM_DATA = 0x018, + XHBM_SS_CONTROL_ADDR_INFO_PROGRAM_DATA = 0x040, + XHBM_SS_CONTROL_ADDR_INFO_PROGRAM_CTRL = 0x044, + XHBM_SS_CONTROL_ADDR_INFO_WRITE00_DATA = 0x050, + XHBM_SS_CONTROL_ADDR_INFO_WRITE00_CTRL = 0x054, + XHBM_SS_CONTROL_ADDR_INFO_WRITE01_DATA = 0x060, + XHBM_SS_CONTROL_ADDR_INFO_WRITE01_CTRL = 0x064, + XHBM_SS_CONTROL_ADDR_INFO_WRITE02_DATA = 0x070, + XHBM_SS_CONTROL_ADDR_INFO_WRITE02_CTRL = 0x074, + XHBM_SS_CONTROL_ADDR_INFO_WRITE03_DATA = 0x080, + XHBM_SS_CONTROL_ADDR_INFO_WRITE03_CTRL = 0x084, + XHBM_SS_CONTROL_ADDR_INFO_WRITE04_DATA = 0x090, + XHBM_SS_CONTROL_ADDR_INFO_WRITE04_CTRL = 0x094, + XHBM_SS_CONTROL_ADDR_INFO_WRITE05_DATA = 0x0a0, + XHBM_SS_CONTROL_ADDR_INFO_WRITE05_CTRL = 0x0a4, + XHBM_SS_CONTROL_ADDR_INFO_WRITE06_DATA = 0x0b0, + XHBM_SS_CONTROL_ADDR_INFO_WRITE06_CTRL = 0x0b4, + XHBM_SS_CONTROL_ADDR_INFO_WRITE07_DATA = 0x0c0, + XHBM_SS_CONTROL_ADDR_INFO_WRITE07_CTRL = 0x0c4, + XHBM_SS_CONTROL_ADDR_INFO_WRITE08_DATA = 0x0d0, + XHBM_SS_CONTROL_ADDR_INFO_WRITE08_CTRL = 0x0d4, + XHBM_SS_CONTROL_ADDR_INFO_WRITE09_DATA = 0x0e0, + XHBM_SS_CONTROL_ADDR_INFO_WRITE09_CTRL = 0x0e4, + XHBM_SS_CONTROL_ADDR_INFO_WRITE10_DATA = 0x0f0, + XHBM_SS_CONTROL_ADDR_INFO_WRITE10_CTRL = 0x0f4, + XHBM_SS_CONTROL_ADDR_INFO_WRITE11_DATA = 0x100, + XHBM_SS_CONTROL_ADDR_INFO_WRITE11_CTRL = 0x104, + XHBM_SS_CONTROL_ADDR_INFO_WRITE12_DATA = 0x110, + XHBM_SS_CONTROL_ADDR_INFO_WRITE12_CTRL = 0x114, + XHBM_SS_CONTROL_ADDR_INFO_WRITE13_DATA = 0x120, + XHBM_SS_CONTROL_ADDR_INFO_WRITE13_CTRL = 0x124, + XHBM_SS_CONTROL_ADDR_INFO_WRITE14_DATA = 0x130, + XHBM_SS_CONTROL_ADDR_INFO_WRITE14_CTRL = 0x134, + XHBM_SS_CONTROL_ADDR_INFO_WRITE15_DATA = 0x140, + XHBM_SS_CONTROL_ADDR_INFO_WRITE15_CTRL = 0x144, + XHBM_SS_CONTROL_ADDR_INFO_READ00_DATA = 0x150, + XHBM_SS_CONTROL_ADDR_INFO_READ00_CTRL = 0x154, + XHBM_SS_CONTROL_ADDR_INFO_READ01_DATA = 0x160, + XHBM_SS_CONTROL_ADDR_INFO_READ01_CTRL = 0x164, + XHBM_SS_CONTROL_ADDR_INFO_READ02_DATA = 0x170, + XHBM_SS_CONTROL_ADDR_INFO_READ02_CTRL = 0x174, + XHBM_SS_CONTROL_ADDR_INFO_READ03_DATA = 0x180, + XHBM_SS_CONTROL_ADDR_INFO_READ03_CTRL = 0x184, + XHBM_SS_CONTROL_ADDR_INFO_READ04_DATA = 0x190, + XHBM_SS_CONTROL_ADDR_INFO_READ04_CTRL = 0x194, + XHBM_SS_CONTROL_ADDR_INFO_READ05_DATA = 0x1a0, + XHBM_SS_CONTROL_ADDR_INFO_READ05_CTRL = 0x1a4, + XHBM_SS_CONTROL_ADDR_INFO_READ06_DATA = 0x1b0, + XHBM_SS_CONTROL_ADDR_INFO_READ06_CTRL = 0x1b4, + XHBM_SS_CONTROL_ADDR_INFO_READ07_DATA = 0x1c0, + XHBM_SS_CONTROL_ADDR_INFO_READ07_CTRL = 0x1c4, + XHBM_SS_CONTROL_ADDR_INFO_READ08_DATA = 0x1d0, + XHBM_SS_CONTROL_ADDR_INFO_READ08_CTRL = 0x1d4, + XHBM_SS_CONTROL_ADDR_INFO_READ09_DATA = 0x1e0, + XHBM_SS_CONTROL_ADDR_INFO_READ09_CTRL = 0x1e4, + XHBM_SS_CONTROL_ADDR_INFO_READ10_DATA = 0x1f0, + XHBM_SS_CONTROL_ADDR_INFO_READ10_CTRL = 0x1f4, + XHBM_SS_CONTROL_ADDR_INFO_READ11_DATA = 0x200, + XHBM_SS_CONTROL_ADDR_INFO_READ11_CTRL = 0x204, + XHBM_SS_CONTROL_ADDR_INFO_READ12_DATA = 0x210, + XHBM_SS_CONTROL_ADDR_INFO_READ12_CTRL = 0x214, + XHBM_SS_CONTROL_ADDR_INFO_READ13_DATA = 0x220, + XHBM_SS_CONTROL_ADDR_INFO_READ13_CTRL = 0x224, + XHBM_SS_CONTROL_ADDR_INFO_READ14_DATA = 0x230, + XHBM_SS_CONTROL_ADDR_INFO_READ14_CTRL = 0x234, + XHBM_SS_CONTROL_ADDR_INFO_READ15_DATA = 0x240, + XHBM_SS_CONTROL_ADDR_INFO_READ15_CTRL = 0x244, + XHBM_SS_CONTROL_ADDR_DEBUG_OUTPUT_DATA = 0x250, + XHBM_SS_CONTROL_ADDR_DEBUG_OUTPUT_CTRL = 0x254, + XHBM_SS_CONTROL_ADDR_HIF_INPUT_DEBUG_PROGRAM_BASE = 0x020, + XHBM_SS_CONTROL_ADDR_HIF_INPUT_DEBUG_PROGRAM_HIGH = 0x03f, +} + +impl From for u64 { + fn from(addr: INGO_NTT_SUPER_PROGRAM_ADDR) -> Self { + addr as u64 + } +} diff --git a/src/lib.rs b/src/lib.rs index e3b3dc3..30ee351 100755 --- a/src/lib.rs +++ b/src/lib.rs @@ -9,4 +9,5 @@ pub mod driver_client; pub mod error; pub mod ingo_hash; pub mod ingo_msm; +pub mod ingo_ntt; pub mod utils; diff --git a/src/utils.rs b/src/utils.rs index 4048b1b..84528d8 100755 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,5 +1,4 @@ use anyhow; -use serde::{de, Deserialize}; use std::{ fs::{File, OpenOptions}, io::{Error, Read}, @@ -130,23 +129,6 @@ pub fn convert_to_32_byte_array(init: &[u8]) -> [u8; 32] { arr } -pub fn deserialize_hex<'de, D: serde::Deserializer<'de>>(d: D) -> Result { - let hex: String = Deserialize::deserialize(d)?; - u64::from_str_radix(hex.trim_start_matches("0x"), 16).map_err(de::Error::custom) -} - -pub fn deserialize_option_hex<'de, D: serde::Deserializer<'de>>( - d: D, -) -> Result, D::Error> { - let hex: Option = Option::deserialize(d)?; - if let Some(hex) = hex { - return Ok(Some( - u64::from_str_radix(hex.trim_start_matches("0x"), 16).map_err(de::Error::custom)?, - )); - } - Ok(None) -} - // ==== general ==== pub fn retry( args: T, diff --git a/tests/integration_msm.rs b/tests/integration_msm.rs index 7d6df1c..3e08533 100644 --- a/tests/integration_msm.rs +++ b/tests/integration_msm.rs @@ -1,4 +1,4 @@ -use ingo_blaze::{driver_client::dclient::*, ingo_msm::*, utils::*}; +use ingo_blaze::{driver_client::*, ingo_msm::*, utils::*}; use num_traits::Pow; use std::{ env, @@ -24,12 +24,12 @@ fn load_msm_binary_test() -> Result<(), Box> { log::info!("MSM Size: {}", msm_size); log::info!("Create Driver API instance"); - let dclient = DriverClient::new(&id, DriverConfig::driver_client_c1100_cfg()); - let driver = msm_api::MSMClient::new( - msm_api::MSMInit { - mem_type: msm_api::PointMemoryType::DMA, + let dclient = DriverClient::new(&id, DriverConfig::driver_client_cfg(CardType::C1100)); + let driver = MSMClient::new( + MSMInit { + mem_type: PointMemoryType::DMA, is_precompute: false, - curve: msm_api::Curve::BLS381, + curve: Curve::BLS381, }, dclient, ); @@ -54,7 +54,7 @@ fn load_msm_binary_test() -> Result<(), Box> { params[0], params[1].reverse_bits() ); - let params_parce = msm_api::MSMImageParametrs::parse_image_params(params[1]); + let params_parce = MSMImageParametrs::parse_image_params(params[1]); params_parce.debug_information(); log::info!("Checking MSM core is ready: "); @@ -62,30 +62,28 @@ fn load_msm_binary_test() -> Result<(), Box> { driver.task_label()?; let (points, scalars, msm_result, results) = - msm::input_generator_bls12_381(msm_size as usize, msm_api::PRECOMPUTE_FACTOR_BASE); + msm::input_generator_bls12_381(msm_size as usize, PRECOMPUTE_FACTOR_BASE); log::info!("Starting to initialize task and set number of elements: "); - let msm_params = msm_api::MSMParams { + let msm_params = MSMParams { nof_elements: msm_size, hbm_point_addr: None, }; - let _ = driver.initialize(msm_params); + driver.initialize(msm_params)?; + driver.start_process(None)?; + log::info!("Starting to calculate MSM: "); - let _ = driver.set_data(msm_api::MSMInput { + driver.set_data(MSMInput { points: Some(points), scalars, params: msm_params, - }); + })?; driver.wait_result()?; let mres = driver.result(None).unwrap().unwrap(); let (is_on_curve, is_eq) = msm::result_check_bls12_381(mres.result, msm_result, results, msm_size as usize); - log::info!( - "Is point on the {:?} curve {}", - msm_api::Curve::BLS377, - is_on_curve - ); + log::info!("Is point on the {:?} curve {}", Curve::BLS377, is_on_curve); log::info!("Is Result Equal To Expected {}", is_eq); assert!(is_on_curve); assert!(is_eq); @@ -103,48 +101,45 @@ fn msm_bls12_377_test() -> Result<(), Box> { .unwrap(); log::info!("Create Driver API instance"); - let dclient = DriverClient::new(&id, DriverConfig::driver_client_c1100_cfg()); - let driver = msm_api::MSMClient::new( - msm_api::MSMInit { - mem_type: msm_api::PointMemoryType::DMA, + let dclient = DriverClient::new(&id, DriverConfig::driver_client_cfg(CardType::C1100)); + let driver = MSMClient::new( + MSMInit { + mem_type: PointMemoryType::DMA, is_precompute: false, - curve: msm_api::Curve::BLS377, + curve: Curve::BLS377, }, dclient, ); let params = driver.loaded_binary_parameters(); - let params_parce = msm_api::MSMImageParametrs::parse_image_params(params[1]); + let params_parce = MSMImageParametrs::parse_image_params(params[1]); params_parce.debug_information(); log::info!("Checking MSM core is ready: "); driver.is_msm_engine_ready()?; driver.task_label()?; let (points, scalars, msm_result, results) = - msm::input_generator_bls12_377(msm_size as usize, msm_api::PRECOMPUTE_FACTOR_BASE); + msm::input_generator_bls12_377(msm_size as usize, PRECOMPUTE_FACTOR_BASE); log::info!("Starting to initialize task and set number of elements: "); - let msm_params = msm_api::MSMParams { + let msm_params = MSMParams { nof_elements: msm_size, hbm_point_addr: None, }; - let _ = driver.initialize(msm_params); + driver.initialize(msm_params)?; + driver.start_process(None)?; log::info!("Starting to calculate MSM: "); - let _ = driver.set_data(msm_api::MSMInput { + driver.set_data(MSMInput { points: Some(points), scalars, params: msm_params, - }); + })?; driver.wait_result()?; let mres = driver.result(None).unwrap().unwrap(); let (is_on_curve, is_eq) = msm::result_check_bls12_377(mres.result, msm_result, results, msm_size as usize); - log::info!( - "Is point on the {:?} curve {}", - msm_api::Curve::BLS377, - is_on_curve - ); + log::info!("Is point on the {:?} curve {}", Curve::BLS377, is_on_curve); log::info!("Is Result Equal To Expected {}", is_eq); assert!(is_on_curve); assert!(is_eq); @@ -161,21 +156,21 @@ fn msm_bls12_381_test() -> Result<(), Box> { .unwrap(); let (points, scalars, msm_result, results) = - msm::input_generator_bls12_381(msm_size as usize, msm_api::PRECOMPUTE_FACTOR_BASE); + msm::input_generator_bls12_381(msm_size as usize, PRECOMPUTE_FACTOR_BASE); log::info!("Create Driver API instance"); - let dclient = DriverClient::new(&id, DriverConfig::driver_client_c1100_cfg()); - let driver = msm_api::MSMClient::new( - msm_api::MSMInit { - mem_type: msm_api::PointMemoryType::DMA, + let dclient = DriverClient::new(&id, DriverConfig::driver_client_cfg(CardType::C1100)); + let driver = MSMClient::new( + MSMInit { + mem_type: PointMemoryType::DMA, is_precompute: false, - curve: msm_api::Curve::BLS381, + curve: Curve::BLS381, }, dclient, ); let params = driver.loaded_binary_parameters(); - let params_parce = msm_api::MSMImageParametrs::parse_image_params(params[1]); + let params_parce = MSMImageParametrs::parse_image_params(params[1]); params_parce.debug_information(); log::info!("Checking MSM core is ready: "); driver.is_msm_engine_ready()?; @@ -183,18 +178,19 @@ fn msm_bls12_381_test() -> Result<(), Box> { driver.driver_client.firewalls_status(); log::info!("Starting to initialize task and set number of elements: "); - let msm_params = msm_api::MSMParams { + let msm_params = MSMParams { nof_elements: msm_size, hbm_point_addr: None, }; - let _ = driver.initialize(msm_params); + driver.initialize(msm_params)?; + driver.start_process(None)?; log::info!("Starting to calculate MSM: "); - let _ = driver.set_data(msm_api::MSMInput { + driver.set_data(MSMInput { points: Some(points), scalars, params: msm_params, - }); + })?; driver.driver_client.firewalls_status(); driver.task_label()?; log::info!("Waiting MSM result: "); @@ -202,11 +198,7 @@ fn msm_bls12_381_test() -> Result<(), Box> { let mres = driver.result(None).unwrap().unwrap(); let (is_on_curve, is_eq) = msm::result_check_bls12_381(mres.result, msm_result, results, msm_size as usize); - log::info!( - "Is point on the {:?} curve {}", - msm_api::Curve::BLS381, - is_on_curve - ); + log::info!("Is point on the {:?} curve {}", Curve::BLS381, is_on_curve); log::info!("Is Result Equal To Expected {}", is_eq); assert!(is_on_curve); assert!(is_eq); @@ -224,48 +216,45 @@ fn msm_bn254_test() -> Result<(), Box> { .unwrap(); log::info!("Create Driver API instance"); - let dclient = DriverClient::new(&id, DriverConfig::driver_client_c1100_cfg()); - let driver = msm_api::MSMClient::new( - msm_api::MSMInit { - mem_type: msm_api::PointMemoryType::DMA, + let dclient = DriverClient::new(&id, DriverConfig::driver_client_cfg(CardType::C1100)); + let driver = MSMClient::new( + MSMInit { + mem_type: PointMemoryType::DMA, is_precompute: false, - curve: msm_api::Curve::BN254, + curve: Curve::BN254, }, dclient, ); let params = driver.loaded_binary_parameters(); - let params_parce = msm_api::MSMImageParametrs::parse_image_params(params[1]); + let params_parce = MSMImageParametrs::parse_image_params(params[1]); params_parce.debug_information(); log::info!("Checking MSM core is ready: "); driver.is_msm_engine_ready()?; driver.task_label()?; let (points, scalars, msm_result, results) = - msm::input_generator_bn254(msm_size as usize, msm_api::PRECOMPUTE_FACTOR_BASE); + msm::input_generator_bn254(msm_size as usize, PRECOMPUTE_FACTOR_BASE); log::info!("Starting to initialize task and set number of elements: "); - let msm_params = msm_api::MSMParams { + let msm_params = MSMParams { nof_elements: msm_size, hbm_point_addr: None, }; - let _ = driver.initialize(msm_params); + driver.initialize(msm_params)?; + driver.start_process(None)?; log::info!("Starting to calculate MSM: "); - let _ = driver.set_data(msm_api::MSMInput { + driver.set_data(MSMInput { points: Some(points), scalars, params: msm_params, - }); + })?; driver.wait_result()?; let mres = driver.result(None).unwrap().unwrap(); let (is_on_curve, is_eq) = msm::result_check_bn254(mres.result, msm_result, results, msm_size as usize); - log::info!( - "Is point on the {:?} curve {}", - msm_api::Curve::BN254, - is_on_curve - ); + log::info!("Is point on the {:?} curve {}", Curve::BN254, is_on_curve); log::info!("Is Result Equal To Expected {}", is_eq); assert!(is_on_curve); assert!(is_eq); @@ -302,10 +291,8 @@ fn msm_bls12_377_precompute_test() -> Result<(), Box> { log::debug!("Timer generation start"); let start_gen = Instant::now(); - let (points, scalars, _, results) = msm::input_generator_bls12_377( - Pow::pow(base, max_exp) as usize, - msm_api::PRECOMPUTE_FACTOR, - ); + let (points, scalars, _, results) = + msm::input_generator_bls12_377(Pow::pow(base, max_exp) as usize, PRECOMPUTE_FACTOR); let duration_gen = start_gen.elapsed(); log::debug!("Time elapsed in input generation is: {:?}", duration_gen); @@ -313,24 +300,23 @@ fn msm_bls12_377_precompute_test() -> Result<(), Box> { for iter in low_exp..=max_exp { let msm_size = Pow::pow(base, iter) as usize; log::debug!("MSM size: {}", msm_size); - let mut points_to_run = vec![0; msm_size * 96 * msm_api::PRECOMPUTE_FACTOR as usize]; + let mut points_to_run = vec![0; msm_size * 96 * PRECOMPUTE_FACTOR as usize]; let mut scalars_to_run = vec![0; msm_size * 32]; - points_to_run - .copy_from_slice(&points[0..msm_size * 96 * msm_api::PRECOMPUTE_FACTOR as usize]); + points_to_run.copy_from_slice(&points[0..msm_size * 96 * PRECOMPUTE_FACTOR as usize]); scalars_to_run.copy_from_slice(&scalars[0..msm_size * 32]); log::info!("Create Driver API instance"); - let dclient = DriverClient::new(&id, DriverConfig::driver_client_c1100_cfg()); - let driver = msm_api::MSMClient::new( - msm_api::MSMInit { - mem_type: msm_api::PointMemoryType::DMA, + let dclient = DriverClient::new(&id, DriverConfig::driver_client_cfg(CardType::C1100)); + let driver = MSMClient::new( + MSMInit { + mem_type: PointMemoryType::DMA, is_precompute: true, - curve: msm_api::Curve::BLS377, + curve: Curve::BLS377, }, dclient, ); - driver.reset()?; + driver.driver_client.reset()?; log::info!("Checking MSM core is ready: "); driver.is_msm_engine_ready()?; @@ -338,11 +324,12 @@ fn msm_bls12_377_precompute_test() -> Result<(), Box> { driver.driver_client.firewalls_status(); log::info!("Starting to initialize task and set number of elements: "); - let msm_params = msm_api::MSMParams { + let msm_params = MSMParams { nof_elements: msm_size as u32, hbm_point_addr: None, }; - let _ = driver.initialize(msm_params); + driver.initialize(msm_params)?; + driver.start_process(None)?; driver.driver_client.firewalls_status(); @@ -350,11 +337,11 @@ fn msm_bls12_377_precompute_test() -> Result<(), Box> { log::debug!("Timer start"); let start_set_data = Instant::now(); let start_full = Instant::now(); - let _ = driver.set_data(msm_api::MSMInput { + driver.set_data(MSMInput { points: Some(points_to_run), scalars: scalars_to_run, params: msm_params, - }); + })?; let dur_set = start_set_data.elapsed(); let start_get = Instant::now(); @@ -404,24 +391,24 @@ fn msm_bls12_377_precompute_max_test() -> Result<(), Box> log::debug!("Timer start to generate test data"); let start_gen = Instant::now(); let (points, scalars, msm_result, results) = - msm::input_generator_bls12_377(msm_size as usize, msm_api::PRECOMPUTE_FACTOR); + msm::input_generator_bls12_377(msm_size as usize, PRECOMPUTE_FACTOR); let duration_gen = start_gen.elapsed(); log::debug!("Time elapsed in generate test data is: {:?}", duration_gen); log::info!("Create Driver API instance"); - let dclient = DriverClient::new(&id, DriverConfig::driver_client_c1100_cfg()); - let driver = msm_api::MSMClient::new( - msm_api::MSMInit { - mem_type: msm_api::PointMemoryType::DMA, + let dclient = DriverClient::new(&id, DriverConfig::driver_client_cfg(CardType::C1100)); + let driver = MSMClient::new( + MSMInit { + mem_type: PointMemoryType::DMA, is_precompute: true, - curve: msm_api::Curve::BLS377, + curve: Curve::BLS377, }, dclient, ); - driver.reset()?; + driver.driver_client.reset()?; let params = driver.loaded_binary_parameters(); - let params_parce = msm_api::MSMImageParametrs::parse_image_params(params[1]); + let params_parce = MSMImageParametrs::parse_image_params(params[1]); params_parce.debug_information(); log::info!("Checking MSM core is ready: "); driver.is_msm_engine_ready()?; @@ -429,11 +416,12 @@ fn msm_bls12_377_precompute_max_test() -> Result<(), Box> driver.driver_client.firewalls_status(); log::info!("Starting to initialize task and set number of elements: "); - let msm_params = msm_api::MSMParams { + let msm_params = MSMParams { nof_elements: msm_size, hbm_point_addr: None, }; - let _ = driver.initialize(msm_params); + driver.initialize(msm_params)?; + driver.start_process(None)?; driver.driver_client.firewalls_status(); driver.task_label()?; @@ -443,11 +431,11 @@ fn msm_bls12_377_precompute_max_test() -> Result<(), Box> log::debug!("Timer start"); let start_set_data = Instant::now(); let start_full = Instant::now(); - let _ = driver.set_data(msm_api::MSMInput { + driver.set_data(MSMInput { points: Some(points), scalars, params: msm_params, - }); + })?; let dur_set = start_set_data.elapsed(); let start_wait = Instant::now(); @@ -488,10 +476,8 @@ fn msm_bls12_381_precompute_test() -> Result<(), Box> { log::debug!("Timer generation start"); let start_gen = Instant::now(); - let (points, scalars, _, results) = msm::input_generator_bls12_381( - Pow::pow(base, max_exp) as usize, - msm_api::PRECOMPUTE_FACTOR, - ); + let (points, scalars, _, results) = + msm::input_generator_bls12_381(Pow::pow(base, max_exp) as usize, PRECOMPUTE_FACTOR); let duration_gen = start_gen.elapsed(); log::debug!("Time elapsed in input generation is: {:?}", duration_gen); @@ -499,24 +485,23 @@ fn msm_bls12_381_precompute_test() -> Result<(), Box> { for iter in low_exp..=max_exp { let msm_size = Pow::pow(base, iter) as usize; log::debug!("MSM size: {}", msm_size); - let mut points_to_run = vec![0; msm_size * 96 * msm_api::PRECOMPUTE_FACTOR as usize]; + let mut points_to_run = vec![0; msm_size * 96 * PRECOMPUTE_FACTOR as usize]; let mut scalars_to_run = vec![0; msm_size * 32]; - points_to_run - .copy_from_slice(&points[0..msm_size * 96 * msm_api::PRECOMPUTE_FACTOR as usize]); + points_to_run.copy_from_slice(&points[0..msm_size * 96 * PRECOMPUTE_FACTOR as usize]); scalars_to_run.copy_from_slice(&scalars[0..msm_size * 32]); log::info!("Create Driver API instance"); - let dclient = DriverClient::new(&id, DriverConfig::driver_client_c1100_cfg()); - let driver = msm_api::MSMClient::new( - msm_api::MSMInit { - mem_type: msm_api::PointMemoryType::DMA, + let dclient = DriverClient::new(&id, DriverConfig::driver_client_cfg(CardType::C1100)); + let driver = MSMClient::new( + MSMInit { + mem_type: PointMemoryType::DMA, is_precompute: true, - curve: msm_api::Curve::BLS381, + curve: Curve::BLS381, }, dclient, ); - driver.reset()?; + driver.driver_client.reset()?; log::info!("Checking MSM core is ready: "); driver.is_msm_engine_ready()?; @@ -524,11 +509,12 @@ fn msm_bls12_381_precompute_test() -> Result<(), Box> { // driver.driver_client.firewalls_status(); log::info!("Starting to initialize task and set number of elements: "); - let msm_params = msm_api::MSMParams { + let msm_params = MSMParams { nof_elements: msm_size as u32, hbm_point_addr: None, }; - let _ = driver.initialize(msm_params); + driver.initialize(msm_params)?; + driver.start_process(None)?; driver.driver_client.firewalls_status(); @@ -536,11 +522,11 @@ fn msm_bls12_381_precompute_test() -> Result<(), Box> { log::debug!("Timer start"); let start_set_data = Instant::now(); let start_full = Instant::now(); - let _ = driver.set_data(msm_api::MSMInput { + driver.set_data(MSMInput { points: Some(points_to_run), scalars: scalars_to_run, params: msm_params, - }); + })?; // driver.get_api(); let dur_set = start_set_data.elapsed(); let start_get = Instant::now(); @@ -590,24 +576,24 @@ fn msm_bls12_381_precompute_max_test() -> Result<(), Box> log::debug!("Timer start to generate test data"); let start_gen = Instant::now(); let (points, scalars, msm_result, results) = - msm::input_generator_bls12_381(msm_size as usize, msm_api::PRECOMPUTE_FACTOR); + msm::input_generator_bls12_381(msm_size as usize, PRECOMPUTE_FACTOR); let duration_gen = start_gen.elapsed(); log::debug!("Time elapsed in generate test data is: {:?}", duration_gen); log::info!("Create Driver API instance"); - let dclient = DriverClient::new(&id, DriverConfig::driver_client_c1100_cfg()); - let driver = msm_api::MSMClient::new( - msm_api::MSMInit { - mem_type: msm_api::PointMemoryType::DMA, + let dclient = DriverClient::new(&id, DriverConfig::driver_client_cfg(CardType::C1100)); + let driver = MSMClient::new( + MSMInit { + mem_type: PointMemoryType::DMA, is_precompute: true, - curve: msm_api::Curve::BLS381, + curve: Curve::BLS381, }, dclient, ); - driver.reset()?; + driver.driver_client.reset()?; let params = driver.loaded_binary_parameters(); - let params_parce = msm_api::MSMImageParametrs::parse_image_params(params[1]); + let params_parce = MSMImageParametrs::parse_image_params(params[1]); params_parce.debug_information(); log::info!("Checking MSM core is ready: "); driver.is_msm_engine_ready()?; @@ -615,11 +601,12 @@ fn msm_bls12_381_precompute_max_test() -> Result<(), Box> driver.driver_client.firewalls_status(); log::info!("Starting to initialize task and set number of elements: "); - let msm_params = msm_api::MSMParams { + let msm_params = MSMParams { nof_elements: msm_size, hbm_point_addr: None, }; - let _ = driver.initialize(msm_params); + driver.initialize(msm_params)?; + driver.start_process(None)?; driver.driver_client.firewalls_status(); driver.task_label()?; @@ -629,11 +616,11 @@ fn msm_bls12_381_precompute_max_test() -> Result<(), Box> log::debug!("Timer start"); let start_set_data = Instant::now(); let start_full = Instant::now(); - let _ = driver.set_data(msm_api::MSMInput { + driver.set_data(MSMInput { points: Some(points), scalars, params: msm_params, - }); + })?; let dur_set = start_set_data.elapsed(); let start_wait = Instant::now(); diff --git a/tests/integration_msm_hbm.rs b/tests/integration_msm_hbm.rs index 2efd26c..844b4a6 100644 --- a/tests/integration_msm_hbm.rs +++ b/tests/integration_msm_hbm.rs @@ -1,5 +1,5 @@ use crate::msm::RunResults; -use ingo_blaze::{driver_client::dclient::*, ingo_msm::*}; +use ingo_blaze::{driver_client::*, ingo_msm::*}; use num_traits::Pow; use std::{ env, @@ -19,10 +19,8 @@ fn hbm_msm_bls12_381_precomp_test() -> Result<(), Box> { log::debug!("Timer generation start"); let start_gen = Instant::now(); - let (points, scalars, _, results) = msm::input_generator_bls12_381( - Pow::pow(base, max_exp) as usize, - msm_api::PRECOMPUTE_FACTOR, - ); + let (points, scalars, _, results) = + msm::input_generator_bls12_381(Pow::pow(base, max_exp) as usize, PRECOMPUTE_FACTOR); let duration_gen = start_gen.elapsed(); log::debug!("Time elapsed in input generation is: {:?}", duration_gen); @@ -37,16 +35,16 @@ fn hbm_msm_bls12_381_precomp_test() -> Result<(), Box> { scalars_to_run.copy_from_slice(&scalars[0..msm_size * 32]); log::info!("Create Driver API instance"); - let dclient = DriverClient::new(&id, DriverConfig::driver_client_c1100_cfg()); - let driver = msm_api::MSMClient::new( - msm_api::MSMInit { - mem_type: msm_api::PointMemoryType::DMA, + let dclient = DriverClient::new(&id, DriverConfig::driver_client_cfg(CardType::C1100)); + let driver = MSMClient::new( + MSMInit { + mem_type: PointMemoryType::DMA, is_precompute: true, - curve: msm_api::Curve::BLS381, + curve: Curve::BLS381, }, dclient, ); - driver.reset()?; + driver.driver_client.reset()?; let hbm_addr: u64 = 0x0; let offset: u64 = 0x0; @@ -63,23 +61,24 @@ fn hbm_msm_bls12_381_precomp_test() -> Result<(), Box> { driver.driver_client.firewalls_status(); log::info!("Starting to initialize task and set number of elements: "); - let msm_params = msm_api::MSMParams { + let msm_params = MSMParams { nof_elements: msm_size as u32, hbm_point_addr: Some((hbm_addr, offset)), }; - let _ = driver.initialize(msm_params); + driver.initialize(msm_params)?; + driver.start_process(None)?; driver.driver_client.firewalls_status(); log::info!("Starting to calculate MSM: "); log::debug!("Timer start"); let start_set_data = Instant::now(); let start_full = Instant::now(); - let _ = driver.set_data(msm_api::MSMInput { + driver.set_data(MSMInput { points: None, scalars: scalars_to_run, params: msm_params, - }); + })?; driver.get_api(); let dur_set = start_set_data.elapsed(); let start_get = Instant::now(); @@ -128,10 +127,8 @@ fn hbm_msm_bls12_377_precomp_test() -> Result<(), Box> { log::debug!("Timer generation start"); let start_gen = Instant::now(); - let (points, scalars, _, results) = msm::input_generator_bls12_377( - Pow::pow(base, max_exp) as usize, - msm_api::PRECOMPUTE_FACTOR, - ); + let (points, scalars, _, results) = + msm::input_generator_bls12_377(Pow::pow(base, max_exp) as usize, PRECOMPUTE_FACTOR); let duration_gen = start_gen.elapsed(); log::debug!("Time elapsed in input generation is: {:?}", duration_gen); @@ -146,16 +143,16 @@ fn hbm_msm_bls12_377_precomp_test() -> Result<(), Box> { scalars_to_run.copy_from_slice(&scalars[0..msm_size * 32]); log::info!("Create Driver API instance"); - let dclient = DriverClient::new(&id, DriverConfig::driver_client_c1100_cfg()); - let driver = msm_api::MSMClient::new( - msm_api::MSMInit { - mem_type: msm_api::PointMemoryType::HBM, + let dclient = DriverClient::new(&id, DriverConfig::driver_client_cfg(CardType::C1100)); + let driver = MSMClient::new( + MSMInit { + mem_type: PointMemoryType::HBM, is_precompute: true, - curve: msm_api::Curve::BLS377, + curve: Curve::BLS377, }, dclient, ); - driver.reset()?; + driver.driver_client.reset()?; // log::debug!("writing data to HBM"); let hbm_addr: u64 = 0x0; @@ -173,23 +170,24 @@ fn hbm_msm_bls12_377_precomp_test() -> Result<(), Box> { driver.driver_client.firewalls_status(); log::info!("Starting to initialize task and set number of elements: "); - let msm_params = msm_api::MSMParams { + let msm_params = MSMParams { nof_elements: msm_size as u32, hbm_point_addr: Some((hbm_addr, offset)), }; - let _ = driver.initialize(msm_params); + driver.initialize(msm_params)?; + driver.start_process(None)?; driver.driver_client.firewalls_status(); log::info!("Starting to calculate MSM: "); log::debug!("Timer start"); let start_set_data = Instant::now(); let start_full = Instant::now(); - let _ = driver.set_data(msm_api::MSMInput { + driver.set_data(MSMInput { points: None, scalars: scalars_to_run, params: msm_params, - }); + })?; let dur_set = start_set_data.elapsed(); let start_get = Instant::now(); driver.driver_client.firewalls_status(); diff --git a/tests/integration_ntt.rs b/tests/integration_ntt.rs new file mode 100644 index 0000000..2d42e09 --- /dev/null +++ b/tests/integration_ntt.rs @@ -0,0 +1,146 @@ +use ingo_blaze::{driver_client::*, ingo_ntt::*, utils}; +use log::info; +use std::{env, error::Error, fs::File, io::Read}; + +#[test] +fn ntt_test_correctness() -> Result<(), Box> { + env_logger::try_init().expect("Invalid logger initialisation"); + let id = env::var("ID").unwrap_or_else(|_| 0.to_string()); + + let input_fname = env::var("INFNAME").unwrap(); + let mut in_f = File::open(input_fname).expect("no file found"); + let mut in_vec: Vec = Default::default(); + in_f.read_to_end(&mut in_vec)?; + + let output_fname = env::var("OUTFNAME").unwrap(); + let mut out_f = File::open(output_fname).expect("no file found"); + let mut out_vec: Vec = Default::default(); + out_f.read_to_end(&mut out_vec)?; + + let buf_host = 0; + let buf_kernel = 0; + + info!("Create Driver API instance"); + let dclient = DriverClient::new(&id, DriverConfig::driver_client_cfg(CardType::C1100)); + if std::env::var("BIN").is_ok() { + let bin_fname = std::env::var("BIN").unwrap(); + info!("Start reading binary"); + let bin = utils::read_binary_file(&bin_fname)?; + info!("Start setup FPGA"); + dclient.setup_before_load_binary()?; + info!("Start loading driver"); + dclient.load_binary(&bin)?; + } + + let driver = NTTClient::new(NTT::Ntt, dclient); + log::info!("Starting set NTT data"); + driver.set_data(NTTInput { + buf_host, + data: in_vec, + })?; + log::info!("Successfully set NTT data"); + driver.driver_client.initialize_cms()?; + driver.driver_client.reset_sensor_data()?; + + for i in 0..1 { + log::info!("Starting NTT: {:?}", i); + driver.initialize(NttInit {})?; + driver.start_process(Some(buf_kernel))?; + driver.wait_result()?; + driver.driver_client.reset()?; + log::info!("Finishing NTT: {:?}", i); + } + + log::info!("Try to get NTT result"); + let res = driver.result(Some(buf_kernel))?.unwrap(); + log::info!("Get NTT result of size: {:?}", res.len()); + assert_eq!(res, out_vec); + + Ok(()) +} + +#[test] +fn ntt_parallel_test_correctness() -> Result<(), Box> { + env_logger::try_init().expect("Invalid logger initialisation"); + const NOF_VECTORS: usize = 3; + let id = env::var("ID").unwrap_or_else(|_| 0.to_string()); + + let in_dir = env::var("INDIR").unwrap(); + let mut in_vecs: Vec> = vec![Default::default(); NOF_VECTORS]; + + for (i, in_vec) in in_vecs.iter_mut().enumerate().take(NOF_VECTORS) { + let input_fname = format! {"{}/in_bin_{:02?}.dat", in_dir, i}; + let mut in_f = File::open(&input_fname).expect("no file found"); + in_f.read_to_end(in_vec)?; + log::info!("Read input from file {:?}", input_fname); + } + + let ref_dir = env::var("REFDIR").unwrap(); + let mut ref_vecs: Vec> = vec![Default::default(); NOF_VECTORS]; + for (i, ref_vec) in ref_vecs.iter_mut().enumerate().take(NOF_VECTORS) { + let ref_fname = format! {"{}/ref_bin_{:02?}.dat", ref_dir, i}; + let mut ref_f = File::open(&ref_fname).expect("no file found"); + ref_f.read_to_end(ref_vec)?; + log::info!("Read reference from file {:?}", ref_fname); + } + + info!("Create Driver API instance"); + let dclient = DriverClient::new(&id, DriverConfig::driver_client_cfg(CardType::C1100)); + if std::env::var("BIN").is_ok() { + let bin_fname = std::env::var("BIN").unwrap(); + info!("Start reading binary"); + let bin = utils::read_binary_file(&bin_fname)?; + info!("Start setup FPGA"); + dclient.setup_before_load_binary()?; + info!("Start loading driver"); + dclient.load_binary(&bin)?; + } + + let driver = NTTClient::new(NTT::Ntt, dclient); + driver.initialize(NttInit {})?; + + let mut outputs: Vec> = Vec::new(); + for i in 0..(NOF_VECTORS + 2) { + let buf_host = i % 2; + let buf_kernel = 1 - buf_host; + info!("Cycle {}: host = {}, kernel = {}", i, buf_host, buf_kernel); + log::info!("Starting process {:?} on kernel {:?}", i, buf_kernel); + driver.start_process(Some(buf_kernel))?; + + log::info!("Try to get NTT result"); + let res = driver.result(Some(buf_host))?.unwrap(); + log::info!("Get NTT result of size: {:?}", res.len()); + if i >= 2 { + log::info!("Save result {}", i - 2); + outputs.push(res) + } + + let host_wr_idx = i; + let host_wr_idx_adj = if host_wr_idx > NOF_VECTORS - 1 { + NOF_VECTORS - 1 + } else { + host_wr_idx + }; + log::info!( + "Starting set NTT data with params: [{:?}, {:?}]", + host_wr_idx, + host_wr_idx_adj + ); + driver.set_data(NTTInput { + buf_host, + data: in_vecs[host_wr_idx_adj].clone(), + })?; + log::info!("Successfully set NTT data"); + driver.wait_result()?; + log::info!("Finishing Cycle: {:?}", i); + } + + log::info!("Starting to check correctness"); + for (i, out_vec) in outputs.into_iter().enumerate() { + log::info!("Checking output: {:?}", i); + assert_eq!(out_vec, ref_vecs[i].clone()); + log::info!("Result {:?} correct", i); + } + + Ok(()) +} diff --git a/tests/integration_poseidon.rs b/tests/integration_poseidon.rs index 785010c..d3f5723 100755 --- a/tests/integration_poseidon.rs +++ b/tests/integration_poseidon.rs @@ -6,9 +6,11 @@ use std::{ use dotenv::dotenv; use ingo_blaze::{ - driver_client::dclient::*, - ingo_hash::poseidon_api::{Hash, PoseidonClient, PoseidonInitializeParameters, PoseidonResult}, - ingo_hash::utils::{num_of_elements_in_base_layer, TreeMode}, + driver_client::*, + ingo_hash::{ + num_of_elements_in_base_layer, Hash, PoseidonClient, PoseidonInitializeParameters, + PoseidonResult, TreeMode, + }, }; use log::info; use num::{BigUint, Num}; @@ -28,10 +30,10 @@ const ONE: u32 = 1; fn test_sanity_check() { let instruction_path = get_instruction_path(); - let dclient = DriverClient::new("0", DriverConfig::driver_client_c1100_cfg()); + let dclient = DriverClient::new("0", DriverConfig::driver_client_cfg(CardType::C1100)); let poseidon: PoseidonClient = PoseidonClient::new(Hash::Poseidon, dclient); - poseidon.reset().expect_err("Failed while reset"); + poseidon.dclient.reset().expect_err("Failed while reset"); let params = poseidon.loaded_binary_parameters(); info!("Driver parameters: [{:?}, {:032b}]", params[0], params[1]); @@ -60,7 +62,7 @@ fn test_build_small_tree_par() { env_logger::try_init().expect("Invalid logger initialization"); - let dclient = DriverClient::new("0", DriverConfig::driver_client_c1100_cfg()); + let dclient = DriverClient::new("0", DriverConfig::driver_client_cfg(CardType::C1100)); let poseidon: PoseidonClient = PoseidonClient::new(Hash::Poseidon, dclient); let poseidon = Arc::new(Mutex::new(poseidon)); @@ -123,7 +125,7 @@ fn test_build_small_tree() { env_logger::try_init().expect("Invalid logger initialization"); - let dclient = DriverClient::new("0", DriverConfig::driver_client_c1100_cfg()); + let dclient = DriverClient::new("0", DriverConfig::driver_client_cfg(CardType::C1100)); let poseidon: PoseidonClient = PoseidonClient::new(Hash::Poseidon, dclient); let params = PoseidonInitializeParameters {