Skip to content

Commit

Permalink
add update endpoints
Browse files Browse the repository at this point in the history
  • Loading branch information
N-Clerkx committed Sep 12, 2024
1 parent a23a020 commit 71dfa5b
Show file tree
Hide file tree
Showing 5 changed files with 140 additions and 0 deletions.
5 changes: 5 additions & 0 deletions api/dmsapi/database/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ class Keyword_GroupCreate(Keyword_GroupBase):
pass


class Keyword_GroupUpdate(SQLModel, table=False):
group_name_nl: str | None = Field(min_length=1, max_length=100, default=None)
group_name_en: str | None = Field(min_length=1, max_length=100, default=None)


class Keyword_GroupPublic(Keyword_GroupBase):
id: uuid.UUID

Expand Down
59 changes: 59 additions & 0 deletions api/dmsapi/extensions/keywords/keyword_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
Keyword_Group,
Keyword_GroupCreate,
Keyword_GroupPublicWithKeywords,
Keyword_GroupUpdate,
KeywordBase,
KeywordUpdate,
OKResponse,
Expand Down Expand Up @@ -46,6 +47,34 @@ def create_facility(self, facility: FacilityCreate) -> Facility:
session.refresh(db_facility)
return db_facility

def update_facility(
self, facility_id: str, facility_update: FacilityCreate
) -> Facility:
"""Update a facility by ID.
Args:
facility_id: ID of the facility to update.
facility: facility to update.
Returns:
updated facility.
"""
try:
uuid = UUID(facility_id)
except ValueError:
raise InvalidQueryParameter(f"Facility ID {facility_id} invalid UUID")
with Session(self.db_engine) as session:
facility = session.get(Facility, uuid)
if facility is None:
raise NotFoundError(f"Facility with ID {facility_id} not found")

for key, value in facility_update.model_dump().items():
if value:
setattr(facility, key, value)
session.add(facility)
session.commit()
session.refresh(facility)
return facility

def get_facility(
self, facility_id: Annotated[str, Path(title="The ID of the facility to get")]
) -> Facility:
Expand Down Expand Up @@ -202,6 +231,36 @@ def create_keywordgroup(self, keywordgroup: Keyword_GroupCreate):
session.refresh(db_keywordgroup)
return db_keywordgroup

def update_keywordgroup(
self,
keywordgroup_id: Annotated[str, Path(title="The ID of the group to update")],
keywordgroup_update: Keyword_GroupUpdate,
):
"""Update an existing keyword group.
Args:
keywordgroup: updated keyword group.
keywordgroup_id: ID of the keyword group to update.
"""
try:
uuid = UUID(keywordgroup_id)
except ValueError:
raise InvalidQueryParameter(
f"Keywordgroup ID {keywordgroup_id} invalid UUID"
)
with Session(self.db_engine) as session:
keywordgroup = session.get(Keyword_Group, uuid)
if keywordgroup is None:
raise NotFoundError(f"Keywordgroup ID {keywordgroup_id} not found")

for key, value in keywordgroup_update.model_dump().items():
if value:
setattr(keywordgroup, key, value)
session.add(keywordgroup)
session.commit()
session.refresh(keywordgroup)
return keywordgroup

