diff --git a/examples/auth-flask-login/app.py b/examples/auth-flask-login/app.py index b29803841..587f53329 100644 --- a/examples/auth-flask-login/app.py +++ b/examples/auth-flask-login/app.py @@ -108,9 +108,11 @@ class MyAdminIndexView(admin.AdminIndexView): @expose('/') def index(self): - if not login.current_user.is_authenticated: - return redirect(url_for('.login_view')) - return super(MyAdminIndexView, self).index() + return ( + super(MyAdminIndexView, self).index() + if login.current_user.is_authenticated + else redirect(url_for('.login_view')) + ) @expose('/login/', methods=('GET', 'POST')) def login_view(self): @@ -201,8 +203,14 @@ def build_sample_db(): user.first_name = first_names[i] user.last_name = last_names[i] user.login = user.first_name.lower() - user.email = user.login + "@example.com" - user.password = generate_password_hash(''.join(random.choice(string.ascii_lowercase + string.digits) for i in range(10))) + user.email = f"{user.login}@example.com" + user.password = generate_password_hash( + ''.join( + random.choice(string.ascii_lowercase + string.digits) + for _ in range(10) + ) + ) + db.session.add(user) db.session.commit() diff --git a/examples/auth/app.py b/examples/auth/app.py index f9cf7bc75..ebe0e2eac 100644 --- a/examples/auth/app.py +++ b/examples/auth/app.py @@ -138,8 +138,12 @@ def build_sample_db(): ] for i in range(len(first_names)): - tmp_email = first_names[i].lower() + "." + last_names[i].lower() + "@example.com" - tmp_pass = ''.join(random.choice(string.ascii_lowercase + string.digits) for i in range(10)) + tmp_email = f"{first_names[i].lower()}.{last_names[i].lower()}@example.com" + tmp_pass = ''.join( + random.choice(string.ascii_lowercase + string.digits) + for _ in range(10) + ) + user_datastore.create_user( first_name=first_names[i], last_name=last_names[i], diff --git a/examples/auth/config.py b/examples/auth/config.py index c25c011ff..3f199a33f 100644 --- a/examples/auth/config.py +++ b/examples/auth/config.py @@ -3,7 +3,7 @@ # Create in-memory database DATABASE_FILE = 'sample_db.sqlite' -SQLALCHEMY_DATABASE_URI = 'sqlite:///' + DATABASE_FILE +SQLALCHEMY_DATABASE_URI = f'sqlite:///{DATABASE_FILE}' SQLALCHEMY_ECHO = True # Flask-Security config diff --git a/examples/babel/app.py b/examples/babel/app.py index 8275be64e..c0950bb58 100644 --- a/examples/babel/app.py +++ b/examples/babel/app.py @@ -23,9 +23,7 @@ @babel.localeselector def get_locale(): - override = request.args.get('lang') - - if override: + if override := request.args.get('lang'): session['lang'] = override return session.get('lang', 'en') @@ -58,7 +56,7 @@ def __unicode__(self): # Flask views @app.route('/') def index(): - tmp = u""" + return u"""

Click me to get to Admin! (English)

Click me to get to Admin! (Czech)

Click me to get to Admin! (German)

@@ -71,7 +69,6 @@ def index():

Click me to get to Admin! (Chinese - Simplified)

Click me to get to Admin! (Chinese - Traditional)

""" - return tmp if __name__ == '__main__': # Create admin diff --git a/examples/bootstrap4/app.py b/examples/bootstrap4/app.py index bd766c17d..700810bb2 100644 --- a/examples/bootstrap4/app.py +++ b/examples/bootstrap4/app.py @@ -92,8 +92,8 @@ def build_sample_db(): for i in range(len(first_names)): user = User() - user.name = first_names[i] + " " + last_names[i] - user.email = first_names[i].lower() + "@example.com" + user.name = f"{first_names[i]} {last_names[i]}" + user.email = f"{first_names[i].lower()}@example.com" db.session.add(user) sample_text = [ diff --git a/examples/custom-layout/app.py b/examples/custom-layout/app.py index 9660c9723..e8bcf8419 100644 --- a/examples/custom-layout/app.py +++ b/examples/custom-layout/app.py @@ -86,8 +86,8 @@ def build_sample_db(): for i in range(len(first_names)): user = User() - user.name = first_names[i] + " " + last_names[i] - user.email = first_names[i].lower() + "@example.com" + user.name = f"{first_names[i]} {last_names[i]}" + user.email = f"{first_names[i].lower()}@example.com" db.session.add(user) sample_text = [ diff --git a/examples/forms-files-images/app.py b/examples/forms-files-images/app.py index e39136c0f..3a2e6467b 100644 --- a/examples/forms-files-images/app.py +++ b/examples/forms-files-images/app.py @@ -111,7 +111,7 @@ class CKTextAreaWidget(widgets.TextArea): def __call__(self, field, **kwargs): # add WYSIWYG class to existing classes existing_classes = kwargs.pop('class', '') or kwargs.pop('class_', '') - kwargs['class'] = '{} {}'.format(existing_classes, "ckeditor") + kwargs['class'] = f'{existing_classes} ckeditor' return super(CKTextAreaWidget, self).__call__(field, **kwargs) @@ -144,12 +144,15 @@ class FileView(sqla.ModelView): class ImageView(sqla.ModelView): - def _list_thumbnail(view, context, model, name): - if not model.path: - return '' - - return Markup('' % url_for('static', - filename=form.thumbgen_filename(model.path))) + def _list_thumbnail(self, context, model, name): + return ( + Markup( + '' + % url_for('static', filename=form.thumbgen_filename(model.path)) + ) + if model.path + else '' + ) column_formatters = { 'path': _list_thumbnail @@ -256,9 +259,9 @@ def build_sample_db(): user = User() user.first_name = first_names[i] user.last_name = last_names[i] - user.email = user.first_name.lower() + "@example.com" - tmp = ''.join(random.choice(string.digits) for i in range(10)) - user.phone = "(" + tmp[0:3] + ") " + tmp[3:6] + " " + tmp[6::] + user.email = f"{user.first_name.lower()}@example.com" + tmp = ''.join(random.choice(string.digits) for _ in range(10)) + user.phone = "(" + tmp[:3] + ") " + tmp[3:6] + " " + tmp[6::] user.city = locations[i][0] user.country = locations[i][1] db.session.add(user) @@ -267,13 +270,13 @@ def build_sample_db(): for name in images: image = Image() image.name = name - image.path = name.lower() + ".jpg" + image.path = f"{name.lower()}.jpg" db.session.add(image) for i in [1, 2, 3]: file = File() - file.name = "Example " + str(i) - file.path = "example_" + str(i) + ".pdf" + file.name = f"Example {str(i)}" + file.path = f"example_{str(i)}.pdf" db.session.add(file) sample_text = "

This is a test

" + \ diff --git a/examples/peewee/app.py b/examples/peewee/app.py index de0113575..cdeb042a9 100644 --- a/examples/peewee/app.py +++ b/examples/peewee/app.py @@ -32,7 +32,7 @@ class UserInfo(BaseModel): user = peewee.ForeignKeyField(User) def __unicode__(self): - return '%s - %s' % (self.key, self.value) + return f'{self.key} - {self.value}' class Post(BaseModel): diff --git a/examples/pymongo/app.py b/examples/pymongo/app.py index dc66b28b0..519108e2a 100644 --- a/examples/pymongo/app.py +++ b/examples/pymongo/app.py @@ -77,7 +77,7 @@ def get_list(self, *args, **kwargs): users = db.user.find(query, fields=('name',)) # Contribute user names to the models - users_map = dict((x['_id'], x['name']) for x in users) + users_map = {x['_id']: x['name'] for x in users} for item in data: item['user_name'] = users_map.get(item['user_id']) diff --git a/examples/sqla-association_proxy/app.py b/examples/sqla-association_proxy/app.py index 8d8e32ee9..f2c814771 100644 --- a/examples/sqla-association_proxy/app.py +++ b/examples/sqla-association_proxy/app.py @@ -67,7 +67,7 @@ def __init__(self, keyword=None): self.keyword = keyword def __repr__(self): - return 'Keyword(%s)' % repr(self.keyword) + return f'Keyword({repr(self.keyword)})' class UserAdmin(sqla.ModelView): diff --git a/examples/sqla-custom-inline-forms/app.py b/examples/sqla-custom-inline-forms/app.py index 659d80d74..19563a4be 100644 --- a/examples/sqla-custom-inline-forms/app.py +++ b/examples/sqla-custom-inline-forms/app.py @@ -89,9 +89,7 @@ def postprocess_form(self, form_class): return form_class def on_model_change(self, form, model): - file_data = request.files.get(form.upload.name) - - if file_data: + if file_data := request.files.get(form.upload.name): model.path = secure_filename(file_data.filename) file_data.save(op.join(base_path, model.path)) diff --git a/examples/sqla/admin/__init__.py b/examples/sqla/admin/__init__.py index 99c979b39..cabdbd502 100644 --- a/examples/sqla/admin/__init__.py +++ b/examples/sqla/admin/__init__.py @@ -13,9 +13,7 @@ @babel.localeselector def get_locale(): - override = request.args.get('lang') - - if override: + if override := request.args.get('lang'): session['lang'] = override return session.get('lang', 'en') diff --git a/examples/sqla/admin/config.py b/examples/sqla/admin/config.py index c685d83a0..86c9b7541 100644 --- a/examples/sqla/admin/config.py +++ b/examples/sqla/admin/config.py @@ -7,6 +7,6 @@ # Create in-memory database DATABASE_FILE = 'sample_db.sqlite' -SQLALCHEMY_DATABASE_URI = 'sqlite:///' + DATABASE_FILE +SQLALCHEMY_DATABASE_URI = f'sqlite:///{DATABASE_FILE}' SQLALCHEMY_ECHO = True SQLALCHEMY_TRACK_MODIFICATIONS = False diff --git a/examples/sqla/admin/data.py b/examples/sqla/admin/data.py index 0cb461356..44f146c38 100644 --- a/examples/sqla/admin/data.py +++ b/examples/sqla/admin/data.py @@ -41,7 +41,7 @@ def build_sample_db(): user.type = random.choice(AVAILABLE_USER_TYPES)[0] user.first_name = first_names[i] user.last_name = last_names[i] - user.email = first_names[i].lower() + "@example.com" + user.email = f"{first_names[i].lower()}@example.com" user.website = "https://www.example.com" user.ip_address = "127.0.0.1" @@ -107,7 +107,7 @@ def build_sample_db(): entry = random.choice(sample_text) # select text at random post = Post() post.user = user - post.title = "{}'s opinion on {}".format(user.first_name, entry['title']) + post.title = f"{user.first_name}'s opinion on {entry['title']}" post.text = entry['content'] post.background_color = random.choice(["#cccccc", "red", "lightblue", "#0f0"]) tmp = int(1000 * random.random()) # random number between 0 and 1000: @@ -120,12 +120,12 @@ def build_sample_db(): db.session.add(trunk) for i in range(5): branch = Tree() - branch.name = "Branch " + str(i + 1) + branch.name = f"Branch {str(i + 1)}" branch.parent = trunk db.session.add(branch) for j in range(5): leaf = Tree() - leaf.name = "Leaf " + str(j + 1) + leaf.name = f"Leaf {str(j + 1)}" leaf.parent = branch db.session.add(leaf) diff --git a/examples/sqla/admin/main.py b/examples/sqla/admin/main.py index 3fc7878a4..60a9c47df 100644 --- a/examples/sqla/admin/main.py +++ b/examples/sqla/admin/main.py @@ -15,7 +15,7 @@ # Flask views @app.route('/') def index(): - tmp = u""" + return u"""

Click me to get to Admin! (English)

Click me to get to Admin! (Czech)

Click me to get to Admin! (German)

@@ -28,7 +28,6 @@ def index():

Click me to get to Admin! (Chinese - Simplified)

Click me to get to Admin! (Chinese - Traditional)

""" - return tmp @app.route('/favicon.ico') @@ -50,7 +49,11 @@ def operation(self): # Customized User model admin def phone_number_formatter(view, context, model, name): - return Markup("{}".format(model.phone_number)) if model.phone_number else None + return ( + Markup(f"{model.phone_number}") + if model.phone_number + else None + ) def is_numberic_validator(form, field): diff --git a/examples/sqla/admin/models.py b/examples/sqla/admin/models.py index 25d1f1a27..0dac1bf25 100644 --- a/examples/sqla/admin/models.py +++ b/examples/sqla/admin/models.py @@ -54,18 +54,21 @@ class User(db.Model): def phone_number(self): if self.dialling_code and self.local_phone_number: number = str(self.local_phone_number) - return "+{} ({}) {} {} {}".format(self.dialling_code, number[0], number[1:3], number[3:6], number[6::]) + return f"+{self.dialling_code} ({number[0]}) {number[1:3]} {number[3:6]} {number[6::]}" + return @phone_number.expression - def phone_number(cls): - return sql.operators.ColumnOperators.concat(cast(cls.dialling_code, db.String), cls.local_phone_number) + def phone_number(self): + return sql.operators.ColumnOperators.concat( + cast(self.dialling_code, db.String), self.local_phone_number + ) def __str__(self): - return "{}, {}".format(self.last_name, self.first_name) + return f"{self.last_name}, {self.first_name}" def __repr__(self): - return "{}: {}".format(self.id, self.__str__()) + return f"{self.id}: {self.__str__()}" # Create M2M table @@ -90,7 +93,7 @@ class Post(db.Model): tags = db.relationship('Tag', secondary=post_tags_table) def __str__(self): - return "{}".format(self.title) + return f"{self.title}" class Tag(db.Model): @@ -98,7 +101,7 @@ class Tag(db.Model): name = db.Column(db.Unicode(64), unique=True) def __str__(self): - return "{}".format(self.name) + return f"{self.name}" class Tree(db.Model): @@ -110,4 +113,4 @@ class Tree(db.Model): parent = db.relationship('Tree', remote_side=[id], backref='children') def __str__(self): - return "{}".format(self.name) + return f"{self.name}" diff --git a/flask_admin/_backwards.py b/flask_admin/_backwards.py index 0f1a2a49b..efbc73a4d 100644 --- a/flask_admin/_backwards.py +++ b/flask_admin/_backwards.py @@ -29,8 +29,11 @@ def get_property(obj, name, old_name, default=None): Default value """ if hasattr(obj, old_name): - warnings.warn('Property %s is obsolete, please use %s instead' % - (old_name, name), stacklevel=2) + warnings.warn( + f'Property {old_name} is obsolete, please use {name} instead', + stacklevel=2, + ) + return getattr(obj, old_name) return getattr(obj, name, default) @@ -40,7 +43,7 @@ class ObsoleteAttr(object): def __init__(self, new_name, old_name, default): self.new_name = new_name self.old_name = old_name - self.cache = '_cache_' + new_name + self.cache = f'_cache_{new_name}' self.default = default def __get__(self, obj, objtype=None): @@ -53,8 +56,11 @@ def __get__(self, obj, objtype=None): # Check if there's old attribute if hasattr(obj, self.old_name): - warnings.warn('Property %s is obsolete, please use %s instead' % - (self.old_name, self.new_name), stacklevel=2) + warnings.warn( + f'Property {self.old_name} is obsolete, please use {self.new_name} instead', + stacklevel=2, + ) + return getattr(obj, self.old_name) # Return default otherwise diff --git a/flask_admin/_compat.py b/flask_admin/_compat.py index 610d29301..84f71370e 100644 --- a/flask_admin/_compat.py +++ b/flask_admin/_compat.py @@ -27,10 +27,7 @@ filter_list = lambda f, l: list(filter(f, l)) def as_unicode(s): - if isinstance(s, bytes): - return s.decode('utf-8') - - return str(s) + return s.decode('utf-8') if isinstance(s, bytes) else str(s) def csv_encode(s): ''' Returns unicode string expected by Python 3's csv module ''' @@ -50,10 +47,7 @@ def csv_encode(s): filter_list = filter def as_unicode(s): - if isinstance(s, str): - return s.decode('utf-8') - - return unicode(s) + return s.decode('utf-8') if isinstance(s, str) else unicode(s) def csv_encode(s): ''' Returns byte string expected by Python 2's csv module ''' diff --git a/flask_admin/actions.py b/flask_admin/actions.py index f217fa0c4..4bed96132 100644 --- a/flask_admin/actions.py +++ b/flask_admin/actions.py @@ -89,8 +89,7 @@ def get_actions_list(self): if self.is_action_allowed(name): actions.append((name, text_type(text))) - confirmation = self._actions_data[name][2] - if confirmation: + if confirmation := self._actions_data[name][2]: actions_confirmation[name] = text_type(confirmation) return actions, actions_confirmation @@ -122,7 +121,7 @@ def handle_action(self, return_view=None): flash_errors(form, message='Failed to perform action. %(error)s') if return_view: - url = self.get_url('.' + return_view) + url = self.get_url(f'.{return_view}') else: url = get_redirect_target() or self.get_url('.index_view') diff --git a/flask_admin/base.py b/flask_admin/base.py index e706c916e..9bb0963f7 100644 --- a/flask_admin/base.py +++ b/flask_admin/base.py @@ -42,10 +42,7 @@ def expose_plugview(url='/'): def wrap(v): handler = expose(url, v.methods) - if hasattr(v, 'as_view'): - return handler(v.as_view(v.__name__)) - else: - return handler(v) + return handler(v.as_view(v.__name__)) if hasattr(v, 'as_view') else handler(v) return wrap @@ -80,26 +77,26 @@ class AdminViewMeta(type): Does some precalculations (like getting list of view methods from the class) to avoid calculating them for each view class instance. """ - def __init__(cls, classname, bases, fields): - type.__init__(cls, classname, bases, fields) + def __init__(self, classname, bases, fields): + type.__init__(self, classname, bases, fields) # Gather exposed views - cls._urls = [] - cls._default_view = None + self._urls = [] + self._default_view = None - for p in dir(cls): - attr = getattr(cls, p) + for p in dir(self): + attr = getattr(self, p) if hasattr(attr, '_urls'): # Collect methods for url, methods in attr._urls: - cls._urls.append((url, p, methods)) + self._urls.append((url, p, methods)) if url == '/': - cls._default_view = p + self._default_view = p # Wrap views - setattr(cls, p, _wrap_view(attr)) + setattr(self, p, _wrap_view(attr)) class BaseViewClass(object): @@ -205,33 +202,28 @@ def __init__(self, name=None, category=None, endpoint=None, url=None, # Default view if self._default_view is None: - raise Exception(u'Attempted to instantiate admin view %s without default view' % self.__class__.__name__) + raise Exception( + f'Attempted to instantiate admin view {self.__class__.__name__} without default view' + ) def _get_endpoint(self, endpoint): """ Generate Flask endpoint name. By default converts class name to lower case if endpoint is not explicitly provided. """ - if endpoint: - return endpoint - - return self.__class__.__name__.lower() + return endpoint or self.__class__.__name__.lower() def _get_view_url(self, admin, url): """ Generate URL for the view. Override to change default behavior. """ if url is None: - if admin.url != '/': - url = '%s/%s' % (admin.url, self.endpoint) + if admin.url == '/': + url = '/' if self == admin.index_view else f'/{self.endpoint}' else: - if self == admin.index_view: - url = '/' - else: - url = '/%s' % self.endpoint - else: - if not url.startswith('/'): - url = '%s/%s' % (admin.url, url) + url = f'{admin.url}/{self.endpoint}' + elif not url.startswith('/'): + url = f'{admin.url}/{url}' return url @@ -391,10 +383,7 @@ def get_url(self, endpoint, **kwargs): @property def _debug(self): - if not self.admin or not self.admin.app: - return False - - return self.admin.app.debug + return False if not self.admin or not self.admin.app else self.admin.app.debug class AdminIndexView(BaseView): @@ -502,7 +491,7 @@ def __init__(self, app=None, name=None, self._views = [] self._menu = [] - self._menu_categories = dict() + self._menu_categories = {} self._menu_links = [] if name is None: @@ -722,7 +711,7 @@ def init_app(self, app, index_view=None, def _init_extension(self): if not hasattr(self.app, 'extensions'): - self.app.extensions = dict() + self.app.extensions = {} admins = self.app.extensions.get('admin', []) diff --git a/flask_admin/contrib/appengine/view.py b/flask_admin/contrib/appengine/view.py index 010e4bdab..7adc9daea 100644 --- a/flask_admin/contrib/appengine/view.py +++ b/flask_admin/contrib/appengine/view.py @@ -53,7 +53,7 @@ class MyAdminView(ModelView): """ def scaffold_form(self): - form_class = wt_ndb.model_form( + return wt_ndb.model_form( self.model(), base_class=Form, only=self.form_columns, @@ -61,7 +61,6 @@ def scaffold_form(self): field_args=self.form_args, converter=self.model_form_converter(), ) - return form_class def scaffold_list_form(self, widget=None, validators=None): form_class = wt_ndb.model_form( @@ -71,8 +70,7 @@ def scaffold_list_form(self, widget=None, validators=None): field_args=self.form_args, converter=self.model_form_converter(), ) - result = create_editable_list_form(Form, form_class, widget) - return result + return create_editable_list_form(Form, form_class, widget) def get_list(self, page, sort_field, sort_desc, search, filters, page_size=None): @@ -179,7 +177,7 @@ def get_list(self, page, sort_field, sort_desc, search, filters): if sort_field: if sort_desc: - sort_field = "-" + sort_field + sort_field = f"-{sort_field}" q.order(sort_field) results = q.fetch(self.page_size, offset=page * self.page_size) @@ -232,4 +230,4 @@ def ModelView(model): elif issubclass(model, db.Model): return DbModelView(model) else: - raise ValueError("Unsupported model: %s" % model) + raise ValueError(f"Unsupported model: {model}") diff --git a/flask_admin/contrib/fileadmin/__init__.py b/flask_admin/contrib/fileadmin/__init__.py index f5c786adb..8c08835ba 100644 --- a/flask_admin/contrib/fileadmin/__init__.py +++ b/flask_admin/contrib/fileadmin/__init__.py @@ -481,10 +481,7 @@ def edit_form(self): Override to implement custom behavior. """ edit_form_class = self.get_edit_form() - if request.form: - return edit_form_class(request.form) - else: - return edit_form_class() + return edit_form_class(request.form) if request.form else edit_form_class() def delete_form(self): """ @@ -493,10 +490,7 @@ def delete_form(self): Override to implement custom behavior. """ delete_form_class = self.get_delete_form() - if request.form: - return delete_form_class(request.form) - else: - return delete_form_class() + return delete_form_class(request.form) if request.form else delete_form_class() def action_form(self): """ @@ -505,10 +499,7 @@ def action_form(self): Override to implement custom behavior. """ action_form_class = self.get_action_form() - if request.form: - return action_form_class(request.form) - else: - return action_form_class() + return action_form_class(request.form) if request.form else action_form_class() def is_file_allowed(self, filename): """ @@ -590,15 +581,13 @@ def _get_dir_url(self, endpoint, path=None, **kwargs): :param kwargs: Additional arguments """ - if not path: - return self.get_url(endpoint, **kwargs) - else: + if path: if self._on_windows: path = path.replace('\\', '/') kwargs['path'] = path - return self.get_url(endpoint, **kwargs) + return self.get_url(endpoint, **kwargs) def _get_file_url(self, path, **kwargs): """ @@ -610,11 +599,7 @@ def _get_file_url(self, path, **kwargs): if self._on_windows: path = path.replace('\\', '/') - if self.is_file_editable(path): - route = '.edit' - else: - route = '.download' - + route = '.edit' if self.is_file_editable(path) else '.download' return self.get_url(route, path=path, **kwargs) def _normalize_path(self, path): @@ -631,11 +616,7 @@ def _normalize_path(self, path): path = '' else: path = op.normpath(path) - if base_path: - directory = self._separator.join([base_path, path]) - else: - directory = path - + directory = self._separator.join([base_path, path]) if base_path else path directory = op.normpath(directory) if not self.is_in_folder(base_path, directory): @@ -871,14 +852,10 @@ def index_view(self, path=None): action_form = None def sort_url(column, path, invert=False): - desc = None - if not path: path = None - if invert and not sort_desc: - desc = 1 - + desc = 1 if invert and not sort_desc else None return self.get_url('.index_view', path=path, sort=column, desc=desc) return self.render(self.list_template, @@ -948,9 +925,7 @@ def download(self, path=None): base_path, directory, path = self._normalize_path(path) - # backward compatibility with base_url - base_url = self.get_base_url() - if base_url: + if base_url := self.get_base_url(): base_url = urljoin(self.get_url('.index_view'), base_url) return redirect(urljoin(quote(base_url), quote(path))) @@ -1063,13 +1038,12 @@ def rename(self): form = self.name_form() path = form.path.data - if path: - base_path, full_path, path = self._normalize_path(path) - - return_url = self._get_dir_url('.index_view', op.dirname(path)) - else: + if not path: return redirect(self.get_url('.index_view')) + base_path, full_path, path = self._normalize_path(path) + + return_url = self._get_dir_url('.index_view', op.dirname(path)) if not self.can_rename: flash(gettext('Renaming is disabled.'), 'error') return redirect(return_url) @@ -1113,15 +1087,11 @@ def edit(self): """ Edit view method """ - next_url = None - path = request.args.getlist('path') if not path: return redirect(self.get_url('.index_view')) - if len(path) > 1: - next_url = self.get_url('.edit', path=path[1:]) - + next_url = self.get_url('.edit', path=path[1:]) if len(path) > 1 else None path = path[0] base_path, full_path, path = self._normalize_path(path) @@ -1156,18 +1126,12 @@ def edit(self): except IOError: flash(gettext("Error reading %(name)s.", name=path), 'error') error = True - except: - flash(gettext("Unexpected error while reading from %(name)s", name=path), 'error') - error = True else: try: content = content.decode('utf8') except UnicodeDecodeError: flash(gettext("Cannot edit %(name)s.", name=path), 'error') error = True - except: - flash(gettext("Unexpected error while reading from %(name)s", name=path), 'error') - error = True else: form.content.data = content diff --git a/flask_admin/contrib/fileadmin/azure.py b/flask_admin/contrib/fileadmin/azure.py index 8e0fb44e3..1d389d097 100644 --- a/flask_admin/contrib/fileadmin/azure.py +++ b/flask_admin/contrib/fileadmin/azure.py @@ -112,10 +112,10 @@ def get_files(self, path, directory): folders.add(folder_name) folders.discard(directory) + is_dir = True for folder in folders: name = folder.split(self.separator)[-1] rel_path = folder - is_dir = True size = 0 last_modified = 0 files.append((name, rel_path, is_dir, size, last_modified)) @@ -125,14 +125,12 @@ def get_files(self, path, directory): def is_dir(self, path): path = self._ensure_blob_path(path) - num_blobs = 0 - for blob in self._client.list_blobs(self._container_name, path): + for num_blobs, blob in enumerate(self._client.list_blobs(self._container_name, path), start=1): blob_path_parts = blob.name.split(self.separator) is_explicit_directory = blob_path_parts[-1] == self._fakedir if is_explicit_directory: return True - num_blobs += 1 path_cannot_be_leaf = num_blobs >= 2 if path_cannot_be_leaf: return True @@ -178,7 +176,7 @@ def send_file(self, file_path): BlobPermissions.READ, expiry=now + self._send_file_validity, start=now - self._send_file_lookback) - return redirect('%s?%s' % (url, sas)) + return redirect(f'{url}?{sas}') def read_file(self, path): path = self._ensure_blob_path(path) diff --git a/flask_admin/contrib/fileadmin/s3.py b/flask_admin/contrib/fileadmin/s3.py index 1c9bdc719..8eeb7879c 100644 --- a/flask_admin/contrib/fileadmin/s3.py +++ b/flask_admin/contrib/fileadmin/s3.py @@ -65,9 +65,7 @@ def __init__(self, bucket_name, region, aws_access_key_id, def get_files(self, path, directory): def _strip_path(name, path): - if name.startswith(path): - return name.replace(path, '', 1) - return name + return name.replace(path, '', 1) if name.startswith(path) else name def _remove_trailing_slash(name): return name[:-1] @@ -95,11 +93,11 @@ def _iso_to_epoch(timestamp): def _get_bucket_list_prefix(self, path): parts = path.split(self.separator) - if len(parts) == 1: - search = '' - else: - search = self.separator.join(parts[:-1]) + self.separator - return search + return ( + '' + if len(parts) == 1 + else self.separator.join(parts[:-1]) + self.separator + ) def _get_path_keys(self, path): search = self._get_bucket_list_prefix(path) diff --git a/flask_admin/contrib/geoa/fields.py b/flask_admin/contrib/geoa/fields.py index 0e9a30371..3d7f75bd3 100644 --- a/flask_admin/contrib/geoa/fields.py +++ b/flask_admin/contrib/geoa/fields.py @@ -19,31 +19,27 @@ def __init__(self, label=None, validators=None, geometry_type="GEOMETRY", super(GeoJSONField, self).__init__(label, validators, **kwargs) self.web_srid = 4326 self.srid = srid - if self.srid == -1: - self.transform_srid = self.web_srid - else: - self.transform_srid = self.srid + self.transform_srid = self.web_srid if self.srid == -1 else self.srid self.geometry_type = geometry_type.upper() self.session = session def _value(self): if self.raw_data: return self.raw_data[0] - if type(self.data) is geoalchemy2.elements.WKBElement: - if self.srid == -1: - return self.session.scalar(func.ST_AsGeoJSON(self.data)) - else: - return self.session.scalar( - func.ST_AsGeoJSON( - func.ST_Transform(self.data, self.web_srid) - ) - ) - else: + if type(self.data) is not geoalchemy2.elements.WKBElement: return '' + if self.srid == -1: + return self.session.scalar(func.ST_AsGeoJSON(self.data)) + else: + return self.session.scalar( + func.ST_AsGeoJSON( + func.ST_Transform(self.data, self.web_srid) + ) + ) def process_formdata(self, valuelist): super(GeoJSONField, self).process_formdata(valuelist) - if str(self.data) == '': + if not str(self.data): self.data = None if self.data is not None: web_shape = self.session.scalar( @@ -57,4 +53,4 @@ def process_formdata(self, valuelist): ) ) ) - self.data = 'SRID=' + str(self.srid) + ';' + str(web_shape) + self.data = f'SRID={str(self.srid)};{str(web_shape)}' diff --git a/flask_admin/contrib/geoa/typefmt.py b/flask_admin/contrib/geoa/typefmt.py index 4e1902247..605e91687 100644 --- a/flask_admin/contrib/geoa/typefmt.py +++ b/flask_admin/contrib/geoa/typefmt.py @@ -22,7 +22,7 @@ def geom_formatter(view, value): value.srid = 4326 geojson = view.session.query(view.model).with_entities(func.ST_AsGeoJSON(value)).scalar() - return Markup('' % (params, geojson)) + return Markup(f'') DEFAULT_FORMATTERS = BASE_FORMATTERS.copy() diff --git a/flask_admin/contrib/mongoengine/ajax.py b/flask_admin/contrib/mongoengine/ajax.py index 470911fbc..7c190ffbd 100644 --- a/flask_admin/contrib/mongoengine/ajax.py +++ b/flask_admin/contrib/mongoengine/ajax.py @@ -20,7 +20,9 @@ def __init__(self, name, model, **options): self._cached_fields = self._process_fields() if not self.fields: - raise ValueError('AJAX loading requires `fields` to be specified for %s.%s' % (model, self.name)) + raise ValueError( + f'AJAX loading requires `fields` to be specified for {model}.{self.name}' + ) def _process_fields(self): remote_fields = [] @@ -30,7 +32,7 @@ def _process_fields(self): attr = getattr(self.model, field, None) if not attr: - raise ValueError('%s.%s does not exist.' % (self.model, field)) + raise ValueError(f'{self.model}.{field} does not exist.') remote_fields.append(attr) else: @@ -39,10 +41,7 @@ def _process_fields(self): return remote_fields def format(self, model): - if not model: - return None - - return (as_unicode(model.pk), as_unicode(model)) + return (as_unicode(model.pk), as_unicode(model)) if model else None def get_one(self, pk): return self.model.objects.filter(pk=pk).first() @@ -54,7 +53,7 @@ def get_list(self, term, offset=0, limit=DEFAULT_PAGE_SIZE): criteria = None for field in self._cached_fields: - flt = {u'%s__icontains' % field.name: term} + flt = {f'{field.name}__icontains': term} if not criteria: criteria = mongoengine.Q(**flt) @@ -73,16 +72,16 @@ def create_ajax_loader(model, name, field_name, opts): prop = getattr(model, field_name, None) if prop is None: - raise ValueError('Model %s does not have field %s.' % (model, field_name)) + raise ValueError(f'Model {model} does not have field {field_name}.') ftype = type(prop).__name__ - if ftype == 'ListField' or ftype == 'SortedListField': + if ftype in ['ListField', 'SortedListField']: prop = prop.field ftype = type(prop).__name__ if ftype != 'ReferenceField': - raise ValueError('Dont know how to convert %s type for AJAX loader' % ftype) + raise ValueError(f'Dont know how to convert {ftype} type for AJAX loader') remote_model = prop.document_type return QueryAjaxModelLoader(name, remote_model, **opts) @@ -90,18 +89,13 @@ def create_ajax_loader(model, name, field_name, opts): def process_ajax_references(references, view): def make_name(base, name): - if base: - return ('%s-%s' % (base, name)).lower() - else: - return as_unicode(name).lower() + return f'{base}-{name}'.lower() if base else as_unicode(name).lower() def handle_field(field, subdoc, base): ftype = type(field).__name__ - if ftype == 'ListField' or ftype == 'SortedListField': - child_doc = getattr(subdoc, '_form_subdocuments', {}).get(None) - - if child_doc: + if ftype in ['ListField', 'SortedListField']: + if child_doc := getattr(subdoc, '_form_subdocuments', {}).get(None): handle_field(field.field, child_doc, base) elif ftype == 'EmbeddedDocumentField': result = {} @@ -121,11 +115,10 @@ def handle_field(field, subdoc, base): subdoc._form_ajax_refs = result - child_doc = getattr(subdoc, '_form_subdocuments', None) - if child_doc: + if child_doc := getattr(subdoc, '_form_subdocuments', None): handle_subdoc(field.document_type_obj, subdoc, base) else: - raise ValueError('Failed to process subdocument field %s' % (field,)) + raise ValueError(f'Failed to process subdocument field {field}') def handle_subdoc(model, subdoc, base): documents = getattr(subdoc, '_form_subdocuments', {}) @@ -134,7 +127,7 @@ def handle_subdoc(model, subdoc, base): field = getattr(model, name, None) if not field: - raise ValueError('Invalid subdocument field %s.%s' % (model, name)) + raise ValueError(f'Invalid subdocument field {model}.{name}') handle_field(field, doc, make_name(base, name)) diff --git a/flask_admin/contrib/mongoengine/fields.py b/flask_admin/contrib/mongoengine/fields.py index 64418ecb0..eaeb9308b 100644 --- a/flask_admin/contrib/mongoengine/fields.py +++ b/flask_admin/contrib/mongoengine/fields.py @@ -59,7 +59,7 @@ def __init__(self, label=None, validators=None, **kwargs): def process(self, formdata, data=unset_value): if formdata: - marker = '_%s-delete' % self.name + marker = f'_{self.name}-delete' if marker in formdata: self._should_delete = True @@ -74,11 +74,7 @@ def populate_obj(self, obj, name): return if isinstance(self.data, FileStorage) and not is_empty(self.data.stream): - if not field.grid_id: - func = field.put - else: - func = field.replace - + func = field.replace if field.grid_id else field.put func(self.data.stream, filename=self.data.filename, content_type=self.data.content_type) diff --git a/flask_admin/contrib/mongoengine/filters.py b/flask_admin/contrib/mongoengine/filters.py index ccbe30368..94cf952b1 100644 --- a/flask_admin/contrib/mongoengine/filters.py +++ b/flask_admin/contrib/mongoengine/filters.py @@ -32,7 +32,7 @@ def __init__(self, column, name, options=None, data_type=None): # Common filters class FilterEqual(BaseMongoEngineFilter): def apply(self, query, value): - flt = {'%s' % self.column.name: value} + flt = {f'{self.column.name}': value} return query.filter(**flt) def operation(self): @@ -41,7 +41,7 @@ def operation(self): class FilterNotEqual(BaseMongoEngineFilter): def apply(self, query, value): - flt = {'%s__ne' % self.column.name: value} + flt = {f'{self.column.name}__ne': value} return query.filter(**flt) def operation(self): @@ -51,7 +51,7 @@ def operation(self): class FilterLike(BaseMongoEngineFilter): def apply(self, query, value): term, data = parse_like_term(value) - flt = {'%s__%s' % (self.column.name, term): data} + flt = {f'{self.column.name}__{term}': data} return query.filter(**flt) def operation(self): @@ -61,7 +61,7 @@ def operation(self): class FilterNotLike(BaseMongoEngineFilter): def apply(self, query, value): term, data = parse_like_term(value) - flt = {'%s__not__%s' % (self.column.name, term): data} + flt = {f'{self.column.name}__not__{term}': data} return query.filter(**flt) def operation(self): @@ -70,7 +70,7 @@ def operation(self): class FilterGreater(BaseMongoEngineFilter): def apply(self, query, value): - flt = {'%s__gt' % self.column.name: value} + flt = {f'{self.column.name}__gt': value} return query.filter(**flt) def operation(self): @@ -79,7 +79,7 @@ def operation(self): class FilterSmaller(BaseMongoEngineFilter): def apply(self, query, value): - flt = {'%s__lt' % self.column.name: value} + flt = {f'{self.column.name}__lt': value} return query.filter(**flt) def operation(self): @@ -89,9 +89,9 @@ def operation(self): class FilterEmpty(BaseMongoEngineFilter, filters.BaseBooleanFilter): def apply(self, query, value): if value == '1': - flt = {'%s' % self.column.name: None} + flt = {f'{self.column.name}': None} else: - flt = {'%s__ne' % self.column.name: None} + flt = {f'{self.column.name}__ne': None} return query.filter(**flt) def operation(self): @@ -106,7 +106,7 @@ def clean(self, value): return [v.strip() for v in value.split(',') if v.strip()] def apply(self, query, value): - flt = {'%s__in' % self.column.name: value} + flt = {f'{self.column.name}__in': value} return query.filter(**flt) def operation(self): @@ -115,7 +115,7 @@ def operation(self): class FilterNotInList(FilterInList): def apply(self, query, value): - flt = {'%s__nin' % self.column.name: value} + flt = {f'{self.column.name}__nin': value} return query.filter(**flt) def operation(self): @@ -125,13 +125,13 @@ def operation(self): # Customized type filters class BooleanEqualFilter(FilterEqual, filters.BaseBooleanFilter): def apply(self, query, value): - flt = {'%s' % self.column.name: value == '1'} + flt = {f'{self.column.name}': value == '1'} return query.filter(**flt) class BooleanNotEqualFilter(FilterNotEqual, filters.BaseBooleanFilter): def apply(self, query, value): - flt = {'%s' % self.column.name: value != '1'} + flt = {f'{self.column.name}': value != '1'} return query.filter(**flt) @@ -208,15 +208,19 @@ def __init__(self, column, name, options=None, data_type=None): def apply(self, query, value): start, end = value - flt = {'%s__gte' % self.column.name: start, '%s__lte' % self.column.name: end} + flt = {f'{self.column.name}__gte': start, f'{self.column.name}__lte': end} return query.filter(**flt) class DateTimeNotBetweenFilter(DateTimeBetweenFilter): def apply(self, query, value): start, end = value - return query.filter(Q(**{'%s__not__gte' % self.column.name: start}) | - Q(**{'%s__not__lte' % self.column.name: end})) + return query.filter( + ( + Q(**{f'{self.column.name}__not__gte': start}) + | Q(**{f'{self.column.name}__not__lte': end}) + ) + ) def operation(self): return lazy_gettext('not between') @@ -240,7 +244,7 @@ def clean(self, value): return ObjectId(value.strip()) def apply(self, query, value): - flt = {'%s' % self.column.name: value} + flt = {f'{self.column.name}': value} return query.filter(**flt) def operation(self): diff --git a/flask_admin/contrib/mongoengine/form.py b/flask_admin/contrib/mongoengine/form.py index 1934f2bc3..e47628c85 100644 --- a/flask_admin/contrib/mongoengine/form.py +++ b/flask_admin/contrib/mongoengine/form.py @@ -28,9 +28,7 @@ def __init__(self, view): self.view = view def _get_field_override(self, name): - form_overrides = getattr(self.view, 'form_overrides', None) - - if form_overrides: + if form_overrides := getattr(self.view, 'form_overrides', None): return form_overrides.get(name) return None @@ -68,7 +66,7 @@ def convert(self, model, field, field_args): } if field_args: - kwargs.update(field_args) + kwargs |= field_args if kwargs['validators']: # Create a copy of the list since we will be modifying it. @@ -98,8 +96,7 @@ def convert(self, model, field, field_args): if hasattr(field, 'to_form_field'): return field.to_form_field(model, kwargs) - override = self._get_field_override(field.name) - if override: + if override := self._get_field_override(field.name): return override(**kwargs) if ftype in self.converters: @@ -116,8 +113,7 @@ def conv_List(self, model, field, kwargs): raise ValueError('ListField "%s" must have field specified for model %s' % (field.name, model)) if isinstance(field.field, ReferenceField): - loader = getattr(self.view, '_form_ajax_refs', {}).get(field.name) - if loader: + if loader := getattr(self.view, '_form_ajax_refs', {}).get(field.name): return AjaxSelectMultipleField(loader, **kwargs) kwargs['widget'] = form.Select2Widget(multiple=True) @@ -166,8 +162,7 @@ def conv_EmbeddedDocument(self, model, field, kwargs): def conv_Reference(self, model, field, kwargs): kwargs['allow_blank'] = not field.required - loader = getattr(self.view, '_form_ajax_refs', {}).get(field.name) - if loader: + if loader := getattr(self.view, '_form_ajax_refs', {}).get(field.name): return AjaxSelectField(loader, **kwargs) kwargs['widget'] = form.Select2Widget() @@ -237,7 +232,7 @@ def find(name): if p is not None: return p - raise ValueError('Invalid model property name %s.%s' % (model, name)) + raise ValueError(f'Invalid model property name {model}.{name}') properties = ((p, find(p)) for p in only) elif exclude: diff --git a/flask_admin/contrib/mongoengine/helpers.py b/flask_admin/contrib/mongoengine/helpers.py index 48ee9f77c..a496d51a1 100644 --- a/flask_admin/contrib/mongoengine/helpers.py +++ b/flask_admin/contrib/mongoengine/helpers.py @@ -16,18 +16,17 @@ def make_gridfs_args(value): def make_thumb_args(value): - if getattr(value, 'thumbnail', None): - args = { - 'id': value.thumbnail._id, - 'coll': value.collection_name - } + if not getattr(value, 'thumbnail', None): + return make_gridfs_args(value) + args = { + 'id': value.thumbnail._id, + 'coll': value.collection_name + } - if value.db_alias != 'default': - args['db'] = value.db_alias + if value.db_alias != 'default': + args['db'] = value.db_alias - return args - else: - return make_gridfs_args(value) + return args def format_error(error): diff --git a/flask_admin/contrib/mongoengine/tools.py b/flask_admin/contrib/mongoengine/tools.py index 4d52671ce..ddec105c7 100644 --- a/flask_admin/contrib/mongoengine/tools.py +++ b/flask_admin/contrib/mongoengine/tools.py @@ -24,5 +24,5 @@ def parse_like_term(term): oper = 'contains' # add case insensitive flag if not case_sensitive: - oper = 'i' + oper + oper = f'i{oper}' return oper, term diff --git a/flask_admin/contrib/mongoengine/typefmt.py b/flask_admin/contrib/mongoengine/typefmt.py index 840526de3..6e8d9393e 100644 --- a/flask_admin/contrib/mongoengine/typefmt.py +++ b/flask_admin/contrib/mongoengine/typefmt.py @@ -27,17 +27,25 @@ def grid_formatter(view, value): def grid_image_formatter(view, value): - if not value.grid_id: - return '' - - return Markup( - ('
' + - '' + - '
') % - { - 'url': view.get_url('.api_file_view', **helpers.make_gridfs_args(value)), - 'thumb': view.get_url('.api_file_view', **helpers.make_thumb_args(value)), - }) + return ( + Markup( + ( + '
' + + '' + + '
' + ) + % { + 'url': view.get_url( + '.api_file_view', **helpers.make_gridfs_args(value) + ), + 'thumb': view.get_url( + '.api_file_view', **helpers.make_thumb_args(value) + ), + } + ) + if value.grid_id + else '' + ) DEFAULT_FORMATTERS = BASE_FORMATTERS.copy() diff --git a/flask_admin/contrib/mongoengine/view.py b/flask_admin/contrib/mongoengine/view.py index 10663164a..f066f432f 100644 --- a/flask_admin/contrib/mongoengine/view.py +++ b/flask_admin/contrib/mongoengine/view.py @@ -26,7 +26,7 @@ log = logging.getLogger("flask-admin.mongo") -SORTABLE_FIELDS = set(( +SORTABLE_FIELDS = { mongoengine.StringField, mongoengine.IntField, mongoengine.FloatField, @@ -38,8 +38,8 @@ mongoengine.ReferenceField, mongoengine.EmailField, mongoengine.UUIDField, - mongoengine.URLField -)) + mongoengine.URLField, +} class ModelView(BaseModelView): @@ -339,14 +339,12 @@ def scaffold_sortable_columns(self): """ Return a dictionary of sortable columns (name, field) """ - columns = {} - - for n, f in self._get_model_fields(): - if type(f) in SORTABLE_FIELDS: - if self.column_display_pk or type(f) != mongoengine.ObjectIdField: - columns[n] = f - - return columns + return { + n: f + for n, f in self._get_model_fields() + if type(f) in SORTABLE_FIELDS + and (self.column_display_pk or type(f) != mongoengine.ObjectIdField) + } def init_search(self): """ @@ -378,30 +376,22 @@ def scaffold_filters(self, name): :param name: Either field name or field instance """ - if isinstance(name, string_types): - attr = self.model._fields.get(name) - else: - attr = name - + attr = self.model._fields.get(name) if isinstance(name, string_types) else name if attr is None: - raise Exception('Failed to find field for filter: %s' % name) + raise Exception(f'Failed to find field for filter: {name}') - # Find name - visible_name = None - - if not isinstance(name, string_types): - visible_name = self.get_column_name(attr.name) + visible_name = ( + None + if isinstance(name, string_types) + else self.get_column_name(attr.name) + ) if not visible_name: visible_name = self.get_column_name(name) # Convert filter type_name = type(attr).__name__ - flt = self.filter_converter.convert(type_name, - attr, - visible_name) - - return flt + return self.filter_converter.convert(type_name, attr, visible_name) def is_valid_filter(self, filter): """ @@ -416,15 +406,15 @@ def scaffold_form(self): """ Create form from the model. """ - form_class = get_form(self.model, - self.model_form_converter(self), - base_class=self.form_base_class, - only=self.form_columns, - exclude=self.form_excluded_columns, - field_args=self.form_args, - extra_fields=self.form_extra_fields) - - return form_class + return get_form( + self.model, + self.model_form_converter(self), + base_class=self.form_base_class, + only=self.form_columns, + exclude=self.form_excluded_columns, + field_args=self.form_args, + extra_fields=self.form_extra_fields, + ) def scaffold_list_form(self, widget=None, validators=None): """ @@ -469,10 +459,10 @@ def _search(self, query, search_term): for field in self._search_fields: if type(field) == mongoengine.ReferenceField: import re - regex = re.compile('.*%s.*' % term) + regex = re.compile(f'.*{term}.*') else: regex = term - flt = {'%s__%s' % (field.name, op): regex} + flt = {f'{field.name}__{op}': regex} q = mongoengine.Q(**flt) if criteria is None: @@ -517,18 +507,14 @@ def get_list(self, page, sort_column, sort_desc, search, filters, query = self._search(query, search) # Get count - count = query.count() if not self.simple_list_pager else None + count = None if self.simple_list_pager else query.count() # Sorting if sort_column: - query = query.order_by('%s%s' % ('-' if sort_desc else '', sort_column)) - else: - order = self._get_default_order() - - if order: - keys = ['%s%s' % ('-' if desc else '', col) - for (col, desc) in order] - query = query.order_by(*keys) + query = query.order_by(f"{'-' if sort_desc else ''}{sort_column}") + elif order := self._get_default_order(): + keys = [f"{'-' if desc else ''}{col}" for (col, desc) in order] + query = query.order_by(*keys) # Pagination if page_size is None: @@ -667,11 +653,11 @@ def is_action_allowed(self, name): lazy_gettext('Are you sure you want to delete selected records?')) def action_delete(self, ids): try: - count = 0 - all_ids = [self.object_id_converter(pk) for pk in ids] - for obj in self.get_query().in_bulk(all_ids).values(): - count += self.delete_model(obj) + count = sum( + self.delete_model(obj) + for obj in self.get_query().in_bulk(all_ids).values() + ) flash(ngettext('Record was successfully deleted.', '%(count)s records were successfully deleted.', diff --git a/flask_admin/contrib/mongoengine/widgets.py b/flask_admin/contrib/mongoengine/widgets.py index 55a503c81..3ba2037c9 100644 --- a/flask_admin/contrib/mongoengine/widgets.py +++ b/flask_admin/contrib/mongoengine/widgets.py @@ -29,9 +29,10 @@ def __call__(self, field, **kwargs): 'name': escape(data.name), 'content_type': escape(data.content_type), 'size': data.length // 1024, - 'marker': '_%s-delete' % field.name + 'marker': f'_{field.name}-delete', } + return Markup('%s' % (placeholder, html_params(name=field.name, type='file', @@ -54,9 +55,10 @@ def __call__(self, field, **kwargs): args = helpers.make_thumb_args(field.data) placeholder = self.template % { 'thumb': get_url('.api_file_view', **args), - 'marker': '_%s-delete' % field.name + 'marker': f'_{field.name}-delete', } + return Markup('%s' % (placeholder, html_params(name=field.name, type='file', diff --git a/flask_admin/contrib/peewee/ajax.py b/flask_admin/contrib/peewee/ajax.py index 99d8842db..eab80b78f 100644 --- a/flask_admin/contrib/peewee/ajax.py +++ b/flask_admin/contrib/peewee/ajax.py @@ -18,7 +18,10 @@ def __init__(self, name, model, **options): self.fields = options.get('fields') if not self.fields: - raise ValueError('AJAX loading requires `fields` to be specified for %s.%s' % (model, self.name)) + raise ValueError( + f'AJAX loading requires `fields` to be specified for {model}.{self.name}' + ) + self._cached_fields = self._process_fields() @@ -32,7 +35,7 @@ def _process_fields(self): attr = getattr(self.model, field, None) if not attr: - raise ValueError('%s.%s does not exist.' % (self.model, field)) + raise ValueError(f'{self.model}.{field} does not exist.') remote_fields.append(attr) else: @@ -41,10 +44,7 @@ def _process_fields(self): return remote_fields def format(self, model): - if not model: - return None - - return (getattr(model, self.pk), as_unicode(model)) + return (getattr(model, self.pk), as_unicode(model)) if model else None def get_one(self, pk): return self.model.get(**{self.pk: pk}) @@ -74,7 +74,7 @@ def create_ajax_loader(model, name, field_name, options): prop = getattr(model, field_name, None) if prop is None: - raise ValueError('Model %s does not have field %s.' % (model, field_name)) + raise ValueError(f'Model {model} does not have field {field_name}.') # TODO: Check for field remote_model = prop.rel_model diff --git a/flask_admin/contrib/peewee/form.py b/flask_admin/contrib/peewee/form.py index e48390ffd..abc553bb8 100644 --- a/flask_admin/contrib/peewee/form.py +++ b/flask_admin/contrib/peewee/form.py @@ -72,7 +72,7 @@ def save_related(self, obj): attr = getattr(self.model, self.prop) values = self.model.select().where(attr == model_id).execute() - pk_map = dict((str(getattr(v, self._pk)), v) for v in values) + pk_map = {str(getattr(v, self._pk)): v for v in values} # Handle request data for field in self.entries: @@ -123,9 +123,7 @@ def __init__(self, view, additional=None): self.overrides = getattr(self.view, 'form_overrides', None) or {} def handle_foreign_key(self, model, field, **kwargs): - loader = getattr(self.view, '_form_ajax_refs', {}).get(field.name) - - if loader: + if loader := getattr(self.view, '_form_ajax_refs', {}).get(field.name): if field.null: kwargs['allow_blank'] = True @@ -199,13 +197,14 @@ def get_info(self, p): else: model = getattr(p, 'model', None) if model is None: - raise Exception('Unknown inline model admin: %s' % repr(p)) + raise Exception(f'Unknown inline model admin: {repr(p)}') - attrs = dict() + attrs = { + attr: getattr(p, attr) + for attr in dir(p) + if not attr.startswith('_') and attr != 'model' + } - for attr in dir(p): - if not attr.startswith('_') and attr != 'model': - attrs[attr] = getattr(p, attr) info = InlineFormAdmin(model, **attrs) @@ -215,13 +214,11 @@ def get_info(self, p): return info def process_ajax_refs(self, info): - refs = getattr(info, 'form_ajax_refs', None) - result = {} - if refs: + if refs := getattr(info, 'form_ajax_refs', None): for name, opts in iteritems(refs): - new_name = '%s.%s' % (info.model.__name__.lower(), name) + new_name = f'{info.model.__name__.lower()}.{name}' loader = None if isinstance(opts, (list, tuple)): @@ -243,12 +240,11 @@ def contribute(self, converter, model, form_class, inline_model): for field in get_meta_fields(info.model): field_type = type(field) - if field_type == ForeignKeyField: - if field.rel_model == model: - reverse_field = field - break + if field_type == ForeignKeyField and field.rel_model == model: + reverse_field = field + break else: - raise Exception('Cannot find reverse relation for model %s' % info.model) + raise Exception(f'Cannot find reverse relation for model {info.model}') # Remove reverse property from the list ignore = [reverse_field.name] diff --git a/flask_admin/contrib/peewee/tools.py b/flask_admin/contrib/peewee/tools.py index 98277d56b..dff9a5548 100644 --- a/flask_admin/contrib/peewee/tools.py +++ b/flask_admin/contrib/peewee/tools.py @@ -4,18 +4,16 @@ def get_primary_key(model): def parse_like_term(term): if term.startswith('^'): - stmt = '%s%%' % term[1:] + return '%s%%' % term[1:] elif term.startswith('='): - stmt = term[1:] + return term[1:] else: - stmt = '%%%s%%' % term - - return stmt + return '%%%s%%' % term def get_meta_fields(model): - if hasattr(model._meta, 'sorted_fields'): - fields = model._meta.sorted_fields - else: - fields = model._meta.get_fields() - return fields + return ( + model._meta.sorted_fields + if hasattr(model._meta, 'sorted_fields') + else model._meta.get_fields() + ) diff --git a/flask_admin/contrib/peewee/view.py b/flask_admin/contrib/peewee/view.py index ee77873a4..b35b375a0 100644 --- a/flask_admin/contrib/peewee/view.py +++ b/flask_admin/contrib/peewee/view.py @@ -185,9 +185,11 @@ def scaffold_pk(self): def get_pk_value(self, model): if self.model._meta.composite_key: - return tuple([ + return tuple( model._data[field_name] - for field_name in self.model._meta.primary_key.field_names]) + for field_name in self.model._meta.primary_key.field_names + ) + return getattr(model, self._primary_key) def scaffold_list_columns(self): @@ -197,21 +199,20 @@ def scaffold_list_columns(self): # Verify type field_class = type(f) - if field_class == ForeignKeyField: - columns.append(n) - elif self.column_display_pk or field_class != PrimaryKeyField: + if ( + field_class == ForeignKeyField + or self.column_display_pk + or field_class != PrimaryKeyField + ): columns.append(n) - return columns def scaffold_sortable_columns(self): - columns = dict() - - for n, f in self._get_model_fields(): - if self.column_display_pk or type(f) != PrimaryKeyField: - columns[n] = f - - return columns + return { + n: f + for n, f in self._get_model_fields() + if self.column_display_pk or type(f) != PrimaryKeyField + } def init_search(self): if self.column_searchable_list: @@ -235,7 +236,7 @@ def scaffold_filters(self, name): attr = name if attr is None: - raise Exception('Failed to find field for filter: %s' % name) + raise Exception(f'Failed to find field for filter: {name}') # Check if field is in different model model_class = None @@ -245,20 +246,17 @@ def scaffold_filters(self, name): model_class = attr.model if model_class != self.model: - visible_name = '%s / %s' % (self.get_column_name(model_class.__name__), - self.get_column_name(attr.name)) + visible_name = f'{self.get_column_name(model_class.__name__)} / {self.get_column_name(attr.name)}' + else: - if not isinstance(name, string_types): - visible_name = self.get_column_name(attr.name) - else: - visible_name = self.get_column_name(name) + visible_name = ( + self.get_column_name(name) + if isinstance(name, string_types) + else self.get_column_name(attr.name) + ) type_name = type(attr).__name__ - flt = self.filter_converter.convert(type_name, - attr, - visible_name) - - return flt + return self.filter_converter.convert(type_name, attr, visible_name) def is_valid_filter(self, filter): return isinstance(filter, filters.BasePeeweeFilter) @@ -416,17 +414,15 @@ def get_list(self, page, sort_column, sort_desc, search, filters, query = f.apply(query, f.clean(value)) # Get count - count = query.count() if not self.simple_list_pager else None + count = None if self.simple_list_pager else query.count() # Apply sorting if sort_column is not None: sort_field = self._sortable_columns[sort_column] order = [(sort_field, sort_desc)] query, joins = self._order_by(query, joins, order) - else: - order = self._get_default_order() - if order: - query, joins = self._order_by(query, joins, order) + elif order := self._get_default_order(): + query, joins = self._order_by(query, joins, order) # Pagination if page_size is None: diff --git a/flask_admin/contrib/pymongo/tools.py b/flask_admin/contrib/pymongo/tools.py index e931527de..638f69884 100644 --- a/flask_admin/contrib/pymongo/tools.py +++ b/flask_admin/contrib/pymongo/tools.py @@ -9,8 +9,8 @@ def parse_like_term(term): Search term """ if term.startswith('^'): - return '^{}'.format(re.escape(term[1:])) + return f'^{re.escape(term[1:])}' elif term.startswith('='): - return '^{}$'.format(re.escape(term[1:])) + return f'^{re.escape(term[1:])}$' return re.escape(term) diff --git a/flask_admin/contrib/pymongo/view.py b/flask_admin/contrib/pymongo/view.py index fe859fba8..fa7fa73c4 100644 --- a/flask_admin/contrib/pymongo/view.py +++ b/flask_admin/contrib/pymongo/view.py @@ -97,7 +97,7 @@ def __init__(self, coll, name = self._prettify_name(coll.name) if endpoint is None: - endpoint = ('%sview' % coll.name).lower() + endpoint = f'{coll.name}view'.lower() super(ModelView, self).__init__(None, name, category, endpoint, url, menu_class_name=menu_class_name, @@ -184,11 +184,9 @@ def _search(self, query, search_term): regex = parse_like_term(value) - stmt = [] - for field in self._search_fields: - stmt.append({field: {'$regex': regex}}) - - if stmt: + if stmt := [ + {field: {'$regex': regex}} for field in self._search_fields + ]: if len(stmt) == 1: queries.append(stmt[0]) else: @@ -196,16 +194,8 @@ def _search(self, query, search_term): # Construct final query if queries: - if len(queries) == 1: - final = queries[0] - else: - final = {'$and': queries} - - if query: - query = {'$and': [query, final]} - else: - query = final - + final = queries[0] if len(queries) == 1 else {'$and': queries} + query = {'$and': [query, final]} if query else final return query def get_list(self, page, sort_column, sort_desc, search, filters, @@ -251,29 +241,22 @@ def get_list(self, page, sort_column, sort_desc, search, filters, query = self._search(query, search) # Get count - count = self.coll.find(query).count() if not self.simple_list_pager else None + count = None if self.simple_list_pager else self.coll.find(query).count() # Sorting sort_by = None if sort_column: sort_by = [(sort_column, pymongo.DESCENDING if sort_desc else pymongo.ASCENDING)] - else: - order = self._get_default_order() - - if order: - sort_by = [(col, pymongo.DESCENDING if desc else pymongo.ASCENDING) - for (col, desc) in order] + elif order := self._get_default_order(): + sort_by = [(col, pymongo.DESCENDING if desc else pymongo.ASCENDING) + for (col, desc) in order] # Pagination if page_size is None: page_size = self.page_size - skip = 0 - - if page and page_size: - skip = page * page_size - + skip = page * page_size if page and page_size else 0 results = self.coll.find(query, sort=sort_by, skip=skip, limit=page_size) if execute: @@ -386,12 +369,7 @@ def is_action_allowed(self, name): lazy_gettext('Are you sure you want to delete selected records?')) def action_delete(self, ids): try: - count = 0 - - # TODO: Optimize me - for pk in ids: - if self.delete_model(self.get_one(pk)): - count += 1 + count = sum(1 for pk in ids if self.delete_model(self.get_one(pk))) flash(ngettext('Record was successfully deleted.', '%(count)s records were successfully deleted.', diff --git a/flask_admin/contrib/rediscli.py b/flask_admin/contrib/rediscli.py index 659dc5e0d..296c038c5 100644 --- a/flask_admin/contrib/rediscli.py +++ b/flask_admin/contrib/rediscli.py @@ -103,9 +103,7 @@ def _execute_command(self, name, args): :param args: Command arguments """ - # Do some remapping - new_cmd = self.remapped_commands.get(name) - if new_cmd: + if new_cmd := self.remapped_commands.get(name): name = new_cmd # Execute command @@ -150,8 +148,11 @@ def _cmd_help(self, *args): Help command implementation. """ if not args: - help = 'Usage: help .\nList of supported commands: ' - help += ', '.join(n for n in sorted(self.commands)) + help = ( + 'Usage: help .\nList of supported commands: ' + + ', '.join(sorted(self.commands)) + ) + return TextWrapper(help) cmd = args[0] @@ -188,7 +189,7 @@ def execute_view(self): return self._execute_command(parts[0], parts[1:]) except CommandError as err: - return self._error('Cli: %s' % err) + return self._error(f'Cli: {err}') except Exception as ex: log.exception(ex) - return self._error('Cli: %s' % ex) + return self._error(f'Cli: {ex}') diff --git a/flask_admin/contrib/sqla/ajax.py b/flask_admin/contrib/sqla/ajax.py index 8e4ab3cf9..492790c2d 100644 --- a/flask_admin/contrib/sqla/ajax.py +++ b/flask_admin/contrib/sqla/ajax.py @@ -26,7 +26,10 @@ def __init__(self, name, session, model, **options): self.filters = options.get('filters') if not self.fields: - raise ValueError('AJAX loading requires `fields` to be specified for %s.%s' % (model, self.name)) + raise ValueError( + f'AJAX loading requires `fields` to be specified for {model}.{self.name}' + ) + self._cached_fields = self._process_fields() @@ -43,7 +46,7 @@ def _process_fields(self): attr = getattr(self.model, field, None) if not attr: - raise ValueError('%s.%s does not exist.' % (self.model, field)) + raise ValueError(f'{self.model}.{field} does not exist.') remote_fields.append(attr) else: @@ -53,10 +56,7 @@ def _process_fields(self): return remote_fields def format(self, model): - if not model: - return None - - return getattr(model, self.pk), as_unicode(model) + return (getattr(model, self.pk), as_unicode(model)) if model else None def get_query(self): return self.session.query(self.model) @@ -75,7 +75,11 @@ def get_list(self, term, offset=0, limit=DEFAULT_PAGE_SIZE): query = query.filter(or_(*filters)) if self.filters: - filters = [text("%s.%s" % (self.model.__tablename__.lower(), value)) for value in self.filters] + filters = [ + text(f"{self.model.__tablename__.lower()}.{value}") + for value in self.filters + ] + query = query.filter(and_(*filters)) if self.order_by: @@ -88,10 +92,10 @@ def create_ajax_loader(model, session, name, field_name, options): attr = getattr(model, field_name, None) if attr is None: - raise ValueError('Model %s does not have field %s.' % (model, field_name)) + raise ValueError(f'Model {model} does not have field {field_name}.') if not is_relationship(attr) and not is_association_proxy(attr): - raise ValueError('%s.%s is not a relation.' % (model, field_name)) + raise ValueError(f'{model}.{field_name} is not a relation.') if is_association_proxy(attr): attr = attr.remote_attr diff --git a/flask_admin/contrib/sqla/fields.py b/flask_admin/contrib/sqla/fields.py index fcb22ceb0..35e4a5ef4 100644 --- a/flask_admin/contrib/sqla/fields.py +++ b/flask_admin/contrib/sqla/fields.py @@ -176,7 +176,7 @@ def pre_validate(self, form): if self._invalid_formdata: raise ValidationError(self.gettext(u'Not a valid choice')) elif self.data: - obj_list = list(x[1] for x in self._get_object_list()) + obj_list = [x[1] for x in self._get_object_list()] for v in self.data: if v not in obj_list: raise ValidationError(self.gettext(u'Not a valid choice')) @@ -233,7 +233,7 @@ def process(self, formdata, data=unset_value): def populate_obj(self, obj, name): """ Combines each FormField key/value into a dictionary for storage """ - _fake = type(str('_fake'), (object, ), {}) + _fake = type('_fake', (object, ), {}) output = {} for form_field in self.entries: @@ -297,7 +297,7 @@ def populate_obj(self, obj, name): return # Create primary key map - pk_map = dict((get_obj_pk(v, self._pk), v) for v in values) + pk_map = {get_obj_pk(v, self._pk): v for v in values} # Handle request data for field in self.entries: diff --git a/flask_admin/contrib/sqla/filters.py b/flask_admin/contrib/sqla/filters.py index cdbe680eb..9ea2d956c 100644 --- a/flask_admin/contrib/sqla/filters.py +++ b/flask_admin/contrib/sqla/filters.py @@ -87,7 +87,7 @@ def operation(self): class FilterEmpty(BaseSQLAFilter, filters.BaseBooleanFilter): def apply(self, query, value, alias=None): if value == '1': - return query.filter(self.get_column(alias) == None) # noqa: E711 + return query.filter(self.get_column(alias) is None) else: return query.filter(self.get_column(alias) != None) # noqa: E711 @@ -113,7 +113,7 @@ class FilterNotInList(FilterInList): def apply(self, query, value, alias=None): # NOT IN can exclude NULL values, so "or_ == None" needed to be added column = self.get_column(alias) - return query.filter(or_(~column.in_(value), column == None)) # noqa: E711 + return query.filter(or_(~column.in_(value), column is None)) def operation(self): return lazy_gettext('not in list') @@ -384,7 +384,7 @@ def apply(self, query, user_query, alias=None): break if choice_type: # != can exclude NULL values, so "or_ == None" needed to be added - return query.filter(or_(column != choice_type, column == None)) # noqa: E711 + return query.filter(or_(column != choice_type, column is None)) else: return query @@ -399,17 +399,20 @@ def apply(self, query, user_query, alias=None): if user_query: # loop through choice 'values' looking for matches if isinstance(column.type.choices, enum.EnumMeta): - for choice in column.type.choices: - if user_query.lower() in choice.name.lower(): - choice_types.append(choice.value) + choice_types.extend( + choice.value + for choice in column.type.choices + if user_query.lower() in choice.name.lower() + ) + else: - for type, value in column.type.choices: - if user_query.lower() in value.lower(): - choice_types.append(type) - if choice_types: - return query.filter(column.in_(choice_types)) - else: - return query + choice_types.extend( + type + for type, value in column.type.choices + if user_query.lower() in value.lower() + ) + + return query.filter(column.in_(choice_types)) if choice_types else query class ChoiceTypeNotLikeFilter(FilterNotLike): @@ -422,16 +425,22 @@ def apply(self, query, user_query, alias=None): if user_query: # loop through choice 'values' looking for matches if isinstance(column.type.choices, enum.EnumMeta): - for choice in column.type.choices: - if user_query.lower() in choice.name.lower(): - choice_types.append(choice.value) + choice_types.extend( + choice.value + for choice in column.type.choices + if user_query.lower() in choice.name.lower() + ) + else: - for type, value in column.type.choices: - if user_query.lower() in value.lower(): - choice_types.append(type) + choice_types.extend( + type + for type, value in column.type.choices + if user_query.lower() in value.lower() + ) + if choice_types: # != can exclude NULL values, so "or_ == None" needed to be added - return query.filter(or_(column.notin_(choice_types), column == None)) # noqa: E711 + return query.filter(or_(column.notin_(choice_types), column is None)) else: return query diff --git a/flask_admin/contrib/sqla/form.py b/flask_admin/contrib/sqla/form.py index f6dd499fb..88a372c60 100644 --- a/flask_admin/contrib/sqla/form.py +++ b/flask_admin/contrib/sqla/form.py @@ -44,13 +44,12 @@ def _get_label(self, name, field_args): if 'label' in field_args: return field_args['label'] - column_labels = get_property(self.view, 'column_labels', 'rename_columns') - - if column_labels: + if column_labels := get_property( + self.view, 'column_labels', 'rename_columns' + ): return column_labels.get(name) - prettify_override = getattr(self.view, 'prettify_name', None) - if prettify_override: + if prettify_override := getattr(self.view, 'prettify_name', None): return prettify_override(name) return prettify_name(name) @@ -59,23 +58,17 @@ def _get_description(self, name, field_args): if 'description' in field_args: return field_args['description'] - column_descriptions = getattr(self.view, 'column_descriptions', None) - - if column_descriptions: + if column_descriptions := getattr(self.view, 'column_descriptions', None): return column_descriptions.get(name) def _get_field_override(self, name): - form_overrides = getattr(self.view, 'form_overrides', None) - - if form_overrides: + if form_overrides := getattr(self.view, 'form_overrides', None): return form_overrides.get(name) return None def _model_select_field(self, prop, multiple, remote_model, **kwargs): - loader = getattr(self.view, '_form_ajax_refs', {}).get(prop.key) - - if loader: + if loader := getattr(self.view, '_form_ajax_refs', {}).get(prop.key): if multiple: return AjaxSelectMultipleField(loader, **kwargs) else: @@ -118,9 +111,7 @@ def _convert_relation(self, name, prop, property_is_association_proxy, kwargs): if not requirement_validator_specified: kwargs['validators'].append(validators.InputRequired()) - # Override field type if necessary - override = self._get_field_override(prop.key) - if override: + if override := self._get_field_override(prop.key): return override(**kwargs) multiple = (property_is_association_proxy or @@ -138,21 +129,29 @@ def convert(self, model, mapper, name, prop, field_args, hidden_pk): } if field_args: - kwargs.update(field_args) + kwargs |= field_args if kwargs['validators']: # Create a copy of the list since we will be modifying it. kwargs['validators'] = list(kwargs['validators']) # Check if it is relation or property - if hasattr(prop, 'direction') or is_association_proxy(prop): + if ( + hasattr(prop, 'direction') + or not hasattr(prop, 'direction') + and is_association_proxy(prop) + ): property_is_association_proxy = is_association_proxy(prop) if property_is_association_proxy: if not hasattr(prop.remote_attr, 'prop'): raise Exception('Association proxy referencing another association proxy is not supported.') prop = prop.remote_attr.prop return self._convert_relation(name, prop, property_is_association_proxy, kwargs) - elif hasattr(prop, 'columns'): # Ignore pk/fk + elif ( + not hasattr(prop, 'direction') + and not is_association_proxy(prop) + and hasattr(prop, 'columns') + ): # Ignore pk/fk # Check if more than one column mapped to the property if len(prop.columns) > 1 and not isinstance(prop, ColumnProperty): columns = filter_foreign_columns(model.__table__, prop.columns) @@ -160,7 +159,10 @@ def convert(self, model, mapper, name, prop, field_args, hidden_pk): if len(columns) == 0: return None elif len(columns) > 1: - warnings.warn('Can not convert multiple-column properties (%s.%s)' % (model, prop.key)) + warnings.warn( + f'Can not convert multiple-column properties ({model}.{prop.key})' + ) + return None column = columns[0] @@ -184,18 +186,17 @@ def convert(self, model, mapper, name, prop, field_args, hidden_pk): if hidden_pk: # If requested to add hidden field, show it return fields.HiddenField() - else: - # By default, don't show primary keys either - # If PK is not explicitly allowed, ignore it - if prop.key not in form_columns: - return None - - # Current Unique Validator does not work with multicolumns-pks - if not has_multiple_pks(model): - kwargs['validators'].append(Unique(self.session, - model, - column)) - unique = True + # By default, don't show primary keys either + # If PK is not explicitly allowed, ignore it + if prop.key not in form_columns: + return None + + # Current Unique Validator does not work with multicolumns-pks + if not has_multiple_pks(model): + kwargs['validators'].append(Unique(self.session, + model, + column)) + unique = True # If field is unique, validate it if column.unique and not unique: @@ -228,9 +229,8 @@ def convert(self, model, mapper, name, prop, field_args, hidden_pk): if value is not None: if getattr(default, 'is_callable', False): value = lambda: default.arg(None) # noqa: E731 - else: - if not getattr(default, 'is_scalar', True): - value = None + elif not getattr(default, 'is_scalar', True): + value = None if value is not None: kwargs['default'] = value @@ -239,16 +239,13 @@ def convert(self, model, mapper, name, prop, field_args, hidden_pk): if column.nullable: kwargs['validators'].append(validators.Optional()) - # Override field type if necessary - override = self._get_field_override(prop.key) - if override: + if override := self._get_field_override(prop.key): return override(**kwargs) # Check if a list of 'form_choices' are specified form_choices = getattr(self.view, 'form_choices', None) if mapper.class_ == self.view.model and form_choices: - choices = form_choices.get(prop.key) - if choices: + if choices := form_choices.get(prop.key): return form.Select2Field( choices=choices, allow_blank=column.nullable, @@ -441,7 +438,7 @@ def avoid_empty_strings(value): except AttributeError: # values are not always strings pass - return value if value else None + return value or None def choice_type_coerce_factory(type_): @@ -474,10 +471,7 @@ def _resolve_prop(prop): Property to resolve """ # Try to see if it is proxied property - if hasattr(prop, '_proxied_property'): - return prop._proxied_property - - return prop + return prop._proxied_property if hasattr(prop, '_proxied_property') else prop # Get list of fields and generate form @@ -527,7 +521,11 @@ def find(name): column, path = get_field_with_path(model, name, return_remote_proxy_attr=False) - if path and not (is_relationship(column) or is_association_proxy(column)): + if ( + path + and not is_relationship(column) + and not is_association_proxy(column) + ): raise Exception("form column is located in another table and " "requires inline_models: {0}".format(name)) @@ -539,7 +537,7 @@ def find(name): if column is not None and hasattr(column, 'property'): return relation_name, column.property - raise ValueError('Invalid model property name %s.%s' % (model, name)) + raise ValueError(f'Invalid model property name {model}.{name}') # Filter properties while maintaining property order in 'only' list properties = (find(x) for x in only) @@ -602,20 +600,18 @@ def get_info(self, p): if info is None: if hasattr(p, '_sa_class_manager'): return self.form_admin_class(p) - else: - model = getattr(p, 'model', None) + model = getattr(p, 'model', None) - if model is None: - raise Exception('Unknown inline model admin: %s' % repr(p)) + if model is None: + raise Exception(f'Unknown inline model admin: {repr(p)}') - attrs = dict() - for attr in dir(p): - if not attr.startswith('_') and attr != 'model': - attrs[attr] = getattr(p, attr) + attrs = { + attr: getattr(p, attr) + for attr in dir(p) + if not attr.startswith('_') and attr != 'model' + } - return self.form_admin_class(model, **attrs) - - info = self.form_admin_class(model, **attrs) + return self.form_admin_class(model, **attrs) # Resolve AJAX FKs info._form_ajax_refs = self.process_ajax_refs(info) @@ -623,13 +619,11 @@ def get_info(self, p): return info def process_ajax_refs(self, info): - refs = getattr(info, 'form_ajax_refs', None) - result = {} - if refs: + if refs := getattr(info, 'form_ajax_refs', None): for name, opts in iteritems(refs): - new_name = '%s-%s' % (info.model.__name__.lower(), name) + new_name = f'{info.model.__name__.lower()}-{name}' loader = None if isinstance(opts, dict): @@ -664,28 +658,30 @@ def _calculate_mapping_key_pair(self, model, info): reverse_prop = None for prop in target_mapper.iterate_properties: - if hasattr(prop, 'direction') and prop.direction.name in ('MANYTOONE', 'MANYTOMANY'): - if issubclass(model, prop.mapper.class_): - reverse_prop = prop - break + if ( + hasattr(prop, 'direction') + and prop.direction.name in ('MANYTOONE', 'MANYTOMANY') + and issubclass(model, prop.mapper.class_) + ): + reverse_prop = prop + break else: - raise Exception('Cannot find reverse relation for model %s' % info.model) + raise Exception(f'Cannot find reverse relation for model {info.model}') # Find forward property forward_prop = None - if prop.direction.name == 'MANYTOONE': - candidate = 'ONETOMANY' - else: - candidate = 'MANYTOMANY' - + candidate = 'ONETOMANY' if prop.direction.name == 'MANYTOONE' else 'MANYTOMANY' for prop in mapper.iterate_properties: - if hasattr(prop, 'direction') and prop.direction.name == candidate: - if prop.mapper.class_ == target_mapper.class_: - forward_prop = prop - break + if ( + hasattr(prop, 'direction') + and prop.direction.name == candidate + and prop.mapper.class_ == target_mapper.class_ + ): + forward_prop = prop + break else: - raise Exception('Cannot find forward relation for model %s' % info.model) + raise Exception(f'Cannot find forward relation for model {info.model}') return forward_prop.key, reverse_prop.key @@ -745,10 +741,9 @@ def contribute(self, model, form_class, inline_model): # Post-process form child_form = info.postprocess_form(child_form) - kwargs = dict() + kwargs = {} - label = self.get_label(info, forward_prop_key) - if label: + if label := self.get_label(info, forward_prop_key): kwargs['label'] = label if self.view.form_args: @@ -776,7 +771,7 @@ def _calculate_mapping_key_pair(self, model, info): mapper = info.model._sa_class_manager.mapper.base_mapper target_mapper = model._sa_class_manager.mapper - inline_relationship = dict() + inline_relationship = {} for forward_prop in mapper.iterate_properties: if not hasattr(forward_prop, 'direction'): @@ -788,15 +783,9 @@ def _calculate_mapping_key_pair(self, model, info): if forward_prop.mapper.class_ != target_mapper.class_: continue - # in case when model has few relationships to target model or - # has just installed references manually. This is more quick - # solution rather than rotate yet another one loop - ref = getattr(forward_prop, 'backref') - - if not ref: - ref = getattr(forward_prop, 'back_populates') - - if ref: + if ref := getattr(forward_prop, 'backref') or getattr( + forward_prop, 'back_populates' + ): inline_relationship[ref] = forward_prop.key continue @@ -813,13 +802,11 @@ def _calculate_mapping_key_pair(self, model, info): inline_relationship[backward_prop.key] = forward_prop.key break else: - raise Exception( - 'Cannot find reverse relation for model %s' % info.model) + raise Exception(f'Cannot find reverse relation for model {info.model}') break if not inline_relationship: - raise Exception( - 'Cannot find forward relation for model %s' % info.model) + raise Exception(f'Cannot find forward relation for model {info.model}') return inline_relationship @@ -829,7 +816,7 @@ def contribute(self, model, form_class, inline_model): inline_relationships = self._calculate_mapping_key_pair(model, info) # Remove reverse property from the list - ignore = [value for value in inline_relationships.values()] + ignore = list(inline_relationships.values()) if info.form_excluded_columns: exclude = ignore + list(info.form_excluded_columns) @@ -855,7 +842,7 @@ def contribute(self, model, form_class, inline_model): # Post-process form child_form = info.postprocess_form(child_form) - kwargs = dict() + kwargs = {} # Contribute field for key in inline_relationships.keys(): diff --git a/flask_admin/contrib/sqla/tools.py b/flask_admin/contrib/sqla/tools.py index 28390d479..d56d0ccf2 100644 --- a/flask_admin/contrib/sqla/tools.py +++ b/flask_admin/contrib/sqla/tools.py @@ -19,13 +19,11 @@ def parse_like_term(term): if term.startswith('^'): - stmt = '%s%%' % term[1:] + return '%s%%' % term[1:] elif term.startswith('='): - stmt = term[1:] + return term[1:] else: - stmt = '%%%s%%' % term - - return stmt + return '%%%s%%' % term def filter_foreign_columns(base_table, columns): @@ -82,14 +80,9 @@ def tuple_operator_in(model_pk, ids): """ ands = [] for id in ids: - k = [] - for i in range(len(model_pk)): - k.append(eq(model_pk[i], id[i])) + k = [eq(model_pk[i], id[i]) for i in range(len(model_pk))] ands.append(and_(*k)) - if len(ands) >= 1: - return or_(*ands) - else: - return None + return or_(*ands) if ands else None def get_query_for_ids(modelquery, model, ids): @@ -125,7 +118,7 @@ def get_columns_for_field(field): not hasattr(field, 'property') or not hasattr(field.property, 'columns') or not field.property.columns): - raise Exception('Invalid field %s: does not contains any columns.' % field) + raise Exception(f'Invalid field {field}: does not contains any columns.') return field.property.columns @@ -177,7 +170,7 @@ def get_field_with_path(model, name, return_remote_proxy_attr=True): columns = get_columns_for_field(attr) if len(columns) > 1: - raise Exception('Can only handle one column for %s' % name) + raise Exception(f'Can only handle one column for {name}') column = columns[0] @@ -190,32 +183,31 @@ def get_field_with_path(model, name, return_remote_proxy_attr=True): # copied from sqlalchemy-utils def get_hybrid_properties(model): - return dict( - (key, prop) + return { + key: prop for key, prop in inspect(model).all_orm_descriptors.items() if isinstance(prop, hybrid_property) - ) + } def is_hybrid_property(model, attr_name): - if isinstance(attr_name, string_types): - names = attr_name.split('.') - last_model = model - for i in range(len(names) - 1): - attr = getattr(last_model, names[i]) - if is_association_proxy(attr): - attr = attr.remote_attr - last_model = attr.property.argument - if isinstance(last_model, string_types): - last_model = attr.property._clsregistry_resolve_name(last_model)() - elif isinstance(last_model, _class_resolver): - last_model = model._decl_class_registry[last_model.arg] - elif isinstance(last_model, (types.FunctionType, types.MethodType)): - last_model = last_model() - last_name = names[-1] - return last_name in get_hybrid_properties(last_model) - else: + if not isinstance(attr_name, string_types): return attr_name.name in get_hybrid_properties(model) + names = attr_name.split('.') + last_model = model + for i in range(len(names) - 1): + attr = getattr(last_model, names[i]) + if is_association_proxy(attr): + attr = attr.remote_attr + last_model = attr.property.argument + if isinstance(last_model, string_types): + last_model = attr.property._clsregistry_resolve_name(last_model)() + elif isinstance(last_model, _class_resolver): + last_model = model._decl_class_registry[last_model.arg] + elif isinstance(last_model, (types.FunctionType, types.MethodType)): + last_model = last_model() + last_name = names[-1] + return last_name in get_hybrid_properties(last_model) def is_relationship(attr): diff --git a/flask_admin/contrib/sqla/validators.py b/flask_admin/contrib/sqla/validators.py index b82f1f85d..3db94866d 100644 --- a/flask_admin/contrib/sqla/validators.py +++ b/flask_admin/contrib/sqla/validators.py @@ -37,7 +37,7 @@ def __call__(self, form, field): .filter(self.column == field.data) .one()) - if not hasattr(form, '_obj') or not form._obj == obj: + if not hasattr(form, '_obj') or form._obj != obj: if self.message is None: self.message = field.gettext(u'Already exists.') raise ValidationError(self.message) diff --git a/flask_admin/contrib/sqla/view.py b/flask_admin/contrib/sqla/view.py index 3095e8d3f..0d1b57f97 100755 --- a/flask_admin/contrib/sqla/view.py +++ b/flask_admin/contrib/sqla/view.py @@ -332,9 +332,9 @@ def __init__(self, model, session, self._search_fields = None - self._filter_joins = dict() + self._filter_joins = {} - self._sortable_joins = dict() + self._sortable_joins = {} if self.form_choices is None: self.form_choices = {} @@ -350,13 +350,12 @@ def __init__(self, model, session, self._primary_key = self.scaffold_pk() if self._primary_key is None: - raise Exception('Model %s does not have primary key.' % self.model.__name__) + raise Exception(f'Model {self.model.__name__} does not have primary key.') # Configuration - if not self.column_select_related_list: - self._auto_joins = self.scaffold_auto_joins() - else: - self._auto_joins = self.column_select_related_list + self._auto_joins = ( + self.column_select_related_list or self.scaffold_auto_joins() + ) # Internal API def _get_model_iterator(self, model=None): @@ -441,7 +440,10 @@ def scaffold_list_columns(self): if len(filtered) == 0: continue elif len(filtered) > 1: - warnings.warn('Can not convert multiple-column properties (%s.%s)' % (self.model, p.key)) + warnings.warn( + f'Can not convert multiple-column properties ({self.model}.{p.key})' + ) + continue column = filtered[0] @@ -463,7 +465,7 @@ def scaffold_sortable_columns(self): Return a dictionary of sortable columns. Key is column name, value is sort column/field. """ - columns = dict() + columns = {} for p in self._get_model_iterator(): if hasattr(p, 'columns'): @@ -493,45 +495,43 @@ def get_sortable_columns(self): If `column_sortable_list` is set, will use it. Otherwise, will call `scaffold_sortable_columns` to get them from the model. """ - self._sortable_joins = dict() + self._sortable_joins = {} if self.column_sortable_list is None: return self.scaffold_sortable_columns() - else: - result = dict() - - for c in self.column_sortable_list: - if isinstance(c, tuple): - if isinstance(c[1], tuple): - column, path = [], [] - for item in c[1]: - column_item, path_item = tools.get_field_with_path(self.model, item) - column.append(column_item) - path.append(path_item) - column_name = c[0] - else: - column, path = tools.get_field_with_path(self.model, c[1]) - column_name = c[0] + result = {} + + for c in self.column_sortable_list: + if isinstance(c, tuple): + if isinstance(c[1], tuple): + column, path = [], [] + for item in c[1]: + column_item, path_item = tools.get_field_with_path(self.model, item) + column.append(column_item) + path.append(path_item) else: - column, path = tools.get_field_with_path(self.model, c) - column_name = text_type(c) + column, path = tools.get_field_with_path(self.model, c[1]) + column_name = c[0] + else: + column, path = tools.get_field_with_path(self.model, c) + column_name = text_type(c) - if path and (hasattr(path[0], 'property') or isinstance(path[0], list)): - self._sortable_joins[column_name] = path - elif path: - raise Exception("For sorting columns in a related table, " - "column_sortable_list requires a string " - "like '.'. " - "Failed on: {0}".format(c)) - else: - # column is in same table, use only model attribute name - if getattr(column, 'key', None) is not None: - column_name = column.key + if path and (hasattr(path[0], 'property') or isinstance(path[0], list)): + self._sortable_joins[column_name] = path + elif path: + raise Exception("For sorting columns in a related table, " + "column_sortable_list requires a string " + "like '.'. " + "Failed on: {0}".format(c)) + else: + # column is in same table, use only model attribute name + if getattr(column, 'key', None) is not None: + column_name = column.key - # column_name must match column_name used in `get_list_columns` - result[column_name] = column + # column_name must match column_name used in `get_list_columns` + result[column_name] = column - return result + return result def get_column_names(self, only_columns, excluded_columns): """ @@ -554,15 +554,10 @@ def get_column_names(self, only_columns, excluded_columns): try: column, path = tools.get_field_with_path(self.model, c) - if path: - # column is a relation (InstrumentedAttribute), use full path - column_name = text_type(c) + if not path and getattr(column, 'key', None) is not None: + column_name = column.key else: - # column is in same table, use only model attribute name - if getattr(column, 'key', None) is not None: - column_name = column.key - else: - column_name = text_type(c) + column_name = text_type(c) except AttributeError: # TODO: See ticket #1299 - allow virtual columns. Probably figure out # better way to handle it. For now just assume if column was not found - it @@ -591,7 +586,7 @@ def init_search(self): attr, joins = tools.get_field_with_path(self.model, name) if not attr: - raise Exception('Failed to find field for search field: %s' % name) + raise Exception(f'Failed to find field for search field: {name}') if tools.is_hybrid_property(self.model, name): column = attr @@ -599,8 +594,10 @@ def init_search(self): column.key = name.split('.')[-1] self._search_fields.append((column, joins)) else: - for column in tools.get_columns_for_field(attr): - self._search_fields.append((column, joins)) + self._search_fields.extend( + (column, joins) + for column in tools.get_columns_for_field(attr) + ) return bool(self.column_searchable_list) @@ -639,7 +636,7 @@ def scaffold_filters(self, name): attr, joins = tools.get_field_with_path(self.model, name) if attr is None: - raise Exception('Failed to find field for filter: %s' % name) + raise Exception(f'Failed to find field for filter: {name}') # Figure out filters for related column if is_relationship(attr): @@ -653,15 +650,13 @@ def scaffold_filters(self, name): if column.foreign_keys or column.primary_key: continue - visible_name = '%s / %s' % (self.get_column_name(attr.prop.target.name), - self.get_column_name(p.key)) + visible_name = f'{self.get_column_name(attr.prop.target.name)} / {self.get_column_name(p.key)}' - type_name = type(column.type).__name__ - flt = self.filter_converter.convert(type_name, - column, - visible_name) - if flt: + type_name = type(column.type).__name__ + if flt := self.filter_converter.convert( + type_name, column, visible_name + ): table = column.table if joins: @@ -682,7 +677,7 @@ def scaffold_filters(self, name): columns = tools.get_columns_for_field(attr) if len(columns) > 1: - raise Exception('Can not filter more than on one column for %s' % name) + raise Exception(f'Can not filter more than on one column for {name}') column = columns[0] @@ -695,26 +690,19 @@ def scaffold_filters(self, name): # Join not needed for hybrid properties if (not is_hybrid_property and tools.need_join(self.model, column.table) and name not in self.column_labels): - if joined_column_name: - visible_name = '%s / %s / %s' % ( - joined_column_name, - self.get_column_name(column.table.name), - self.get_column_name(column.name) - ) - else: - visible_name = '%s / %s' % ( - self.get_column_name(column.table.name), - self.get_column_name(column.name) - ) + visible_name = ( + f'{joined_column_name} / {self.get_column_name(column.table.name)} / {self.get_column_name(column.name)}' + if joined_column_name + else f'{self.get_column_name(column.table.name)} / {self.get_column_name(column.name)}' + ) + + elif not isinstance(name, string_types): + visible_name = self.get_column_name(name.property.key) + elif self.column_labels and name in self.column_labels: + visible_name = self.column_labels[name] else: - if not isinstance(name, string_types): - visible_name = self.get_column_name(name.property.key) - else: - if self.column_labels and name in self.column_labels: - visible_name = self.column_labels[name] - else: - visible_name = self.get_column_name(name) - visible_name = visible_name.replace('.', ' / ') + visible_name = self.get_column_name(name) + visible_name = visible_name.replace('.', ' / ') type_name = type(column.type).__name__ @@ -837,13 +825,11 @@ def scaffold_auto_joins(self): if p.direction.name in ['MANYTOONE', 'MANYTOMANY']: relations.add(p.key) - joined = [] - - for prop, name in self._list_columns: - if prop in relations: - joined.append(getattr(self.model, prop)) - - return joined + return [ + getattr(self.model, prop) + for prop, name in self._list_columns + if prop in relations + ] # AJAX foreignkey support def _create_ajax_loader(self, name, options): @@ -900,11 +886,7 @@ def _order_by(self, query, joins, sort_joins, sort_field, sort_desc): column = sort_field if alias is None else getattr(alias, sort_field.key) - if sort_desc: - query = query.order_by(desc(column)) - else: - query = query.order_by(column) - + query = query.order_by(desc(column)) if sort_desc else query.order_by(column) return query, joins def _get_default_order(self): @@ -914,21 +896,20 @@ def _get_default_order(self): yield attr, joins, direction def _apply_sorting(self, query, joins, sort_column, sort_desc): - if sort_column is not None: - if sort_column in self._sortable_columns: - sort_field = self._sortable_columns[sort_column] - sort_joins = self._sortable_joins.get(sort_column) - - if isinstance(sort_field, list): - for field_item, join_item in zip(sort_field, sort_joins): - query, joins = self._order_by(query, joins, join_item, field_item, sort_desc) - else: - query, joins = self._order_by(query, joins, sort_joins, sort_field, sort_desc) - else: + if sort_column is None: order = self._get_default_order() for sort_field, sort_joins, sort_desc in order: query, joins = self._order_by(query, joins, sort_joins, sort_field, sort_desc) + elif sort_column in self._sortable_columns: + sort_field = self._sortable_columns[sort_column] + sort_joins = self._sortable_joins.get(sort_column) + + if isinstance(sort_field, list): + for field_item, join_item in zip(sort_field, sort_joins): + query, joins = self._order_by(query, joins, join_item, field_item, sort_desc) + else: + query, joins = self._order_by(query, joins, sort_joins, sort_field, sort_desc) return query, joins def _apply_search(self, query, count_query, joins, count_joins, search): @@ -1057,7 +1038,7 @@ def get_list(self, page, sort_column, sort_desc, search, filters, count_joins = {} query = self.get_query() - count_query = self.get_count_query() if not self.simple_list_pager else None + count_query = None if self.simple_list_pager else self.get_count_query() # Ignore eager-loaded relations (prevent unnecessary joins) # TODO: Separate join detection for query and count query? @@ -1243,11 +1224,7 @@ def action_delete(self, ids): if self.fast_mass_delete: count = query.delete(synchronize_session=False) else: - count = 0 - - for m in query.all(): - if self.delete_model(m): - count += 1 + count = sum(1 for m in query.all() if self.delete_model(m)) self.session.commit() diff --git a/flask_admin/form/__init__.py b/flask_admin/form/__init__.py index 9db32c5a4..4a019126c 100644 --- a/flask_admin/form/__init__.py +++ b/flask_admin/form/__init__.py @@ -35,7 +35,10 @@ def recreate_field(unbound): UnboundField instance """ if not isinstance(unbound, UnboundField): - raise ValueError('recreate_field expects UnboundField instance, %s was passed.' % type(unbound)) + raise ValueError( + f'recreate_field expects UnboundField instance, {type(unbound)} was passed.' + ) + return unbound.field_class(*unbound.args, **unbound.kwargs) diff --git a/flask_admin/form/rules.py b/flask_admin/form/rules.py index 4651a6632..057c7688f 100644 --- a/flask_admin/form/rules.py +++ b/flask_admin/form/rules.py @@ -82,8 +82,7 @@ def visible_fields(self): """ visible_fields = [] for rule in self.rules: - for field in rule.visible_fields: - visible_fields.append(field) + visible_fields.extend(iter(rule.visible_fields)) return visible_fields def __iter__(self): @@ -103,10 +102,7 @@ def __call__(self, form, form_opts=None, field_args={}): :param field_args: Optional arguments that should be passed to template or the field """ - result = [] - - for r in self.rules: - result.append(r(form, form_opts, field_args)) + result = [r(form, form_opts, field_args) for r in self.rules] return Markup(self.separator.join(result)) @@ -130,10 +126,7 @@ def __init__(self, text, escape=True): self.escape = escape def __call__(self, form, form_opts=None, field_args={}): - if self.escape: - return self.text - - return Markup(self.text) + return self.text if self.escape else Markup(self.text) class HTML(Text): @@ -205,10 +198,10 @@ def __call__(self, form, form_opts=None, field_args={}): macro = self._resolve(context, self.macro_name) if not macro: - raise ValueError('Cannot find macro %s in current context.' % self.macro_name) + raise ValueError(f'Cannot find macro {self.macro_name} in current context.') opts = dict(self.default_args) - opts.update(field_args) + opts |= field_args return macro(**opts) @@ -302,12 +295,12 @@ def __call__(self, form, form_opts=None, field_args={}): field = getattr(form, self.field_name, None) if field is None: - raise ValueError('Form %s does not have field %s' % (form, self.field_name)) + raise ValueError(f'Form {form} does not have field {self.field_name}') opts = {} if form_opts: - opts.update(form_opts.widget_args.get(self.field_name, {})) + opts |= form_opts.widget_args.get(self.field_name, {}) opts.update(field_args) @@ -351,11 +344,7 @@ def __init__(self, rules, header=None, separator=''): :param separator: Child rule separator """ - if header: - rule_set = [Header(header)] + list(rules) - else: - rule_set = list(rules) - + rule_set = [Header(header)] + list(rules) if header else list(rules) super(FieldSet, self).__init__(rule_set, separator=separator) @@ -425,9 +414,7 @@ def __init__(self, field_name, prepend=None, append=None, **kwargs): @property def visible_fields(self): fields = [self.field_name] - for cnf in self._addons: - if cnf['type'] == 'field': - fields.append(cnf['name']) + fields.extend(cnf['name'] for cnf in self._addons if cnf['type'] == 'field') return fields def __call__(self, form, form_opts=None, field_args={}): @@ -444,13 +431,9 @@ def __call__(self, form, form_opts=None, field_args={}): field = getattr(form, self.field_name, None) if field is None: - raise ValueError('Form %s does not have field %s' % (form, self.field_name)) - - if form_opts: - widget_args = form_opts.widget_args - else: - widget_args = {} + raise ValueError(f'Form {form} does not have field {self.field_name}') + widget_args = form_opts.widget_args if form_opts else {} opts = {} prepend = [] append = [] @@ -459,8 +442,7 @@ def __call__(self, form, form_opts=None, field_args={}): typ = cnf['type'] if typ == 'field': name = cnf['name'] - fld = form._fields.get(name, None) - if fld: + if fld := form._fields.get(name, None): w_args = widget_args.setdefault(name, {}) if fld.type in ('BooleanField', 'RadioField'): w_args.setdefault('class', 'form-check-input') @@ -484,7 +466,7 @@ def __call__(self, form, form_opts=None, field_args={}): if append: opts['append'] = Markup(''.join(append)) - opts.update(widget_args.get(self.field_name, {})) + opts |= widget_args.get(self.field_name, {}) opts.update(field_args) params = { @@ -516,8 +498,7 @@ def __init__(self, view, rules): def visible_fields(self): visible_fields = [] for rule in self.rules: - for field in rule.visible_fields: - visible_fields.append(field) + visible_fields.extend(iter(rule.visible_fields)) return visible_fields def convert_string(self, value): @@ -558,5 +539,4 @@ def __iter__(self): """ Iterate through registered rules. """ - for r in self.rules: - yield r + yield from self.rules diff --git a/flask_admin/form/upload.py b/flask_admin/form/upload.py index 009a79f3e..cb9894841 100644 --- a/flask_admin/form/upload.py +++ b/flask_admin/form/upload.py @@ -61,16 +61,21 @@ def __call__(self, field, **kwargs): else: value = field.data or '' - return Markup(template % { - 'text': html_params(type='text', - readonly='readonly', - value=value, - name=field.name), - 'file': html_params(type='file', - value=value, - **kwargs), - 'marker': '_%s-delete' % field.name - }) + return Markup( + ( + template + % { + 'text': html_params( + type='text', + readonly='readonly', + value=value, + name=field.name, + ), + 'file': html_params(type='file', value=value, **kwargs), + 'marker': f'_{field.name}-delete', + } + ) + ) class ImageUploadInput(object): @@ -95,14 +100,12 @@ def __call__(self, field, **kwargs): kwargs.setdefault('name', field.name) args = { - 'text': html_params(type='hidden', - value=field.data, - name=field.name), - 'file': html_params(type='file', - **kwargs), - 'marker': '_%s-delete' % field.name + 'text': html_params(type='hidden', value=field.data, name=field.name), + 'file': html_params(type='file', **kwargs), + 'marker': f'_{field.name}-delete', } + if field.data and isinstance(field.data, string_types): url = self.get_url(field) args['image'] = html_params(src=url) @@ -198,12 +201,15 @@ def is_file_allowed(self, filename): :param filename: File name to check """ - if not self.allowed_extensions: - return True - - return ('.' in filename and - filename.rsplit('.', 1)[1].lower() in - map(lambda x: x.lower(), self.allowed_extensions)) + return ( + ( + '.' in filename + and filename.rsplit('.', 1)[1].lower() + in map(lambda x: x.lower(), self.allowed_extensions) + ) + if self.allowed_extensions + else True + ) def _is_uploaded_file(self, data): return (data and isinstance(data, FileStorage) and data.filename) @@ -221,7 +227,7 @@ def pre_validate(self, form): def process(self, formdata, data=unset_value, extra_filters=None): if formdata: - marker = '_%s-delete' % self.name + marker = f'_{self.name}-delete' if marker in formdata: self._should_delete = True @@ -241,12 +247,10 @@ def process_formdata(self, valuelist): def populate_obj(self, obj, name): field = getattr(obj, name, None) - if field: - # If field should be deleted, clean it up - if self._should_delete: - self._delete_file(field) - setattr(obj, name, None) - return + if field and self._should_delete: + self._delete_file(field) + setattr(obj, name, None) + return if self._is_uploaded_file(self.data): if field: @@ -412,7 +416,7 @@ def pre_validate(self, form): try: self.image = Image.open(self.data) except Exception as e: - raise ValidationError('Invalid image: %s' % e) + raise ValidationError(f'Invalid image: {e}') # Deletion def _delete_file(self, filename): @@ -465,10 +469,9 @@ def _resize(self, image, size): if image.size[0] > width or image.size[1] > height: if force: return ImageOps.fit(self.image, (width, height), Image.ANTIALIAS) - else: - thumb = self.image.copy() - thumb.thumbnail((width, height), Image.ANTIALIAS) - return thumb + thumb = self.image.copy() + thumb.thumbnail((width, height), Image.ANTIALIAS) + return thumb return image @@ -485,7 +488,7 @@ def _save_image(self, image, path, format='JPEG'): def _get_save_format(self, filename, image): if image.format not in self.keep_image_formats: name, ext = op.splitext(filename) - filename = '%s.jpg' % name + filename = f'{name}.jpg' return filename, 'JPEG' return filename, image.format @@ -504,4 +507,4 @@ def thumbgen_filename(filename): Generate thumbnail name from filename. """ name, ext = op.splitext(filename) - return '%s_thumb%s' % (name, ext) + return f'{name}_thumb{ext}' diff --git a/flask_admin/helpers.py b/flask_admin/helpers.py index da6033715..ec4a8772d 100644 --- a/flask_admin/helpers.py +++ b/flask_admin/helpers.py @@ -51,10 +51,12 @@ def is_required_form_field(field): WTForms field to check """ from flask_admin.form.validators import FieldListInputRequired - for validator in field.validators: - if isinstance(validator, (DataRequired, InputRequired, FieldListInputRequired)): - return True - return False + return any( + isinstance( + validator, (DataRequired, InputRequired, FieldListInputRequired) + ) + for validator in field.validators + ) def is_form_submitted(): @@ -104,7 +106,7 @@ def is_field_error(errors): def flash_errors(form, message): from flask_admin.babel import gettext for field_name, errors in iteritems(form.errors): - errors = form[field_name].label.text + u": " + u", ".join(errors) + errors = f"{form[field_name].label.text}: " + u", ".join(errors) flash(gettext(message, error=str(errors)), 'error') diff --git a/flask_admin/menu.py b/flask_admin/menu.py index 2257d3384..a983dabb9 100644 --- a/flask_admin/menu.py +++ b/flask_admin/menu.py @@ -27,11 +27,7 @@ def is_category(self): return False def is_active(self, view): - for c in self._children: - if c.is_active(view): - return True - - return False + return any(c.is_active(view) for c in self._children) def get_class_name(self): return self.class_name @@ -63,18 +59,10 @@ def is_category(self): return True def is_visible(self): - for c in self._children: - if c.is_visible(): - return True - - return False + return any(c.is_visible() for c in self._children) def is_accessible(self): - for c in self._children: - if c.is_accessible(): - return True - - return False + return any(c.is_accessible() for c in self._children) class MenuView(BaseMenu): @@ -100,7 +88,7 @@ def get_url(self): if self._cached_url: return self._cached_url - url = self._view.get_url('%s.%s' % (self._view.endpoint, self._view._default_view)) + url = self._view.get_url(f'{self._view.endpoint}.{self._view._default_view}') if self._cache: self._cached_url = url @@ -108,22 +96,13 @@ def get_url(self): return url def is_active(self, view): - if view == self._view: - return True - - return super(MenuView, self).is_active(view) + return True if view == self._view else super(MenuView, self).is_active(view) def is_visible(self): - if self._view is None: - return False - - return self._view.is_visible() + return False if self._view is None else self._view.is_visible() def is_accessible(self): - if self._view is None: - return False - - return self._view.is_accessible() + return False if self._view is None else self._view.is_accessible() class MenuLink(BaseMenu): diff --git a/flask_admin/model/base.py b/flask_admin/model/base.py index 665e64680..01c046840 100755 --- a/flask_admin/model/base.py +++ b/flask_admin/model/base.py @@ -56,11 +56,7 @@ def __init__(self, page=None, page_size=None, sort=None, sort_desc=None, self.extra_args = extra_args or dict() def clone(self, **kwargs): - if self.filters: - flt = list(self.filters) - else: - flt = None - + flt = list(self.filters) if self.filters else None kwargs.setdefault('page', self.page) kwargs.setdefault('page_size', self.page_size) kwargs.setdefault('sort', self.sort) @@ -85,8 +81,7 @@ def non_lazy(self): for item in self.filters: copy = dict(item) copy['operation'] = as_unicode(copy['operation']) - options = copy['options'] - if options: + if options := copy['options']: copy['options'] = [(k, text_type(v)) for k, v in options] filters.append(copy) @@ -803,7 +798,7 @@ def __init__(self, model, # If name not provided, it is model name if name is None: - name = '%s' % self._prettify_class_name(model.__name__) + name = f'{self._prettify_class_name(model.__name__)}' super(BaseModelView, self).__init__(name, category, endpoint, url, static_folder, menu_class_name=menu_class_name, @@ -938,7 +933,7 @@ def _refresh_cache(self): self.column_type_formatters_detail = dict(typefmt.DETAIL_FORMATTERS) if self.column_descriptions is None: - self.column_descriptions = dict() + self.column_descriptions = {} # Filters self._refresh_filters_cache() @@ -1088,16 +1083,15 @@ def get_sortable_columns(self): """ if self.column_sortable_list is None: return self.scaffold_sortable_columns() or dict() - else: - result = dict() + result = {} - for c in self.column_sortable_list: - if isinstance(c, tuple): - result[c[0]] = c[1] - else: - result[c] = c + for c in self.column_sortable_list: + if isinstance(c, tuple): + result[c[0]] = c[1] + else: + result[c] = c - return result + return result def init_search(self): """ @@ -1150,21 +1144,18 @@ def get_filters(self): If your model backend implementation does not support filters, override this method and return `None`. """ - if self.column_filters: - collection = [] - - for n in self.column_filters: - if self.is_valid_filter(n): - collection.append(self.handle_filter(n)) - else: - flt = self.scaffold_filters(n) - if flt: - collection.extend(flt) - else: - raise Exception('Unsupported filter type %s' % n) - return collection - else: + if not self.column_filters: return None + collection = [] + + for n in self.column_filters: + if self.is_valid_filter(n): + collection.append(self.handle_filter(n)) + elif flt := self.scaffold_filters(n): + collection.extend(flt) + else: + raise Exception(f'Unsupported filter type {n}') + return collection def get_filter_arg(self, index, flt): """ @@ -1178,21 +1169,20 @@ def get_filter_arg(self, index, flt): :param flt: Filter instance """ - if self.named_filter_urls: - operation = flt.operation() + if not self.named_filter_urls: + return str(index) + operation = flt.operation() - try: - # get lazy string original value - operation = operation._args[0] - except AttributeError: - pass + try: + # get lazy string original value + operation = operation._args[0] + except AttributeError: + pass - name = ('%s %s' % (flt.name, as_unicode(operation))).lower() - name = filter_char_re.sub('', name) - name = filter_compact_re.sub('_', name) - return name - else: - return str(index) + name = f'{flt.name} {as_unicode(operation)}'.lower() + name = filter_char_re.sub('', name) + name = filter_compact_re.sub('_', name) + return name def _get_filter_groups(self): """ @@ -1241,10 +1231,7 @@ def get_form(self): Override to implement customized behavior. """ - if self.form is not None: - return self.form - - return self.scaffold_form() + return self.form if self.form is not None else self.scaffold_form() def get_list_form(self): """ @@ -1271,11 +1258,12 @@ def get_list_form(self): """ if self.form_args: # get only validators, other form_args can break FieldList wrapper - validators = dict( - (key, {'validators': value["validators"]}) + validators = { + key: {'validators': value["validators"]} for key, value in iteritems(self.form_args) if value.get("validators") - ) + } + else: validators = None @@ -1405,9 +1393,9 @@ def _get_ruleset_missing_fields(self, ruleset, form): if ruleset: visible_fields = ruleset.visible_fields - for field in form: - if field.name not in visible_fields: - missing_fields.append(field.name) + missing_fields.extend( + field.name for field in form if field.name not in visible_fields + ) return missing_fields @@ -1415,27 +1403,36 @@ def _show_missing_fields_warning(self, text): warnings.warn(text) def _validate_form_class(self, ruleset, form_class, remove_missing=True): - form_fields = [] - for name, obj in iteritems(form_class.__dict__): - if isinstance(obj, UnboundField): - form_fields.append(name) + form_fields = [ + name + for name, obj in iteritems(form_class.__dict__) + if isinstance(obj, UnboundField) + ] missing_fields = [] if ruleset: visible_fields = ruleset.visible_fields - for field_name in form_fields: - if field_name not in visible_fields: - missing_fields.append(field_name) + missing_fields.extend( + field_name + for field_name in form_fields + if field_name not in visible_fields + ) if missing_fields: - self._show_missing_fields_warning('Fields missing from ruleset: %s' % (','.join(missing_fields))) + self._show_missing_fields_warning( + f"Fields missing from ruleset: {','.join(missing_fields)}" + ) + if remove_missing: self._remove_fields_from_form_class(missing_fields, form_class) def _validate_form_instance(self, ruleset, form, remove_missing=True): missing_fields = self._get_ruleset_missing_fields(ruleset=ruleset, form=form) if missing_fields: - self._show_missing_fields_warning('Fields missing from ruleset: %s' % (','.join(missing_fields))) + self._show_missing_fields_warning( + f"Fields missing from ruleset: {','.join(missing_fields)}" + ) + if remove_missing: self._remove_fields_from_form_instance(missing_fields, form) @@ -1711,33 +1708,32 @@ def get_invalid_value_msg(self, value, filter): # URL generation helpers def _get_list_filter_args(self): - if self._filters: - filters = [] - - for arg in request.args: - if not arg.startswith('flt'): - continue + if not self._filters: + return None + filters = [] - if '_' not in arg: - continue + for arg in request.args: + if not arg.startswith('flt'): + continue - pos, key = arg[3:].split('_', 1) + if '_' not in arg: + continue - if key in self._filter_args: - idx, flt = self._filter_args[key] + pos, key = arg[3:].split('_', 1) - value = request.args[arg] + if key in self._filter_args: + idx, flt = self._filter_args[key] - if flt.validate(value): - data = (pos, (idx, as_unicode(flt.name), value)) - filters.append(data) - else: - flash(self.get_invalid_value_msg(value, flt), 'error') + value = request.args[arg] - # Sort filters - return [v[1] for v in sorted(filters, key=lambda n: n[0])] + if flt.validate(value): + data = (pos, (idx, as_unicode(flt.name), value)) + filters.append(data) + else: + flash(self.get_invalid_value_msg(value, flt), 'error') - return None + # Sort filters + return [v[1] for v in sorted(filters, key=lambda n: n[0])] def _get_list_extra_args(self): """ @@ -1788,7 +1784,7 @@ def _get_list_url(self, view_args): desc = 1 if view_args.sort_desc else None kwargs = dict(page=page, sort=view_args.sort, desc=desc, search=view_args.search) - kwargs.update(view_args.extra_args) + kwargs |= view_args.extra_args if view_args.page_size: kwargs['page_size'] = view_args.page_size @@ -1836,15 +1832,18 @@ def _get_list_value(self, context, model, name, column_formatters, else: value = self._get_field_value(model, name) - choices_map = self._column_choices_map.get(name, {}) - if choices_map: + if choices_map := self._column_choices_map.get(name, {}): return choices_map.get(value) or value - type_fmt = None - for typeobj, formatter in column_type_formatters.items(): - if isinstance(value, typeobj): - type_fmt = formatter - break + type_fmt = next( + ( + formatter + for typeobj, formatter in column_type_formatters.items() + if isinstance(value, typeobj) + ), + None, + ) + if type_fmt is not None: value = type_fmt(self, value) @@ -1912,10 +1911,7 @@ def get_export_name(self, export_type='csv'): """ :return: The exported csv file name. """ - filename = '%s_%s.%s' % (self.name, - time.strftime("%Y-%m-%d_%H-%M-%S"), - export_type) - return filename + return f'{self.name}_{time.strftime("%Y-%m-%d_%H-%M-%S")}.{export_type}' # AJAX references def _process_ajax_references(self): @@ -1932,7 +1928,7 @@ def _process_ajax_references(self): elif isinstance(options, AjaxModelLoader): result[name] = options else: - raise ValueError('%s.form_ajax_refs can not handle %s types' % (self, type(options))) + raise ValueError(f'{self}.form_ajax_refs can not handle {type(options)} types') return result @@ -2083,10 +2079,7 @@ def create_view(self): self._validate_form_instance(ruleset=self._form_create_rules, form=form) if self.validate_form(form): - # in versions 1.1.0 and before, this returns a boolean - # in later versions, this is the model itself - model = self.create_model(form) - if model: + if model := self.create_model(form): flash(gettext('Record was successfully created.'), 'success') if '_add_another' in request.form: return redirect(request.url) @@ -2138,16 +2131,15 @@ def edit_view(self): if not hasattr(form, '_validated_ruleset') or not form._validated_ruleset: self._validate_form_instance(ruleset=self._form_edit_rules, form=form) - if self.validate_form(form): - if self.update_model(form, model): - flash(gettext('Record was successfully saved.'), 'success') - if '_add_another' in request.form: - return redirect(self.get_url('.create_view', url=return_url)) - elif '_continue_editing' in request.form: - return redirect(self.get_url('.edit_view', id=self.get_pk_value(model))) - else: - # save button - return redirect(self.get_save_return_url(model, is_created=False)) + if self.validate_form(form) and self.update_model(form, model): + flash(gettext('Record was successfully saved.'), 'success') + if '_add_another' in request.form: + return redirect(self.get_url('.create_view', url=return_url)) + elif '_continue_editing' in request.form: + return redirect(self.get_url('.edit_view', id=self.get_pk_value(model))) + else: + # save button + return redirect(self.get_save_return_url(model, is_created=False)) if request.method == 'GET' or form.errors: self.on_form_prefill(form, id) @@ -2334,13 +2326,13 @@ def _export_tablib(self, export_type, return_url): filename = self.get_export_name(export_type) - disposition = 'attachment;filename=%s' % (secure_filename(filename),) + disposition = f'attachment;filename={secure_filename(filename)}' mimetype, encoding = mimetypes.guess_type(filename) if not mimetype: mimetype = 'application/octet-stream' if encoding: - mimetype = '%s; charset=%s' % (mimetype, encoding) + mimetype = f'{mimetype}; charset={encoding}' ds = tablib.Dataset(headers=[csv_encode(c[1]) for c in self._export_columns]) @@ -2394,9 +2386,7 @@ def ajax_update(self): # prevent validation issues due to submitting a single field # delete all fields except the submitted fields and csrf token for field in list(form): - if (field.name in request.form) or (field.name == 'csrf_token'): - pass - else: + if field.name not in request.form and field.name != 'csrf_token': form.__delitem__(field.name) if self.validate_form(form): @@ -2409,11 +2399,10 @@ def ajax_update(self): if self.update_model(form, record): # Success return gettext('Record was successfully saved.') - else: # Error: No records changed, or problem saving to database. - msgs = ", ".join([msg for msg in get_flashed_messages()]) - return gettext('Failed to update record. %(error)s', - error=msgs), 500 + msgs = ", ".join(list(get_flashed_messages())) + return gettext('Failed to update record. %(error)s', + error=msgs), 500 else: for field in form: for error in field.errors: diff --git a/flask_admin/model/fields.py b/flask_admin/model/fields.py index ed28c7e88..da4bc46d4 100644 --- a/flask_admin/model/fields.py +++ b/flask_admin/model/fields.py @@ -20,9 +20,7 @@ def __init__(self, *args, **kwargs): super(InlineFieldList, self).__init__(*args, **kwargs) def __call__(self, **kwargs): - # Create template - meta = getattr(self, 'meta', None) - if meta: + if meta := getattr(self, 'meta', None): template = self.unbound_field.bind(form=None, name='', _meta=meta) else: template = self.unbound_field.bind(form=None, name='') @@ -47,7 +45,7 @@ def process(self, formdata, data=unset_value, extra_filters=None): # Postprocess - contribute flag if formdata: for f in self.entries: - key = 'del-%s' % f.id + key = f'del-{f.id}' f._should_delete = key in formdata return res @@ -60,17 +58,17 @@ def validate(self, form, extra_validators=tuple()): that FieldList validates all its enclosed fields first before running any of its own validators. """ - self.errors = [] + self.errors = [ + subfield.errors + for subfield in self.entries + if not self.should_delete(subfield) and not subfield.validate(form) + ] - # Run validators on all entries within - for subfield in self.entries: - if not self.should_delete(subfield) and not subfield.validate(form): - self.errors.append(subfield.errors) chain = itertools.chain(self.validators, extra_validators) self._run_validation_chain(form, chain) - return len(self.errors) == 0 + return not self.errors def should_delete(self, field): return getattr(field, '_should_delete', False) @@ -83,7 +81,7 @@ def populate_obj(self, obj, name): ivalues = iter([]) candidates = itertools.chain(ivalues, itertools.repeat(None)) - _fake = type(str('_fake'), (object, ), {}) + _fake = type('_fake', (object, ), {}) output = [] for field, data in zip(self.entries, candidates): @@ -192,8 +190,7 @@ def __init__(self, loader, label=None, validators=None, default=None, **kwargs): self._invalid_formdata = False def _get_data(self): - formdata = self._formdata - if formdata: + if formdata := self._formdata: data = [] # TODO: Optimize? diff --git a/flask_admin/model/filters.py b/flask_admin/model/filters.py index 414392e95..47f19f6f2 100644 --- a/flask_admin/model/filters.py +++ b/flask_admin/model/filters.py @@ -36,9 +36,7 @@ def get_options(self, view): :param view: Associated administrative view class. """ - options = self.options - - if options: + if options := self.options: if callable(options): options = options() @@ -178,10 +176,7 @@ def validate(self, value): for range in value.split(' to ')] # if " to " is missing, fail validation # sqlalchemy's .between() will not work if end date is before start date - if (len(value) == 2) and (value[0] <= value[1]): - return True - else: - return False + return len(value) == 2 and value[0] <= value[1] except ValueError: return False @@ -216,10 +211,7 @@ def validate(self, value): try: value = [datetime.datetime.strptime(range, '%Y-%m-%d %H:%M:%S') for range in value.split(' to ')] - if (len(value) == 2) and (value[0] <= value[1]): - return True - else: - return False + return len(value) == 2 and value[0] <= value[1] except ValueError: return False @@ -261,13 +253,9 @@ def validate(self, value): try: timetuples = [time.strptime(range, '%H:%M:%S') for range in value.split(' to ')] - if (len(timetuples) == 2) and (timetuples[0] <= timetuples[1]): - return True - else: - return False + return len(timetuples) == 2 and timetuples[0] <= timetuples[1] except ValueError: raise - return False class BaseUuidFilter(BaseFilter): @@ -313,7 +301,7 @@ class BaseFilterConverter(object): logic. """ def __init__(self): - self.converters = dict() + self.converters = {} for p in dir(self): attr = getattr(self, p) diff --git a/flask_admin/model/form.py b/flask_admin/model/form.py index 63da1ae1b..9f9d53641 100644 --- a/flask_admin/model/form.py +++ b/flask_admin/model/form.py @@ -77,10 +77,7 @@ def __init__(self, **kwargs): for k, v in iteritems(kwargs): setattr(self, k, v) - # Convert form rules - form_rules = getattr(self, 'form_rules', None) - - if form_rules: + if form_rules := getattr(self, 'form_rules', None): self._form_rules = rules.RuleSet(self, form_rules) else: self._form_rules = None @@ -175,17 +172,19 @@ def get_converter(self, column): # Search by module + name for col_type in types: - type_string = '%s.%s' % (col_type.__module__, col_type.__name__) + type_string = f'{col_type.__module__}.{col_type.__name__}' if type_string in self.converters: return self.converters[type_string] - # Search by name - for col_type in types: - if col_type.__name__ in self.converters: - return self.converters[col_type.__name__] - - return None + return next( + ( + self.converters[col_type.__name__] + for col_type in types + if col_type.__name__ in self.converters + ), + None, + ) def get_form(self, model, base_class=BaseForm, only=None, exclude=None, @@ -214,8 +213,7 @@ def get_label(self, info, name): :param name: Field name """ - form_name = getattr(info, 'form_label', None) - if form_name: + if form_name := getattr(info, 'form_label', None): return form_name column_labels = getattr(self.view, 'column_labels', None) diff --git a/flask_admin/model/template.py b/flask_admin/model/template.py index 495e4199a..8ecd98897 100644 --- a/flask_admin/model/template.py +++ b/flask_admin/model/template.py @@ -14,12 +14,11 @@ def render_ctx(self, context, row_id, row): return self.render(context, row_id, row) def _resolve_symbol(self, context, symbol): - if '.' in symbol: - parts = symbol.split('.') - m = context.resolve(parts[0]) - return reduce(getattr, parts[1:], m) - else: + if '.' not in symbol: return context.resolve(symbol) + parts = symbol.split('.') + m = context.resolve(parts[0]) + return reduce(getattr, parts[1:], m) class LinkRowAction(BaseListRowAction): diff --git a/flask_admin/model/widgets.py b/flask_admin/model/widgets.py index e906790dc..18d2619d5 100644 --- a/flask_admin/model/widgets.py +++ b/flask_admin/model/widgets.py @@ -52,12 +52,9 @@ def __call__(self, field, **kwargs): kwargs['value'] = separator.join(ids) kwargs['data-json'] = json.dumps(result) kwargs['data-multiple'] = u'1' - else: - data = field.loader.format(field.data) - - if data: - kwargs['value'] = data[0] - kwargs['data-json'] = json.dumps(data) + elif data := field.loader.format(field.data): + kwargs['value'] = data[0] + kwargs['data-json'] = json.dumps(data) placeholder = field.loader.options.get('placeholder', gettext('Please select model')) kwargs.setdefault('data-placeholder', placeholder) @@ -65,7 +62,7 @@ def __call__(self, field, **kwargs): minimum_input_length = int(field.loader.options.get('minimum_input_length', 1)) kwargs.setdefault('data-minimum-input-length', minimum_input_length) - return Markup('' % html_params(name=field.name, **kwargs)) + return Markup(f'') class XEditableWidget(object): @@ -94,10 +91,7 @@ def __call__(self, field, **kwargs): kwargs = self.get_kwargs(field, kwargs) - return Markup( - '%s' % (html_params(**kwargs), - escape(display_value)) - ) + return Markup(f'{escape(display_value)}') def get_kwargs(self, field, kwargs): """ @@ -177,6 +171,6 @@ def get_kwargs(self, field, kwargs): else: kwargs['data-value'] = text_type(selected_ids[0]) else: - raise Exception('Unsupported field type: %s' % (type(field),)) + raise Exception(f'Unsupported field type: {type(field)}') return kwargs diff --git a/flask_admin/tests/fileadmin/test_fileadmin_azure.py b/flask_admin/tests/fileadmin/test_fileadmin_azure.py index 31a5be066..82109b864 100644 --- a/flask_admin/tests/fileadmin/test_fileadmin_azure.py +++ b/flask_admin/tests/fileadmin/test_fileadmin_azure.py @@ -16,7 +16,7 @@ def setUp(self): if not azure.BlockBlobService: raise SkipTest('AzureFileAdmin dependencies not installed') - self._container_name = 'fileadmin-tests-%s' % uuid4() + self._container_name = f'fileadmin-tests-{uuid4()}' if not self._test_storage or not self._container_name: raise SkipTest('AzureFileAdmin test credentials not set') diff --git a/flask_admin/tests/geoa/test_basic.py b/flask_admin/tests/geoa/test_basic.py index d6718dcda..7cbb09ae8 100644 --- a/flask_admin/tests/geoa/test_basic.py +++ b/flask_admin/tests/geoa/test_basic.py @@ -89,12 +89,12 @@ def test_model(): html = rv.data.decode('utf-8') pattern = r'(.|\n)+({.*"type": ?"Point".*})(.|\n)+' - group = re.match(pattern, html).group(2) + group = re.match(pattern, html)[2] p = json.loads(group) assert p['coordinates'][0] == 125.8 assert p['coordinates'][1] == 10.0 - url = '/admin/geomodel/edit/?id=%s' % model.id + url = f'/admin/geomodel/edit/?id={model.id}' rv = client.get(url) assert rv.status_code == 200 data = rv.data.decode('utf-8') @@ -121,7 +121,7 @@ def test_model(): # assert list(to_shape(model.multi).geoms[0].coords) == [(100.0, 0.0]) # assert list(to_shape(model.multi).geoms[1].coords) == [(101.0, 1.0]) - url = '/admin/geomodel/delete/?id=%s' % model.id + url = f'/admin/geomodel/delete/?id={model.id}' rv = client.post(url) assert rv.status_code == 302 assert db.session.query(GeoModel).count() == 0 @@ -147,7 +147,7 @@ def test_none(): model = db.session.query(GeoModel).first() - url = '/admin/geomodel/edit/?id=%s' % model.id + url = f'/admin/geomodel/edit/?id={model.id}' rv = client.get(url) assert rv.status_code == 200 data = rv.data.decode('utf-8') diff --git a/flask_admin/tests/mongoengine/test_basic.py b/flask_admin/tests/mongoengine/test_basic.py index b977d23b5..8967ec1dc 100644 --- a/flask_admin/tests/mongoengine/test_basic.py +++ b/flask_admin/tests/mongoengine/test_basic.py @@ -123,7 +123,7 @@ def test_model(): assert rv.status_code == 200 assert b'test1large' in rv.data - url = '/admin/model1/edit/?id=%s' % model.id + url = f'/admin/model1/edit/?id={model.id}' rv = client.get(url) assert rv.status_code == 200 @@ -137,7 +137,7 @@ def test_model(): assert model.test3 == '' assert model.test4 == '' - url = '/admin/model1/delete/?id=%s' % model.id + url = f'/admin/model1/delete/?id={model.id}' rv = client.post(url) assert rv.status_code == 302 assert Model1.objects.count() == 0 @@ -253,12 +253,12 @@ def test_details_view(): assert '/admin/model2/details/' in data # test redirection when details are disabled - url = '/admin/model1/details/?url=%2Fadmin%2Fmodel1%2F&id=' + str(m1_id) + url = f'/admin/model1/details/?url=%2Fadmin%2Fmodel1%2F&id={str(m1_id)}' rv = client.get(url) assert rv.status_code == 302 # test if correct data appears in details view when enabled - url = '/admin/model2/details/?url=%2Fadmin%2Fmodel2%2F&id=' + str(m2_id) + url = f'/admin/model2/details/?url=%2Fadmin%2Fmodel2%2F&id={str(m2_id)}' rv = client.get(url) data = rv.data.decode('utf-8') assert 'String Field' in data @@ -266,7 +266,7 @@ def test_details_view(): assert 'Int Field' in data # test column_details_list - url = '/admin/sf_view/details/?url=%2Fadmin%2Fsf_view%2F&id=' + str(m2_id) + url = f'/admin/sf_view/details/?url=%2Fadmin%2Fsf_view%2F&id={str(m2_id)}' rv = client.get(url) data = rv.data.decode('utf-8') assert 'String Field' in data @@ -1185,7 +1185,7 @@ def test_export_csv(): endpoint='row_limit_2') admin.add_view(view) - for x in range(5): + for _ in range(5): fill_db(Model1, Model2) client = app.test_client() diff --git a/flask_admin/tests/peeweemodel/test_basic.py b/flask_admin/tests/peeweemodel/test_basic.py index fa7f97645..4ad062d30 100644 --- a/flask_admin/tests/peeweemodel/test_basic.py +++ b/flask_admin/tests/peeweemodel/test_basic.py @@ -154,7 +154,7 @@ def test_model(): assert rv.status_code == 200 assert b'test1large' in rv.data - url = '/admin/model1/edit/?id=%s' % model.id + url = f'/admin/model1/edit/?id={model.id}' rv = client.get(url) assert rv.status_code == 200 @@ -168,7 +168,7 @@ def test_model(): assert model.test3 is None or model.test3 == '' assert model.test4 is None or model.test4 == '' - url = '/admin/model1/delete/?id=%s' % model.id + url = f'/admin/model1/delete/?id={model.id}' rv = client.post(url) assert rv.status_code == 302 assert Model1.select().count() == 0 @@ -1045,7 +1045,7 @@ def test_export_csv(): endpoint='row_limit_2') admin.add_view(view) - for x in range(5): + for _ in range(5): fill_db(Model1, Model2) client = app.test_client() diff --git a/flask_admin/tests/pymongo/test_basic.py b/flask_admin/tests/pymongo/test_basic.py index c7d598607..d21dc9748 100644 --- a/flask_admin/tests/pymongo/test_basic.py +++ b/flask_admin/tests/pymongo/test_basic.py @@ -61,7 +61,7 @@ def test_model(): assert rv.status_code == 200 assert 'test1large' in rv.data.decode('utf-8') - url = '/admin/testview/edit/?id=%s' % model['_id'] + url = f"/admin/testview/edit/?id={model['_id']}" rv = client.get(url) assert rv.status_code == 200 @@ -75,7 +75,7 @@ def test_model(): assert model['test1'] == 'test1small' assert model['test2'] == 'test2large' - url = '/admin/testview/delete/?id=%s' % model['_id'] + url = f"/admin/testview/delete/?id={model['_id']}" rv = client.post(url) assert rv.status_code == 302 assert db.test.count() == 0 diff --git a/flask_admin/tests/sqla/test_basic.py b/flask_admin/tests/sqla/test_basic.py index 3aaba3325..ce919ed54 100644 --- a/flask_admin/tests/sqla/test_basic.py +++ b/flask_admin/tests/sqla/test_basic.py @@ -254,7 +254,7 @@ def test_model(): assert u'test1large' in rv.data.decode('utf-8') # check that we can retrieve an edit view - url = '/admin/model1/edit/?id=%s' % model.id + url = f'/admin/model1/edit/?id={model.id}' rv = client.get(url) assert rv.status_code == 200 @@ -299,7 +299,7 @@ def test_model(): assert model.sqla_utils_color is None # check that the model can be deleted - url = '/admin/model1/delete/?id=%s' % model.id + url = f'/admin/model1/delete/?id={model.id}' rv = client.post(url) assert rv.status_code == 302 assert db.session.query(Model1).count() == 0 @@ -1583,8 +1583,8 @@ def number_of_pixels_str(self): return str(self.number_of_pixels()) @number_of_pixels_str.expression - def number_of_pixels_str(cls): - return cast(cls.width * cls.height, db.String) + def number_of_pixels_str(self): + return cast(self.width * self.height, db.String) db.create_all() @@ -1643,7 +1643,7 @@ class Model1(db.Model): @hybrid_property def fullname(self): - return '{} {}'.format(self.firstname, self.lastname) + return f'{self.firstname} {self.lastname}' class Model2(db.Model): id = db.Column(db.Integer, primary_key=True) @@ -2605,7 +2605,7 @@ def test_export_csv(): app, db, admin = setup() Model1, Model2 = create_models(db) - for x in range(5): + for _ in range(5): fill_db(db, Model1, Model2) view = CustomModelView(Model1, db.session, can_export=True, diff --git a/flask_admin/tests/test_base.py b/flask_admin/tests/test_base.py index ea0d69d2e..d2848c82a 100644 --- a/flask_admin/tests/test_base.py +++ b/flask_admin/tests/test_base.py @@ -29,16 +29,10 @@ def _handle_view(self, name, **kwargs): return 'Failure!' def is_accessible(self): - if self.allow_access: - return super(MockView, self).is_accessible() - - return False + return super(MockView, self).is_accessible() if self.allow_access else False def is_visible(self): - if self.visible: - return super(MockView, self).is_visible() - - return False + return super(MockView, self).is_visible() if self.visible else False class MockMethodView(base.BaseView): @@ -465,7 +459,7 @@ def check_class_name(): def check_endpoint(): class CustomView(MockView): def _get_endpoint(self, endpoint): - return 'admin.' + super(CustomView, self)._get_endpoint(endpoint) + return f'admin.{super(CustomView, self)._get_endpoint(endpoint)}' view = CustomView() assert view.endpoint == 'admin.customview' diff --git a/flask_admin/tests/test_model.py b/flask_admin/tests/test_model.py index eee8c401e..c35c678c6 100644 --- a/flask_admin/tests/test_model.py +++ b/flask_admin/tests/test_model.py @@ -62,11 +62,7 @@ def __init__(self, model, data=None, name=None, category=None, self.search_arguments = [] - if data is None: - self.all_models = {1: Model(1), 2: Model(2)} - else: - self.all_models = data - + self.all_models = {1: Model(1), 2: Model(2)} if data is None else data self.last_id = len(self.all_models) + 1 # Scaffolding @@ -355,8 +351,7 @@ def scaffold_form(self): def get_csrf_token(data): data = data.split('name="csrf_token" type="hidden" value="')[1] - token = data.split('"')[0] - return token + return data.split('"')[0] app, admin = setup() diff --git a/flask_admin/tools.py b/flask_admin/tools.py index b0533aae2..9933d3d05 100644 --- a/flask_admin/tools.py +++ b/flask_admin/tools.py @@ -94,11 +94,14 @@ def get_dict_attr(obj, attr, default=None): :param default: Default value if attribute was not found """ - for obj in [obj] + obj.__class__.mro(): - if attr in obj.__dict__: - return obj.__dict__[attr] - - return default + return next( + ( + obj.__dict__[attr] + for obj in [obj] + obj.__class__.mro() + if attr in obj.__dict__ + ), + default, + ) def escape(value): @@ -134,17 +137,16 @@ def iterdecode(value): escaped = False for c in value: - if not escaped: - if c == CHAR_ESCAPE: - escaped = True - continue - elif c == CHAR_SEPARATOR: - result.append(accumulator) - accumulator = u'' - continue - else: + if escaped: escaped = False + elif c == CHAR_ESCAPE: + escaped = True + continue + elif c == CHAR_SEPARATOR: + result.append(accumulator) + accumulator = u'' + continue accumulator += c result.append(accumulator)