diff --git a/app/api/__init__.py b/app/api/__init__.py index 85526d6..bf7ecef 100644 --- a/app/api/__init__.py +++ b/app/api/__init__.py @@ -3,9 +3,8 @@ :license: MIT, see LICENSE for more details. """ -from spectree import SecurityScheme - from lin import SpecTree +from spectree import SecurityScheme api = SpecTree( backend_name="flask", diff --git a/app/api/cms/admin.py b/app/api/cms/admin.py index e3845ff..24df14e 100644 --- a/app/api/cms/admin.py +++ b/app/api/cms/admin.py @@ -6,19 +6,9 @@ """ import math -from flask import Blueprint, request -from sqlalchemy import func - -from app.api.cms.validator import ( - DispatchAuth, - DispatchAuths, - NewGroup, - RemoveAuths, - ResetPasswordForm, - UpdateGroup, - UpdateUserInfoForm, -) +from flask import Blueprint, g from lin import ( + DocResponse, Forbidden, GroupLevelEnum, Logger, @@ -30,7 +20,21 @@ manager, permission_meta, ) -from app.util.page import get_page_from_query, paginate +from sqlalchemy import func + +from app.api import AuthorizationBearerSecurity, api +from app.api.cms.schema import ResetPasswordSchema +from app.api.cms.schema.admin import ( + AdminGroupListSchema, + AdminGroupPermissionSchema, + AdminUserPageSchema, + CreateGroupSchema, + GroupBaseSchema, + GroupIdWithPermissionIdListSchema, + GroupQuerySearchSchema, + UpdateUserInfoSchema, +) +from app.util.page import get_page_from_query admin_api = Blueprint("admin", __name__) @@ -38,16 +42,30 @@ @admin_api.route("/permission") @permission_meta(name="查询所有可分配的权限", module="管理员", mount=False) @admin_required +@api.validate( + tags=["管理员"], + security=[AuthorizationBearerSecurity], +) def permissions(): + """ + 查询所有可分配的权限 + """ return manager.get_ep_infos() @admin_api.route("/users") @permission_meta(name="查询所有用户", module="管理员", mount=False) @admin_required -def get_admin_users(): - start, count = paginate() - group_id = request.args.get("group_id") +@api.validate( + tags=["管理员"], + security=[AuthorizationBearerSecurity], + resp=DocResponse(r=AdminUserPageSchema), + before=GroupQuerySearchSchema.offset_handler, +) +def get_admin_users(query: GroupQuerySearchSchema): + """ + 查询所有用户 + """ # 根据筛选条件和分页,获取 用户id 与 用户组id 的一对多 数据 # 过滤root 分组 query_root_group_id = db.session.query(manager.group_model.id).filter( @@ -57,14 +75,14 @@ def get_admin_users(): manager.user_group_model.user_id.label("user_id"), func.group_concat(manager.user_group_model.group_id).label("group_ids"), ).filter(~manager.user_group_model.group_id.in_(query_root_group_id)) - if group_id: + if g.group_id: _user_groups_list = _user_groups_list.filter( - manager.user_group_model.group_id == group_id + manager.user_group_model.group_id == g.group_id ) user_groups_list = ( _user_groups_list.group_by(manager.user_group_model.user_id) - .offset(start) - .limit(count) + .offset(g.offset) + .limit(g.count) .all() ) @@ -73,8 +91,8 @@ def get_admin_users(): _total = db.session.query( func.count(func.distinct(manager.user_group_model.user_id)) ).filter(~manager.user_group_model.group_id.in_(query_root_group_id)) - if group_id: - _total = _total.filter(manager.user_group_model.group_id == group_id) + if g.group_id: + _total = _total.filter(manager.user_group_model.group_id == g.group_id) total = _total.scalar() # 获取本次需要返回的用户的数据 @@ -87,7 +105,7 @@ def get_admin_users(): user_dict[user.id] = user # 使用用户组来过滤,则不需要补全用户组信息 - if not group_id: + if not g.group_id: # 拿到本次请求返回用户的所有 用户组 id List group_ids = [ int(i) @@ -109,29 +127,32 @@ def get_admin_users(): user_dict[user_id].groups = groups items.append(user_dict[user_id]) - total_page = math.ceil(total / count) + total_page = math.ceil(total / g.count) page = get_page_from_query() - return { - "count": count, - "items": users, - "page": page, - "total": total, - "total_page": total_page, - } + return AdminUserPageSchema( + count=g.count, total=total, total_page=total_page, page=page, items=users + ) @admin_api.route("/user//password", methods=["PUT"]) @permission_meta(name="修改用户密码", module="管理员", mount=False) @admin_required -def change_user_password(uid): - form = ResetPasswordForm().validate_for_api() +@api.validate( + tags=["管理员"], + resp=DocResponse(NotFound("用户不存在"), Success("密码修改成功")), + security=[AuthorizationBearerSecurity], +) +def change_user_password(uid: int, json: ResetPasswordSchema): + """ + 修改用户密码 + """ user = manager.find_user(id=uid) - if user is None: + if not user: raise NotFound("用户不存在") with db.auto_commit(): - user.reset_password(form.new_password.data) + user.reset_password(g.new_password) return Success("密码修改成功") @@ -140,7 +161,15 @@ def change_user_password(uid): @permission_meta(name="删除用户", module="管理员", mount=False) @Logger(template="管理员删除了一个用户") @admin_required +@api.validate( + tags=["管理员"], + security=[AuthorizationBearerSecurity], + resp=DocResponse(NotFound("用户不存在"), Success("删除成功"), Forbidden("用户不能删除")), +) def delete_user(uid): + """ + 删除用户 + """ user = manager.user_model.get(id=uid) if user is None: raise NotFound("用户不存在") @@ -159,19 +188,29 @@ def delete_user(uid): @admin_api.route("/user/", methods=["PUT"]) @permission_meta(name="管理员更新用户信息", module="管理员", mount=False) @admin_required -def update_user(uid): - form = UpdateUserInfoForm().validate_for_api() - +@api.validate( + tags=["管理员"], + security=[AuthorizationBearerSecurity], + resp=DocResponse( + NotFound("用户不存在"), + Success("更新成功"), + ParameterError("邮箱已被注册,请重新输入邮箱"), + ), +) +def update_user(uid: int, json: UpdateUserInfoSchema): + """ + 更新用户信息 + """ user = manager.user_model.get(id=uid) if user is None: raise NotFound("用户不存在") - if user.email != form.email.data: - exists = manager.user_model.get(email=form.email.data) + if user.email != g.email: + exists = manager.user_model.get(email=g.email) if exists: raise ParameterError("邮箱已被注册,请重新输入邮箱") with db.auto_commit(): - user.email = form.email.data - group_ids = form.group_ids.data + user.email = g.email + group_ids = g.group_ids # 清空原来的所有关联关系 manager.user_group_model.query.filter_by(user_id=user.id).delete( synchronize_session=False @@ -179,7 +218,7 @@ def update_user(uid): # 根据传入分组ids 新增关联记录 user_group_list = list() # 如果没传分组数据,则将其设定为 guest 分组 - if len(group_ids) == 0: + if not group_ids: group_ids = [manager.group_model.get(level=GroupLevelEnum.GUEST.value).id] for group_id in group_ids: user_group = manager.user_group_model() @@ -190,52 +229,18 @@ def update_user(uid): return Success("操作成功") -@admin_api.route("/group") -@permission_meta(name="查询所有分组及其权限", module="管理员", mount=False) -@admin_required -def get_admin_groups(): - start, count = paginate() - groups = ( - manager.group_model.query.filter( - manager.group_model.level != GroupLevelEnum.ROOT.value - ) - .offset(start) - .limit(count) - .all() - ) - if groups is None: - raise NotFound("不存在任何分组") - - for group in groups: - permissions = manager.permission_model.select_by_group_id(group.id) - setattr(group, "permissions", permissions) - group._fields.append("permissions") - - # root分组隐藏不显示 - total = ( - db.session.query(func.count(manager.group_model.id)) - .filter( - manager.group_model.level != GroupLevelEnum.ROOT.value, - manager.group_model.is_deleted == False, - ) - .scalar() - ) - total_page = math.ceil(total / count) - page = get_page_from_query() - - return { - "count": count, - "items": groups, - "page": page, - "total": total, - "total_page": total_page, - } - - @admin_api.route("/group/all") @permission_meta(name="查询所有分组", module="管理员", mount=False) @admin_required +@api.validate( + tags=["管理员"], + security=[AuthorizationBearerSecurity], + resp=DocResponse(NotFound("不存在任何分组"), r=AdminGroupListSchema), +) def get_all_group(): + """ + 获取所有分组 + """ groups = manager.group_model.query.filter( manager.group_model.is_deleted == False, manager.group_model.level != GroupLevelEnum.ROOT.value, @@ -248,7 +253,15 @@ def get_all_group(): @admin_api.route("/group/") @permission_meta(name="查询一个分组及其权限", module="管理员", mount=False) @admin_required +@api.validate( + tags=["管理员"], + security=[AuthorizationBearerSecurity], + resp=DocResponse(NotFound("分组不存在"), r=AdminGroupPermissionSchema), +) def get_group(gid): + """ + 获取一个分组及其权限 + """ group = manager.group_model.get(id=gid, one=True, soft=False) if group is None: raise NotFound("分组不存在") @@ -262,19 +275,29 @@ def get_group(gid): @permission_meta(name="新建分组", module="管理员", mount=False) @Logger(template="管理员新建了一个分组") # 记录日志 @admin_required -def create_group(): - form = NewGroup().validate_for_api() - exists = manager.group_model.get(name=form.name.data) +@api.validate( + tags=["管理员"], + security=[AuthorizationBearerSecurity], + resp=DocResponse( + ParameterError("用户组已存在, 不可创建同名分组"), + Success("新建用户组成功"), + ), +) +def create_group(json: CreateGroupSchema): + """ + 新建分组 + """ + exists = manager.group_model.get(name=g.name) if exists: raise Forbidden("分组已存在,不可创建同名分组") with db.auto_commit(): group = manager.group_model.create( - name=form.name.data, - info=form.info.data, + name=g.name, + info=g.info, ) db.session.flush() group_permission_list = list() - for permission_id in form.permission_ids.data: + for permission_id in g.permission_ids: gp = manager.group_permission_model() gp.group_id = group.id gp.permission_id = permission_id @@ -286,20 +309,38 @@ def create_group(): @admin_api.route("/group/", methods=["PUT"]) @permission_meta(name="更新一个分组", module="管理员", mount=False) @admin_required -def update_group(gid): - form = UpdateGroup().validate_for_api() +@api.validate( + tags=["管理员"], + security=[AuthorizationBearerSecurity], + resp=DocResponse( + ParameterError("分组不存在,更新失败"), + Success("更新分组成功"), + ), +) +def update_group(gid, json: GroupBaseSchema): + """ + 更新一个分组基本信息 + """ exists = manager.group_model.get(id=gid) if not exists: raise NotFound("分组不存在,更新失败") - exists.update(name=form.name.data, info=form.info.data, commit=True) - return Success("更新分组成功") + exists.update(name=g.name, info=g.info, commit=True) + return Success("更新成功") @admin_api.route("/group/", methods=["DELETE"]) @permission_meta(name="删除一个分组", module="管理员", mount=False) @Logger(template="管理员删除一个分组") # 记录日志 @admin_required +@api.validate( + tags=["管理员"], + security=[AuthorizationBearerSecurity], + resp=DocResponse(NotFound("分组不存在,删除失败"), Forbidden("分组不可删除"), Success("删除分组成功")), +) def delete_group(gid): + """ + 删除一个分组 + """ exist = manager.group_model.get(id=gid) if not exist: raise NotFound("分组不存在,删除失败") @@ -319,35 +360,28 @@ def delete_group(gid): return Success("删除分组成功") -@admin_api.route("/permission/dispatch", methods=["POST"]) -@permission_meta(name="分配单个权限", module="管理员", mount=False) -@admin_required -def dispatch_auth(): - form = DispatchAuth().validate_for_api() - one = manager.group_permission_model.get( - group_id=form.group_id.data, permission_id=form.permission_id.data - ) - if one: - raise Forbidden("已有权限,不可重复添加") - manager.group_permission_model.create( - group_id=form.group_id.data, permission_id=form.permission_id.data, commit=True - ) - return Success("添加权限成功") - - @admin_api.route("/permission/dispatch/batch", methods=["POST"]) @permission_meta(name="分配多个权限", module="管理员", mount=False) @admin_required -def dispatch_auths(): - form = DispatchAuths().validate_for_api() +@api.validate( + tags=["管理员"], + security=[AuthorizationBearerSecurity], + resp=DocResponse( + Success("添加权限成功"), + ), +) +def dispatch_auths(json: GroupIdWithPermissionIdListSchema): + """ + 分配多个权限 + """ with db.auto_commit(): - for permission_id in form.permission_ids.data: + for permission_id in g.permission_ids: one = manager.group_permission_model.get( - group_id=form.group_id.data, permission_id=permission_id + group_id=g.group_id, permission_id=permission_id ) if not one: manager.group_permission_model.create( - group_id=form.group_id.data, + group_id=g.group_id, permission_id=permission_id, ) return Success("添加权限成功") @@ -356,13 +390,20 @@ def dispatch_auths(): @admin_api.route("/permission/remove", methods=["POST"]) @permission_meta(name="删除多个权限", module="管理员", mount=False) @admin_required -def remove_auths(): - form = RemoveAuths().validate_for_api() +@api.validate( + tags=["管理员"], + resp=DocResponse(Success("删除权限成功")), + security=[AuthorizationBearerSecurity], +) +def remove_auths(json: GroupIdWithPermissionIdListSchema): + """ + 删除多个权限 + """ with db.auto_commit(): db.session.query(manager.group_permission_model).filter( - manager.group_permission_model.permission_id.in_(form.permission_ids.data), - manager.group_permission_model.group_id == form.group_id.data, + manager.group_permission_model.permission_id.in_(g.permission_ids), + manager.group_permission_model.group_id == g.group_id, ).delete(synchronize_session=False) return Success("删除权限成功") diff --git a/app/api/cms/file.py b/app/api/cms/file.py index a0f732e..821a982 100644 --- a/app/api/cms/file.py +++ b/app/api/cms/file.py @@ -3,8 +3,9 @@ :license: MIT, see LICENSE for more details. """ from flask import Blueprint, request - from lin import login_required + +from app.api import AuthorizationBearerSecurity, api from app.extension.file.local_uploader import LocalUploader file_api = Blueprint("file", __name__) @@ -12,7 +13,14 @@ @file_api.route("", methods=["POST"]) @login_required +@api.validate( + tags=["文件"], + security=[AuthorizationBearerSecurity], +) def post_file(): + """ + 上传文件 + """ files = request.files uploader = LocalUploader(files) ret = uploader.upload() diff --git a/app/api/cms/log.py b/app/api/cms/log.py index 8967717..9899670 100644 --- a/app/api/cms/log.py +++ b/app/api/cms/log.py @@ -1,11 +1,15 @@ import math from flask import Blueprint, g +from lin import DocResponse, Log, db, group_required, permission_meta from sqlalchemy import text from app.api import AuthorizationBearerSecurity, api -from app.api.cms.schema import LogPageSchema, LogQuerySearchSchema, UsernameListSchema -from lin import DocResponse, Log, db, group_required, permission_meta +from app.api.cms.schema.log import ( + LogPageSchema, + LogQuerySearchSchema, + UsernameListSchema, +) log_api = Blueprint("log", __name__) diff --git a/app/api/cms/model/user.py b/app/api/cms/model/user.py index b423f47..f462195 100644 --- a/app/api/cms/model/user.py +++ b/app/api/cms/model/user.py @@ -1,7 +1,6 @@ -from sqlalchemy import func - from lin import User as LinUser from lin import db, manager +from sqlalchemy import func class User(LinUser): diff --git a/app/api/cms/schema/__init__.py b/app/api/cms/schema/__init__.py index ba681c2..459c476 100644 --- a/app/api/cms/schema/__init__.py +++ b/app/api/cms/schema/__init__.py @@ -1,66 +1,33 @@ -import re -from datetime import datetime from typing import List, Optional -from flask import g -from pydantic import Field, validator +from lin import BaseModel, ParameterError +from pydantic import EmailStr, Field, validator -from lin import BaseModel -from app.schema import BasePageSchema, datetime_regex +class EmailSchema(BaseModel): + email: Optional[str] = Field(description="用户邮箱") -class UsernameListSchema(BaseModel): - items: List[str] + @validator("email") + def check_email(cls, v, values, **kwargs): + return EmailStr.validate(v) if v else "" -class LogQuerySearchSchema(BaseModel): - keyword: Optional[str] = None - name: Optional[str] = None - start: Optional[str] = Field(None, description="YY-MM-DD HH:MM:SS") - end: Optional[str] = Field(None, description="YY-MM-DD HH:MM:SS") - count: int = Field(5, gt=0, lt=16, description="0 < count < 16") - page: int = 0 +class ResetPasswordSchema(BaseModel): + new_password: str = Field(description="新密码", min_length=6, max_length=22) + confirm_password: str = Field(description="确认密码", min_length=6, max_length=22) - @validator("start", "end") - def datetime_match(cls, v, values, **kwargs): - if re.match(datetime_regex, v): - return v - raise ValueError("时间格式有误") + @validator("confirm_password") + def passwords_match(cls, v, values, **kwargs): + if v != values["new_password"]: + raise ParameterError("两次输入的密码不一致,请输入相同的密码") + return v - @staticmethod - def offset_handler(req, resp, req_validation_error, instance): - g.offset = req.context.query.count * req.context.query.page +class GroupIdListSchema(BaseModel): + group_ids: List[int] = Field(description="用户组ID列表") -class LogSchema(BaseModel): - message: str - user_id: int - username: str - status_code: int - method: str - path: str - permission: str - time: datetime = Field(alias="create_time") - - -class LogPageSchema(BasePageSchema): - items: List[LogSchema] - - -class LoginSchema(BaseModel): - username: str = Field(description="用户名") - password: str = Field(description="密码") - captcha: Optional[str] = Field(description="验证码") - - -class AccessTokenSchema(BaseModel): - __root__: str = Field(description="access_token") - - -class RefreshTokenSchema(BaseModel): - __root__: str = Field(description="refresh_token") - - -class LoginTokenSchema(BaseModel): - access_token: AccessTokenSchema - refresh_token: RefreshTokenSchema + @validator("group_ids", each_item=True) + def check_group_id(cls, v, values, **kwargs): + if v <= 0: + raise ParameterError("用户组ID必须大于0") + return v diff --git a/app/api/cms/schema/admin.py b/app/api/cms/schema/admin.py new file mode 100644 index 0000000..3951495 --- /dev/null +++ b/app/api/cms/schema/admin.py @@ -0,0 +1,77 @@ +from typing import List, Optional + +from lin import BaseModel, ParameterError +from pydantic import Field, validator + +from app.schema import BasePageSchema, QueryPageSchema + +from . import EmailSchema, GroupIdListSchema + + +class AdminGroupSchema(BaseModel): + id: int = Field(description="用户组ID") + info: str = Field(description="用户组信息") + name: str = Field(description="用户组名称") + + +class AdminGroupListSchema(BaseModel): + __root__: List[AdminGroupSchema] + + +class AdminUserSchema(EmailSchema): + id: int = Field(description="用户ID") + username: str = Field(description="用户名") + groups: List[AdminGroupSchema] = Field(description="用户组列表") + + +class AdminUserPageSchema(BasePageSchema): + items: List[AdminUserSchema] + + +class GroupQuerySearchSchema(QueryPageSchema): + group_id: Optional[int] = Field(gt=0, description="用户组ID") + + +class UpdateUserInfoSchema(GroupIdListSchema, EmailSchema): + pass + + +class PermissionSchema(BaseModel): + id: int = Field(description="权限ID") + name: str = Field(description="权限名称") + module: str = Field(description="权限所属模块") + mount: bool = Field(description="是否为挂载权限") + + +class AdminGroupPermissionSchema(AdminGroupSchema): + permissions: List[PermissionSchema] + + +class AdminGroupPermissionPageSchema(BasePageSchema): + items: List[AdminGroupPermissionSchema] + + +class GroupBaseSchema(BaseModel): + name: str = Field(description="用户组名称") + info: Optional[str] = Field(description="用户组信息") + + +class CreateGroupSchema(GroupBaseSchema): + permission_ids: List[int] = Field(description="权限ID列表") + + @validator("permission_ids", each_item=True) + def check_permission_id(cls, v, values, **kwargs): + if v <= 0: + raise ParameterError("权限ID必须大于0") + return v + + +class GroupIdWithPermissionIdListSchema(BaseModel): + group_id: int = Field(description="用户组ID") + permission_ids: List[int] = Field(description="权限ID列表") + + @validator("permission_ids", each_item=True) + def check_permission_id(cls, v, values, **kwargs): + if v <= 0: + raise ParameterError("权限ID必须大于0") + return v diff --git a/app/api/cms/schema/log.py b/app/api/cms/schema/log.py new file mode 100644 index 0000000..bd2a092 --- /dev/null +++ b/app/api/cms/schema/log.py @@ -0,0 +1,40 @@ +import re +from datetime import datetime +from typing import List, Optional + +from lin import BaseModel +from pydantic import Field, validator + +from app.schema import BasePageSchema, QueryPageSchema, datetime_regex + + +class UsernameListSchema(BaseModel): + items: List[str] + + +class LogQuerySearchSchema(QueryPageSchema): + keyword: Optional[str] = None + name: Optional[str] = None + start: Optional[str] = Field(None, description="YY-MM-DD HH:MM:SS") + end: Optional[str] = Field(None, description="YY-MM-DD HH:MM:SS") + + @validator("start", "end") + def datetime_match(cls, v): + if re.match(datetime_regex, v): + return v + raise ValueError("时间格式有误") + + +class LogSchema(BaseModel): + message: str + user_id: int + username: str + status_code: int + method: str + path: str + permission: str + time: datetime = Field(alias="create_time") + + +class LogPageSchema(BasePageSchema): + items: List[LogSchema] diff --git a/app/api/cms/schema/user.py b/app/api/cms/schema/user.py new file mode 100644 index 0000000..da6b514 --- /dev/null +++ b/app/api/cms/schema/user.py @@ -0,0 +1,68 @@ +import re +from typing import List, Optional + +from lin import BaseModel, ParameterError +from pydantic import AnyHttpUrl, Field, validator + +from . import EmailSchema, GroupIdListSchema, ResetPasswordSchema + + +class LoginSchema(BaseModel): + username: str = Field(description="用户名") + password: str = Field(description="密码") + captcha: Optional[str] = Field(description="验证码") + + +class LoginTokenSchema(BaseModel): + access_token: str = Field(description="access_token") + refresh_token: str = Field(description="refresh_token") + + +class CaptchaSchema(BaseModel): + image: str = Field("", description="验证码图片base64编码") + tag: str = Field("", description="验证码标记码") + + +class PermissionNameSchema(BaseModel): + name: str = Field(description="权限名称") + + +class PermissionModuleSchema(BaseModel): + module: List[PermissionNameSchema] = Field(description="权限模块") + + +class UserBaseInfoSchema(EmailSchema): + nickname: Optional[str] = Field(description="用户昵称", min_length=2, max_length=10) + avatar: Optional[AnyHttpUrl] = Field(description="头像url") + + +class UserSchema(UserBaseInfoSchema): + id: int = Field(description="用户id") + username: str = Field(description="用户名") + + +class UserPermissionSchema(UserSchema): + admin: bool = Field(description="是否是管理员") + permissions: List[PermissionModuleSchema] = Field(description="用户权限") + + +class ChangePasswordSchema(ResetPasswordSchema): + old_password: str = Field(description="旧密码") + + +class UserRegisterSchema(GroupIdListSchema, EmailSchema): + username: str = Field(description="用户名", min_length=2, max_length=10) + password: str = Field(description="密码", min_length=6, max_length=22) + confirm_password: str = Field(description="确认密码", min_length=6, max_length=22) + + @validator("confirm_password") + def passwords_match(cls, v, values, **kwargs): + if v != values["password"]: + raise ParameterError("两次输入的密码不一致,请输入相同的密码") + return v + + @validator("username") + def check_username(cls, v, values, **kwargs): + if not re.match(r"^[a-zA-Z0-9_]{2,10}$", v): + raise ParameterError("用户名只能由字母、数字、下划线组成,且长度为2-10位") + return v diff --git a/app/api/cms/user.py b/app/api/cms/user.py index fe16ca5..c2f8374 100644 --- a/app/api/cms/user.py +++ b/app/api/cms/user.py @@ -4,7 +4,7 @@ :copyright: © 2020 by the Lin team. :license: MIT, see LICENSE for more details. """ -from flask import Blueprint, current_app, request +from flask import Blueprint, current_app, g, request from flask_jwt_extended import ( create_access_token, create_refresh_token, @@ -13,16 +13,6 @@ verify_jwt_in_request, ) from itsdangerous import JSONWebSignatureSerializer as JWSSerializer - -from app.api import api -from app.api.cms.exception import RefreshFailed -from app.api.cms.schema import LoginSchema, LoginTokenSchema -from app.api.cms.validator import ( - ChangePasswordForm, - LoginForm, - RegisterForm, - UpdateInfoForm, -) from lin import ( DocResponse, Duplicated, @@ -39,6 +29,19 @@ manager, permission_meta, ) + +from app.api import AuthorizationBearerSecurity, api +from app.api.cms.exception import RefreshFailed +from app.api.cms.schema.user import ( + CaptchaSchema, + ChangePasswordSchema, + LoginSchema, + LoginTokenSchema, + UserBaseInfoSchema, + UserPermissionSchema, + UserRegisterSchema, + UserSchema, +) from app.util.captcha import CaptchaTool from app.util.common import split_group @@ -49,14 +52,41 @@ @permission_meta(name="注册", module="用户", mount=False) @Logger(template="管理员新建了一个用户") # 记录日志 @admin_required -def register(): - form = RegisterForm().validate_for_api() - if manager.user_model.count_by_username(form.username.data) > 0: +@api.validate( + tags=["用户"], + security=[AuthorizationBearerSecurity], + resp=DocResponse(Success("用户创建成功"), Duplicated("字段重复,请重新输入")), +) +def register(json: UserRegisterSchema): + """ + 注册新用户 + """ + if manager.user_model.count_by_username(g.username) > 0: raise Duplicated("用户名重复,请重新输入") # type: ignore - if form.email.data and form.email.data.strip() != "": - if manager.user_model.count_by_email(form.email.data) > 0: + if g.email and g.email.strip() != "": + if manager.user_model.count_by_email(g.email) > 0: raise Duplicated("注册邮箱重复,请重新输入") # type: ignore - _register_user(form) + # create a user + with db.auto_commit(): + user = manager.user_model() + user.username = g.username + if g.email and g.email.strip() != "": + user.email = g.email + db.session.add(user) + db.session.flush() + user.password = g.password + group_ids = g.group_ids + # 如果没传分组数据,则将其设定为 guest 分组 + if len(group_ids) == 0: + from lin import GroupLevelEnum + + group_ids = [GroupLevelEnum.GUEST.value] + for group_id in group_ids: + user_group = manager.user_group_model() + user_group.user_id = user.id + user_group.group_id = group_id + db.session.add(user_group) + return Success("用户创建成功") # type: ignore @@ -64,18 +94,17 @@ def register(): @api.validate(resp=DocResponse(Failed("验证码校验失败"), r=LoginTokenSchema), tags=["用户"]) def login(json: LoginSchema): """ - 用户登录, 返回 access_token 和 refresh_token + 用户登录 """ - form = LoginForm().validate_for_api() # 校对验证码 if current_app.config.get("LOGIN_CAPTCHA"): tag = request.headers.get("tag") secret_key = current_app.config.get("SECRET_KEY") serializer = JWSSerializer(secret_key) - if form.captcha.data != serializer.loads(tag): + if g.captcha != serializer.loads(tag): raise Failed("验证码校验失败") # type: ignore - user = manager.user_model.verify(form.username.data, form.password.data) + user = manager.user_model.verify(g.username, g.password) # 用户未登录,此处不能用装饰器记录日志 Log.create_log( message=f"{user.username}登录成功获取了令牌", @@ -88,41 +117,52 @@ def login(json: LoginSchema): commit=True, ) access_token, refresh_token = get_tokens(user) - return {"access_token": access_token, "refresh_token": refresh_token} + return LoginTokenSchema(access_token=access_token, refresh_token=refresh_token) @user_api.route("", methods=["PUT"]) @permission_meta(name="用户更新信息", module="用户", mount=False) @login_required -def update(): - form = UpdateInfoForm().validate_for_api() +@api.validate( + tags=["用户"], + security=[AuthorizationBearerSecurity], + resp=DocResponse(Success("用户信息更新成功"), ParameterError("邮箱已被注册,请重新输入邮箱")), +) +def update(json: UserBaseInfoSchema): + """ + 更新用户信息 + """ user = get_current_user() - email = form.email.data - nickname = form.nickname.data - avatar = form.avatar.data - if email and user.email != email: - exists = manager.user_model.get(email=form.email.data) + if g.email and user.email != g.email: + exists = manager.user_model.get(email=g.email) if exists: raise ParameterError("邮箱已被注册,请重新输入邮箱") with db.auto_commit(): - if email: - user.email = form.email.data - if nickname: - user.nickname = form.nickname.data - if avatar: - user._avatar = form.avatar.data - return Success("操作成功") + if g.email: + user.email = g.email + if g.nickname: + user.nickname = g.nickname + if g.avatar: + user._avatar = g.avatar + return Success("用户信息更新成功") @user_api.route("/change_password", methods=["PUT"]) @permission_meta(name="修改密码", module="用户", mount=False) @Logger(template="{user.username}修改了自己的密码") # 记录日志 @login_required -def change_password(): - form = ChangePasswordForm().validate_for_api() +@api.validate( + tags=["用户"], + security=[AuthorizationBearerSecurity], + resp=DocResponse(Success("密码修改成功"), Failed("密码修改失败")), +) +def change_password(json: ChangePasswordSchema): + """ + 修改密码 + """ user = get_current_user() - ok = user.change_password(form.old_password.data, form.new_password.data) + ok = user.change_password(g.old_password, g.new_password) if ok: db.session.commit() return Success("密码修改成功") @@ -133,24 +173,39 @@ def change_password(): @user_api.route("/information") @permission_meta(name="查询自己信息", module="用户", mount=False) @login_required +@api.validate( + tags=["用户"], + security=[AuthorizationBearerSecurity], + resp=DocResponse(r=UserSchema), +) def get_information(): + """ + 获取用户信息 + """ current_user = get_current_user() return current_user @user_api.route("/refresh") @permission_meta(name="刷新令牌", module="用户", mount=False) +@api.validate( + resp=DocResponse(RefreshFailed, NotFound("refresh_token未被识别"), r=LoginTokenSchema), + tags=["用户"], +) def refresh(): + """ + 刷新令牌 + """ try: verify_jwt_in_request(refresh=True) except Exception: - return RefreshFailed() + return RefreshFailed identity = get_jwt_identity() if identity: access_token = create_access_token(identity=identity) refresh_token = create_refresh_token(identity=identity) - return {"access_token": access_token, "refresh_token": refresh_token} + return LoginTokenSchema(access_token=access_token, refresh_token=refresh_token) return NotFound("refresh_token未被识别") @@ -158,7 +213,15 @@ def refresh(): @user_api.route("/permissions") @permission_meta(name="查询自己拥有的权限", module="用户", mount=False) @login_required +@api.validate( + tags=["用户"], + security=[AuthorizationBearerSecurity], + resp=DocResponse(r=UserPermissionSchema), +) def get_allowed_apis(): + """ + 获取用户拥有的权限 + """ user = get_current_user() groups = manager.group_model.select_by_user_id(user.id) group_ids = [group.id for group in groups] @@ -175,33 +238,17 @@ def get_allowed_apis(): return user -def _register_user(form: RegisterForm): - with db.auto_commit(): - user = manager.user_model() - user.username = form.username.data - if form.email.data and form.email.data.strip() != "": - user.email = form.email.data - db.session.add(user) - db.session.flush() - user.password = form.password.data - group_ids = form.group_ids.data - # 如果没传分组数据,则将其设定为 id 2 的 guest 分组 - if len(group_ids) == 0: - group_ids = [2] - for group_id in group_ids: - user_group = manager.user_group_model() - user_group.user_id = user.id - user_group.group_id = group_id - db.session.add(user_group) - - @user_api.route("/captcha", methods=["GET", "POST"]) +@api.validate( + resp=DocResponse(r=CaptchaSchema), + tags=["用户"], +) def get_captcha(): """ 获取图形验证码 """ if not current_app.config.get("LOGIN_CAPTCHA"): - return {"tag": "", "image": ""} + return CaptchaSchema() # type: ignore image, code = CaptchaTool().get_verify_code() secret_key = current_app.config.get("SECRET_KEY") serializer = JWSSerializer(secret_key) diff --git a/app/api/cms/validator/__init__.py b/app/api/cms/validator/__init__.py index 69d7091..8b6466b 100644 --- a/app/api/cms/validator/__init__.py +++ b/app/api/cms/validator/__init__.py @@ -5,11 +5,10 @@ import re import time +from lin import Form, ParameterError, manager from wtforms import DateTimeField, FieldList, IntegerField, PasswordField, StringField from wtforms.validators import DataRequired, EqualTo, NumberRange, Regexp, length -from lin import Form, ParameterError, manager - # 注册校验 diff --git a/app/api/v1/book.py b/app/api/v1/book.py index 7938b0c..0878f20 100644 --- a/app/api/v1/book.py +++ b/app/api/v1/book.py @@ -6,6 +6,7 @@ """ from flask import Blueprint, g +from lin import DocResponse, Success, group_required, login_required, permission_meta from app.api import AuthorizationBearerSecurity, api from app.api.v1.exception import BookNotFound @@ -16,13 +17,6 @@ BookQuerySearchSchema, BookSchemaList, ) -from lin import ( - DocResponse, - Success, - group_required, - login_required, - permission_meta, -) book_api = Blueprint("book", __name__) diff --git a/app/api/v1/model/book.py b/app/api/v1/model/book.py index ef1f2fb..e4abe14 100644 --- a/app/api/v1/model/book.py +++ b/app/api/v1/model/book.py @@ -3,9 +3,8 @@ :license: MIT, see LICENSE for more details. """ -from sqlalchemy import Column, Integer, String - from lin import InfoCrud as Base +from sqlalchemy import Column, Integer, String class Book(Base): diff --git a/app/cli/db/fake.py b/app/cli/db/fake.py index 8b29b9d..f63581b 100644 --- a/app/cli/db/fake.py +++ b/app/cli/db/fake.py @@ -3,9 +3,10 @@ :license: MIT, see LICENSE for more details. """ -from app.api.v1.model.book import Book from lin import db +from app.api.v1.model.book import Book + def fake(): with db.auto_commit(): diff --git a/app/extension/file/file.py b/app/extension/file/file.py index 1765260..b4808f5 100644 --- a/app/extension/file/file.py +++ b/app/extension/file/file.py @@ -1,6 +1,5 @@ -from sqlalchemy import Column, Index, Integer, String, func, text - from lin import InfoCrud, db +from sqlalchemy import Column, Index, Integer, String, func, text class File(InfoCrud): diff --git a/app/extension/file/local_uploader.py b/app/extension/file/local_uploader.py index 3fd15dc..9a361b4 100644 --- a/app/extension/file/local_uploader.py +++ b/app/extension/file/local_uploader.py @@ -1,9 +1,8 @@ import os from flask import current_app -from werkzeug.utils import secure_filename - from lin import Uploader +from werkzeug.utils import secure_filename from .file import File diff --git a/app/plugin/oss/app/controller.py b/app/plugin/oss/app/controller.py index 0f47a6f..7cbcccc 100644 --- a/app/plugin/oss/app/controller.py +++ b/app/plugin/oss/app/controller.py @@ -2,7 +2,6 @@ import oss2 from flask import jsonify, request - from lin import ( Failed, ParameterError, diff --git a/app/plugin/oss/app/model.py b/app/plugin/oss/app/model.py index c97371b..60a146b 100644 --- a/app/plugin/oss/app/model.py +++ b/app/plugin/oss/app/model.py @@ -1,6 +1,5 @@ -from sqlalchemy import Column, Integer, String - from lin import BaseCrud +from sqlalchemy import Column, Integer, String class OSS(BaseCrud): diff --git a/app/plugin/poem/app/__init__.py b/app/plugin/poem/app/__init__.py index a09d7d5..78d6d77 100644 --- a/app/plugin/poem/app/__init__.py +++ b/app/plugin/poem/app/__init__.py @@ -7,9 +7,10 @@ def initial_data(): - from app import create_app from lin import db + from app import create_app + app = create_app() with app.app_context(): data = Poem.query.limit(1).all() diff --git a/app/plugin/poem/app/controller.py b/app/plugin/poem/app/controller.py index e9825cd..6186ebc 100644 --- a/app/plugin/poem/app/controller.py +++ b/app/plugin/poem/app/controller.py @@ -1,6 +1,6 @@ from flask import jsonify - from lin import Redprint + from app.plugin.poem.app.form import PoemListForm, PoemSearchForm from .model import Poem diff --git a/app/plugin/poem/app/form.py b/app/plugin/poem/app/form.py index ea85bbf..d7262ae 100644 --- a/app/plugin/poem/app/form.py +++ b/app/plugin/poem/app/form.py @@ -1,8 +1,7 @@ +from lin import Form from wtforms import IntegerField, StringField from wtforms.validators import DataRequired, Optional -from lin import Form - class PoemListForm(Form): count = IntegerField() diff --git a/app/plugin/poem/app/model.py b/app/plugin/poem/app/model.py index 80e5065..79cd5de 100644 --- a/app/plugin/poem/app/model.py +++ b/app/plugin/poem/app/model.py @@ -1,7 +1,6 @@ -from sqlalchemy import Column, Integer, String, Text, text - from lin import InfoCrud as Base from lin import NotFound, db, lin_config +from sqlalchemy import Column, Integer, String, Text, text class Poem(Base): diff --git a/app/plugin/qiniu/app/controller.py b/app/plugin/qiniu/app/controller.py index 47a492b..5b24530 100644 --- a/app/plugin/qiniu/app/controller.py +++ b/app/plugin/qiniu/app/controller.py @@ -4,9 +4,8 @@ """ from flask import request -from qiniu import Auth - from lin import FileExtensionError, Redprint, lin_config +from qiniu import Auth from .model import Qiniu diff --git a/app/plugin/qiniu/app/model.py b/app/plugin/qiniu/app/model.py index 817f8a3..dcfa49e 100644 --- a/app/plugin/qiniu/app/model.py +++ b/app/plugin/qiniu/app/model.py @@ -1,6 +1,5 @@ -from sqlalchemy import Column, Integer, String - from lin import BaseCrud +from sqlalchemy import Column, Integer, String class Qiniu(BaseCrud): diff --git a/app/schema/__init__.py b/app/schema/__init__.py index 80a9381..aeaf34c 100644 --- a/app/schema/__init__.py +++ b/app/schema/__init__.py @@ -1,6 +1,8 @@ from typing import Any, List +from flask import g from lin import BaseModel +from pydantic import Field datetime_regex = "^((([1-9][0-9][0-9][0-9]-(0[13578]|1[02])-(0[1-9]|[12][0-9]|3[01]))|(20[0-3][0-9]-(0[2469]|11)-(0[1-9]|[12][0-9]|30))) (20|21|22|23|[0-1][0-9]):[0-5][0-9]:[0-5][0-9])$" @@ -11,3 +13,12 @@ class BasePageSchema(BaseModel): total: int total_page: int items: List[Any] + + +class QueryPageSchema(BaseModel): + count: int = Field(5, gt=0, lt=16, description="0 < count < 16") + page: int = 0 + + @staticmethod + def offset_handler(req, resp, req_validation_error, instance): + g.offset = req.context.query.count * req.context.query.page diff --git a/pyproject.toml b/pyproject.toml index 5c7a5fd..6261372 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,9 +4,14 @@ version = "0.4" description = "🎀A simple and practical CMS implememted by flask" authors = ["sunlin92 "] +[[tool.poetry.source]] +name = "tsinghua" +url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple/" +default = true + [tool.poetry.dependencies] python = "^3.8" -Lin-CMS = "^0.4.2" +Lin-CMS = "^0.4.5" pillow = "^9.0.0" flask-cors = "^3.0.10" gunicorn = "^20.1.0" @@ -15,6 +20,7 @@ flask-socketio = "^5.1.1" blinker = "^1.4" python-dotenv = "^0.19.2" gevent-websocket = "^0.10.1" +pydantic = {extras = ["email"], version = "^1.9.0"} [tool.poetry.dev-dependencies] flask-sqlacodegen = "^1.1.8" diff --git a/requirements-dev.txt b/requirements-dev.txt index f344637..9db4d7c 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,5 +1,5 @@ python==3.8 -Lin-CMS==0.4.2 +Lin-CMS==0.4.5 pillow==9.0.0 flask-cors==3.0.10 gunicorn==20.1.0 diff --git a/requirements-prod.txt b/requirements-prod.txt index 3f3d32d..caf3e97 100644 --- a/requirements-prod.txt +++ b/requirements-prod.txt @@ -1,4 +1,4 @@ -Lin-CMS==0.4.2 +Lin-CMS==0.4.5 pillow==9.0.0 flask-cors==3.0.10 gunicorn==20.1.0