Merge pull request nf-core#3276 from mashehu/fix-create-params
handle new schema structure in `nf-core pipelines create-params-file`
mashehu authored Nov 12, 2024
2 parents 27659c5 + e57a096 commit 24e2dc2
Showing 5 changed files with 75 additions and 66 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -35,6 +35,7 @@
- Update GitHub Actions ([#3237](https://github.com/nf-core/tools/pull/3237))
- add `--dir/-d` option to schema commands ([#3247](https://github.com/nf-core/tools/pull/3247))
- Update pre-commit hook astral-sh/ruff-pre-commit to v0.7.1 ([#3250](https://github.com/nf-core/tools/pull/3250))
- handle new schema structure in `nf-core pipelines create-params-file` ([#3276](https://github.com/nf-core/tools/pull/3276))
- Update Gitpod image to use Miniforge instead of Miniconda ([#3274](https://github.com/nf-core/tools/pull/3274))
- Update pre-commit hook astral-sh/ruff-pre-commit to v0.7.3 ([#3275](https://github.com/nf-core/tools/pull/3275))

43 changes: 28 additions & 15 deletions nf_core/pipelines/params_file.py
Expand Up @@ -2,9 +2,9 @@

import json
import logging
import os
import textwrap
from typing import Literal, Optional
from pathlib import Path
from typing import Dict, List, Literal, Optional

import questionary

@@ -27,7 +27,7 @@
ModeLiteral = Literal["both", "start", "end", "none"]


def _print_wrapped(text, fill_char="-", mode="both", width=80, indent=0, drop_whitespace=True):
def _print_wrapped(text, fill_char="-", mode="both", width=80, indent=0, drop_whitespace=True) -> str:
"""Helper function to format text for the params-file template.
Args:
@@ -100,7 +100,7 @@ def __init__(
self.wfs = nf_core.pipelines.list.Workflows()
self.wfs.get_remote_workflows()

def get_pipeline(self):
def get_pipeline(self) -> Optional[bool]:
"""
Prompt the user for a pipeline name and get the schema
"""
@@ -124,11 +124,14 @@ def get_pipeline(self):
).unsafe_ask()

# Get the schema
self.schema_obj = nf_core.pipelines.schema.PipelineSchema()
self.schema_obj = PipelineSchema()
if self.schema_obj is None:
return False
self.schema_obj.get_schema_path(self.pipeline, local_only=False, revision=self.pipeline_revision)
self.schema_obj.get_wf_params()
return True

def format_group(self, definition, show_hidden=False):
def format_group(self, definition, show_hidden=False) -> str:
"""Format a group of parameters of the schema as commented YAML.
Args:
@@ -167,7 +170,9 @@ def format_group(self, definition, show_hidden=False):

return out

def format_param(self, name, properties, required_properties=(), show_hidden=False):
def format_param(
self, name: str, properties: Dict, required_properties: List[str] = [], show_hidden: bool = False
) -> Optional[str]:
"""
Format a single parameter of the schema as commented YAML
@@ -188,6 +193,9 @@ def format_param(self, name, properties, required_properties=(), show_hidden=False):
return None

description = properties.get("description", "")
if self.schema_obj is None:
log.error("No schema object found")
return ""
self.schema_obj.get_schema_defaults()
default = properties.get("default")
type = properties.get("type")
@@ -209,7 +217,7 @@ def format_param(self, name, properties, required_properties=(), show_hidden=False):

return out

def generate_params_file(self, show_hidden=False):
def generate_params_file(self, show_hidden: bool = False) -> str:
"""Generate the contents of a parameter template file.
Assumes the pipeline has been fetched (if remote) and the schema loaded.
@@ -220,6 +228,10 @@ def generate_params_file(self, show_hidden=False):
Returns:
str: Formatted output for the pipeline schema
"""
if self.schema_obj is None:
log.error("No schema object found")
return ""

schema = self.schema_obj.schema
pipeline_name = self.schema_obj.pipeline_manifest.get("name", self.pipeline)
pipeline_version = self.schema_obj.pipeline_manifest.get("version", "0.0.0")
@@ -234,13 +246,13 @@ def generate_params_file(self, show_hidden=False):
out += "\n"

# Add all parameter groups
for definition in schema.get("definitions", {}).values():
for definition in schema.get("definitions", schema.get("$defs", {})).values():
out += self.format_group(definition, show_hidden=show_hidden)
out += "\n"

return out

def write_params_file(self, output_fn="nf-params.yaml", show_hidden=False, force=False):
def write_params_file(self, output_fn: Path = Path("nf-params.yaml"), show_hidden=False, force=False) -> bool:
"""Build a template file for the pipeline schema.
Args:
@@ -254,7 +266,9 @@ def write_params_file(self, output_fn="nf-params.yaml", show_hidden=False, force=False):
"""

self.get_pipeline()

if self.schema_obj is None:
log.error("No schema object found")
return False
try:
self.schema_obj.load_schema()
self.schema_obj.validate_schema()
@@ -265,11 +279,10 @@ def write_params_file(self, output_fn="nf-params.yaml", show_hidden=False, force=False):

schema_out = self.generate_params_file(show_hidden=show_hidden)

if os.path.exists(output_fn) and not force:
if output_fn.exists() and not force:
log.error(f"File '{output_fn}' exists! Please delete first, or use '--force'")
return False
with open(output_fn, "w") as fh:
fh.write(schema_out)
log.info(f"Parameter file written to '{output_fn}'")
output_fn.write_text(schema_out)
log.info(f"Parameter file written to '{output_fn}'")

return True
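
The functional heart of this diff is the parameter-group lookup in `generate_params_file`: newer `nextflow_schema.json` files keep their parameter groups under the JSON Schema `$defs` keyword rather than the older `definitions` keyword, and the builder now falls back to `$defs` when `definitions` is absent. Below is a minimal, self-contained sketch of that lookup pattern; the example schemas and the `param_groups` helper are illustrative, not nf-core code.

```python
# Illustrative sketch of the fallback used in generate_params_file() above:
# prefer the legacy "definitions" key, fall back to "$defs", default to {}.
# The example schemas and the param_groups() helper are hypothetical.
legacy_schema = {
    "definitions": {
        "input_output_options": {"properties": {"input": {"type": "string", "default": None}}},
    }
}
new_schema = {
    "$defs": {
        "input_output_options": {"properties": {"input": {"type": "string", "default": None}}},
    }
}


def param_groups(schema: dict) -> dict:
    """Return the parameter groups regardless of which keyword the schema uses."""
    return schema.get("definitions", schema.get("$defs", {}))


for schema in (legacy_schema, new_schema):
    for group_name, group in param_groups(schema).items():
        print(group_name, "->", sorted(group.get("properties", {})))
```

One subtlety of this old-key-wins lookup: a schema that defines both keywords would have its `$defs` silently ignored, which matches the behaviour of the line added in the diff.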
1 change: 1 addition & 0 deletions nf_core/pipelines/schema.py
@@ -124,6 +124,7 @@ def get_schema_path(
# Supplied path exists - assume a local pipeline directory or schema
if path.exists():
log.debug(f"Path exists: {path}. Assuming local pipeline directory or schema")
local_only = True
if revision is not None:
log.warning(f"Local workflow supplied, ignoring revision '{revision}'")
if path.is_dir():
90 changes: 39 additions & 51 deletions tests/pipelines/test_params_file.py
@@ -1,79 +1,67 @@
import json
import os
import shutil
import tempfile
from pathlib import Path

import nf_core.pipelines.create.create
import nf_core.pipelines.schema
from nf_core.pipelines.params_file import ParamsFileBuilder

from ..test_pipelines import TestPipelines

class TestParamsFileBuilder:

class TestParamsFileBuilder(TestPipelines):
"""Class for schema tests"""

@classmethod
def setup_class(cls):
def setUp(self):
"""Create a new PipelineSchema object"""
cls.schema_obj = nf_core.pipelines.schema.PipelineSchema()
cls.root_repo_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))

# Create a test pipeline in temp directory
cls.tmp_dir = tempfile.mkdtemp()
cls.template_dir = Path(cls.tmp_dir, "wf")
create_obj = nf_core.pipelines.create.create.PipelineCreate(
"testpipeline", "a description", "Me", outdir=cls.template_dir, no_git=True
)
create_obj.init_pipeline()

cls.template_schema = Path(cls.template_dir, "nextflow_schema.json")
cls.params_template_builder = ParamsFileBuilder(cls.template_dir)
cls.invalid_template_schema = Path(cls.template_dir, "nextflow_schema_invalid.json")

# Remove the allOf section to make the schema invalid
with open(cls.template_schema) as fh:
o = json.load(fh)
del o["allOf"]

with open(cls.invalid_template_schema, "w") as fh:
json.dump(o, fh)

@classmethod
def teardown_class(cls):
if Path(cls.tmp_dir).exists():
shutil.rmtree(cls.tmp_dir)
super().setUp()

self.template_schema = Path(self.pipeline_dir, "nextflow_schema.json")
self.params_template_builder = ParamsFileBuilder(self.pipeline_dir)
self.outfile = Path(self.pipeline_dir, "params-file.yml")

def test_build_template(self):
outfile = Path(self.tmp_dir, "params-file.yml")
self.params_template_builder.write_params_file(str(outfile))
self.params_template_builder.write_params_file(self.outfile)

assert outfile.exists()
assert self.outfile.exists()

with open(outfile) as fh:
with open(self.outfile) as fh:
out = fh.read()

assert "nf-core/testpipeline" in out

def test_build_template_invalid_schema(self, caplog):
def test_build_template_invalid_schema(self):
"""Build a schema from a template"""
outfile = Path(self.tmp_dir, "params-file-invalid.yml")
builder = ParamsFileBuilder(self.invalid_template_schema)
res = builder.write_params_file(str(outfile))
schema = {}
with open(self.template_schema) as fh:
schema = json.load(fh)
del schema["allOf"]

with open(self.template_schema, "w") as fh:
json.dump(schema, fh)

builder = ParamsFileBuilder(self.template_schema)
res = builder.write_params_file(self.outfile)

assert res is False
assert "Pipeline schema file is invalid" in caplog.text
assert "Pipeline schema file is invalid" in self.caplog.text

def test_build_template_file_exists(self, caplog):
def test_build_template_file_exists(self):
"""Build a schema from a template"""

# Creates a new empty file
outfile = Path(self.tmp_dir) / "params-file.yml"
with open(outfile, "w"):
pass
self.outfile.touch()

res = self.params_template_builder.write_params_file(outfile)
res = self.params_template_builder.write_params_file(self.outfile)

assert res is False
assert f"File '{outfile}' exists!" in caplog.text
assert f"File '{self.outfile}' exists!" in self.caplog.text

self.outfile.unlink()

outfile.unlink()
def test_build_template_content(self):
"""Test that the content of the params file is correct"""
self.params_template_builder.write_params_file(self.outfile)

with open(self.outfile) as fh:
out = fh.read()

assert "nf-core/testpipeline" in out
assert "# input = null" in out
6 changes: 6 additions & 0 deletions tests/test_pipelines.py
@@ -1,6 +1,8 @@
import shutil
from unittest import TestCase

import pytest

from nf_core.utils import Pipeline

from .utils import create_tmp_pipeline
@@ -24,3 +26,7 @@ def _make_pipeline_copy(self):
new_pipeline = self.tmp_dir / "nf-core-testpipeline-copy"
shutil.copytree(self.pipeline_dir, new_pipeline)
return new_pipeline

@pytest.fixture(autouse=True)
def _use_caplog(self, caplog):
self.caplog = caplog
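
The `_use_caplog` fixture added here is what lets the rewritten tests in `tests/pipelines/test_params_file.py` assert on `self.caplog.text` instead of taking `caplog` as a test argument: pytest applies autouse fixtures to `unittest.TestCase` subclasses even though their test methods cannot request fixtures directly. A standalone sketch of the pattern follows; the class and test names are made up for illustration.

```python
# Standalone illustration of the autouse-caplog pattern above; run with pytest.
# TestWithCaplog and test_logs_are_captured are hypothetical names.
import logging
from unittest import TestCase

import pytest

log = logging.getLogger(__name__)


class TestWithCaplog(TestCase):
    @pytest.fixture(autouse=True)
    def _use_caplog(self, caplog):
        # pytest injects the caplog fixture here; stash it on the instance so
        # plain unittest test methods can inspect the captured log records.
        self.caplog = caplog

    def test_logs_are_captured(self):
        log.error("File 'nf-params.yaml' exists!")
        assert "exists!" in self.caplog.text
```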
