diff --git a/rust/drivers/datafusion/src/lib.rs b/rust/drivers/datafusion/src/lib.rs index 5141d856cd..55adcc5719 100644 --- a/rust/drivers/datafusion/src/lib.rs +++ b/rust/drivers/datafusion/src/lib.rs @@ -17,14 +17,16 @@ #![allow(refining_impl_trait)] +use adbc_core::ffi::constants; +use datafusion::dataframe::DataFrameWriteOptions; use datafusion::datasource::TableType; use datafusion::prelude::*; use datafusion_substrait::logical_plan::consumer::from_substrait_plan; use datafusion_substrait::substrait::proto::Plan; use prost::Message; +use std::fmt::Debug; use std::sync::Arc; use std::vec::IntoIter; -use std::{collections::HashMap, fmt::Debug}; use tokio::runtime::Runtime; use arrow_array::builder::{ @@ -113,9 +115,7 @@ impl Driver for DataFusionDriver { type DatabaseType = DataFusionDatabase; fn new_database(&mut self) -> Result { - Ok(Self::DatabaseType { - options: HashMap::new(), - }) + Ok(Self::DatabaseType {}) } fn new_database_with_opts( @@ -127,9 +127,7 @@ impl Driver for DataFusionDriver { ), >, ) -> adbc_core::error::Result { - let mut database = Self::DatabaseType { - options: HashMap::new(), - }; + let mut database = Self::DatabaseType {}; for (key, value) in opts { database.set_option(key, value)?; } @@ -137,36 +135,48 @@ impl Driver for DataFusionDriver { } } -pub struct DataFusionDatabase { - options: HashMap, -} +pub struct DataFusionDatabase {} impl Optionable for DataFusionDatabase { type Option = OptionDatabase; fn set_option( &mut self, - _key: Self::Option, + key: Self::Option, _value: adbc_core::options::OptionValue, ) -> adbc_core::error::Result<()> { - self.options.insert(_key, _value); - Ok(()) + Err(Error::with_message_and_status( + format!("Unrecognized option: {key:?}"), + Status::NotFound, + )) } - fn get_option_string(&self, _key: Self::Option) -> adbc_core::error::Result { - todo!() + fn get_option_string(&self, key: Self::Option) -> adbc_core::error::Result { + Err(Error::with_message_and_status( + format!("Unrecognized option: {key:?}"), + Status::NotFound, + )) } - fn get_option_bytes(&self, _key: Self::Option) -> adbc_core::error::Result> { - todo!() + fn get_option_bytes(&self, key: Self::Option) -> adbc_core::error::Result> { + Err(Error::with_message_and_status( + format!("Unrecognized option: {key:?}"), + Status::NotFound, + )) } - fn get_option_int(&self, _key: Self::Option) -> adbc_core::error::Result { - todo!() + fn get_option_int(&self, key: Self::Option) -> adbc_core::error::Result { + Err(Error::with_message_and_status( + format!("Unrecognized option: {key:?}"), + Status::NotFound, + )) } - fn get_option_double(&self, _key: Self::Option) -> adbc_core::error::Result { - todo!() + fn get_option_double(&self, key: Self::Option) -> adbc_core::error::Result { + Err(Error::with_message_and_status( + format!("Unrecognized option: {key:?}"), + Status::NotFound, + )) } } @@ -189,7 +199,7 @@ impl Database for DataFusionDatabase { fn new_connection_with_opts( &mut self, - _opts: impl IntoIterator< + opts: impl IntoIterator< Item = ( adbc_core::options::OptionConnection, adbc_core::options::OptionValue, @@ -203,10 +213,16 @@ impl Database for DataFusionDatabase { .build() .unwrap(); - Ok(DataFusionConnection { + let mut connection = DataFusionConnection { runtime: Arc::new(runtime), ctx: Arc::new(ctx), - }) + }; + + for (key, value) in opts { + connection.set_option(key, value)?; + } + + Ok(connection) } } @@ -220,26 +236,85 @@ impl Optionable for DataFusionConnection { fn set_option( &mut self, - _key: Self::Option, - _value: adbc_core::options::OptionValue, + key: Self::Option, + value: adbc_core::options::OptionValue, ) -> adbc_core::error::Result<()> { - todo!() + match key.as_ref() { + constants::ADBC_CONNECTION_OPTION_CURRENT_CATALOG => match value { + OptionValue::String(value) => { + self.runtime.block_on(async { + let query = format!("SET datafusion.catalog.default_catalog = {value}"); + self.ctx.sql(query.as_str()).await.unwrap(); + }); + Ok(()) + } + _ => Err(Error::with_message_and_status( + "CurrentCatalog value must be of type String", + Status::InvalidArguments, + )), + }, + constants::ADBC_CONNECTION_OPTION_CURRENT_DB_SCHEMA => match value { + OptionValue::String(value) => { + self.runtime.block_on(async { + let query = format!("SET datafusion.catalog.default_schema = {value}"); + self.ctx.sql(query.as_str()).await.unwrap(); + }); + Ok(()) + } + _ => Err(Error::with_message_and_status( + "CurrentSchema value must be of type String", + Status::InvalidArguments, + )), + }, + _ => Err(Error::with_message_and_status( + format!("Unrecognized option: {key:?}"), + Status::NotFound, + )), + } } - fn get_option_string(&self, _key: Self::Option) -> adbc_core::error::Result { - todo!() + fn get_option_string(&self, key: Self::Option) -> adbc_core::error::Result { + match key.as_ref() { + constants::ADBC_CONNECTION_OPTION_CURRENT_CATALOG => Ok(self + .ctx + .state() + .config_options() + .catalog + .default_catalog + .clone()), + constants::ADBC_CONNECTION_OPTION_CURRENT_DB_SCHEMA => Ok(self + .ctx + .state() + .config_options() + .catalog + .default_schema + .clone()), + _ => Err(Error::with_message_and_status( + format!("Unrecognized option: {key:?}"), + Status::NotFound, + )), + } } - fn get_option_bytes(&self, _key: Self::Option) -> adbc_core::error::Result> { - todo!() + fn get_option_bytes(&self, key: Self::Option) -> adbc_core::error::Result> { + Err(Error::with_message_and_status( + format!("Unrecognized option: {key:?}"), + Status::NotFound, + )) } - fn get_option_int(&self, _key: Self::Option) -> adbc_core::error::Result { - todo!() + fn get_option_int(&self, key: Self::Option) -> adbc_core::error::Result { + Err(Error::with_message_and_status( + format!("Unrecognized option: {key:?}"), + Status::NotFound, + )) } - fn get_option_double(&self, _key: Self::Option) -> adbc_core::error::Result { - todo!() + fn get_option_double(&self, key: Self::Option) -> adbc_core::error::Result { + Err(Error::with_message_and_status( + format!("Unrecognized option: {key:?}"), + Status::NotFound, + )) } } @@ -622,6 +697,8 @@ impl Connection for DataFusionConnection { ctx: self.ctx.clone(), sql_query: None, substrait_plan: None, + bound_record_batch: None, + ingest_target_table: None, }) } @@ -707,6 +784,8 @@ pub struct DataFusionStatement { ctx: Arc, sql_query: Option, substrait_plan: Option, + bound_record_batch: Option, + ingest_target_table: Option, } impl Optionable for DataFusionStatement { @@ -714,32 +793,72 @@ impl Optionable for DataFusionStatement { fn set_option( &mut self, - _key: Self::Option, - _value: adbc_core::options::OptionValue, + key: Self::Option, + value: adbc_core::options::OptionValue, ) -> adbc_core::error::Result<()> { - todo!() + match key.as_ref() { + constants::ADBC_INGEST_OPTION_TARGET_TABLE => match value { + OptionValue::String(value) => { + self.ingest_target_table = Some(value); + Ok(()) + } + _ => Err(Error::with_message_and_status( + "IngestOptionTargetTable value must be of type String", + Status::InvalidArguments, + )), + }, + _ => Err(Error::with_message_and_status( + format!("Unrecognized option: {key:?}"), + Status::NotFound, + )), + } } - fn get_option_string(&self, _key: Self::Option) -> adbc_core::error::Result { - todo!() + fn get_option_string(&self, key: Self::Option) -> adbc_core::error::Result { + match key.as_ref() { + constants::ADBC_INGEST_OPTION_TARGET_TABLE => { + let target_table = self.ingest_target_table.clone(); + match target_table { + Some(table) => Ok(table), + None => Err(Error::with_message_and_status( + format!("{key:?} has not been set"), + Status::NotFound, + )), + } + } + _ => Err(Error::with_message_and_status( + format!("Unrecognized option: {key:?}"), + Status::NotFound, + )), + } } - fn get_option_bytes(&self, _key: Self::Option) -> adbc_core::error::Result> { - todo!() + fn get_option_bytes(&self, key: Self::Option) -> adbc_core::error::Result> { + Err(Error::with_message_and_status( + format!("Unrecognized option: {key:?}"), + Status::NotFound, + )) } - fn get_option_int(&self, _key: Self::Option) -> adbc_core::error::Result { - todo!() + fn get_option_int(&self, key: Self::Option) -> adbc_core::error::Result { + Err(Error::with_message_and_status( + format!("Unrecognized option: {key:?}"), + Status::NotFound, + )) } - fn get_option_double(&self, _key: Self::Option) -> adbc_core::error::Result { - todo!() + fn get_option_double(&self, key: Self::Option) -> adbc_core::error::Result { + Err(Error::with_message_and_status( + format!("Unrecognized option: {key:?}"), + Status::NotFound, + )) } } impl Statement for DataFusionStatement { - fn bind(&mut self, _batch: arrow_array::RecordBatch) -> adbc_core::error::Result<()> { - todo!() + fn bind(&mut self, batch: arrow_array::RecordBatch) -> adbc_core::error::Result<()> { + self.bound_record_batch.replace(batch); + Ok(()) } fn bind_stream( @@ -768,13 +887,29 @@ impl Statement for DataFusionStatement { } fn execute_update(&mut self) -> adbc_core::error::Result> { - self.runtime.block_on(async { - let _ = self - .ctx - .sql(&self.sql_query.clone().unwrap()) - .await - .unwrap(); - }); + if self.sql_query.is_some() { + self.runtime.block_on(async { + let _ = self + .ctx + .sql(&self.sql_query.clone().unwrap()) + .await + .unwrap(); + }); + } else if let Some(batch) = self.bound_record_batch.take() { + self.runtime.block_on(async { + let table = match self.ingest_target_table.clone() { + Some(table) => table, + None => todo!(), + }; + + self.ctx + .read_batch(batch) + .unwrap() + .write_table(table.as_str(), DataFrameWriteOptions::new()) + .await + .unwrap(); + }); + } Ok(Some(0)) } diff --git a/rust/drivers/datafusion/tests/test_datafusion.rs b/rust/drivers/datafusion/tests/test_datafusion.rs index 4e1db4fa83..38764459a5 100644 --- a/rust/drivers/datafusion/tests/test_datafusion.rs +++ b/rust/drivers/datafusion/tests/test_datafusion.rs @@ -16,11 +16,11 @@ // under the License. use adbc_core::driver_manager::{ManagedConnection, ManagedDriver}; -use adbc_core::{Connection, Database, Driver, Statement}; +use adbc_core::{Connection, Database, Driver, Optionable, Statement}; use arrow_array::RecordBatch; use datafusion::prelude::*; -use adbc_core::options::AdbcVersion; +use adbc_core::options::{AdbcVersion, OptionConnection, OptionStatement, OptionValue}; use arrow_select::concat::concat_batches; use datafusion_substrait::logical_plan::producer::to_substrait_plan; use datafusion_substrait::substrait::proto::Plan; @@ -85,6 +85,45 @@ fn execute_substrait(connection: &mut ManagedConnection, plan: Plan) -> RecordBa concat_batches(&schema, &batches).unwrap() } +#[test] +fn test_connection_options() { + let mut connection = get_connection(); + + let current_catalog = connection + .get_option_string(OptionConnection::CurrentCatalog) + .unwrap(); + + assert_eq!(current_catalog, "datafusion"); + + let _ = connection.set_option( + OptionConnection::CurrentCatalog, + OptionValue::String("datafusion2".to_string()), + ); + + let current_catalog = connection + .get_option_string(OptionConnection::CurrentCatalog) + .unwrap(); + + assert_eq!(current_catalog, "datafusion2"); + + let current_schema = connection + .get_option_string(OptionConnection::CurrentSchema) + .unwrap(); + + assert_eq!(current_schema, "public"); + + let _ = connection.set_option( + OptionConnection::CurrentSchema, + OptionValue::String("public2".to_string()), + ); + + let current_schema = connection + .get_option_string(OptionConnection::CurrentSchema) + .unwrap(); + + assert_eq!(current_schema, "public2"); +} + #[test] fn test_get_objects_database() { let mut connection = get_connection(); @@ -112,6 +151,32 @@ fn test_execute_sql() { assert_eq!(batch.num_columns(), 2); } +#[test] +fn test_ingest() { + let mut connection = get_connection(); + + execute_update(&mut connection, "CREATE TABLE IF NOT EXISTS datafusion.public.example (c1 INT, c2 VARCHAR) AS VALUES(1,'HELLO'),(2,'DATAFUSION'),(3,'!')"); + + let batch = execute_sql_query(&mut connection, "SELECT * FROM datafusion.public.example"); + + assert_eq!(batch.num_rows(), 3); + assert_eq!(batch.num_columns(), 2); + + let mut statement = connection.new_statement().unwrap(); + + let _ = statement.set_option( + OptionStatement::TargetTable, + OptionValue::String("example".to_string()), + ); + let _ = statement.bind(batch); + + let _ = statement.execute_update(); + + let batch = execute_sql_query(&mut connection, "SELECT * FROM datafusion.public.example"); + + assert_eq!(batch.num_rows(), 6); +} + #[test] fn test_execute_substrait() { let mut connection = get_connection();