def get_keyword_group(
self, keywordgroup_id: Annotated[str, Path(title="The ID of the group to get")]
) -> Keyword_Group:
Expand Down
42 changes: 42 additions & 0 deletions api/dmsapi/extensions/keywords/keyword_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class KeywordExtension(ApiExtension):
Manage facilities:
POST /facility
PUT /facility/{facility_id}
GET /facilities
GET /facility/{facility_id}
PUT /facility/{facility_id}
Expand All @@ -32,6 +33,7 @@ class KeywordExtension(ApiExtension):
Manage keyword groups:
POST /keywordgroup
PUT /keywordgroup/{keywordgroup_id}
GET /keywordgroups
GET /keywordgroup/{keywordgroup_id}
DELETE /keywordgroup/{keywordgroup_id}
Expand Down Expand Up @@ -68,12 +70,14 @@ def register(self, app: FastAPI) -> None:
"""
self.router.prefix = app.state.router_prefix
self.add_create_facility()
self.add_update_facility()
self.add_get_facility()
self.add_get_facilities()
self.add_delete_facility()
self.add_link_keyword_group_to_facility()
self.add_unlink_keyword_group_from_facility()
self.add_create_keyword_group()
self.add_update_keyword_group()
self.add_get_keyword_group()
self.add_get_keyword_groups()
self.add_delete_keyword_group()
Expand All @@ -93,6 +97,35 @@ def add_create_facility(self):
methods=["POST"],
)

def add_update_facility(self):
self.router.add_api_route(
name="Update Facility",
path="/facility/{facility_id}",
endpoint=self.client.update_facility,
response_model=Facility,
responses={
200: {
"content": {
MimeTypes.json.value: {},
},
"model": Facility,
},
400: {
"content": {
MimeTypes.json.value: {},
},
"model": ErrorResponse,
},
404: {
"content": {
MimeTypes.json.value: {},
},
"model": ErrorResponse,
},
},
methods=["PUT"],
)

def add_get_facility(self):
self.router.add_api_route(
name="Get Facility",
Expand Down Expand Up @@ -187,6 +220,15 @@ def add_create_keyword_group(self):
methods=["POST"],
)

def add_update_keyword_group(self):
self.router.add_api_route(
name="Update Keyword Group",
path="/keywordgroup/{keywordgroup_id}",
endpoint=self.client.update_keywordgroup,
response_model=Keyword_GroupPublic,
methods=["PUT"],
)

def add_get_keyword_group(self):
self.router.add_api_route(
name="Get Keyword Group",
Expand Down
15 changes: 15 additions & 0 deletions api/tests/facility_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,21 @@ async def test_create_facility_invalid(app_client: AsyncClient):
assert response.json()["code"] == "RequestValidationError"


# test get facility
@pytest.mark.asyncio
async def test_update_facility(app_client: AsyncClient, facility: Facility):
response = await app_client.get(f"/facility/{facility.id}")
assert response.status_code == 200
facility_obj = Facility(**response.json())
assert facility_obj.name == "test_facility"
update_response = await app_client.put(
f"/facility/{facility.id}", json={"name": "updated_facility"}
)
assert update_response.status_code == 200
updated_facility_obj = Facility(**update_response.json())
assert updated_facility_obj.name == "updated_facility"


# test get facility
@pytest.mark.asyncio
async def test_get_facility(app_client: AsyncClient, facility: Facility):
Expand Down
19 changes: 19 additions & 0 deletions api/tests/keyword_group_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,25 @@ async def test_create_keywordgroup_invalid(
assert response.json()["code"] == "RequestValidationError"


@pytest.mark.asyncio
async def test_update_keywordgroup(
app_client: AsyncClient, keyword_group: Keyword_Group
):
keyword_group_json = await app_client.get(f"/keywordgroup/{keyword_group.id}")
assert keyword_group is not None
keyword_group_obj = Keyword_Group(**keyword_group_json.json())
assert keyword_group_obj.group_name_nl == "test"
assert keyword_group_obj.group_name_en == "engelse_test"

response = await app_client.put(
f"/keywordgroup/{keyword_group.id}",
json={"group_name_nl": "updated_test", "group_name_en": "updated_engelse_test"},
)
assert response.status_code == 200
assert response.json()["group_name_nl"] == "updated_test"
assert response.json()["group_name_en"] == "updated_engelse_test"


@pytest.mark.asyncio
async def test_get_keywordgroup(app_client: AsyncClient, keyword_group: Keyword_Group):
keyword_group_json = await app_client.get(f"/keywordgroup/{keyword_group.id}")
Expand Down

0 comments on commit 71dfa5b

Please sign in to comment.