From 27676fcd7875c3347b140f4810ae68e4dcc7d641 Mon Sep 17 00:00:00 2001 From: "Carson M." Date: Sun, 30 Jun 2024 12:36:51 -0500 Subject: [PATCH] fix: take `RunOptions` by reference --- src/training/trainer.rs | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/training/trainer.rs b/src/training/trainer.rs index 6ef4c92..f7c7cb3 100644 --- a/src/training/trainer.rs +++ b/src/training/trainer.rs @@ -96,7 +96,7 @@ impl Trainer { &'s self, inputs: impl Into>, labels: impl Into> - ) -> Result> { + ) -> Result> { match inputs.into() { SessionInputs::ValueSlice(input_values) => match labels.into() { SessionInputs::ValueSlice(labels) => self.step_inner(input_values.iter().chain(labels), None), @@ -112,11 +112,11 @@ impl Trainer { } } - fn step_inner<'s, 'i1, 'v1: 'i1, 'i2, 'v2: 'i2>( + fn step_inner<'r, 's: 'r, 'i1, 'v1: 'i1, 'i2, 'v2: 'i2>( &'s self, input_values: impl Iterator>, - run_options: Option> - ) -> Result> { + run_options: Option<&'r RunOptions> + ) -> Result> { let mut output_tensor_ptrs: Vec<*mut ort_sys::OrtValue> = vec![std::ptr::null_mut(); self.train_output_names.len()]; let input_ort_values: Vec<*const ort_sys::OrtValue> = input_values.map(|input_array_ort| input_array_ort.ptr().cast_const()).collect(); @@ -145,7 +145,7 @@ impl Trainer { &'s self, inputs: impl Into>, labels: impl Into> - ) -> Result> { + ) -> Result> { match inputs.into() { SessionInputs::ValueSlice(input_values) => match labels.into() { SessionInputs::ValueSlice(labels) => self.eval_step_inner(input_values.iter().chain(labels), None), @@ -161,11 +161,11 @@ impl Trainer { } } - fn eval_step_inner<'s, 'i1, 'v1: 'i1, 'i2, 'v2: 'i2>( + fn eval_step_inner<'r, 's: 'r, 'i1, 'v1: 'i1, 'i2, 'v2: 'i2>( &'s self, input_values: impl Iterator>, - run_options: Option> - ) -> Result> { + run_options: Option<&'r RunOptions> + ) -> Result> { let mut output_tensor_ptrs: Vec<*mut ort_sys::OrtValue> = vec![std::ptr::null_mut(); self.train_output_names.len()]; let input_ort_values: Vec<*const ort_sys::OrtValue> = input_values.map(|input_array_ort| input_array_ort.ptr().cast_const()).collect();