Skip to content

Commit

Permalink
package input files and folders (backend)
Browse files Browse the repository at this point in the history
  • Loading branch information
andrii-i committed Mar 14, 2024
1 parent 07c3fe3 commit 5a784c5
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 2 deletions.
4 changes: 4 additions & 0 deletions jupyter_scheduler/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ class CreateJob(BaseModel):
name: str
output_filename_template: Optional[str] = OUTPUT_FILENAME_TEMPLATE
compute_type: Optional[str] = None
package_input_folder: Optional[bool] = None

@root_validator
def compute_input_filename(cls, values) -> Dict:
Expand Down Expand Up @@ -145,6 +146,7 @@ class DescribeJob(BaseModel):
status: Status = Status.CREATED
status_message: Optional[str] = None
downloaded: bool = False
package_input_folder: Optional[bool] = None

class Config:
orm_mode = True
Expand Down Expand Up @@ -209,6 +211,7 @@ class CreateJobDefinition(BaseModel):
compute_type: Optional[str] = None
schedule: Optional[str] = None
timezone: Optional[str] = None
package_input_folder: Optional[bool] = None

@root_validator
def compute_input_filename(cls, values) -> Dict:
Expand All @@ -234,6 +237,7 @@ class DescribeJobDefinition(BaseModel):
create_time: int
update_time: int
active: bool
package_input_folder: Optional[bool] = None

class Config:
orm_mode = True
Expand Down
1 change: 1 addition & 0 deletions jupyter_scheduler/orm.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ class CommonColumns:
output_filename_template = Column(String(256))
update_time = Column(Integer, default=get_utc_timestamp, onupdate=get_utc_timestamp)
create_time = Column(Integer, default=get_utc_timestamp)
package_input_folder = Column(Boolean)


class Job(CommonColumns, Base):
Expand Down
33 changes: 31 additions & 2 deletions jupyter_scheduler/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,25 @@ def copy_input_file(self, input_uri: str, copy_to_path: str):
with fsspec.open(copy_to_path, "wb") as output_file:
output_file.write(input_file.read())

def copy_input_folder(self, input_uri: str, nb_copy_to_path: str) -> None:
    """Copy the input file and the rest of its folder into the staging directory.

    The input file itself is copied with ``self.copy_input_file`` to
    ``nb_copy_to_path``. Every other entry found next to the input file is
    then copied into the directory that contains ``nb_copy_to_path``:
    sub-directories recursively, regular files one by one. Entries that are
    neither a directory nor a file are skipped.

    Args:
        input_uri: Path of the input file, joined onto ``self.root_dir``.
        nb_copy_to_path: Destination path of the input file; its parent
            directory is used as the staging directory for everything else.
    """
    input_dir_path = os.path.dirname(os.path.join(self.root_dir, input_uri))
    staging_dir = os.path.dirname(nb_copy_to_path)
    # Loop-invariant: name of the input file, which is copied separately.
    input_filename = os.path.basename(input_uri)

    # Copy the input file
    self.copy_input_file(input_uri, nb_copy_to_path)

    # Copy the rest of the input folder excluding the input file
    for item in os.listdir(input_dir_path):
        source = os.path.join(input_dir_path, item)
        destination = os.path.join(staging_dir, item)
        if os.path.isdir(source):
            # dirs_exist_ok lets the copy merge into a sub-directory that is
            # already present in the staging folder instead of raising
            # FileExistsError.
            shutil.copytree(source, destination, dirs_exist_ok=True)
        elif os.path.isfile(source) and item != input_filename:
            with fsspec.open(source) as src_file, fsspec.open(
                destination, "wb"
            ) as output_file:
                # Stream in chunks so large files are never held fully in memory.
                shutil.copyfileobj(src_file, output_file)

def create_job(self, model: CreateJob) -> str:
if not model.job_definition_id and not self.file_exists(model.input_uri):
raise InputUriError(model.input_uri)
Expand Down Expand Up @@ -401,7 +420,10 @@ def create_job(self, model: CreateJob) -> str:
session.commit()

staging_paths = self.get_staging_paths(DescribeJob.from_orm(job))
self.copy_input_file(model.input_uri, staging_paths["input"])
if model.package_input_folder:
self.copy_input_folder(model.input_uri, staging_paths["input"])
else:
self.copy_input_file(model.input_uri, staging_paths["input"])

# The MP context forces new processes to not be forked on Linux.
# This is necessary because `asyncio.get_event_loop()` is bugged in
Expand Down Expand Up @@ -541,7 +563,10 @@ def create_job_definition(self, model: CreateJobDefinition) -> str:
job_definition_id = job_definition.job_definition_id

staging_paths = self.get_staging_paths(DescribeJobDefinition.from_orm(job_definition))
self.copy_input_file(model.input_uri, staging_paths["input"])
if model.package_input_folder:
self.copy_input_folder(model.input_uri, staging_paths["input"])
else:
self.copy_input_file(model.input_uri, staging_paths["input"])

if self.task_runner and job_definition.schedule:
self.task_runner.add_job_definition(job_definition_id)
Expand Down Expand Up @@ -690,6 +715,10 @@ def get_staging_paths(self, model: Union[DescribeJob, DescribeJobDefinition]) ->

staging_paths["input"] = os.path.join(self.staging_path, id, model.input_filename)

if model.package_input_folder:
notebook_dir = os.path.dirname(staging_paths["input"])
staging_paths["input_dir"] = notebook_dir

return staging_paths


Expand Down

0 comments on commit 5a784c5

Please sign in to comment.