diff --git a/internal/runtime/hostfunctions.go b/internal/runtime/hostfunctions.go index 26c885bf2..6b810cc8d 100644 --- a/internal/runtime/hostfunctions.go +++ b/internal/runtime/hostfunctions.go @@ -1271,6 +1271,7 @@ func RegisterHostFunctions(runtime wazero.Runtime, env *RuntimeEnvironment) (waz WithResultNames("key_ptr"). Export("db_next_key") + // Compile and return the module return builder.Compile(context.Background()) } diff --git a/internal/runtime/wazeroruntime.go b/internal/runtime/wazeroruntime.go index fd8ec1aa3..5bfb6e089 100644 --- a/internal/runtime/wazeroruntime.go +++ b/internal/runtime/wazeroruntime.go @@ -60,6 +60,9 @@ const ( // Size of a Region struct in bytes (3x4 bytes) regionSize = 12 + + // Maximum memory pages (1024 pages = 64 MiB) + maxMemoryPages = 32768 ) // Region describes data allocated in Wasm's linear memory @@ -69,6 +72,27 @@ type Region struct { Length uint32 } +// toBytes converts a Region to its byte representation (little endian) +func (r *Region) toBytes() []byte { + bytes := make([]byte, 12) + binary.LittleEndian.PutUint32(bytes[0:4], r.Offset) + binary.LittleEndian.PutUint32(bytes[4:8], r.Capacity) + binary.LittleEndian.PutUint32(bytes[8:12], r.Length) + return bytes +} + +// fromBytes reads a Region from its byte representation (little endian) +func regionFromBytes(data []byte) (*Region, error) { + if len(data) != 12 { + return nil, fmt.Errorf("invalid region size: expected 12 bytes, got %d", len(data)) + } + return &Region{ + Offset: binary.LittleEndian.Uint32(data[0:4]), + Capacity: binary.LittleEndian.Uint32(data[4:8]), + Length: binary.LittleEndian.Uint32(data[8:12]), + }, nil +} + // validateRegion performs plausibility checks on a Region func validateRegion(region *Region) error { if region.Offset == 0 { @@ -83,6 +107,100 @@ func validateRegion(region *Region) error { return nil } +// readRegion reads a Region struct from memory at the given pointer +func (m *memoryManager) readRegion(ptr uint32) (*Region, error) { + if ptr == 0 { + return nil, fmt.Errorf("null pointer") + } + + // Read 12 bytes for the Region struct + data, ok := m.memory.Read(ptr, 12) + if !ok { + return nil, fmt.Errorf("failed to read Region at ptr=%d", ptr) + } + + region, err := regionFromBytes(data) + if err != nil { + return nil, err + } + + if err := validateRegion(region); err != nil { + return nil, err + } + + return region, nil +} + +// writeRegion writes a Region struct to memory at the given pointer +func (m *memoryManager) writeRegion(ptr uint32, region *Region) error { + if ptr == 0 { + return fmt.Errorf("null pointer") + } + + if err := validateRegion(region); err != nil { + return err + } + + if !m.memory.Write(ptr, region.toBytes()) { + return fmt.Errorf("failed to write Region at ptr=%d", ptr) + } + + return nil +} + +// allocateRegion allocates memory for a Region struct and the data it will contain +func (m *memoryManager) allocateRegion(size uint32) (*Region, uint32, error) { + // Get the allocate function + allocate := m.module.ExportedFunction("allocate") + if allocate == nil { + return nil, 0, fmt.Errorf("allocate function not found in WASM module") + } + + // Check if requested size is within bounds + memSize := uint64(m.memory.Size() * wasmPageSize) + if uint64(size) > memSize { + return nil, 0, fmt.Errorf("requested allocation size %d exceeds memory size %d", size, memSize) + } + + // Allocate memory for the data + dataResult, err := allocate.Call(context.Background(), uint64(size)) + if err != nil { + return nil, 0, fmt.Errorf("failed to allocate data memory: %w", err) + } + dataPtr := uint32(dataResult[0]) + + // Verify data pointer is within bounds + if uint64(dataPtr)+uint64(size) > memSize { + return nil, 0, fmt.Errorf("allocated memory region [%d, %d] exceeds memory size %d", dataPtr, dataPtr+size, memSize) + } + + // Create the Region struct + region := &Region{ + Offset: dataPtr, + Capacity: size, + Length: 0, + } + + // Allocate memory for the Region struct + regionResult, err := allocate.Call(context.Background(), uint64(regionSize)) + if err != nil { + return nil, 0, fmt.Errorf("failed to allocate region memory: %w", err) + } + regionPtr := uint32(regionResult[0]) + + // Verify region pointer is within bounds + if uint64(regionPtr)+uint64(regionSize) > memSize { + return nil, 0, fmt.Errorf("allocated region struct [%d, %d] exceeds memory size %d", regionPtr, regionPtr+regionSize, memSize) + } + + // Write the Region struct to memory + if !m.memory.Write(regionPtr, region.toBytes()) { + return nil, 0, fmt.Errorf("failed to write region to memory at ptr=%d", regionPtr) + } + + return region, regionPtr, nil +} + // memoryManager handles WASM memory allocation and deallocation type memoryManager struct { memory api.Memory @@ -96,176 +214,127 @@ func newMemoryManager(memory api.Memory, module api.Module) *memoryManager { } } -// writeToMemory writes data to WASM memory and returns the pointer and size -func (m *memoryManager) writeToMemory(data []byte) (uint32, uint32, error) { +// allocateAndWrite allocates memory and writes data directly, returning the pointer +func (m *memoryManager) allocateAndWrite(data []byte) (uint32, error) { if data == nil { - return 0, 0, nil + return 0, nil } // Get the allocate function allocate := m.module.ExportedFunction("allocate") if allocate == nil { - return 0, 0, fmt.Errorf("allocate function not found in WASM module") + return 0, fmt.Errorf("allocate function not found in WASM module") } - // Allocate memory for the Region struct (12 bytes) and the data - size := uint32(len(data)) - results, err := allocate.Call(context.Background(), uint64(size+regionSize)) - if err != nil { - return 0, 0, fmt.Errorf("failed to allocate memory: %w", err) - } - ptr := uint32(results[0]) + // Calculate memory size in bytes + memSize := m.memory.Size() + memSizeBytes := memSize * uint32(wasmPageSize) - // Create and write the Region struct - region := &Region{ - Offset: ptr + regionSize, // Data starts after the Region struct - Capacity: size, - Length: size, + // Check if we have enough memory + if uint32(len(data)) > memSizeBytes { + return 0, fmt.Errorf("requested allocation size %d exceeds memory size %d", len(data), memSizeBytes) } - // Validate the region before writing - if err := validateRegion(region); err != nil { - deallocate := m.module.ExportedFunction("deallocate") - if deallocate != nil { - if _, err := deallocate.Call(context.Background(), uint64(ptr)); err != nil { - return 0, 0, fmt.Errorf("deallocation failed: %w", err) - } - } - return 0, 0, fmt.Errorf("invalid region: %w", err) + // Allocate memory for the data + dataResult, err := allocate.Call(context.Background(), uint64(len(data))) + if err != nil { + return 0, fmt.Errorf("failed to allocate data memory: %w", err) } + dataPtr := uint32(dataResult[0]) - // Write the Region struct - if err := m.writeRegion(ptr, region); err != nil { - deallocate := m.module.ExportedFunction("deallocate") - if deallocate != nil { - if _, err := deallocate.Call(context.Background(), uint64(ptr)); err != nil { - return 0, 0, fmt.Errorf("deallocation failed: %w", err) - } - } - return 0, 0, fmt.Errorf("failed to write region: %w", err) + // Verify data pointer is within bounds + if dataPtr == 0 || dataPtr+uint32(len(data)) > memSizeBytes { + return 0, fmt.Errorf("allocated memory region [%d, %d] exceeds memory size %d", dataPtr, dataPtr+uint32(len(data)), memSizeBytes) } // Write the actual data - if !m.memory.Write(region.Offset, data) { + if !m.memory.Write(dataPtr, data) { deallocate := m.module.ExportedFunction("deallocate") if deallocate != nil { - if _, err := deallocate.Call(context.Background(), uint64(ptr)); err != nil { - return 0, 0, fmt.Errorf("deallocation failed: %w", err) + if _, err := deallocate.Call(context.Background(), uint64(dataPtr)); err != nil { + fmt.Printf("[DEBUG][Memory] Deallocation failed for ptr=%d: %v\n", dataPtr, err) } } - return 0, 0, fmt.Errorf("failed to write data to memory at ptr=%d size=%d", region.Offset, size) + return 0, fmt.Errorf("failed to write data to memory at ptr=%d size=%d", dataPtr, len(data)) } - return ptr, size, nil + return dataPtr, nil } -// readFromMemory reads data from WASM memory -func (m *memoryManager) readFromMemory(ptr, size uint32) ([]byte, error) { +// readData reads data from memory at the given pointer, handling both raw data pointers and Region struct pointers +func (m *memoryManager) readData(ptr uint32) ([]byte, error) { if ptr == 0 { return nil, nil } - // Read the Region struct first - region, err := m.readRegion(ptr) - if err != nil { - return nil, fmt.Errorf("failed to read region: %w", err) - } + // Calculate memory size in bytes + memSize := m.memory.Size() + memSizeBytes := memSize * uint32(wasmPageSize) - // Validate the region - if err := validateRegion(region); err != nil { - return nil, fmt.Errorf("invalid region: %w", err) + // Check if pointer is within bounds + if ptr >= memSizeBytes { + return nil, fmt.Errorf("pointer %d is out of memory bounds (size: %d)", ptr, memSizeBytes) } - // Verify the size matches what we expect - if region.Length != size { - return nil, fmt.Errorf("size mismatch: expected %d bytes but region specifies %d bytes", size, region.Length) - } - - // Read the actual data using the region's length - data, ok := m.memory.Read(region.Offset, region.Length) + // First try to read as a Region struct + regionData, ok := m.memory.Read(ptr, regionSize) if !ok { - return nil, fmt.Errorf("failed to read memory at ptr=%d size=%d", region.Offset, region.Length) + return nil, fmt.Errorf("failed to read memory at ptr=%d", ptr) } - // Make a copy to ensure we own the data - result := make([]byte, len(data)) - copy(result, data) - - return result, nil -} + // Try to parse as a Region struct + region, err := regionFromBytes(regionData) + if err == nil && region.Offset != 0 && region.Length <= region.Capacity { + // Verify region bounds + if region.Offset+region.Length > memSizeBytes { + return nil, fmt.Errorf("region [%d, %d] exceeds memory bounds (size: %d)", region.Offset, region.Offset+region.Length, memSizeBytes) + } -// readRegion reads a Region struct from memory and validates it -// readRegion reads a Region struct from memory and validates it -func (m *memoryManager) readRegion(ptr uint32) (*Region, error) { - if ptr == 0 { - return nil, fmt.Errorf("null region pointer") + // Looks like a valid Region struct, read the actual data + data, ok := m.memory.Read(region.Offset, region.Length) + if !ok { + return nil, fmt.Errorf("failed to read data at ptr=%d length=%d", region.Offset, region.Length) + } + // Make a copy to ensure we own the data + result := make([]byte, len(data)) + copy(result, data) + return result, nil } - // Read the Region struct (12 bytes total) - data, ok := m.memory.Read(ptr, regionSize) + // Not a valid Region struct, try to read as raw data + // First read 4 bytes for the length prefix + lengthData, ok := m.memory.Read(ptr, 4) if !ok { - return nil, fmt.Errorf("failed to read Region struct at ptr=%d", ptr) + return nil, fmt.Errorf("failed to read length prefix at ptr=%d", ptr) } + length := binary.LittleEndian.Uint32(lengthData) - // Parse the Region struct (little-endian) - region := &Region{ - Offset: binary.LittleEndian.Uint32(data[0:4]), - Capacity: binary.LittleEndian.Uint32(data[4:8]), - Length: binary.LittleEndian.Uint32(data[8:12]), - } - - // Validate the region - if err := validateRegion(region); err != nil { - return nil, fmt.Errorf("invalid region: %w", err) + // Verify length is within bounds + if ptr+4+length > memSizeBytes { + return nil, fmt.Errorf("raw data region [%d, %d] exceeds memory bounds (size: %d)", ptr+4, ptr+4+length, memSizeBytes) } - return region, nil -} - -// writeRegion writes a Region struct to memory -func (m *memoryManager) writeRegion(ptr uint32, region *Region) error { - if ptr == 0 { - return fmt.Errorf("null region pointer") - } - - // Validate the region before writing - if err := validateRegion(region); err != nil { - return fmt.Errorf("invalid region: %w", err) - } - - // Ensure we're not writing out of bounds - memSize := uint64(m.memory.Size()) * wasmPageSize - if uint64(region.Offset)+uint64(region.Capacity) > memSize { - return fmt.Errorf("region exceeds memory bounds: offset=%d, capacity=%d, memSize=%d", region.Offset, region.Capacity, memSize) - } - - // Create the Region struct bytes (little-endian) - data := make([]byte, regionSize) - binary.LittleEndian.PutUint32(data[0:4], region.Offset) - binary.LittleEndian.PutUint32(data[4:8], region.Capacity) - binary.LittleEndian.PutUint32(data[8:12], region.Length) - - // Write the Region struct - if !m.memory.Write(ptr, data) { - return fmt.Errorf("failed to write Region struct at ptr=%d", ptr) + // Then read the actual data + data, ok := m.memory.Read(ptr+4, length) + if !ok { + return nil, fmt.Errorf("failed to read raw data at ptr=%d length=%d", ptr+4, length) } - return nil + // Make a copy to ensure we own the data + result := make([]byte, len(data)) + copy(result, data) + return result, nil } func NewWazeroRuntime() (*WazeroRuntime, error) { // Create a new wazero runtime with memory configuration runtimeConfig := wazero.NewRuntimeConfig(). - WithMemoryLimitPages(4096). // Set max memory to 256 MiB (4096 * 64KB) - WithMemoryCapacityFromMax(false) // Eagerly allocate memory + WithMemoryLimitPages(maxMemoryPages). // Set max memory to 64 MiB (1024 * 64KB) + WithMemoryCapacityFromMax(true). // Eagerly allocate memory to ensure availability + WithCloseOnContextDone(true) // Ensure resources are cleaned up r := wazero.NewRuntimeWithConfig(context.Background(), runtimeConfig) - // Create mock implementations - kvStore := &MockKVStore{} - api := NewMockGoAPI() - querier := &MockQuerier{} - return &WazeroRuntime{ runtime: r, codeCache: make(map[string][]byte), @@ -274,9 +343,6 @@ func NewWazeroRuntime() (*WazeroRuntime, error) { pinnedModules: make(map[string]struct{}), moduleHits: make(map[string]uint32), moduleSizes: make(map[string]uint64), - kvStore: kvStore, - api: api, - querier: querier, }, nil } @@ -983,10 +1049,7 @@ func (w *WazeroRuntime) callContractFn( ctx := context.Background() - // 4) Register and instantiate the host module "env" - if printDebug { - fmt.Println("[DEBUG] Registering host functions ...") - } + // 4) Create runtime environment runtimeEnv := &RuntimeEnvironment{ DB: store, API: *api, @@ -996,72 +1059,97 @@ func (w *WazeroRuntime) callContractFn( gasUsed: 0, iterators: make(map[uint64]map[uint64]types.Iterator), } - hm, err := RegisterHostFunctions(w.runtime, runtimeEnv) + + // 5) Register and instantiate the host module "env" + if printDebug { + fmt.Println("[DEBUG] Registering host functions ...") + } + hostModule, err := RegisterHostFunctions(w.runtime, runtimeEnv) if err != nil { - errStr := fmt.Sprintf("[callContractFn] Error: failed to register host functions: %v", err) + errStr := fmt.Sprintf("[callContractFn] Error registering host functions: %v", err) fmt.Println(errStr) - return nil, types.GasReport{}, errors.New(errStr) + return nil, types.GasReport{}, fmt.Errorf("failed to register host functions: %w", err) } defer func() { if printDebug { fmt.Println("[DEBUG] Closing host module ...") } - hm.Close(ctx) + hostModule.Close(ctx) }() - // Instantiate the env module - if printDebug { - fmt.Println("[DEBUG] Instantiating 'env' module ...") - } - envConfig := wazero.NewModuleConfig(). - WithName("env"). - WithStartFunctions() - envModule, err := w.runtime.InstantiateModule(ctx, hm, envConfig) + // 6) Instantiate the host module first + envModule, err := w.runtime.InstantiateModule(ctx, hostModule, + wazero.NewModuleConfig().WithName("env")) if err != nil { - errStr := fmt.Sprintf("[callContractFn] Error: failed to instantiate env module: %v", err) + errStr := fmt.Sprintf("[callContractFn] Error instantiating env module: %v", err) fmt.Println(errStr) - return nil, types.GasReport{}, errors.New(errStr) + return nil, types.GasReport{}, fmt.Errorf("failed to instantiate env module: %w", err) } defer func() { if printDebug { - fmt.Println("[DEBUG] Closing 'env' module ...") + fmt.Println("[DEBUG] Closing env module ...") } envModule.Close(ctx) }() - // 5) Instantiate the contract module + // 7) Instantiate the contract module with proper memory configuration if printDebug { fmt.Println("[DEBUG] Instantiating contract module ...") } - modConfig := wazero.NewModuleConfig(). - WithName("contract"). - WithStartFunctions() - module, err := w.runtime.InstantiateModule(ctx, compiled, modConfig) + contractModule, err := w.runtime.InstantiateModule(ctx, compiled, + wazero.NewModuleConfig(). + WithName("contract"). + WithStartFunctions()) if err != nil { - errStr := fmt.Sprintf("[callContractFn] Error: failed to instantiate contract: %v", err) + errStr := fmt.Sprintf("[callContractFn] Error instantiating module: %v", err) fmt.Println(errStr) - return nil, types.GasReport{}, errors.New(errStr) + return nil, types.GasReport{}, fmt.Errorf("failed to instantiate module: %w", err) } defer func() { if printDebug { fmt.Println("[DEBUG] Closing contract module ...") } - module.Close(ctx) + contractModule.Close(ctx) }() - // 6) Create memory manager - memory := module.Memory() + // 8) Create memory manager and validate memory + memory := contractModule.Memory() if memory == nil { const errStr = "[callContractFn] Error: no memory section in module" fmt.Println(errStr) return nil, types.GasReport{}, errors.New(errStr) } - mm := newMemoryManager(memory, module) + + // Validate memory size + memSize := memory.Size() + if memSize == 0 { + const errStr = "[callContractFn] Error: memory size is 0" + fmt.Println(errStr) + return nil, types.GasReport{}, errors.New(errStr) + } + + // Calculate memory size in bytes + memSizeBytes := memSize * uint32(wasmPageSize) + if printDebug { + fmt.Printf("[DEBUG][Memory] Module memory size: %d pages (%d bytes)\n", memSize, memSizeBytes) + } + + // Initialize memory with zero values + zeroMem := make([]byte, memSizeBytes) + if !memory.Write(0, zeroMem) { + const errStr = "[callContractFn] Error: failed to initialize memory" + fmt.Println(errStr) + return nil, types.GasReport{}, errors.New(errStr) + } + + mm := newMemoryManager(memory, contractModule) + + // Write data directly to memory using Region structs if printDebug { fmt.Printf("[DEBUG] Writing environment to memory (size=%d) ...\n", len(adaptedEnv)) } - envPtr, _, err := mm.writeToMemory(adaptedEnv) + envPtr, err := mm.allocateAndWrite(adaptedEnv) if err != nil { errStr := fmt.Sprintf("[callContractFn] Error: failed to write env: %v", err) fmt.Println(errStr) @@ -1071,7 +1159,7 @@ func (w *WazeroRuntime) callContractFn( if printDebug { fmt.Printf("[DEBUG] Writing msg to memory (size=%d) ...\n", len(msg)) } - msgPtr, _, err := mm.writeToMemory(msg) + msgPtr, err := mm.allocateAndWrite(msg) if err != nil { errStr := fmt.Sprintf("[callContractFn] Error: failed to write msg: %v", err) fmt.Println(errStr) @@ -1089,16 +1177,22 @@ func (w *WazeroRuntime) callContractFn( if printDebug { fmt.Printf("[DEBUG] Writing info to memory (size=%d) ...\n", len(info)) } - infoPtr, _, err := mm.writeToMemory(info) + infoPtr, err := mm.allocateAndWrite(info) if err != nil { errStr := fmt.Sprintf("[callContractFn] Error: failed to write info: %v", err) fmt.Println(errStr) return nil, types.GasReport{}, errors.New(errStr) } callParams = []uint64{uint64(envPtr), uint64(infoPtr), uint64(msgPtr)} + if printDebug { + fmt.Printf("[DEBUG][Memory] Call parameters: env_ptr=%d, info_ptr=%d, msg_ptr=%d\n", envPtr, infoPtr, msgPtr) + } case "query", "sudo", "reply": callParams = []uint64{uint64(envPtr), uint64(msgPtr)} + if printDebug { + fmt.Printf("[DEBUG][Memory] Call parameters: env_ptr=%d, msg_ptr=%d\n", envPtr, msgPtr) + } default: errStr := fmt.Sprintf("[callContractFn] Error: unknown function name: %s", name) @@ -1107,7 +1201,7 @@ func (w *WazeroRuntime) callContractFn( } // Call the contract function - fn := module.ExportedFunction(name) + fn := contractModule.ExportedFunction(name) if fn == nil { errStr := fmt.Sprintf("[callContractFn] Error: function %q not found in contract", name) fmt.Println(errStr) @@ -1134,29 +1228,9 @@ func (w *WazeroRuntime) callContractFn( fmt.Printf("[DEBUG] results from contract call: %#v\n", results) } - // Read result from memory + // Read result directly from memory resultPtr := uint32(results[0]) - resultRegion, err := mm.readRegion(resultPtr) - if err != nil { - errStr := fmt.Sprintf("[callContractFn] Error: failed to read result region: %v", err) - fmt.Println(errStr) - return nil, types.GasReport{}, errors.New(errStr) - } - - if printDebug { - fmt.Printf("[DEBUG] result region: Offset=%d, Capacity=%d, Length=%d\n", - resultRegion.Offset, resultRegion.Capacity, resultRegion.Length) - } - - // Validate the result region - if err := validateRegion(resultRegion); err != nil { - errStr := fmt.Sprintf("[callContractFn] Error: invalid result region: %v", err) - fmt.Println(errStr) - return nil, types.GasReport{}, errors.New(errStr) - } - - // Read the actual result data using the region's length - resultData, err := mm.readFromMemory(resultRegion.Offset, resultRegion.Length) + resultData, err := mm.readData(resultPtr) if err != nil { errStr := fmt.Sprintf("[callContractFn] Error: failed to read result data: %v", err) fmt.Println(errStr) @@ -1189,38 +1263,3 @@ func (w *WazeroRuntime) callContractFn( return resultData, gr, nil } - -// SimulateStoreCode validates the code but does not store it -func (w *WazeroRuntime) SimulateStoreCode(code []byte) ([]byte, error, bool) { - if code == nil { - return nil, errors.New("Null/Nil argument: wasm"), false - } - - if len(code) == 0 { - return nil, errors.New("Wasm bytecode could not be deserialized"), false - } - - // Attempt to compile the module just to validate. - compiled, err := w.runtime.CompileModule(context.Background(), code) - if err != nil { - return nil, errors.New("Wasm bytecode could not be deserialized"), false - } - defer compiled.Close(context.Background()) - - // Check memory requirements - memoryCount := 0 - for _, exp := range compiled.ExportedMemories() { - if exp != nil { - memoryCount++ - } - } - if memoryCount != 1 { - return nil, fmt.Errorf("Error during static Wasm validation: Wasm contract must contain exactly one memory"), false - } - - // Compute checksum but do not store in any cache - checksum := sha256.Sum256(code) - - // Return checksum, no error, and persisted=false - return checksum[:], nil, false -}