diff --git a/api/dmsapi/database/models.py b/api/dmsapi/database/models.py index a36eba3..0b07afb 100644 --- a/api/dmsapi/database/models.py +++ b/api/dmsapi/database/models.py @@ -29,6 +29,11 @@ class Keyword_GroupCreate(Keyword_GroupBase): pass +class Keyword_GroupUpdate(SQLModel, table=False): + group_name_nl: str | None = Field(min_length=1, max_length=100, default=None) + group_name_en: str | None = Field(min_length=1, max_length=100, default=None) + + class Keyword_GroupPublic(Keyword_GroupBase): id: uuid.UUID diff --git a/api/dmsapi/extensions/keywords/keyword_client.py b/api/dmsapi/extensions/keywords/keyword_client.py index 71e5781..1721174 100644 --- a/api/dmsapi/extensions/keywords/keyword_client.py +++ b/api/dmsapi/extensions/keywords/keyword_client.py @@ -15,6 +15,7 @@ Keyword_Group, Keyword_GroupCreate, Keyword_GroupPublicWithKeywords, + Keyword_GroupUpdate, KeywordBase, KeywordUpdate, OKResponse, @@ -46,6 +47,34 @@ def create_facility(self, facility: FacilityCreate) -> Facility: session.refresh(db_facility) return db_facility + def update_facility( + self, facility_id: str, facility_update: FacilityCreate + ) -> Facility: + """Update a facility by ID. + + Args: + facility_id: ID of the facility to update. + facility: facility to update. + Returns: + updated facility. + """ + try: + uuid = UUID(facility_id) + except ValueError: + raise InvalidQueryParameter(f"Facility ID {facility_id} invalid UUID") + with Session(self.db_engine) as session: + facility = session.get(Facility, uuid) + if facility is None: + raise NotFoundError(f"Facility with ID {facility_id} not found") + + for key, value in facility_update.model_dump().items(): + if value: + setattr(facility, key, value) + session.add(facility) + session.commit() + session.refresh(facility) + return facility + def get_facility( self, facility_id: Annotated[str, Path(title="The ID of the facility to get")] ) -> Facility: @@ -202,6 +231,36 @@ def create_keywordgroup(self, keywordgroup: Keyword_GroupCreate): session.refresh(db_keywordgroup) return db_keywordgroup + def update_keywordgroup( + self, + keywordgroup_id: Annotated[str, Path(title="The ID of the group to update")], + keywordgroup_update: Keyword_GroupUpdate, + ): + """Update an existing keyword group. + + Args: + keywordgroup: updated keyword group. + keywordgroup_id: ID of the keyword group to update. + """ + try: + uuid = UUID(keywordgroup_id) + except ValueError: + raise InvalidQueryParameter( + f"Keywordgroup ID {keywordgroup_id} invalid UUID" + ) + with Session(self.db_engine) as session: + keywordgroup = session.get(Keyword_Group, uuid) + if keywordgroup is None: + raise NotFoundError(f"Keywordgroup ID {keywordgroup_id} not found") + + for key, value in keywordgroup_update.model_dump().items(): + if value: + setattr(keywordgroup, key, value) + session.add(keywordgroup) + session.commit() + session.refresh(keywordgroup) + return keywordgroup + def get_keyword_group( self, keywordgroup_id: Annotated[str, Path(title="The ID of the group to get")] ) -> Keyword_Group: diff --git a/api/dmsapi/extensions/keywords/keyword_extension.py b/api/dmsapi/extensions/keywords/keyword_extension.py index 0fe57eb..b0b5d15 100644 --- a/api/dmsapi/extensions/keywords/keyword_extension.py +++ b/api/dmsapi/extensions/keywords/keyword_extension.py @@ -21,6 +21,7 @@ class KeywordExtension(ApiExtension): Manage facilities: POST /facility + PUT /facility/{facility_id} GET /facilities GET /facility/{facility_id} PUT /facility/{facility_id} @@ -32,6 +33,7 @@ class KeywordExtension(ApiExtension): Manage keyword groups: POST /keywordgroup + PUT /keywordgroup/{keywordgroup_id} GET /keywordgroups GET /keywordgroup/{keywordgroup_id} DELETE /keywordgroup/{keywordgroup_id} @@ -68,12 +70,14 @@ def register(self, app: FastAPI) -> None: """ self.router.prefix = app.state.router_prefix self.add_create_facility() + self.add_update_facility() self.add_get_facility() self.add_get_facilities() self.add_delete_facility() self.add_link_keyword_group_to_facility() self.add_unlink_keyword_group_from_facility() self.add_create_keyword_group() + self.add_update_keyword_group() self.add_get_keyword_group() self.add_get_keyword_groups() self.add_delete_keyword_group() @@ -93,6 +97,35 @@ def add_create_facility(self): methods=["POST"], ) + def add_update_facility(self): + self.router.add_api_route( + name="Update Facility", + path="/facility/{facility_id}", + endpoint=self.client.update_facility, + response_model=Facility, + responses={ + 200: { + "content": { + MimeTypes.json.value: {}, + }, + "model": Facility, + }, + 400: { + "content": { + MimeTypes.json.value: {}, + }, + "model": ErrorResponse, + }, + 404: { + "content": { + MimeTypes.json.value: {}, + }, + "model": ErrorResponse, + }, + }, + methods=["PUT"], + ) + def add_get_facility(self): self.router.add_api_route( name="Get Facility", @@ -187,6 +220,15 @@ def add_create_keyword_group(self): methods=["POST"], ) + def add_update_keyword_group(self): + self.router.add_api_route( + name="Update Keyword Group", + path="/keywordgroup/{keywordgroup_id}", + endpoint=self.client.update_keywordgroup, + response_model=Keyword_GroupPublic, + methods=["PUT"], + ) + def add_get_keyword_group(self): self.router.add_api_route( name="Get Keyword Group", diff --git a/api/tests/facility_test.py b/api/tests/facility_test.py index 1758c4b..08e403d 100644 --- a/api/tests/facility_test.py +++ b/api/tests/facility_test.py @@ -23,6 +23,21 @@ async def test_create_facility_invalid(app_client: AsyncClient): assert response.json()["code"] == "RequestValidationError" +# test get facility +@pytest.mark.asyncio +async def test_update_facility(app_client: AsyncClient, facility: Facility): + response = await app_client.get(f"/facility/{facility.id}") + assert response.status_code == 200 + facility_obj = Facility(**response.json()) + assert facility_obj.name == "test_facility" + update_response = await app_client.put( + f"/facility/{facility.id}", json={"name": "updated_facility"} + ) + assert update_response.status_code == 200 + updated_facility_obj = Facility(**update_response.json()) + assert updated_facility_obj.name == "updated_facility" + + # test get facility @pytest.mark.asyncio async def test_get_facility(app_client: AsyncClient, facility: Facility): diff --git a/api/tests/keyword_group_test.py b/api/tests/keyword_group_test.py index ff2f73c..7543aaf 100644 --- a/api/tests/keyword_group_test.py +++ b/api/tests/keyword_group_test.py @@ -26,6 +26,25 @@ async def test_create_keywordgroup_invalid( assert response.json()["code"] == "RequestValidationError" +@pytest.mark.asyncio +async def test_update_keywordgroup( + app_client: AsyncClient, keyword_group: Keyword_Group +): + keyword_group_json = await app_client.get(f"/keywordgroup/{keyword_group.id}") + assert keyword_group is not None + keyword_group_obj = Keyword_Group(**keyword_group_json.json()) + assert keyword_group_obj.group_name_nl == "test" + assert keyword_group_obj.group_name_en == "engelse_test" + + response = await app_client.put( + f"/keywordgroup/{keyword_group.id}", + json={"group_name_nl": "updated_test", "group_name_en": "updated_engelse_test"}, + ) + assert response.status_code == 200 + assert response.json()["group_name_nl"] == "updated_test" + assert response.json()["group_name_en"] == "updated_engelse_test" + + @pytest.mark.asyncio async def test_get_keywordgroup(app_client: AsyncClient, keyword_group: Keyword_Group): keyword_group_json = await app_client.get(f"/keywordgroup/{keyword_group.id}")