Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Boost import sheet from CSV / XLSX #646

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ use tokio::sync::Mutex;
use umya_spreadsheet::new_file;
use uuid::Uuid;

const MAX_ROWS: u32 = 10000;

pub struct SheetRustFunctions;

// Function to detect the delimiter
Expand Down Expand Up @@ -299,38 +301,16 @@ impl SheetRustFunctions {
}
}

// Ensure the number of rows matches the number of records
{
let mut sheet_manager = sheet_manager.lock().await;
while {
let (sheet, _) = sheet_manager.sheets.get_mut(&sheet_id).ok_or("Sheet ID not found")?;
sheet.rows.len() < records.len()
} {
sheet_manager.add_row(&sheet_id, None).await?;
// Add rows with values in chunks
for chunk in records.chunks(MAX_ROWS as usize) {
let mut rows = Vec::new();
for record in chunk {
let row_cells = record.iter().map(|s| s.to_string()).collect::<Vec<String>>();
rows.push(row_cells);
}
}

// Set values for the new columns
let row_ids: Vec<String> = {
let sheet_manager = sheet_manager.lock().await;
let (sheet, _) = sheet_manager.sheets.get(&sheet_id).ok_or("Sheet ID not found")?;
sheet.display_rows.clone()
};

for (row_index, record) in records.iter().enumerate() {
let row_id = row_ids.get(row_index).ok_or("Row ID not found")?.clone();
for (col_index, value) in record.iter().enumerate() {
let column_definition = &column_definitions[col_index];
let mut sheet_manager = sheet_manager.lock().await;
sheet_manager
.set_cell_value(
&sheet_id,
row_id.clone(),
column_definition.id.clone(),
value.to_string(),
)
.await?;
}
let mut sheet_manager = sheet_manager.lock().await;
sheet_manager.add_values(&sheet_id, rows).await?;
}

Ok("Columns created successfully".to_string())
Expand Down Expand Up @@ -383,49 +363,36 @@ impl SheetRustFunctions {
}
}

let mut num_rows: u32 = 0;
for row_index in 1..u32::MAX {
let row_cells = worksheet.get_collection_by_row(&row_index);
let is_empty_row =
row_cells.is_empty() || row_cells.into_iter().all(|cell| cell.get_cell_value().is_empty());

if is_empty_row {
break;
}
// Add rows with values in chunks
for chunk_start in (1..u32::MAX).step_by(MAX_ROWS as usize) {
let mut rows = Vec::new();
let mut is_empty_row = false;
for row_index in chunk_start..(chunk_start + MAX_ROWS) {
let row_cells = worksheet.get_collection_by_row(&row_index);
is_empty_row =
row_cells.is_empty() || row_cells.into_iter().all(|cell| cell.get_cell_value().is_empty());

if is_empty_row {
break;
}

num_rows += 1;
}
let mut row_cells = Vec::new();
for col_index in 1..=num_columns {
if let Some(cell) = worksheet.get_cell((col_index.to_u32().unwrap_or_default(), row_index)) {
row_cells.push(cell.get_value().to_string());
}
}

{
let mut sheet_manager = sheet_manager.lock().await;
rows.push(row_cells);
}

for _ in 0..num_rows {
sheet_manager.add_row(&sheet_id, None).await?;
if !rows.is_empty() {
let mut sheet_manager = sheet_manager.lock().await;
sheet_manager.add_values(&sheet_id, rows).await?;
}
}

let row_ids: Vec<String> = {
let sheet_manager = sheet_manager.lock().await;
let (sheet, _) = sheet_manager.sheets.get(&sheet_id).ok_or("Sheet ID not found")?;
sheet.display_rows.clone()
};

for row_index in 1..=num_rows {
for col_index in 1..=num_columns {
if let Some(cell) = worksheet.get_cell((col_index.to_u32().unwrap_or_default(), row_index)) {
let cell_value = cell.get_value();
let row_id = row_ids.get(row_index as usize - 1).ok_or("Row ID not found")?.clone();

let mut sheet_manager = sheet_manager.lock().await;
sheet_manager
.set_cell_value(
&sheet_id,
row_id,
column_definitions[col_index as usize - 1].id.clone(),
cell_value.to_string(),
)
.await?;
}
if is_empty_row {
break;
}
}
}
Expand Down
22 changes: 21 additions & 1 deletion shinkai-bin/shinkai-node/src/managers/sheet_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ impl SheetManager {
if let Some((first_job_message, _)) = job_messages.first() {
let mut job_manager = job_manager.lock().await;
job_manager
// TODO: I'm not sure about this one
// TODO: I'm not sure about this one
.queue_job_message(first_job_message, user_profile, "")
.await
.map_err(|e| e.to_string())?;
Expand Down Expand Up @@ -434,6 +434,26 @@ impl SheetManager {
Ok(())
}

/// Appends a batch of rows to the sheet identified by `sheet_id`.
///
/// Each inner `Vec<String>` in `values` is one row. The rows are handed to the
/// sheet via `Sheet::add_values`, the updated sheet is saved to the database,
/// and the jobs returned by the sheet are chained and queued.
///
/// # Errors
///
/// Returns an error string when the sheet ID is unknown, the sheet rejects the
/// values, the database handle can no longer be upgraded, saving fails, no job
/// manager is configured, or queuing the jobs fails. The job-manager check
/// happens after the sheet has been updated and saved.
pub async fn add_values(&mut self, sheet_id: &str, values: Vec<Vec<String>>) -> Result<(), String> {
    let sheet = match self.sheets.get_mut(sheet_id) {
        Some((sheet, _)) => sheet,
        None => return Err("Sheet ID not found".to_string()),
    };
    let pending_jobs = sheet.add_values(values).await.map_err(|err| err.to_string())?;

    // Persist the updated sheet before any job is queued.
    let db = self
        .db
        .upgrade()
        .ok_or_else(|| "Couldn't convert to strong db".to_string())?;
    db.save_sheet(sheet.clone(), self.user_profile.clone())
        .map_err(|err| err.to_string())?;

    // Chain the resulting JobMessages and enqueue the first one.
    let job_manager = self
        .job_manager
        .as_ref()
        .ok_or_else(|| "JobManager not set".to_string())?;
    Self::create_and_chain_job_messages(pending_jobs, job_manager, &self.user_profile).await?;

    Ok(())
}

async fn handle_updates(
receiver: Receiver<SheetUpdate>,
ws_manager: Option<Arc<Mutex<dyn WSUpdateHandler + Send>>>,
Expand Down
36 changes: 32 additions & 4 deletions shinkai-libs/shinkai-sheet/src/sheet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ pub struct CellUpdateData {
#[derive(Debug, Clone)]
pub struct ProcessedInput {
pub content: String,
pub local_files: Vec<(String, String)>, // (FilePath, FileName)
pub local_files: Vec<(String, String)>, // (FilePath, FileName)
pub uploaded_files: Vec<(String, String)>, // (FilePath, FileName)
}

Expand Down Expand Up @@ -471,7 +471,8 @@ impl Sheet {
if let Some(cell) = self.get_cell(row.clone(), col_uuid.clone()) {
if let Some(value) = &cell.value {
// Assuming the value is a serialized list of file paths
let files: Vec<(String, String)> = serde_json::from_str(value).unwrap_or_default();
let files: Vec<(String, String)> =
serde_json::from_str(value).unwrap_or_default();
local_files.extend(files);
}
}
Expand All @@ -484,7 +485,7 @@ impl Sheet {
// TODO: eventually if we want to support multiple files, we need to change this
// let file_names: Vec<String> = serde_json::from_str(value).unwrap_or_default();
// for file_name in file_names {
uploaded_files.push((file_inbox_id.clone(), value.clone()));
uploaded_files.push((file_inbox_id.clone(), value.clone()));
// }
}
}
Expand Down Expand Up @@ -524,6 +525,11 @@ impl Sheet {
Ok(())
}

/// Adds a batch of rows to this sheet by dispatching `SheetAction::AddValues`.
///
/// Each inner `Vec<String>` in `values` is one row. Returns the jobs produced
/// by the dispatch; this method itself always returns `Ok`.
pub async fn add_values(&mut self, values: Vec<Vec<String>>) -> Result<Vec<WorkflowSheetJobData>, String> {
    Ok(self.dispatch(SheetAction::AddValues(values)).await)
}

fn compute_input_hash(
&self,
input_cells: &[(RowIndex, ColumnIndex, ColumnDefinition)],
Expand Down Expand Up @@ -681,6 +687,7 @@ pub enum SheetAction {
TriggerUpdateColumnValues(UuidString),
RemoveRow(UuidString),
AddRow(UuidString), // Add other actions as needed
AddValues(Vec<Vec<String>>),
}

// Implement the reducer function
Expand Down Expand Up @@ -1150,9 +1157,30 @@ pub fn sheet_reducer(
jobs.append(&mut new_jobs);
}
}
SheetAction::AddValues(values) => {
for row in values {
let row_uuid = Uuid::new_v4().to_string();
let mut row_cells = HashMap::new();
for (col_index, value) in row.iter().enumerate() {
if let Some(col_uuid) = state.display_columns.get(col_index) {
row_cells.insert(
col_uuid.clone(),
Cell {
value: Some(value.clone()),
last_updated: Utc::now(),
status: CellStatus::Ready,
input_hash: None,
},
);
}
}
state.rows.insert(row_uuid.clone(), row_cells);
state.display_rows.push(row_uuid.clone());
}
}
}
println!("After state: \n");
state.print_as_ascii_table();
(state, jobs)
})
}
}
Loading