Skip to content

Commit

Permalink
fix(Pipelines): handle case of notebook pipeline (#721)
Browse files Browse the repository at this point in the history
* fix(Pipelines): handle case of notebook pipeline

* tests(Pipelines): add test for notebook pipeline
  • Loading branch information
cheikhgwane authored Jun 25, 2024
1 parent f934ea5 commit 9575292
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 14 deletions.
6 changes: 5 additions & 1 deletion hexa/pipelines/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,11 @@ def run(
trigger_mode=trigger_mode,
execution_date=timezone.now(),
state=PipelineRunState.QUEUED,
config=self.merge_pipeline_config(config, pipeline_version.config),
config=(
self.merge_pipeline_config(config, pipeline_version.config)
if pipeline_version
else self.config
),
access_token=str(uuid.uuid4()),
send_mail_notifications=send_mail_notifications,
timeout=timeout,
Expand Down
24 changes: 23 additions & 1 deletion hexa/pipelines/tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from hexa.core.test import TestCase
from hexa.files.tests.mocks.mockgcp import mock_gcp_storage
from hexa.pipelines.models import Pipeline, PipelineRunTrigger
from hexa.pipelines.models import Pipeline, PipelineRunTrigger, PipelineType
from hexa.user_management.models import Feature, FeatureFlag, User
from hexa.workspaces.models import (
Workspace,
Expand Down Expand Up @@ -97,6 +97,28 @@ def test_run_pipeline_not_enabled(self):
self.assertEqual(r.status_code, 400)
self.assertEqual(r.json(), {"error": "Pipeline has no webhook enabled"})

def test_run_pipeline_notebook_webhook(self):
pipeline = Pipeline.objects.create(
code="new_pipeline",
name="notebook.ipynb",
workspace=self.WORKSPACE,
type=PipelineType.NOTEBOOK,
notebook_path="notebook.ipynb",
webhook_enabled=True,
)
pipeline.generate_webhook_token()

response = self.client.post(
reverse(
"pipelines:run",
args=[pipeline.webhook_token],
),
content_type="application/json",
)
self.assertEqual(response.status_code, 200)
self.assertEqual(str(pipeline.last_run.id), response.json()["run_id"])
self.assertEqual(pipeline.last_run.trigger_mode, PipelineRunTrigger.WEBHOOK)

def test_run_pipeline_valid(self):
self.assertEqual(self.PIPELINE.last_run, None)
response = self.client.post(
Expand Down
26 changes: 14 additions & 12 deletions hexa/pipelines/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from hexa.pipelines.models import Environment

from .credentials import PipelinesCredentials
from .models import Pipeline, PipelineRunTrigger, PipelineVersion
from .models import Pipeline, PipelineRunTrigger, PipelineType, PipelineVersion
from .queue import environment_sync_queue

logger = getLogger(__name__)
Expand Down Expand Up @@ -126,17 +126,19 @@ def run_pipeline(
return JsonResponse({"error": "Pipeline has no webhook enabled"}, status=400)

# Get the pipeline version
try:
pipeline_version = pipeline.last_version
if version_id is not None:
pipeline_version = PipelineVersion.objects.get(
pipeline=pipeline, id=version_id
)

if pipeline_version is None:
return JsonResponse({"error": "Pipeline has no version"}, status=400)
except PipelineVersion.DoesNotExist:
return JsonResponse({"error": "Pipeline version not found"}, status=404)
pipeline_version = None
if pipeline.type == PipelineType.ZIPFILE:
try:
pipeline_version = pipeline.last_version
if version_id is not None:
pipeline_version = PipelineVersion.objects.get(
pipeline=pipeline, id=version_id
)

if pipeline_version is None:
return JsonResponse({"error": "Pipeline has no version"}, status=400)
except PipelineVersion.DoesNotExist:
return JsonResponse({"error": "Pipeline version not found"}, status=404)

# Get the data from the request
content_type = request.META.get("CONTENT_TYPE")
Expand Down

0 comments on commit 9575292

Please sign in to comment.