diff --git a/examples/auth-flask-login/app.py b/examples/auth-flask-login/app.py
index b29803841..587f53329 100644
--- a/examples/auth-flask-login/app.py
+++ b/examples/auth-flask-login/app.py
@@ -108,9 +108,11 @@ class MyAdminIndexView(admin.AdminIndexView):
@expose('/')
def index(self):
- if not login.current_user.is_authenticated:
- return redirect(url_for('.login_view'))
- return super(MyAdminIndexView, self).index()
+ return (
+ super(MyAdminIndexView, self).index()
+ if login.current_user.is_authenticated
+ else redirect(url_for('.login_view'))
+ )
@expose('/login/', methods=('GET', 'POST'))
def login_view(self):
@@ -201,8 +203,14 @@ def build_sample_db():
user.first_name = first_names[i]
user.last_name = last_names[i]
user.login = user.first_name.lower()
- user.email = user.login + "@example.com"
- user.password = generate_password_hash(''.join(random.choice(string.ascii_lowercase + string.digits) for i in range(10)))
+ user.email = f"{user.login}@example.com"
+ user.password = generate_password_hash(
+ ''.join(
+ random.choice(string.ascii_lowercase + string.digits)
+ for _ in range(10)
+ )
+ )
+
db.session.add(user)
db.session.commit()
diff --git a/examples/auth/app.py b/examples/auth/app.py
index f9cf7bc75..ebe0e2eac 100644
--- a/examples/auth/app.py
+++ b/examples/auth/app.py
@@ -138,8 +138,12 @@ def build_sample_db():
]
for i in range(len(first_names)):
- tmp_email = first_names[i].lower() + "." + last_names[i].lower() + "@example.com"
- tmp_pass = ''.join(random.choice(string.ascii_lowercase + string.digits) for i in range(10))
+ tmp_email = f"{first_names[i].lower()}.{last_names[i].lower()}@example.com"
+ tmp_pass = ''.join(
+ random.choice(string.ascii_lowercase + string.digits)
+ for _ in range(10)
+ )
+
user_datastore.create_user(
first_name=first_names[i],
last_name=last_names[i],
diff --git a/examples/auth/config.py b/examples/auth/config.py
index c25c011ff..3f199a33f 100644
--- a/examples/auth/config.py
+++ b/examples/auth/config.py
@@ -3,7 +3,7 @@
# Create in-memory database
DATABASE_FILE = 'sample_db.sqlite'
-SQLALCHEMY_DATABASE_URI = 'sqlite:///' + DATABASE_FILE
+SQLALCHEMY_DATABASE_URI = f'sqlite:///{DATABASE_FILE}'
SQLALCHEMY_ECHO = True
# Flask-Security config
diff --git a/examples/babel/app.py b/examples/babel/app.py
index 8275be64e..c0950bb58 100644
--- a/examples/babel/app.py
+++ b/examples/babel/app.py
@@ -23,9 +23,7 @@
@babel.localeselector
def get_locale():
- override = request.args.get('lang')
-
- if override:
+ if override := request.args.get('lang'):
session['lang'] = override
return session.get('lang', 'en')
@@ -58,7 +56,7 @@ def __unicode__(self):
# Flask views
@app.route('/')
def index():
- tmp = u"""
+ return u"""
Click me to get to Admin! (English)
Click me to get to Admin! (Czech)
Click me to get to Admin! (German)
@@ -71,7 +69,6 @@ def index():
Click me to get to Admin! (Chinese - Simplified)
Click me to get to Admin! (Chinese - Traditional)
"""
- return tmp
if __name__ == '__main__':
# Create admin
diff --git a/examples/bootstrap4/app.py b/examples/bootstrap4/app.py
index bd766c17d..700810bb2 100644
--- a/examples/bootstrap4/app.py
+++ b/examples/bootstrap4/app.py
@@ -92,8 +92,8 @@ def build_sample_db():
for i in range(len(first_names)):
user = User()
- user.name = first_names[i] + " " + last_names[i]
- user.email = first_names[i].lower() + "@example.com"
+ user.name = f"{first_names[i]} {last_names[i]}"
+ user.email = f"{first_names[i].lower()}@example.com"
db.session.add(user)
sample_text = [
diff --git a/examples/custom-layout/app.py b/examples/custom-layout/app.py
index 9660c9723..e8bcf8419 100644
--- a/examples/custom-layout/app.py
+++ b/examples/custom-layout/app.py
@@ -86,8 +86,8 @@ def build_sample_db():
for i in range(len(first_names)):
user = User()
- user.name = first_names[i] + " " + last_names[i]
- user.email = first_names[i].lower() + "@example.com"
+ user.name = f"{first_names[i]} {last_names[i]}"
+ user.email = f"{first_names[i].lower()}@example.com"
db.session.add(user)
sample_text = [
diff --git a/examples/forms-files-images/app.py b/examples/forms-files-images/app.py
index e39136c0f..3a2e6467b 100644
--- a/examples/forms-files-images/app.py
+++ b/examples/forms-files-images/app.py
@@ -111,7 +111,7 @@ class CKTextAreaWidget(widgets.TextArea):
def __call__(self, field, **kwargs):
# add WYSIWYG class to existing classes
existing_classes = kwargs.pop('class', '') or kwargs.pop('class_', '')
- kwargs['class'] = '{} {}'.format(existing_classes, "ckeditor")
+ kwargs['class'] = f'{existing_classes} ckeditor'
return super(CKTextAreaWidget, self).__call__(field, **kwargs)
@@ -144,12 +144,15 @@ class FileView(sqla.ModelView):
class ImageView(sqla.ModelView):
- def _list_thumbnail(view, context, model, name):
- if not model.path:
- return ''
-
- return Markup('' % url_for('static',
- filename=form.thumbgen_filename(model.path)))
+ def _list_thumbnail(self, context, model, name):
+ return (
+ Markup(
+ ''
+ % url_for('static', filename=form.thumbgen_filename(model.path))
+ )
+ if model.path
+ else ''
+ )
column_formatters = {
'path': _list_thumbnail
@@ -256,9 +259,9 @@ def build_sample_db():
user = User()
user.first_name = first_names[i]
user.last_name = last_names[i]
- user.email = user.first_name.lower() + "@example.com"
- tmp = ''.join(random.choice(string.digits) for i in range(10))
- user.phone = "(" + tmp[0:3] + ") " + tmp[3:6] + " " + tmp[6::]
+ user.email = f"{user.first_name.lower()}@example.com"
+ tmp = ''.join(random.choice(string.digits) for _ in range(10))
+ user.phone = "(" + tmp[:3] + ") " + tmp[3:6] + " " + tmp[6::]
user.city = locations[i][0]
user.country = locations[i][1]
db.session.add(user)
@@ -267,13 +270,13 @@ def build_sample_db():
for name in images:
image = Image()
image.name = name
- image.path = name.lower() + ".jpg"
+ image.path = f"{name.lower()}.jpg"
db.session.add(image)
for i in [1, 2, 3]:
file = File()
- file.name = "Example " + str(i)
- file.path = "example_" + str(i) + ".pdf"
+ file.name = f"Example {str(i)}"
+ file.path = f"example_{str(i)}.pdf"
db.session.add(file)
sample_text = "This is a test
" + \
diff --git a/examples/peewee/app.py b/examples/peewee/app.py
index de0113575..cdeb042a9 100644
--- a/examples/peewee/app.py
+++ b/examples/peewee/app.py
@@ -32,7 +32,7 @@ class UserInfo(BaseModel):
user = peewee.ForeignKeyField(User)
def __unicode__(self):
- return '%s - %s' % (self.key, self.value)
+ return f'{self.key} - {self.value}'
class Post(BaseModel):
diff --git a/examples/pymongo/app.py b/examples/pymongo/app.py
index dc66b28b0..519108e2a 100644
--- a/examples/pymongo/app.py
+++ b/examples/pymongo/app.py
@@ -77,7 +77,7 @@ def get_list(self, *args, **kwargs):
users = db.user.find(query, fields=('name',))
# Contribute user names to the models
- users_map = dict((x['_id'], x['name']) for x in users)
+ users_map = {x['_id']: x['name'] for x in users}
for item in data:
item['user_name'] = users_map.get(item['user_id'])
diff --git a/examples/sqla-association_proxy/app.py b/examples/sqla-association_proxy/app.py
index 8d8e32ee9..f2c814771 100644
--- a/examples/sqla-association_proxy/app.py
+++ b/examples/sqla-association_proxy/app.py
@@ -67,7 +67,7 @@ def __init__(self, keyword=None):
self.keyword = keyword
def __repr__(self):
- return 'Keyword(%s)' % repr(self.keyword)
+ return f'Keyword({repr(self.keyword)})'
class UserAdmin(sqla.ModelView):
diff --git a/examples/sqla-custom-inline-forms/app.py b/examples/sqla-custom-inline-forms/app.py
index 659d80d74..19563a4be 100644
--- a/examples/sqla-custom-inline-forms/app.py
+++ b/examples/sqla-custom-inline-forms/app.py
@@ -89,9 +89,7 @@ def postprocess_form(self, form_class):
return form_class
def on_model_change(self, form, model):
- file_data = request.files.get(form.upload.name)
-
- if file_data:
+ if file_data := request.files.get(form.upload.name):
model.path = secure_filename(file_data.filename)
file_data.save(op.join(base_path, model.path))
diff --git a/examples/sqla/admin/__init__.py b/examples/sqla/admin/__init__.py
index 99c979b39..cabdbd502 100644
--- a/examples/sqla/admin/__init__.py
+++ b/examples/sqla/admin/__init__.py
@@ -13,9 +13,7 @@
@babel.localeselector
def get_locale():
- override = request.args.get('lang')
-
- if override:
+ if override := request.args.get('lang'):
session['lang'] = override
return session.get('lang', 'en')
diff --git a/examples/sqla/admin/config.py b/examples/sqla/admin/config.py
index c685d83a0..86c9b7541 100644
--- a/examples/sqla/admin/config.py
+++ b/examples/sqla/admin/config.py
@@ -7,6 +7,6 @@
# Create in-memory database
DATABASE_FILE = 'sample_db.sqlite'
-SQLALCHEMY_DATABASE_URI = 'sqlite:///' + DATABASE_FILE
+SQLALCHEMY_DATABASE_URI = f'sqlite:///{DATABASE_FILE}'
SQLALCHEMY_ECHO = True
SQLALCHEMY_TRACK_MODIFICATIONS = False
diff --git a/examples/sqla/admin/data.py b/examples/sqla/admin/data.py
index 0cb461356..44f146c38 100644
--- a/examples/sqla/admin/data.py
+++ b/examples/sqla/admin/data.py
@@ -41,7 +41,7 @@ def build_sample_db():
user.type = random.choice(AVAILABLE_USER_TYPES)[0]
user.first_name = first_names[i]
user.last_name = last_names[i]
- user.email = first_names[i].lower() + "@example.com"
+ user.email = f"{first_names[i].lower()}@example.com"
user.website = "https://www.example.com"
user.ip_address = "127.0.0.1"
@@ -107,7 +107,7 @@ def build_sample_db():
entry = random.choice(sample_text) # select text at random
post = Post()
post.user = user
- post.title = "{}'s opinion on {}".format(user.first_name, entry['title'])
+ post.title = f"{user.first_name}'s opinion on {entry['title']}"
post.text = entry['content']
post.background_color = random.choice(["#cccccc", "red", "lightblue", "#0f0"])
tmp = int(1000 * random.random()) # random number between 0 and 1000:
@@ -120,12 +120,12 @@ def build_sample_db():
db.session.add(trunk)
for i in range(5):
branch = Tree()
- branch.name = "Branch " + str(i + 1)
+ branch.name = f"Branch {str(i + 1)}"
branch.parent = trunk
db.session.add(branch)
for j in range(5):
leaf = Tree()
- leaf.name = "Leaf " + str(j + 1)
+ leaf.name = f"Leaf {str(j + 1)}"
leaf.parent = branch
db.session.add(leaf)
diff --git a/examples/sqla/admin/main.py b/examples/sqla/admin/main.py
index 3fc7878a4..60a9c47df 100644
--- a/examples/sqla/admin/main.py
+++ b/examples/sqla/admin/main.py
@@ -15,7 +15,7 @@
# Flask views
@app.route('/')
def index():
- tmp = u"""
+ return u"""
Click me to get to Admin! (English)
Click me to get to Admin! (Czech)
Click me to get to Admin! (German)
@@ -28,7 +28,6 @@ def index():
Click me to get to Admin! (Chinese - Simplified)
Click me to get to Admin! (Chinese - Traditional)
"""
- return tmp
@app.route('/favicon.ico')
@@ -50,7 +49,11 @@ def operation(self):
# Customized User model admin
def phone_number_formatter(view, context, model, name):
- return Markup("{}".format(model.phone_number)) if model.phone_number else None
+ return (
+ Markup(f"{model.phone_number}")
+ if model.phone_number
+ else None
+ )
def is_numberic_validator(form, field):
diff --git a/examples/sqla/admin/models.py b/examples/sqla/admin/models.py
index 25d1f1a27..0dac1bf25 100644
--- a/examples/sqla/admin/models.py
+++ b/examples/sqla/admin/models.py
@@ -54,18 +54,21 @@ class User(db.Model):
def phone_number(self):
if self.dialling_code and self.local_phone_number:
number = str(self.local_phone_number)
- return "+{} ({}) {} {} {}".format(self.dialling_code, number[0], number[1:3], number[3:6], number[6::])
+ return f"+{self.dialling_code} ({number[0]}) {number[1:3]} {number[3:6]} {number[6::]}"
+
return
@phone_number.expression
- def phone_number(cls):
- return sql.operators.ColumnOperators.concat(cast(cls.dialling_code, db.String), cls.local_phone_number)
+ def phone_number(self):
+ return sql.operators.ColumnOperators.concat(
+ cast(self.dialling_code, db.String), self.local_phone_number
+ )
def __str__(self):
- return "{}, {}".format(self.last_name, self.first_name)
+ return f"{self.last_name}, {self.first_name}"
def __repr__(self):
- return "{}: {}".format(self.id, self.__str__())
+ return f"{self.id}: {self.__str__()}"
# Create M2M table
@@ -90,7 +93,7 @@ class Post(db.Model):
tags = db.relationship('Tag', secondary=post_tags_table)
def __str__(self):
- return "{}".format(self.title)
+ return f"{self.title}"
class Tag(db.Model):
@@ -98,7 +101,7 @@ class Tag(db.Model):
name = db.Column(db.Unicode(64), unique=True)
def __str__(self):
- return "{}".format(self.name)
+ return f"{self.name}"
class Tree(db.Model):
@@ -110,4 +113,4 @@ class Tree(db.Model):
parent = db.relationship('Tree', remote_side=[id], backref='children')
def __str__(self):
- return "{}".format(self.name)
+ return f"{self.name}"
diff --git a/flask_admin/_backwards.py b/flask_admin/_backwards.py
index 0f1a2a49b..efbc73a4d 100644
--- a/flask_admin/_backwards.py
+++ b/flask_admin/_backwards.py
@@ -29,8 +29,11 @@ def get_property(obj, name, old_name, default=None):
Default value
"""
if hasattr(obj, old_name):
- warnings.warn('Property %s is obsolete, please use %s instead' %
- (old_name, name), stacklevel=2)
+ warnings.warn(
+ f'Property {old_name} is obsolete, please use {name} instead',
+ stacklevel=2,
+ )
+
return getattr(obj, old_name)
return getattr(obj, name, default)
@@ -40,7 +43,7 @@ class ObsoleteAttr(object):
def __init__(self, new_name, old_name, default):
self.new_name = new_name
self.old_name = old_name
- self.cache = '_cache_' + new_name
+ self.cache = f'_cache_{new_name}'
self.default = default
def __get__(self, obj, objtype=None):
@@ -53,8 +56,11 @@ def __get__(self, obj, objtype=None):
# Check if there's old attribute
if hasattr(obj, self.old_name):
- warnings.warn('Property %s is obsolete, please use %s instead' %
- (self.old_name, self.new_name), stacklevel=2)
+ warnings.warn(
+ f'Property {self.old_name} is obsolete, please use {self.new_name} instead',
+ stacklevel=2,
+ )
+
return getattr(obj, self.old_name)
# Return default otherwise
diff --git a/flask_admin/_compat.py b/flask_admin/_compat.py
index 610d29301..84f71370e 100644
--- a/flask_admin/_compat.py
+++ b/flask_admin/_compat.py
@@ -27,10 +27,7 @@
filter_list = lambda f, l: list(filter(f, l))
def as_unicode(s):
- if isinstance(s, bytes):
- return s.decode('utf-8')
-
- return str(s)
+ return s.decode('utf-8') if isinstance(s, bytes) else str(s)
def csv_encode(s):
''' Returns unicode string expected by Python 3's csv module '''
@@ -50,10 +47,7 @@ def csv_encode(s):
filter_list = filter
def as_unicode(s):
- if isinstance(s, str):
- return s.decode('utf-8')
-
- return unicode(s)
+ return s.decode('utf-8') if isinstance(s, str) else unicode(s)
def csv_encode(s):
''' Returns byte string expected by Python 2's csv module '''
diff --git a/flask_admin/actions.py b/flask_admin/actions.py
index f217fa0c4..4bed96132 100644
--- a/flask_admin/actions.py
+++ b/flask_admin/actions.py
@@ -89,8 +89,7 @@ def get_actions_list(self):
if self.is_action_allowed(name):
actions.append((name, text_type(text)))
- confirmation = self._actions_data[name][2]
- if confirmation:
+ if confirmation := self._actions_data[name][2]:
actions_confirmation[name] = text_type(confirmation)
return actions, actions_confirmation
@@ -122,7 +121,7 @@ def handle_action(self, return_view=None):
flash_errors(form, message='Failed to perform action. %(error)s')
if return_view:
- url = self.get_url('.' + return_view)
+ url = self.get_url(f'.{return_view}')
else:
url = get_redirect_target() or self.get_url('.index_view')
diff --git a/flask_admin/base.py b/flask_admin/base.py
index e706c916e..9bb0963f7 100644
--- a/flask_admin/base.py
+++ b/flask_admin/base.py
@@ -42,10 +42,7 @@ def expose_plugview(url='/'):
def wrap(v):
handler = expose(url, v.methods)
- if hasattr(v, 'as_view'):
- return handler(v.as_view(v.__name__))
- else:
- return handler(v)
+ return handler(v.as_view(v.__name__)) if hasattr(v, 'as_view') else handler(v)
return wrap
@@ -80,26 +77,26 @@ class AdminViewMeta(type):
Does some precalculations (like getting list of view methods from the class) to avoid
calculating them for each view class instance.
"""
- def __init__(cls, classname, bases, fields):
- type.__init__(cls, classname, bases, fields)
+ def __init__(self, classname, bases, fields):
+ type.__init__(self, classname, bases, fields)
# Gather exposed views
- cls._urls = []
- cls._default_view = None
+ self._urls = []
+ self._default_view = None
- for p in dir(cls):
- attr = getattr(cls, p)
+ for p in dir(self):
+ attr = getattr(self, p)
if hasattr(attr, '_urls'):
# Collect methods
for url, methods in attr._urls:
- cls._urls.append((url, p, methods))
+ self._urls.append((url, p, methods))
if url == '/':
- cls._default_view = p
+ self._default_view = p
# Wrap views
- setattr(cls, p, _wrap_view(attr))
+ setattr(self, p, _wrap_view(attr))
class BaseViewClass(object):
@@ -205,33 +202,28 @@ def __init__(self, name=None, category=None, endpoint=None, url=None,
# Default view
if self._default_view is None:
- raise Exception(u'Attempted to instantiate admin view %s without default view' % self.__class__.__name__)
+ raise Exception(
+ f'Attempted to instantiate admin view {self.__class__.__name__} without default view'
+ )
def _get_endpoint(self, endpoint):
"""
Generate Flask endpoint name. By default converts class name to lower case if endpoint is
not explicitly provided.
"""
- if endpoint:
- return endpoint
-
- return self.__class__.__name__.lower()
+ return endpoint or self.__class__.__name__.lower()
def _get_view_url(self, admin, url):
"""
Generate URL for the view. Override to change default behavior.
"""
if url is None:
- if admin.url != '/':
- url = '%s/%s' % (admin.url, self.endpoint)
+ if admin.url == '/':
+ url = '/' if self == admin.index_view else f'/{self.endpoint}'
else:
- if self == admin.index_view:
- url = '/'
- else:
- url = '/%s' % self.endpoint
- else:
- if not url.startswith('/'):
- url = '%s/%s' % (admin.url, url)
+ url = f'{admin.url}/{self.endpoint}'
+ elif not url.startswith('/'):
+ url = f'{admin.url}/{url}'
return url
@@ -391,10 +383,7 @@ def get_url(self, endpoint, **kwargs):
@property
def _debug(self):
- if not self.admin or not self.admin.app:
- return False
-
- return self.admin.app.debug
+ return False if not self.admin or not self.admin.app else self.admin.app.debug
class AdminIndexView(BaseView):
@@ -502,7 +491,7 @@ def __init__(self, app=None, name=None,
self._views = []
self._menu = []
- self._menu_categories = dict()
+ self._menu_categories = {}
self._menu_links = []
if name is None:
@@ -722,7 +711,7 @@ def init_app(self, app, index_view=None,
def _init_extension(self):
if not hasattr(self.app, 'extensions'):
- self.app.extensions = dict()
+ self.app.extensions = {}
admins = self.app.extensions.get('admin', [])
diff --git a/flask_admin/contrib/appengine/view.py b/flask_admin/contrib/appengine/view.py
index 010e4bdab..7adc9daea 100644
--- a/flask_admin/contrib/appengine/view.py
+++ b/flask_admin/contrib/appengine/view.py
@@ -53,7 +53,7 @@ class MyAdminView(ModelView):
"""
def scaffold_form(self):
- form_class = wt_ndb.model_form(
+ return wt_ndb.model_form(
self.model(),
base_class=Form,
only=self.form_columns,
@@ -61,7 +61,6 @@ def scaffold_form(self):
field_args=self.form_args,
converter=self.model_form_converter(),
)
- return form_class
def scaffold_list_form(self, widget=None, validators=None):
form_class = wt_ndb.model_form(
@@ -71,8 +70,7 @@ def scaffold_list_form(self, widget=None, validators=None):
field_args=self.form_args,
converter=self.model_form_converter(),
)
- result = create_editable_list_form(Form, form_class, widget)
- return result
+ return create_editable_list_form(Form, form_class, widget)
def get_list(self, page, sort_field, sort_desc, search, filters,
page_size=None):
@@ -179,7 +177,7 @@ def get_list(self, page, sort_field, sort_desc, search, filters):
if sort_field:
if sort_desc:
- sort_field = "-" + sort_field
+ sort_field = f"-{sort_field}"
q.order(sort_field)
results = q.fetch(self.page_size, offset=page * self.page_size)
@@ -232,4 +230,4 @@ def ModelView(model):
elif issubclass(model, db.Model):
return DbModelView(model)
else:
- raise ValueError("Unsupported model: %s" % model)
+ raise ValueError(f"Unsupported model: {model}")
diff --git a/flask_admin/contrib/fileadmin/__init__.py b/flask_admin/contrib/fileadmin/__init__.py
index f5c786adb..8c08835ba 100644
--- a/flask_admin/contrib/fileadmin/__init__.py
+++ b/flask_admin/contrib/fileadmin/__init__.py
@@ -481,10 +481,7 @@ def edit_form(self):
Override to implement custom behavior.
"""
edit_form_class = self.get_edit_form()
- if request.form:
- return edit_form_class(request.form)
- else:
- return edit_form_class()
+ return edit_form_class(request.form) if request.form else edit_form_class()
def delete_form(self):
"""
@@ -493,10 +490,7 @@ def delete_form(self):
Override to implement custom behavior.
"""
delete_form_class = self.get_delete_form()
- if request.form:
- return delete_form_class(request.form)
- else:
- return delete_form_class()
+ return delete_form_class(request.form) if request.form else delete_form_class()
def action_form(self):
"""
@@ -505,10 +499,7 @@ def action_form(self):
Override to implement custom behavior.
"""
action_form_class = self.get_action_form()
- if request.form:
- return action_form_class(request.form)
- else:
- return action_form_class()
+ return action_form_class(request.form) if request.form else action_form_class()
def is_file_allowed(self, filename):
"""
@@ -590,15 +581,13 @@ def _get_dir_url(self, endpoint, path=None, **kwargs):
:param kwargs:
Additional arguments
"""
- if not path:
- return self.get_url(endpoint, **kwargs)
- else:
+ if path:
if self._on_windows:
path = path.replace('\\', '/')
kwargs['path'] = path
- return self.get_url(endpoint, **kwargs)
+ return self.get_url(endpoint, **kwargs)
def _get_file_url(self, path, **kwargs):
"""
@@ -610,11 +599,7 @@ def _get_file_url(self, path, **kwargs):
if self._on_windows:
path = path.replace('\\', '/')
- if self.is_file_editable(path):
- route = '.edit'
- else:
- route = '.download'
-
+ route = '.edit' if self.is_file_editable(path) else '.download'
return self.get_url(route, path=path, **kwargs)
def _normalize_path(self, path):
@@ -631,11 +616,7 @@ def _normalize_path(self, path):
path = ''
else:
path = op.normpath(path)
- if base_path:
- directory = self._separator.join([base_path, path])
- else:
- directory = path
-
+ directory = self._separator.join([base_path, path]) if base_path else path
directory = op.normpath(directory)
if not self.is_in_folder(base_path, directory):
@@ -871,14 +852,10 @@ def index_view(self, path=None):
action_form = None
def sort_url(column, path, invert=False):
- desc = None
-
if not path:
path = None
- if invert and not sort_desc:
- desc = 1
-
+ desc = 1 if invert and not sort_desc else None
return self.get_url('.index_view', path=path, sort=column, desc=desc)
return self.render(self.list_template,
@@ -948,9 +925,7 @@ def download(self, path=None):
base_path, directory, path = self._normalize_path(path)
- # backward compatibility with base_url
- base_url = self.get_base_url()
- if base_url:
+ if base_url := self.get_base_url():
base_url = urljoin(self.get_url('.index_view'), base_url)
return redirect(urljoin(quote(base_url), quote(path)))
@@ -1063,13 +1038,12 @@ def rename(self):
form = self.name_form()
path = form.path.data
- if path:
- base_path, full_path, path = self._normalize_path(path)
-
- return_url = self._get_dir_url('.index_view', op.dirname(path))
- else:
+ if not path:
return redirect(self.get_url('.index_view'))
+ base_path, full_path, path = self._normalize_path(path)
+
+ return_url = self._get_dir_url('.index_view', op.dirname(path))
if not self.can_rename:
flash(gettext('Renaming is disabled.'), 'error')
return redirect(return_url)
@@ -1113,15 +1087,11 @@ def edit(self):
"""
Edit view method
"""
- next_url = None
-
path = request.args.getlist('path')
if not path:
return redirect(self.get_url('.index_view'))
- if len(path) > 1:
- next_url = self.get_url('.edit', path=path[1:])
-
+ next_url = self.get_url('.edit', path=path[1:]) if len(path) > 1 else None
path = path[0]
base_path, full_path, path = self._normalize_path(path)
@@ -1156,18 +1126,12 @@ def edit(self):
except IOError:
flash(gettext("Error reading %(name)s.", name=path), 'error')
error = True
- except:
- flash(gettext("Unexpected error while reading from %(name)s", name=path), 'error')
- error = True
else:
try:
content = content.decode('utf8')
except UnicodeDecodeError:
flash(gettext("Cannot edit %(name)s.", name=path), 'error')
error = True
- except:
- flash(gettext("Unexpected error while reading from %(name)s", name=path), 'error')
- error = True
else:
form.content.data = content
diff --git a/flask_admin/contrib/fileadmin/azure.py b/flask_admin/contrib/fileadmin/azure.py
index 8e0fb44e3..1d389d097 100644
--- a/flask_admin/contrib/fileadmin/azure.py
+++ b/flask_admin/contrib/fileadmin/azure.py
@@ -112,10 +112,10 @@ def get_files(self, path, directory):
folders.add(folder_name)
folders.discard(directory)
+ is_dir = True
for folder in folders:
name = folder.split(self.separator)[-1]
rel_path = folder
- is_dir = True
size = 0
last_modified = 0
files.append((name, rel_path, is_dir, size, last_modified))
@@ -125,14 +125,12 @@ def get_files(self, path, directory):
def is_dir(self, path):
path = self._ensure_blob_path(path)
- num_blobs = 0
- for blob in self._client.list_blobs(self._container_name, path):
+ for num_blobs, blob in enumerate(self._client.list_blobs(self._container_name, path), start=1):
blob_path_parts = blob.name.split(self.separator)
is_explicit_directory = blob_path_parts[-1] == self._fakedir
if is_explicit_directory:
return True
- num_blobs += 1
path_cannot_be_leaf = num_blobs >= 2
if path_cannot_be_leaf:
return True
@@ -178,7 +176,7 @@ def send_file(self, file_path):
BlobPermissions.READ,
expiry=now + self._send_file_validity,
start=now - self._send_file_lookback)
- return redirect('%s?%s' % (url, sas))
+ return redirect(f'{url}?{sas}')
def read_file(self, path):
path = self._ensure_blob_path(path)
diff --git a/flask_admin/contrib/fileadmin/s3.py b/flask_admin/contrib/fileadmin/s3.py
index 1c9bdc719..8eeb7879c 100644
--- a/flask_admin/contrib/fileadmin/s3.py
+++ b/flask_admin/contrib/fileadmin/s3.py
@@ -65,9 +65,7 @@ def __init__(self, bucket_name, region, aws_access_key_id,
def get_files(self, path, directory):
def _strip_path(name, path):
- if name.startswith(path):
- return name.replace(path, '', 1)
- return name
+ return name.replace(path, '', 1) if name.startswith(path) else name
def _remove_trailing_slash(name):
return name[:-1]
@@ -95,11 +93,11 @@ def _iso_to_epoch(timestamp):
def _get_bucket_list_prefix(self, path):
parts = path.split(self.separator)
- if len(parts) == 1:
- search = ''
- else:
- search = self.separator.join(parts[:-1]) + self.separator
- return search
+ return (
+ ''
+ if len(parts) == 1
+ else self.separator.join(parts[:-1]) + self.separator
+ )
def _get_path_keys(self, path):
search = self._get_bucket_list_prefix(path)
diff --git a/flask_admin/contrib/geoa/fields.py b/flask_admin/contrib/geoa/fields.py
index 0e9a30371..3d7f75bd3 100644
--- a/flask_admin/contrib/geoa/fields.py
+++ b/flask_admin/contrib/geoa/fields.py
@@ -19,31 +19,27 @@ def __init__(self, label=None, validators=None, geometry_type="GEOMETRY",
super(GeoJSONField, self).__init__(label, validators, **kwargs)
self.web_srid = 4326
self.srid = srid
- if self.srid == -1:
- self.transform_srid = self.web_srid
- else:
- self.transform_srid = self.srid
+ self.transform_srid = self.web_srid if self.srid == -1 else self.srid
self.geometry_type = geometry_type.upper()
self.session = session
def _value(self):
if self.raw_data:
return self.raw_data[0]
- if type(self.data) is geoalchemy2.elements.WKBElement:
- if self.srid == -1:
- return self.session.scalar(func.ST_AsGeoJSON(self.data))
- else:
- return self.session.scalar(
- func.ST_AsGeoJSON(
- func.ST_Transform(self.data, self.web_srid)
- )
- )
- else:
+ if type(self.data) is not geoalchemy2.elements.WKBElement:
return ''
+ if self.srid == -1:
+ return self.session.scalar(func.ST_AsGeoJSON(self.data))
+ else:
+ return self.session.scalar(
+ func.ST_AsGeoJSON(
+ func.ST_Transform(self.data, self.web_srid)
+ )
+ )
def process_formdata(self, valuelist):
super(GeoJSONField, self).process_formdata(valuelist)
- if str(self.data) == '':
+ if not str(self.data):
self.data = None
if self.data is not None:
web_shape = self.session.scalar(
@@ -57,4 +53,4 @@ def process_formdata(self, valuelist):
)
)
)
- self.data = 'SRID=' + str(self.srid) + ';' + str(web_shape)
+ self.data = f'SRID={str(self.srid)};{str(web_shape)}'
diff --git a/flask_admin/contrib/geoa/typefmt.py b/flask_admin/contrib/geoa/typefmt.py
index 4e1902247..605e91687 100644
--- a/flask_admin/contrib/geoa/typefmt.py
+++ b/flask_admin/contrib/geoa/typefmt.py
@@ -22,7 +22,7 @@ def geom_formatter(view, value):
value.srid = 4326
geojson = view.session.query(view.model).with_entities(func.ST_AsGeoJSON(value)).scalar()
- return Markup('' % (params, geojson))
+ return Markup(f'')
DEFAULT_FORMATTERS = BASE_FORMATTERS.copy()
diff --git a/flask_admin/contrib/mongoengine/ajax.py b/flask_admin/contrib/mongoengine/ajax.py
index 470911fbc..7c190ffbd 100644
--- a/flask_admin/contrib/mongoengine/ajax.py
+++ b/flask_admin/contrib/mongoengine/ajax.py
@@ -20,7 +20,9 @@ def __init__(self, name, model, **options):
self._cached_fields = self._process_fields()
if not self.fields:
- raise ValueError('AJAX loading requires `fields` to be specified for %s.%s' % (model, self.name))
+ raise ValueError(
+ f'AJAX loading requires `fields` to be specified for {model}.{self.name}'
+ )
def _process_fields(self):
remote_fields = []
@@ -30,7 +32,7 @@ def _process_fields(self):
attr = getattr(self.model, field, None)
if not attr:
- raise ValueError('%s.%s does not exist.' % (self.model, field))
+ raise ValueError(f'{self.model}.{field} does not exist.')
remote_fields.append(attr)
else:
@@ -39,10 +41,7 @@ def _process_fields(self):
return remote_fields
def format(self, model):
- if not model:
- return None
-
- return (as_unicode(model.pk), as_unicode(model))
+ return (as_unicode(model.pk), as_unicode(model)) if model else None
def get_one(self, pk):
return self.model.objects.filter(pk=pk).first()
@@ -54,7 +53,7 @@ def get_list(self, term, offset=0, limit=DEFAULT_PAGE_SIZE):
criteria = None
for field in self._cached_fields:
- flt = {u'%s__icontains' % field.name: term}
+ flt = {f'{field.name}__icontains': term}
if not criteria:
criteria = mongoengine.Q(**flt)
@@ -73,16 +72,16 @@ def create_ajax_loader(model, name, field_name, opts):
prop = getattr(model, field_name, None)
if prop is None:
- raise ValueError('Model %s does not have field %s.' % (model, field_name))
+ raise ValueError(f'Model {model} does not have field {field_name}.')
ftype = type(prop).__name__
- if ftype == 'ListField' or ftype == 'SortedListField':
+ if ftype in ['ListField', 'SortedListField']:
prop = prop.field
ftype = type(prop).__name__
if ftype != 'ReferenceField':
- raise ValueError('Dont know how to convert %s type for AJAX loader' % ftype)
+ raise ValueError(f'Dont know how to convert {ftype} type for AJAX loader')
remote_model = prop.document_type
return QueryAjaxModelLoader(name, remote_model, **opts)
@@ -90,18 +89,13 @@ def create_ajax_loader(model, name, field_name, opts):
def process_ajax_references(references, view):
def make_name(base, name):
- if base:
- return ('%s-%s' % (base, name)).lower()
- else:
- return as_unicode(name).lower()
+ return f'{base}-{name}'.lower() if base else as_unicode(name).lower()
def handle_field(field, subdoc, base):
ftype = type(field).__name__
- if ftype == 'ListField' or ftype == 'SortedListField':
- child_doc = getattr(subdoc, '_form_subdocuments', {}).get(None)
-
- if child_doc:
+ if ftype in ['ListField', 'SortedListField']:
+ if child_doc := getattr(subdoc, '_form_subdocuments', {}).get(None):
handle_field(field.field, child_doc, base)
elif ftype == 'EmbeddedDocumentField':
result = {}
@@ -121,11 +115,10 @@ def handle_field(field, subdoc, base):
subdoc._form_ajax_refs = result
- child_doc = getattr(subdoc, '_form_subdocuments', None)
- if child_doc:
+ if child_doc := getattr(subdoc, '_form_subdocuments', None):
handle_subdoc(field.document_type_obj, subdoc, base)
else:
- raise ValueError('Failed to process subdocument field %s' % (field,))
+ raise ValueError(f'Failed to process subdocument field {field}')
def handle_subdoc(model, subdoc, base):
documents = getattr(subdoc, '_form_subdocuments', {})
@@ -134,7 +127,7 @@ def handle_subdoc(model, subdoc, base):
field = getattr(model, name, None)
if not field:
- raise ValueError('Invalid subdocument field %s.%s' % (model, name))
+ raise ValueError(f'Invalid subdocument field {model}.{name}')
handle_field(field, doc, make_name(base, name))
diff --git a/flask_admin/contrib/mongoengine/fields.py b/flask_admin/contrib/mongoengine/fields.py
index 64418ecb0..eaeb9308b 100644
--- a/flask_admin/contrib/mongoengine/fields.py
+++ b/flask_admin/contrib/mongoengine/fields.py
@@ -59,7 +59,7 @@ def __init__(self, label=None, validators=None, **kwargs):
def process(self, formdata, data=unset_value):
if formdata:
- marker = '_%s-delete' % self.name
+ marker = f'_{self.name}-delete'
if marker in formdata:
self._should_delete = True
@@ -74,11 +74,7 @@ def populate_obj(self, obj, name):
return
if isinstance(self.data, FileStorage) and not is_empty(self.data.stream):
- if not field.grid_id:
- func = field.put
- else:
- func = field.replace
-
+ func = field.replace if field.grid_id else field.put
func(self.data.stream,
filename=self.data.filename,
content_type=self.data.content_type)
diff --git a/flask_admin/contrib/mongoengine/filters.py b/flask_admin/contrib/mongoengine/filters.py
index ccbe30368..94cf952b1 100644
--- a/flask_admin/contrib/mongoengine/filters.py
+++ b/flask_admin/contrib/mongoengine/filters.py
@@ -32,7 +32,7 @@ def __init__(self, column, name, options=None, data_type=None):
# Common filters
class FilterEqual(BaseMongoEngineFilter):
def apply(self, query, value):
- flt = {'%s' % self.column.name: value}
+ flt = {f'{self.column.name}': value}
return query.filter(**flt)
def operation(self):
@@ -41,7 +41,7 @@ def operation(self):
class FilterNotEqual(BaseMongoEngineFilter):
def apply(self, query, value):
- flt = {'%s__ne' % self.column.name: value}
+ flt = {f'{self.column.name}__ne': value}
return query.filter(**flt)
def operation(self):
@@ -51,7 +51,7 @@ def operation(self):
class FilterLike(BaseMongoEngineFilter):
def apply(self, query, value):
term, data = parse_like_term(value)
- flt = {'%s__%s' % (self.column.name, term): data}
+ flt = {f'{self.column.name}__{term}': data}
return query.filter(**flt)
def operation(self):
@@ -61,7 +61,7 @@ def operation(self):
class FilterNotLike(BaseMongoEngineFilter):
def apply(self, query, value):
term, data = parse_like_term(value)
- flt = {'%s__not__%s' % (self.column.name, term): data}
+ flt = {f'{self.column.name}__not__{term}': data}
return query.filter(**flt)
def operation(self):
@@ -70,7 +70,7 @@ def operation(self):
class FilterGreater(BaseMongoEngineFilter):
def apply(self, query, value):
- flt = {'%s__gt' % self.column.name: value}
+ flt = {f'{self.column.name}__gt': value}
return query.filter(**flt)
def operation(self):
@@ -79,7 +79,7 @@ def operation(self):
class FilterSmaller(BaseMongoEngineFilter):
def apply(self, query, value):
- flt = {'%s__lt' % self.column.name: value}
+ flt = {f'{self.column.name}__lt': value}
return query.filter(**flt)
def operation(self):
@@ -89,9 +89,9 @@ def operation(self):
class FilterEmpty(BaseMongoEngineFilter, filters.BaseBooleanFilter):
def apply(self, query, value):
if value == '1':
- flt = {'%s' % self.column.name: None}
+ flt = {f'{self.column.name}': None}
else:
- flt = {'%s__ne' % self.column.name: None}
+ flt = {f'{self.column.name}__ne': None}
return query.filter(**flt)
def operation(self):
@@ -106,7 +106,7 @@ def clean(self, value):
return [v.strip() for v in value.split(',') if v.strip()]
def apply(self, query, value):
- flt = {'%s__in' % self.column.name: value}
+ flt = {f'{self.column.name}__in': value}
return query.filter(**flt)
def operation(self):
@@ -115,7 +115,7 @@ def operation(self):
class FilterNotInList(FilterInList):
def apply(self, query, value):
- flt = {'%s__nin' % self.column.name: value}
+ flt = {f'{self.column.name}__nin': value}
return query.filter(**flt)
def operation(self):
@@ -125,13 +125,13 @@ def operation(self):
# Customized type filters
class BooleanEqualFilter(FilterEqual, filters.BaseBooleanFilter):
def apply(self, query, value):
- flt = {'%s' % self.column.name: value == '1'}
+ flt = {f'{self.column.name}': value == '1'}
return query.filter(**flt)
class BooleanNotEqualFilter(FilterNotEqual, filters.BaseBooleanFilter):
def apply(self, query, value):
- flt = {'%s' % self.column.name: value != '1'}
+ flt = {f'{self.column.name}': value != '1'}
return query.filter(**flt)
@@ -208,15 +208,19 @@ def __init__(self, column, name, options=None, data_type=None):
def apply(self, query, value):
start, end = value
- flt = {'%s__gte' % self.column.name: start, '%s__lte' % self.column.name: end}
+ flt = {f'{self.column.name}__gte': start, f'{self.column.name}__lte': end}
return query.filter(**flt)
class DateTimeNotBetweenFilter(DateTimeBetweenFilter):
def apply(self, query, value):
start, end = value
- return query.filter(Q(**{'%s__not__gte' % self.column.name: start}) |
- Q(**{'%s__not__lte' % self.column.name: end}))
+ return query.filter(
+ (
+ Q(**{f'{self.column.name}__not__gte': start})
+ | Q(**{f'{self.column.name}__not__lte': end})
+ )
+ )
def operation(self):
return lazy_gettext('not between')
@@ -240,7 +244,7 @@ def clean(self, value):
return ObjectId(value.strip())
def apply(self, query, value):
- flt = {'%s' % self.column.name: value}
+ flt = {f'{self.column.name}': value}
return query.filter(**flt)
def operation(self):
diff --git a/flask_admin/contrib/mongoengine/form.py b/flask_admin/contrib/mongoengine/form.py
index 1934f2bc3..e47628c85 100644
--- a/flask_admin/contrib/mongoengine/form.py
+++ b/flask_admin/contrib/mongoengine/form.py
@@ -28,9 +28,7 @@ def __init__(self, view):
self.view = view
def _get_field_override(self, name):
- form_overrides = getattr(self.view, 'form_overrides', None)
-
- if form_overrides:
+ if form_overrides := getattr(self.view, 'form_overrides', None):
return form_overrides.get(name)
return None
@@ -68,7 +66,7 @@ def convert(self, model, field, field_args):
}
if field_args:
- kwargs.update(field_args)
+ kwargs |= field_args
if kwargs['validators']:
# Create a copy of the list since we will be modifying it.
@@ -98,8 +96,7 @@ def convert(self, model, field, field_args):
if hasattr(field, 'to_form_field'):
return field.to_form_field(model, kwargs)
- override = self._get_field_override(field.name)
- if override:
+ if override := self._get_field_override(field.name):
return override(**kwargs)
if ftype in self.converters:
@@ -116,8 +113,7 @@ def conv_List(self, model, field, kwargs):
raise ValueError('ListField "%s" must have field specified for model %s' % (field.name, model))
if isinstance(field.field, ReferenceField):
- loader = getattr(self.view, '_form_ajax_refs', {}).get(field.name)
- if loader:
+ if loader := getattr(self.view, '_form_ajax_refs', {}).get(field.name):
return AjaxSelectMultipleField(loader, **kwargs)
kwargs['widget'] = form.Select2Widget(multiple=True)
@@ -166,8 +162,7 @@ def conv_EmbeddedDocument(self, model, field, kwargs):
def conv_Reference(self, model, field, kwargs):
kwargs['allow_blank'] = not field.required
- loader = getattr(self.view, '_form_ajax_refs', {}).get(field.name)
- if loader:
+ if loader := getattr(self.view, '_form_ajax_refs', {}).get(field.name):
return AjaxSelectField(loader, **kwargs)
kwargs['widget'] = form.Select2Widget()
@@ -237,7 +232,7 @@ def find(name):
if p is not None:
return p
- raise ValueError('Invalid model property name %s.%s' % (model, name))
+ raise ValueError(f'Invalid model property name {model}.{name}')
properties = ((p, find(p)) for p in only)
elif exclude:
diff --git a/flask_admin/contrib/mongoengine/helpers.py b/flask_admin/contrib/mongoengine/helpers.py
index 48ee9f77c..a496d51a1 100644
--- a/flask_admin/contrib/mongoengine/helpers.py
+++ b/flask_admin/contrib/mongoengine/helpers.py
@@ -16,18 +16,17 @@ def make_gridfs_args(value):
def make_thumb_args(value):
- if getattr(value, 'thumbnail', None):
- args = {
- 'id': value.thumbnail._id,
- 'coll': value.collection_name
- }
+ if not getattr(value, 'thumbnail', None):
+ return make_gridfs_args(value)
+ args = {
+ 'id': value.thumbnail._id,
+ 'coll': value.collection_name
+ }
- if value.db_alias != 'default':
- args['db'] = value.db_alias
+ if value.db_alias != 'default':
+ args['db'] = value.db_alias
- return args
- else:
- return make_gridfs_args(value)
+ return args
def format_error(error):
diff --git a/flask_admin/contrib/mongoengine/tools.py b/flask_admin/contrib/mongoengine/tools.py
index 4d52671ce..ddec105c7 100644
--- a/flask_admin/contrib/mongoengine/tools.py
+++ b/flask_admin/contrib/mongoengine/tools.py
@@ -24,5 +24,5 @@ def parse_like_term(term):
oper = 'contains'
# add case insensitive flag
if not case_sensitive:
- oper = 'i' + oper
+ oper = f'i{oper}'
return oper, term
diff --git a/flask_admin/contrib/mongoengine/typefmt.py b/flask_admin/contrib/mongoengine/typefmt.py
index 840526de3..6e8d9393e 100644
--- a/flask_admin/contrib/mongoengine/typefmt.py
+++ b/flask_admin/contrib/mongoengine/typefmt.py
@@ -27,17 +27,25 @@ def grid_formatter(view, value):
def grid_image_formatter(view, value):
- if not value.grid_id:
- return ''
-
- return Markup(
- ('' +
- '
' +
- '
') %
- {
- 'url': view.get_url('.api_file_view', **helpers.make_gridfs_args(value)),
- 'thumb': view.get_url('.api_file_view', **helpers.make_thumb_args(value)),
- })
+ return (
+ Markup(
+ (
+ ''
+ + '
'
+ + '
'
+ )
+ % {
+ 'url': view.get_url(
+ '.api_file_view', **helpers.make_gridfs_args(value)
+ ),
+ 'thumb': view.get_url(
+ '.api_file_view', **helpers.make_thumb_args(value)
+ ),
+ }
+ )
+ if value.grid_id
+ else ''
+ )
DEFAULT_FORMATTERS = BASE_FORMATTERS.copy()
diff --git a/flask_admin/contrib/mongoengine/view.py b/flask_admin/contrib/mongoengine/view.py
index 10663164a..f066f432f 100644
--- a/flask_admin/contrib/mongoengine/view.py
+++ b/flask_admin/contrib/mongoengine/view.py
@@ -26,7 +26,7 @@
log = logging.getLogger("flask-admin.mongo")
-SORTABLE_FIELDS = set((
+SORTABLE_FIELDS = {
mongoengine.StringField,
mongoengine.IntField,
mongoengine.FloatField,
@@ -38,8 +38,8 @@
mongoengine.ReferenceField,
mongoengine.EmailField,
mongoengine.UUIDField,
- mongoengine.URLField
-))
+ mongoengine.URLField,
+}
class ModelView(BaseModelView):
@@ -339,14 +339,12 @@ def scaffold_sortable_columns(self):
"""
Return a dictionary of sortable columns (name, field)
"""
- columns = {}
-
- for n, f in self._get_model_fields():
- if type(f) in SORTABLE_FIELDS:
- if self.column_display_pk or type(f) != mongoengine.ObjectIdField:
- columns[n] = f
-
- return columns
+ return {
+ n: f
+ for n, f in self._get_model_fields()
+ if type(f) in SORTABLE_FIELDS
+ and (self.column_display_pk or type(f) != mongoengine.ObjectIdField)
+ }
def init_search(self):
"""
@@ -378,30 +376,22 @@ def scaffold_filters(self, name):
:param name:
Either field name or field instance
"""
- if isinstance(name, string_types):
- attr = self.model._fields.get(name)
- else:
- attr = name
-
+ attr = self.model._fields.get(name) if isinstance(name, string_types) else name
if attr is None:
- raise Exception('Failed to find field for filter: %s' % name)
+ raise Exception(f'Failed to find field for filter: {name}')
- # Find name
- visible_name = None
-
- if not isinstance(name, string_types):
- visible_name = self.get_column_name(attr.name)
+ visible_name = (
+ None
+ if isinstance(name, string_types)
+ else self.get_column_name(attr.name)
+ )
if not visible_name:
visible_name = self.get_column_name(name)
# Convert filter
type_name = type(attr).__name__
- flt = self.filter_converter.convert(type_name,
- attr,
- visible_name)
-
- return flt
+ return self.filter_converter.convert(type_name, attr, visible_name)
def is_valid_filter(self, filter):
"""
@@ -416,15 +406,15 @@ def scaffold_form(self):
"""
Create form from the model.
"""
- form_class = get_form(self.model,
- self.model_form_converter(self),
- base_class=self.form_base_class,
- only=self.form_columns,
- exclude=self.form_excluded_columns,
- field_args=self.form_args,
- extra_fields=self.form_extra_fields)
-
- return form_class
+ return get_form(
+ self.model,
+ self.model_form_converter(self),
+ base_class=self.form_base_class,
+ only=self.form_columns,
+ exclude=self.form_excluded_columns,
+ field_args=self.form_args,
+ extra_fields=self.form_extra_fields,
+ )
def scaffold_list_form(self, widget=None, validators=None):
"""
@@ -469,10 +459,10 @@ def _search(self, query, search_term):
for field in self._search_fields:
if type(field) == mongoengine.ReferenceField:
import re
- regex = re.compile('.*%s.*' % term)
+ regex = re.compile(f'.*{term}.*')
else:
regex = term
- flt = {'%s__%s' % (field.name, op): regex}
+ flt = {f'{field.name}__{op}': regex}
q = mongoengine.Q(**flt)
if criteria is None:
@@ -517,18 +507,14 @@ def get_list(self, page, sort_column, sort_desc, search, filters,
query = self._search(query, search)
# Get count
- count = query.count() if not self.simple_list_pager else None
+ count = None if self.simple_list_pager else query.count()
# Sorting
if sort_column:
- query = query.order_by('%s%s' % ('-' if sort_desc else '', sort_column))
- else:
- order = self._get_default_order()
-
- if order:
- keys = ['%s%s' % ('-' if desc else '', col)
- for (col, desc) in order]
- query = query.order_by(*keys)
+ query = query.order_by(f"{'-' if sort_desc else ''}{sort_column}")
+ elif order := self._get_default_order():
+ keys = [f"{'-' if desc else ''}{col}" for (col, desc) in order]
+ query = query.order_by(*keys)
# Pagination
if page_size is None:
@@ -667,11 +653,11 @@ def is_action_allowed(self, name):
lazy_gettext('Are you sure you want to delete selected records?'))
def action_delete(self, ids):
try:
- count = 0
-
all_ids = [self.object_id_converter(pk) for pk in ids]
- for obj in self.get_query().in_bulk(all_ids).values():
- count += self.delete_model(obj)
+ count = sum(
+ self.delete_model(obj)
+ for obj in self.get_query().in_bulk(all_ids).values()
+ )
flash(ngettext('Record was successfully deleted.',
'%(count)s records were successfully deleted.',
diff --git a/flask_admin/contrib/mongoengine/widgets.py b/flask_admin/contrib/mongoengine/widgets.py
index 55a503c81..3ba2037c9 100644
--- a/flask_admin/contrib/mongoengine/widgets.py
+++ b/flask_admin/contrib/mongoengine/widgets.py
@@ -29,9 +29,10 @@ def __call__(self, field, **kwargs):
'name': escape(data.name),
'content_type': escape(data.content_type),
'size': data.length // 1024,
- 'marker': '_%s-delete' % field.name
+ 'marker': f'_{field.name}-delete',
}
+
return Markup('%s' % (placeholder,
html_params(name=field.name,
type='file',
@@ -54,9 +55,10 @@ def __call__(self, field, **kwargs):
args = helpers.make_thumb_args(field.data)
placeholder = self.template % {
'thumb': get_url('.api_file_view', **args),
- 'marker': '_%s-delete' % field.name
+ 'marker': f'_{field.name}-delete',
}
+
return Markup('%s' % (placeholder,
html_params(name=field.name,
type='file',
diff --git a/flask_admin/contrib/peewee/ajax.py b/flask_admin/contrib/peewee/ajax.py
index 99d8842db..eab80b78f 100644
--- a/flask_admin/contrib/peewee/ajax.py
+++ b/flask_admin/contrib/peewee/ajax.py
@@ -18,7 +18,10 @@ def __init__(self, name, model, **options):
self.fields = options.get('fields')
if not self.fields:
- raise ValueError('AJAX loading requires `fields` to be specified for %s.%s' % (model, self.name))
+ raise ValueError(
+ f'AJAX loading requires `fields` to be specified for {model}.{self.name}'
+ )
+
self._cached_fields = self._process_fields()
@@ -32,7 +35,7 @@ def _process_fields(self):
attr = getattr(self.model, field, None)
if not attr:
- raise ValueError('%s.%s does not exist.' % (self.model, field))
+ raise ValueError(f'{self.model}.{field} does not exist.')
remote_fields.append(attr)
else:
@@ -41,10 +44,7 @@ def _process_fields(self):
return remote_fields
def format(self, model):
- if not model:
- return None
-
- return (getattr(model, self.pk), as_unicode(model))
+ return (getattr(model, self.pk), as_unicode(model)) if model else None
def get_one(self, pk):
return self.model.get(**{self.pk: pk})
@@ -74,7 +74,7 @@ def create_ajax_loader(model, name, field_name, options):
prop = getattr(model, field_name, None)
if prop is None:
- raise ValueError('Model %s does not have field %s.' % (model, field_name))
+ raise ValueError(f'Model {model} does not have field {field_name}.')
# TODO: Check for field
remote_model = prop.rel_model
diff --git a/flask_admin/contrib/peewee/form.py b/flask_admin/contrib/peewee/form.py
index e48390ffd..abc553bb8 100644
--- a/flask_admin/contrib/peewee/form.py
+++ b/flask_admin/contrib/peewee/form.py
@@ -72,7 +72,7 @@ def save_related(self, obj):
attr = getattr(self.model, self.prop)
values = self.model.select().where(attr == model_id).execute()
- pk_map = dict((str(getattr(v, self._pk)), v) for v in values)
+ pk_map = {str(getattr(v, self._pk)): v for v in values}
# Handle request data
for field in self.entries:
@@ -123,9 +123,7 @@ def __init__(self, view, additional=None):
self.overrides = getattr(self.view, 'form_overrides', None) or {}
def handle_foreign_key(self, model, field, **kwargs):
- loader = getattr(self.view, '_form_ajax_refs', {}).get(field.name)
-
- if loader:
+ if loader := getattr(self.view, '_form_ajax_refs', {}).get(field.name):
if field.null:
kwargs['allow_blank'] = True
@@ -199,13 +197,14 @@ def get_info(self, p):
else:
model = getattr(p, 'model', None)
if model is None:
- raise Exception('Unknown inline model admin: %s' % repr(p))
+ raise Exception(f'Unknown inline model admin: {repr(p)}')
- attrs = dict()
+ attrs = {
+ attr: getattr(p, attr)
+ for attr in dir(p)
+ if not attr.startswith('_') and attr != 'model'
+ }
- for attr in dir(p):
- if not attr.startswith('_') and attr != 'model':
- attrs[attr] = getattr(p, attr)
info = InlineFormAdmin(model, **attrs)
@@ -215,13 +214,11 @@ def get_info(self, p):
return info
def process_ajax_refs(self, info):
- refs = getattr(info, 'form_ajax_refs', None)
-
result = {}
- if refs:
+ if refs := getattr(info, 'form_ajax_refs', None):
for name, opts in iteritems(refs):
- new_name = '%s.%s' % (info.model.__name__.lower(), name)
+ new_name = f'{info.model.__name__.lower()}.{name}'
loader = None
if isinstance(opts, (list, tuple)):
@@ -243,12 +240,11 @@ def contribute(self, converter, model, form_class, inline_model):
for field in get_meta_fields(info.model):
field_type = type(field)
- if field_type == ForeignKeyField:
- if field.rel_model == model:
- reverse_field = field
- break
+ if field_type == ForeignKeyField and field.rel_model == model:
+ reverse_field = field
+ break
else:
- raise Exception('Cannot find reverse relation for model %s' % info.model)
+ raise Exception(f'Cannot find reverse relation for model {info.model}')
# Remove reverse property from the list
ignore = [reverse_field.name]
diff --git a/flask_admin/contrib/peewee/tools.py b/flask_admin/contrib/peewee/tools.py
index 98277d56b..dff9a5548 100644
--- a/flask_admin/contrib/peewee/tools.py
+++ b/flask_admin/contrib/peewee/tools.py
@@ -4,18 +4,16 @@ def get_primary_key(model):
def parse_like_term(term):
if term.startswith('^'):
- stmt = '%s%%' % term[1:]
+ return '%s%%' % term[1:]
elif term.startswith('='):
- stmt = term[1:]
+ return term[1:]
else:
- stmt = '%%%s%%' % term
-
- return stmt
+ return '%%%s%%' % term
def get_meta_fields(model):
- if hasattr(model._meta, 'sorted_fields'):
- fields = model._meta.sorted_fields
- else:
- fields = model._meta.get_fields()
- return fields
+ return (
+ model._meta.sorted_fields
+ if hasattr(model._meta, 'sorted_fields')
+ else model._meta.get_fields()
+ )
diff --git a/flask_admin/contrib/peewee/view.py b/flask_admin/contrib/peewee/view.py
index ee77873a4..b35b375a0 100644
--- a/flask_admin/contrib/peewee/view.py
+++ b/flask_admin/contrib/peewee/view.py
@@ -185,9 +185,11 @@ def scaffold_pk(self):
def get_pk_value(self, model):
if self.model._meta.composite_key:
- return tuple([
+ return tuple(
model._data[field_name]
- for field_name in self.model._meta.primary_key.field_names])
+ for field_name in self.model._meta.primary_key.field_names
+ )
+
return getattr(model, self._primary_key)
def scaffold_list_columns(self):
@@ -197,21 +199,20 @@ def scaffold_list_columns(self):
# Verify type
field_class = type(f)
- if field_class == ForeignKeyField:
- columns.append(n)
- elif self.column_display_pk or field_class != PrimaryKeyField:
+ if (
+ field_class == ForeignKeyField
+ or self.column_display_pk
+ or field_class != PrimaryKeyField
+ ):
columns.append(n)
-
return columns
def scaffold_sortable_columns(self):
- columns = dict()
-
- for n, f in self._get_model_fields():
- if self.column_display_pk or type(f) != PrimaryKeyField:
- columns[n] = f
-
- return columns
+ return {
+ n: f
+ for n, f in self._get_model_fields()
+ if self.column_display_pk or type(f) != PrimaryKeyField
+ }
def init_search(self):
if self.column_searchable_list:
@@ -235,7 +236,7 @@ def scaffold_filters(self, name):
attr = name
if attr is None:
- raise Exception('Failed to find field for filter: %s' % name)
+ raise Exception(f'Failed to find field for filter: {name}')
# Check if field is in different model
model_class = None
@@ -245,20 +246,17 @@ def scaffold_filters(self, name):
model_class = attr.model
if model_class != self.model:
- visible_name = '%s / %s' % (self.get_column_name(model_class.__name__),
- self.get_column_name(attr.name))
+ visible_name = f'{self.get_column_name(model_class.__name__)} / {self.get_column_name(attr.name)}'
+
else:
- if not isinstance(name, string_types):
- visible_name = self.get_column_name(attr.name)
- else:
- visible_name = self.get_column_name(name)
+ visible_name = (
+ self.get_column_name(name)
+ if isinstance(name, string_types)
+ else self.get_column_name(attr.name)
+ )
type_name = type(attr).__name__
- flt = self.filter_converter.convert(type_name,
- attr,
- visible_name)
-
- return flt
+ return self.filter_converter.convert(type_name, attr, visible_name)
def is_valid_filter(self, filter):
return isinstance(filter, filters.BasePeeweeFilter)
@@ -416,17 +414,15 @@ def get_list(self, page, sort_column, sort_desc, search, filters,
query = f.apply(query, f.clean(value))
# Get count
- count = query.count() if not self.simple_list_pager else None
+ count = None if self.simple_list_pager else query.count()
# Apply sorting
if sort_column is not None:
sort_field = self._sortable_columns[sort_column]
order = [(sort_field, sort_desc)]
query, joins = self._order_by(query, joins, order)
- else:
- order = self._get_default_order()
- if order:
- query, joins = self._order_by(query, joins, order)
+ elif order := self._get_default_order():
+ query, joins = self._order_by(query, joins, order)
# Pagination
if page_size is None:
diff --git a/flask_admin/contrib/pymongo/tools.py b/flask_admin/contrib/pymongo/tools.py
index e931527de..638f69884 100644
--- a/flask_admin/contrib/pymongo/tools.py
+++ b/flask_admin/contrib/pymongo/tools.py
@@ -9,8 +9,8 @@ def parse_like_term(term):
Search term
"""
if term.startswith('^'):
- return '^{}'.format(re.escape(term[1:]))
+ return f'^{re.escape(term[1:])}'
elif term.startswith('='):
- return '^{}$'.format(re.escape(term[1:]))
+ return f'^{re.escape(term[1:])}$'
return re.escape(term)
diff --git a/flask_admin/contrib/pymongo/view.py b/flask_admin/contrib/pymongo/view.py
index fe859fba8..fa7fa73c4 100644
--- a/flask_admin/contrib/pymongo/view.py
+++ b/flask_admin/contrib/pymongo/view.py
@@ -97,7 +97,7 @@ def __init__(self, coll,
name = self._prettify_name(coll.name)
if endpoint is None:
- endpoint = ('%sview' % coll.name).lower()
+ endpoint = f'{coll.name}view'.lower()
super(ModelView, self).__init__(None, name, category, endpoint, url,
menu_class_name=menu_class_name,
@@ -184,11 +184,9 @@ def _search(self, query, search_term):
regex = parse_like_term(value)
- stmt = []
- for field in self._search_fields:
- stmt.append({field: {'$regex': regex}})
-
- if stmt:
+ if stmt := [
+ {field: {'$regex': regex}} for field in self._search_fields
+ ]:
if len(stmt) == 1:
queries.append(stmt[0])
else:
@@ -196,16 +194,8 @@ def _search(self, query, search_term):
# Construct final query
if queries:
- if len(queries) == 1:
- final = queries[0]
- else:
- final = {'$and': queries}
-
- if query:
- query = {'$and': [query, final]}
- else:
- query = final
-
+ final = queries[0] if len(queries) == 1 else {'$and': queries}
+ query = {'$and': [query, final]} if query else final
return query
def get_list(self, page, sort_column, sort_desc, search, filters,
@@ -251,29 +241,22 @@ def get_list(self, page, sort_column, sort_desc, search, filters,
query = self._search(query, search)
# Get count
- count = self.coll.find(query).count() if not self.simple_list_pager else None
+ count = None if self.simple_list_pager else self.coll.find(query).count()
# Sorting
sort_by = None
if sort_column:
sort_by = [(sort_column, pymongo.DESCENDING if sort_desc else pymongo.ASCENDING)]
- else:
- order = self._get_default_order()
-
- if order:
- sort_by = [(col, pymongo.DESCENDING if desc else pymongo.ASCENDING)
- for (col, desc) in order]
+ elif order := self._get_default_order():
+ sort_by = [(col, pymongo.DESCENDING if desc else pymongo.ASCENDING)
+ for (col, desc) in order]
# Pagination
if page_size is None:
page_size = self.page_size
- skip = 0
-
- if page and page_size:
- skip = page * page_size
-
+ skip = page * page_size if page and page_size else 0
results = self.coll.find(query, sort=sort_by, skip=skip, limit=page_size)
if execute:
@@ -386,12 +369,7 @@ def is_action_allowed(self, name):
lazy_gettext('Are you sure you want to delete selected records?'))
def action_delete(self, ids):
try:
- count = 0
-
- # TODO: Optimize me
- for pk in ids:
- if self.delete_model(self.get_one(pk)):
- count += 1
+ count = sum(1 for pk in ids if self.delete_model(self.get_one(pk)))
flash(ngettext('Record was successfully deleted.',
'%(count)s records were successfully deleted.',
diff --git a/flask_admin/contrib/rediscli.py b/flask_admin/contrib/rediscli.py
index 659dc5e0d..296c038c5 100644
--- a/flask_admin/contrib/rediscli.py
+++ b/flask_admin/contrib/rediscli.py
@@ -103,9 +103,7 @@ def _execute_command(self, name, args):
:param args:
Command arguments
"""
- # Do some remapping
- new_cmd = self.remapped_commands.get(name)
- if new_cmd:
+ if new_cmd := self.remapped_commands.get(name):
name = new_cmd
# Execute command
@@ -150,8 +148,11 @@ def _cmd_help(self, *args):
Help command implementation.
"""
if not args:
- help = 'Usage: help .\nList of supported commands: '
- help += ', '.join(n for n in sorted(self.commands))
+ help = (
+ 'Usage: help .\nList of supported commands: '
+ + ', '.join(sorted(self.commands))
+ )
+
return TextWrapper(help)
cmd = args[0]
@@ -188,7 +189,7 @@ def execute_view(self):
return self._execute_command(parts[0], parts[1:])
except CommandError as err:
- return self._error('Cli: %s' % err)
+ return self._error(f'Cli: {err}')
except Exception as ex:
log.exception(ex)
- return self._error('Cli: %s' % ex)
+ return self._error(f'Cli: {ex}')
diff --git a/flask_admin/contrib/sqla/ajax.py b/flask_admin/contrib/sqla/ajax.py
index 8e4ab3cf9..492790c2d 100644
--- a/flask_admin/contrib/sqla/ajax.py
+++ b/flask_admin/contrib/sqla/ajax.py
@@ -26,7 +26,10 @@ def __init__(self, name, session, model, **options):
self.filters = options.get('filters')
if not self.fields:
- raise ValueError('AJAX loading requires `fields` to be specified for %s.%s' % (model, self.name))
+ raise ValueError(
+ f'AJAX loading requires `fields` to be specified for {model}.{self.name}'
+ )
+
self._cached_fields = self._process_fields()
@@ -43,7 +46,7 @@ def _process_fields(self):
attr = getattr(self.model, field, None)
if not attr:
- raise ValueError('%s.%s does not exist.' % (self.model, field))
+ raise ValueError(f'{self.model}.{field} does not exist.')
remote_fields.append(attr)
else:
@@ -53,10 +56,7 @@ def _process_fields(self):
return remote_fields
def format(self, model):
- if not model:
- return None
-
- return getattr(model, self.pk), as_unicode(model)
+ return (getattr(model, self.pk), as_unicode(model)) if model else None
def get_query(self):
return self.session.query(self.model)
@@ -75,7 +75,11 @@ def get_list(self, term, offset=0, limit=DEFAULT_PAGE_SIZE):
query = query.filter(or_(*filters))
if self.filters:
- filters = [text("%s.%s" % (self.model.__tablename__.lower(), value)) for value in self.filters]
+ filters = [
+ text(f"{self.model.__tablename__.lower()}.{value}")
+ for value in self.filters
+ ]
+
query = query.filter(and_(*filters))
if self.order_by:
@@ -88,10 +92,10 @@ def create_ajax_loader(model, session, name, field_name, options):
attr = getattr(model, field_name, None)
if attr is None:
- raise ValueError('Model %s does not have field %s.' % (model, field_name))
+ raise ValueError(f'Model {model} does not have field {field_name}.')
if not is_relationship(attr) and not is_association_proxy(attr):
- raise ValueError('%s.%s is not a relation.' % (model, field_name))
+ raise ValueError(f'{model}.{field_name} is not a relation.')
if is_association_proxy(attr):
attr = attr.remote_attr
diff --git a/flask_admin/contrib/sqla/fields.py b/flask_admin/contrib/sqla/fields.py
index fcb22ceb0..35e4a5ef4 100644
--- a/flask_admin/contrib/sqla/fields.py
+++ b/flask_admin/contrib/sqla/fields.py
@@ -176,7 +176,7 @@ def pre_validate(self, form):
if self._invalid_formdata:
raise ValidationError(self.gettext(u'Not a valid choice'))
elif self.data:
- obj_list = list(x[1] for x in self._get_object_list())
+ obj_list = [x[1] for x in self._get_object_list()]
for v in self.data:
if v not in obj_list:
raise ValidationError(self.gettext(u'Not a valid choice'))
@@ -233,7 +233,7 @@ def process(self, formdata, data=unset_value):
def populate_obj(self, obj, name):
""" Combines each FormField key/value into a dictionary for storage """
- _fake = type(str('_fake'), (object, ), {})
+ _fake = type('_fake', (object, ), {})
output = {}
for form_field in self.entries:
@@ -297,7 +297,7 @@ def populate_obj(self, obj, name):
return
# Create primary key map
- pk_map = dict((get_obj_pk(v, self._pk), v) for v in values)
+ pk_map = {get_obj_pk(v, self._pk): v for v in values}
# Handle request data
for field in self.entries:
diff --git a/flask_admin/contrib/sqla/filters.py b/flask_admin/contrib/sqla/filters.py
index cdbe680eb..9ea2d956c 100644
--- a/flask_admin/contrib/sqla/filters.py
+++ b/flask_admin/contrib/sqla/filters.py
@@ -87,7 +87,7 @@ def operation(self):
class FilterEmpty(BaseSQLAFilter, filters.BaseBooleanFilter):
def apply(self, query, value, alias=None):
if value == '1':
- return query.filter(self.get_column(alias) == None) # noqa: E711
+ return query.filter(self.get_column(alias) is None)
else:
return query.filter(self.get_column(alias) != None) # noqa: E711
@@ -113,7 +113,7 @@ class FilterNotInList(FilterInList):
def apply(self, query, value, alias=None):
# NOT IN can exclude NULL values, so "or_ == None" needed to be added
column = self.get_column(alias)
- return query.filter(or_(~column.in_(value), column == None)) # noqa: E711
+ return query.filter(or_(~column.in_(value), column is None))
def operation(self):
return lazy_gettext('not in list')
@@ -384,7 +384,7 @@ def apply(self, query, user_query, alias=None):
break
if choice_type:
# != can exclude NULL values, so "or_ == None" needed to be added
- return query.filter(or_(column != choice_type, column == None)) # noqa: E711
+ return query.filter(or_(column != choice_type, column is None))
else:
return query
@@ -399,17 +399,20 @@ def apply(self, query, user_query, alias=None):
if user_query:
# loop through choice 'values' looking for matches
if isinstance(column.type.choices, enum.EnumMeta):
- for choice in column.type.choices:
- if user_query.lower() in choice.name.lower():
- choice_types.append(choice.value)
+ choice_types.extend(
+ choice.value
+ for choice in column.type.choices
+ if user_query.lower() in choice.name.lower()
+ )
+
else:
- for type, value in column.type.choices:
- if user_query.lower() in value.lower():
- choice_types.append(type)
- if choice_types:
- return query.filter(column.in_(choice_types))
- else:
- return query
+ choice_types.extend(
+ type
+ for type, value in column.type.choices
+ if user_query.lower() in value.lower()
+ )
+
+ return query.filter(column.in_(choice_types)) if choice_types else query
class ChoiceTypeNotLikeFilter(FilterNotLike):
@@ -422,16 +425,22 @@ def apply(self, query, user_query, alias=None):
if user_query:
# loop through choice 'values' looking for matches
if isinstance(column.type.choices, enum.EnumMeta):
- for choice in column.type.choices:
- if user_query.lower() in choice.name.lower():
- choice_types.append(choice.value)
+ choice_types.extend(
+ choice.value
+ for choice in column.type.choices
+ if user_query.lower() in choice.name.lower()
+ )
+
else:
- for type, value in column.type.choices:
- if user_query.lower() in value.lower():
- choice_types.append(type)
+ choice_types.extend(
+ type
+ for type, value in column.type.choices
+ if user_query.lower() in value.lower()
+ )
+
if choice_types:
# != can exclude NULL values, so "or_ == None" needed to be added
- return query.filter(or_(column.notin_(choice_types), column == None)) # noqa: E711
+ return query.filter(or_(column.notin_(choice_types), column is None))
else:
return query
diff --git a/flask_admin/contrib/sqla/form.py b/flask_admin/contrib/sqla/form.py
index f6dd499fb..88a372c60 100644
--- a/flask_admin/contrib/sqla/form.py
+++ b/flask_admin/contrib/sqla/form.py
@@ -44,13 +44,12 @@ def _get_label(self, name, field_args):
if 'label' in field_args:
return field_args['label']
- column_labels = get_property(self.view, 'column_labels', 'rename_columns')
-
- if column_labels:
+ if column_labels := get_property(
+ self.view, 'column_labels', 'rename_columns'
+ ):
return column_labels.get(name)
- prettify_override = getattr(self.view, 'prettify_name', None)
- if prettify_override:
+ if prettify_override := getattr(self.view, 'prettify_name', None):
return prettify_override(name)
return prettify_name(name)
@@ -59,23 +58,17 @@ def _get_description(self, name, field_args):
if 'description' in field_args:
return field_args['description']
- column_descriptions = getattr(self.view, 'column_descriptions', None)
-
- if column_descriptions:
+ if column_descriptions := getattr(self.view, 'column_descriptions', None):
return column_descriptions.get(name)
def _get_field_override(self, name):
- form_overrides = getattr(self.view, 'form_overrides', None)
-
- if form_overrides:
+ if form_overrides := getattr(self.view, 'form_overrides', None):
return form_overrides.get(name)
return None
def _model_select_field(self, prop, multiple, remote_model, **kwargs):
- loader = getattr(self.view, '_form_ajax_refs', {}).get(prop.key)
-
- if loader:
+ if loader := getattr(self.view, '_form_ajax_refs', {}).get(prop.key):
if multiple:
return AjaxSelectMultipleField(loader, **kwargs)
else:
@@ -118,9 +111,7 @@ def _convert_relation(self, name, prop, property_is_association_proxy, kwargs):
if not requirement_validator_specified:
kwargs['validators'].append(validators.InputRequired())
- # Override field type if necessary
- override = self._get_field_override(prop.key)
- if override:
+ if override := self._get_field_override(prop.key):
return override(**kwargs)
multiple = (property_is_association_proxy or
@@ -138,21 +129,29 @@ def convert(self, model, mapper, name, prop, field_args, hidden_pk):
}
if field_args:
- kwargs.update(field_args)
+ kwargs |= field_args
if kwargs['validators']:
# Create a copy of the list since we will be modifying it.
kwargs['validators'] = list(kwargs['validators'])
# Check if it is relation or property
- if hasattr(prop, 'direction') or is_association_proxy(prop):
+ if (
+ hasattr(prop, 'direction')
+ or not hasattr(prop, 'direction')
+ and is_association_proxy(prop)
+ ):
property_is_association_proxy = is_association_proxy(prop)
if property_is_association_proxy:
if not hasattr(prop.remote_attr, 'prop'):
raise Exception('Association proxy referencing another association proxy is not supported.')
prop = prop.remote_attr.prop
return self._convert_relation(name, prop, property_is_association_proxy, kwargs)
- elif hasattr(prop, 'columns'): # Ignore pk/fk
+ elif (
+ not hasattr(prop, 'direction')
+ and not is_association_proxy(prop)
+ and hasattr(prop, 'columns')
+ ): # Ignore pk/fk
# Check if more than one column mapped to the property
if len(prop.columns) > 1 and not isinstance(prop, ColumnProperty):
columns = filter_foreign_columns(model.__table__, prop.columns)
@@ -160,7 +159,10 @@ def convert(self, model, mapper, name, prop, field_args, hidden_pk):
if len(columns) == 0:
return None
elif len(columns) > 1:
- warnings.warn('Can not convert multiple-column properties (%s.%s)' % (model, prop.key))
+ warnings.warn(
+ f'Can not convert multiple-column properties ({model}.{prop.key})'
+ )
+
return None
column = columns[0]
@@ -184,18 +186,17 @@ def convert(self, model, mapper, name, prop, field_args, hidden_pk):
if hidden_pk:
# If requested to add hidden field, show it
return fields.HiddenField()
- else:
- # By default, don't show primary keys either
- # If PK is not explicitly allowed, ignore it
- if prop.key not in form_columns:
- return None
-
- # Current Unique Validator does not work with multicolumns-pks
- if not has_multiple_pks(model):
- kwargs['validators'].append(Unique(self.session,
- model,
- column))
- unique = True
+ # By default, don't show primary keys either
+ # If PK is not explicitly allowed, ignore it
+ if prop.key not in form_columns:
+ return None
+
+ # Current Unique Validator does not work with multicolumns-pks
+ if not has_multiple_pks(model):
+ kwargs['validators'].append(Unique(self.session,
+ model,
+ column))
+ unique = True
# If field is unique, validate it
if column.unique and not unique:
@@ -228,9 +229,8 @@ def convert(self, model, mapper, name, prop, field_args, hidden_pk):
if value is not None:
if getattr(default, 'is_callable', False):
value = lambda: default.arg(None) # noqa: E731
- else:
- if not getattr(default, 'is_scalar', True):
- value = None
+ elif not getattr(default, 'is_scalar', True):
+ value = None
if value is not None:
kwargs['default'] = value
@@ -239,16 +239,13 @@ def convert(self, model, mapper, name, prop, field_args, hidden_pk):
if column.nullable:
kwargs['validators'].append(validators.Optional())
- # Override field type if necessary
- override = self._get_field_override(prop.key)
- if override:
+ if override := self._get_field_override(prop.key):
return override(**kwargs)
# Check if a list of 'form_choices' are specified
form_choices = getattr(self.view, 'form_choices', None)
if mapper.class_ == self.view.model and form_choices:
- choices = form_choices.get(prop.key)
- if choices:
+ if choices := form_choices.get(prop.key):
return form.Select2Field(
choices=choices,
allow_blank=column.nullable,
@@ -441,7 +438,7 @@ def avoid_empty_strings(value):
except AttributeError:
# values are not always strings
pass
- return value if value else None
+ return value or None
def choice_type_coerce_factory(type_):
@@ -474,10 +471,7 @@ def _resolve_prop(prop):
Property to resolve
"""
# Try to see if it is proxied property
- if hasattr(prop, '_proxied_property'):
- return prop._proxied_property
-
- return prop
+ return prop._proxied_property if hasattr(prop, '_proxied_property') else prop
# Get list of fields and generate form
@@ -527,7 +521,11 @@ def find(name):
column, path = get_field_with_path(model, name, return_remote_proxy_attr=False)
- if path and not (is_relationship(column) or is_association_proxy(column)):
+ if (
+ path
+ and not is_relationship(column)
+ and not is_association_proxy(column)
+ ):
raise Exception("form column is located in another table and "
"requires inline_models: {0}".format(name))
@@ -539,7 +537,7 @@ def find(name):
if column is not None and hasattr(column, 'property'):
return relation_name, column.property
- raise ValueError('Invalid model property name %s.%s' % (model, name))
+ raise ValueError(f'Invalid model property name {model}.{name}')
# Filter properties while maintaining property order in 'only' list
properties = (find(x) for x in only)
@@ -602,20 +600,18 @@ def get_info(self, p):
if info is None:
if hasattr(p, '_sa_class_manager'):
return self.form_admin_class(p)
- else:
- model = getattr(p, 'model', None)
+ model = getattr(p, 'model', None)
- if model is None:
- raise Exception('Unknown inline model admin: %s' % repr(p))
+ if model is None:
+ raise Exception(f'Unknown inline model admin: {repr(p)}')
- attrs = dict()
- for attr in dir(p):
- if not attr.startswith('_') and attr != 'model':
- attrs[attr] = getattr(p, attr)
+ attrs = {
+ attr: getattr(p, attr)
+ for attr in dir(p)
+ if not attr.startswith('_') and attr != 'model'
+ }
- return self.form_admin_class(model, **attrs)
-
- info = self.form_admin_class(model, **attrs)
+ return self.form_admin_class(model, **attrs)
# Resolve AJAX FKs
info._form_ajax_refs = self.process_ajax_refs(info)
@@ -623,13 +619,11 @@ def get_info(self, p):
return info
def process_ajax_refs(self, info):
- refs = getattr(info, 'form_ajax_refs', None)
-
result = {}
- if refs:
+ if refs := getattr(info, 'form_ajax_refs', None):
for name, opts in iteritems(refs):
- new_name = '%s-%s' % (info.model.__name__.lower(), name)
+ new_name = f'{info.model.__name__.lower()}-{name}'
loader = None
if isinstance(opts, dict):
@@ -664,28 +658,30 @@ def _calculate_mapping_key_pair(self, model, info):
reverse_prop = None
for prop in target_mapper.iterate_properties:
- if hasattr(prop, 'direction') and prop.direction.name in ('MANYTOONE', 'MANYTOMANY'):
- if issubclass(model, prop.mapper.class_):
- reverse_prop = prop
- break
+ if (
+ hasattr(prop, 'direction')
+ and prop.direction.name in ('MANYTOONE', 'MANYTOMANY')
+ and issubclass(model, prop.mapper.class_)
+ ):
+ reverse_prop = prop
+ break
else:
- raise Exception('Cannot find reverse relation for model %s' % info.model)
+ raise Exception(f'Cannot find reverse relation for model {info.model}')
# Find forward property
forward_prop = None
- if prop.direction.name == 'MANYTOONE':
- candidate = 'ONETOMANY'
- else:
- candidate = 'MANYTOMANY'
-
+ candidate = 'ONETOMANY' if prop.direction.name == 'MANYTOONE' else 'MANYTOMANY'
for prop in mapper.iterate_properties:
- if hasattr(prop, 'direction') and prop.direction.name == candidate:
- if prop.mapper.class_ == target_mapper.class_:
- forward_prop = prop
- break
+ if (
+ hasattr(prop, 'direction')
+ and prop.direction.name == candidate
+ and prop.mapper.class_ == target_mapper.class_
+ ):
+ forward_prop = prop
+ break
else:
- raise Exception('Cannot find forward relation for model %s' % info.model)
+ raise Exception(f'Cannot find forward relation for model {info.model}')
return forward_prop.key, reverse_prop.key
@@ -745,10 +741,9 @@ def contribute(self, model, form_class, inline_model):
# Post-process form
child_form = info.postprocess_form(child_form)
- kwargs = dict()
+ kwargs = {}
- label = self.get_label(info, forward_prop_key)
- if label:
+ if label := self.get_label(info, forward_prop_key):
kwargs['label'] = label
if self.view.form_args:
@@ -776,7 +771,7 @@ def _calculate_mapping_key_pair(self, model, info):
mapper = info.model._sa_class_manager.mapper.base_mapper
target_mapper = model._sa_class_manager.mapper
- inline_relationship = dict()
+ inline_relationship = {}
for forward_prop in mapper.iterate_properties:
if not hasattr(forward_prop, 'direction'):
@@ -788,15 +783,9 @@ def _calculate_mapping_key_pair(self, model, info):
if forward_prop.mapper.class_ != target_mapper.class_:
continue
- # in case when model has few relationships to target model or
- # has just installed references manually. This is more quick
- # solution rather than rotate yet another one loop
- ref = getattr(forward_prop, 'backref')
-
- if not ref:
- ref = getattr(forward_prop, 'back_populates')
-
- if ref:
+ if ref := getattr(forward_prop, 'backref') or getattr(
+ forward_prop, 'back_populates'
+ ):
inline_relationship[ref] = forward_prop.key
continue
@@ -813,13 +802,11 @@ def _calculate_mapping_key_pair(self, model, info):
inline_relationship[backward_prop.key] = forward_prop.key
break
else:
- raise Exception(
- 'Cannot find reverse relation for model %s' % info.model)
+ raise Exception(f'Cannot find reverse relation for model {info.model}')
break
if not inline_relationship:
- raise Exception(
- 'Cannot find forward relation for model %s' % info.model)
+ raise Exception(f'Cannot find forward relation for model {info.model}')
return inline_relationship
@@ -829,7 +816,7 @@ def contribute(self, model, form_class, inline_model):
inline_relationships = self._calculate_mapping_key_pair(model, info)
# Remove reverse property from the list
- ignore = [value for value in inline_relationships.values()]
+ ignore = list(inline_relationships.values())
if info.form_excluded_columns:
exclude = ignore + list(info.form_excluded_columns)
@@ -855,7 +842,7 @@ def contribute(self, model, form_class, inline_model):
# Post-process form
child_form = info.postprocess_form(child_form)
- kwargs = dict()
+ kwargs = {}
# Contribute field
for key in inline_relationships.keys():
diff --git a/flask_admin/contrib/sqla/tools.py b/flask_admin/contrib/sqla/tools.py
index 28390d479..d56d0ccf2 100644
--- a/flask_admin/contrib/sqla/tools.py
+++ b/flask_admin/contrib/sqla/tools.py
@@ -19,13 +19,11 @@
def parse_like_term(term):
if term.startswith('^'):
- stmt = '%s%%' % term[1:]
+ return '%s%%' % term[1:]
elif term.startswith('='):
- stmt = term[1:]
+ return term[1:]
else:
- stmt = '%%%s%%' % term
-
- return stmt
+ return '%%%s%%' % term
def filter_foreign_columns(base_table, columns):
@@ -82,14 +80,9 @@ def tuple_operator_in(model_pk, ids):
"""
ands = []
for id in ids:
- k = []
- for i in range(len(model_pk)):
- k.append(eq(model_pk[i], id[i]))
+ k = [eq(model_pk[i], id[i]) for i in range(len(model_pk))]
ands.append(and_(*k))
- if len(ands) >= 1:
- return or_(*ands)
- else:
- return None
+ return or_(*ands) if ands else None
def get_query_for_ids(modelquery, model, ids):
@@ -125,7 +118,7 @@ def get_columns_for_field(field):
not hasattr(field, 'property') or
not hasattr(field.property, 'columns') or
not field.property.columns):
- raise Exception('Invalid field %s: does not contains any columns.' % field)
+ raise Exception(f'Invalid field {field}: does not contains any columns.')
return field.property.columns
@@ -177,7 +170,7 @@ def get_field_with_path(model, name, return_remote_proxy_attr=True):
columns = get_columns_for_field(attr)
if len(columns) > 1:
- raise Exception('Can only handle one column for %s' % name)
+ raise Exception(f'Can only handle one column for {name}')
column = columns[0]
@@ -190,32 +183,31 @@ def get_field_with_path(model, name, return_remote_proxy_attr=True):
# copied from sqlalchemy-utils
def get_hybrid_properties(model):
- return dict(
- (key, prop)
+ return {
+ key: prop
for key, prop in inspect(model).all_orm_descriptors.items()
if isinstance(prop, hybrid_property)
- )
+ }
def is_hybrid_property(model, attr_name):
- if isinstance(attr_name, string_types):
- names = attr_name.split('.')
- last_model = model
- for i in range(len(names) - 1):
- attr = getattr(last_model, names[i])
- if is_association_proxy(attr):
- attr = attr.remote_attr
- last_model = attr.property.argument
- if isinstance(last_model, string_types):
- last_model = attr.property._clsregistry_resolve_name(last_model)()
- elif isinstance(last_model, _class_resolver):
- last_model = model._decl_class_registry[last_model.arg]
- elif isinstance(last_model, (types.FunctionType, types.MethodType)):
- last_model = last_model()
- last_name = names[-1]
- return last_name in get_hybrid_properties(last_model)
- else:
+ if not isinstance(attr_name, string_types):
return attr_name.name in get_hybrid_properties(model)
+ names = attr_name.split('.')
+ last_model = model
+ for i in range(len(names) - 1):
+ attr = getattr(last_model, names[i])
+ if is_association_proxy(attr):
+ attr = attr.remote_attr
+ last_model = attr.property.argument
+ if isinstance(last_model, string_types):
+ last_model = attr.property._clsregistry_resolve_name(last_model)()
+ elif isinstance(last_model, _class_resolver):
+ last_model = model._decl_class_registry[last_model.arg]
+ elif isinstance(last_model, (types.FunctionType, types.MethodType)):
+ last_model = last_model()
+ last_name = names[-1]
+ return last_name in get_hybrid_properties(last_model)
def is_relationship(attr):
diff --git a/flask_admin/contrib/sqla/validators.py b/flask_admin/contrib/sqla/validators.py
index b82f1f85d..3db94866d 100644
--- a/flask_admin/contrib/sqla/validators.py
+++ b/flask_admin/contrib/sqla/validators.py
@@ -37,7 +37,7 @@ def __call__(self, form, field):
.filter(self.column == field.data)
.one())
- if not hasattr(form, '_obj') or not form._obj == obj:
+ if not hasattr(form, '_obj') or form._obj != obj:
if self.message is None:
self.message = field.gettext(u'Already exists.')
raise ValidationError(self.message)
diff --git a/flask_admin/contrib/sqla/view.py b/flask_admin/contrib/sqla/view.py
index 3095e8d3f..0d1b57f97 100755
--- a/flask_admin/contrib/sqla/view.py
+++ b/flask_admin/contrib/sqla/view.py
@@ -332,9 +332,9 @@ def __init__(self, model, session,
self._search_fields = None
- self._filter_joins = dict()
+ self._filter_joins = {}
- self._sortable_joins = dict()
+ self._sortable_joins = {}
if self.form_choices is None:
self.form_choices = {}
@@ -350,13 +350,12 @@ def __init__(self, model, session,
self._primary_key = self.scaffold_pk()
if self._primary_key is None:
- raise Exception('Model %s does not have primary key.' % self.model.__name__)
+ raise Exception(f'Model {self.model.__name__} does not have primary key.')
# Configuration
- if not self.column_select_related_list:
- self._auto_joins = self.scaffold_auto_joins()
- else:
- self._auto_joins = self.column_select_related_list
+ self._auto_joins = (
+ self.column_select_related_list or self.scaffold_auto_joins()
+ )
# Internal API
def _get_model_iterator(self, model=None):
@@ -441,7 +440,10 @@ def scaffold_list_columns(self):
if len(filtered) == 0:
continue
elif len(filtered) > 1:
- warnings.warn('Can not convert multiple-column properties (%s.%s)' % (self.model, p.key))
+ warnings.warn(
+ f'Can not convert multiple-column properties ({self.model}.{p.key})'
+ )
+
continue
column = filtered[0]
@@ -463,7 +465,7 @@ def scaffold_sortable_columns(self):
Return a dictionary of sortable columns.
Key is column name, value is sort column/field.
"""
- columns = dict()
+ columns = {}
for p in self._get_model_iterator():
if hasattr(p, 'columns'):
@@ -493,45 +495,43 @@ def get_sortable_columns(self):
If `column_sortable_list` is set, will use it. Otherwise, will call
`scaffold_sortable_columns` to get them from the model.
"""
- self._sortable_joins = dict()
+ self._sortable_joins = {}
if self.column_sortable_list is None:
return self.scaffold_sortable_columns()
- else:
- result = dict()
-
- for c in self.column_sortable_list:
- if isinstance(c, tuple):
- if isinstance(c[1], tuple):
- column, path = [], []
- for item in c[1]:
- column_item, path_item = tools.get_field_with_path(self.model, item)
- column.append(column_item)
- path.append(path_item)
- column_name = c[0]
- else:
- column, path = tools.get_field_with_path(self.model, c[1])
- column_name = c[0]
+ result = {}
+
+ for c in self.column_sortable_list:
+ if isinstance(c, tuple):
+ if isinstance(c[1], tuple):
+ column, path = [], []
+ for item in c[1]:
+ column_item, path_item = tools.get_field_with_path(self.model, item)
+ column.append(column_item)
+ path.append(path_item)
else:
- column, path = tools.get_field_with_path(self.model, c)
- column_name = text_type(c)
+ column, path = tools.get_field_with_path(self.model, c[1])
+ column_name = c[0]
+ else:
+ column, path = tools.get_field_with_path(self.model, c)
+ column_name = text_type(c)
- if path and (hasattr(path[0], 'property') or isinstance(path[0], list)):
- self._sortable_joins[column_name] = path
- elif path:
- raise Exception("For sorting columns in a related table, "
- "column_sortable_list requires a string "
- "like '.'. "
- "Failed on: {0}".format(c))
- else:
- # column is in same table, use only model attribute name
- if getattr(column, 'key', None) is not None:
- column_name = column.key
+ if path and (hasattr(path[0], 'property') or isinstance(path[0], list)):
+ self._sortable_joins[column_name] = path
+ elif path:
+ raise Exception("For sorting columns in a related table, "
+ "column_sortable_list requires a string "
+ "like '.'. "
+ "Failed on: {0}".format(c))
+ else:
+ # column is in same table, use only model attribute name
+ if getattr(column, 'key', None) is not None:
+ column_name = column.key
- # column_name must match column_name used in `get_list_columns`
- result[column_name] = column
+ # column_name must match column_name used in `get_list_columns`
+ result[column_name] = column
- return result
+ return result
def get_column_names(self, only_columns, excluded_columns):
"""
@@ -554,15 +554,10 @@ def get_column_names(self, only_columns, excluded_columns):
try:
column, path = tools.get_field_with_path(self.model, c)
- if path:
- # column is a relation (InstrumentedAttribute), use full path
- column_name = text_type(c)
+ if not path and getattr(column, 'key', None) is not None:
+ column_name = column.key
else:
- # column is in same table, use only model attribute name
- if getattr(column, 'key', None) is not None:
- column_name = column.key
- else:
- column_name = text_type(c)
+ column_name = text_type(c)
except AttributeError:
# TODO: See ticket #1299 - allow virtual columns. Probably figure out
# better way to handle it. For now just assume if column was not found - it
@@ -591,7 +586,7 @@ def init_search(self):
attr, joins = tools.get_field_with_path(self.model, name)
if not attr:
- raise Exception('Failed to find field for search field: %s' % name)
+ raise Exception(f'Failed to find field for search field: {name}')
if tools.is_hybrid_property(self.model, name):
column = attr
@@ -599,8 +594,10 @@ def init_search(self):
column.key = name.split('.')[-1]
self._search_fields.append((column, joins))
else:
- for column in tools.get_columns_for_field(attr):
- self._search_fields.append((column, joins))
+ self._search_fields.extend(
+ (column, joins)
+ for column in tools.get_columns_for_field(attr)
+ )
return bool(self.column_searchable_list)
@@ -639,7 +636,7 @@ def scaffold_filters(self, name):
attr, joins = tools.get_field_with_path(self.model, name)
if attr is None:
- raise Exception('Failed to find field for filter: %s' % name)
+ raise Exception(f'Failed to find field for filter: {name}')
# Figure out filters for related column
if is_relationship(attr):
@@ -653,15 +650,13 @@ def scaffold_filters(self, name):
if column.foreign_keys or column.primary_key:
continue
- visible_name = '%s / %s' % (self.get_column_name(attr.prop.target.name),
- self.get_column_name(p.key))
+ visible_name = f'{self.get_column_name(attr.prop.target.name)} / {self.get_column_name(p.key)}'
- type_name = type(column.type).__name__
- flt = self.filter_converter.convert(type_name,
- column,
- visible_name)
- if flt:
+ type_name = type(column.type).__name__
+ if flt := self.filter_converter.convert(
+ type_name, column, visible_name
+ ):
table = column.table
if joins:
@@ -682,7 +677,7 @@ def scaffold_filters(self, name):
columns = tools.get_columns_for_field(attr)
if len(columns) > 1:
- raise Exception('Can not filter more than on one column for %s' % name)
+ raise Exception(f'Can not filter more than on one column for {name}')
column = columns[0]
@@ -695,26 +690,19 @@ def scaffold_filters(self, name):
# Join not needed for hybrid properties
if (not is_hybrid_property and tools.need_join(self.model, column.table) and
name not in self.column_labels):
- if joined_column_name:
- visible_name = '%s / %s / %s' % (
- joined_column_name,
- self.get_column_name(column.table.name),
- self.get_column_name(column.name)
- )
- else:
- visible_name = '%s / %s' % (
- self.get_column_name(column.table.name),
- self.get_column_name(column.name)
- )
+ visible_name = (
+ f'{joined_column_name} / {self.get_column_name(column.table.name)} / {self.get_column_name(column.name)}'
+ if joined_column_name
+ else f'{self.get_column_name(column.table.name)} / {self.get_column_name(column.name)}'
+ )
+
+ elif not isinstance(name, string_types):
+ visible_name = self.get_column_name(name.property.key)
+ elif self.column_labels and name in self.column_labels:
+ visible_name = self.column_labels[name]
else:
- if not isinstance(name, string_types):
- visible_name = self.get_column_name(name.property.key)
- else:
- if self.column_labels and name in self.column_labels:
- visible_name = self.column_labels[name]
- else:
- visible_name = self.get_column_name(name)
- visible_name = visible_name.replace('.', ' / ')
+ visible_name = self.get_column_name(name)
+ visible_name = visible_name.replace('.', ' / ')
type_name = type(column.type).__name__
@@ -837,13 +825,11 @@ def scaffold_auto_joins(self):
if p.direction.name in ['MANYTOONE', 'MANYTOMANY']:
relations.add(p.key)
- joined = []
-
- for prop, name in self._list_columns:
- if prop in relations:
- joined.append(getattr(self.model, prop))
-
- return joined
+ return [
+ getattr(self.model, prop)
+ for prop, name in self._list_columns
+ if prop in relations
+ ]
# AJAX foreignkey support
def _create_ajax_loader(self, name, options):
@@ -900,11 +886,7 @@ def _order_by(self, query, joins, sort_joins, sort_field, sort_desc):
column = sort_field if alias is None else getattr(alias, sort_field.key)
- if sort_desc:
- query = query.order_by(desc(column))
- else:
- query = query.order_by(column)
-
+ query = query.order_by(desc(column)) if sort_desc else query.order_by(column)
return query, joins
def _get_default_order(self):
@@ -914,21 +896,20 @@ def _get_default_order(self):
yield attr, joins, direction
def _apply_sorting(self, query, joins, sort_column, sort_desc):
- if sort_column is not None:
- if sort_column in self._sortable_columns:
- sort_field = self._sortable_columns[sort_column]
- sort_joins = self._sortable_joins.get(sort_column)
-
- if isinstance(sort_field, list):
- for field_item, join_item in zip(sort_field, sort_joins):
- query, joins = self._order_by(query, joins, join_item, field_item, sort_desc)
- else:
- query, joins = self._order_by(query, joins, sort_joins, sort_field, sort_desc)
- else:
+ if sort_column is None:
order = self._get_default_order()
for sort_field, sort_joins, sort_desc in order:
query, joins = self._order_by(query, joins, sort_joins, sort_field, sort_desc)
+ elif sort_column in self._sortable_columns:
+ sort_field = self._sortable_columns[sort_column]
+ sort_joins = self._sortable_joins.get(sort_column)
+
+ if isinstance(sort_field, list):
+ for field_item, join_item in zip(sort_field, sort_joins):
+ query, joins = self._order_by(query, joins, join_item, field_item, sort_desc)
+ else:
+ query, joins = self._order_by(query, joins, sort_joins, sort_field, sort_desc)
return query, joins
def _apply_search(self, query, count_query, joins, count_joins, search):
@@ -1057,7 +1038,7 @@ def get_list(self, page, sort_column, sort_desc, search, filters,
count_joins = {}
query = self.get_query()
- count_query = self.get_count_query() if not self.simple_list_pager else None
+ count_query = None if self.simple_list_pager else self.get_count_query()
# Ignore eager-loaded relations (prevent unnecessary joins)
# TODO: Separate join detection for query and count query?
@@ -1243,11 +1224,7 @@ def action_delete(self, ids):
if self.fast_mass_delete:
count = query.delete(synchronize_session=False)
else:
- count = 0
-
- for m in query.all():
- if self.delete_model(m):
- count += 1
+ count = sum(1 for m in query.all() if self.delete_model(m))
self.session.commit()
diff --git a/flask_admin/form/__init__.py b/flask_admin/form/__init__.py
index 9db32c5a4..4a019126c 100644
--- a/flask_admin/form/__init__.py
+++ b/flask_admin/form/__init__.py
@@ -35,7 +35,10 @@ def recreate_field(unbound):
UnboundField instance
"""
if not isinstance(unbound, UnboundField):
- raise ValueError('recreate_field expects UnboundField instance, %s was passed.' % type(unbound))
+ raise ValueError(
+ f'recreate_field expects UnboundField instance, {type(unbound)} was passed.'
+ )
+
return unbound.field_class(*unbound.args, **unbound.kwargs)
diff --git a/flask_admin/form/rules.py b/flask_admin/form/rules.py
index 4651a6632..057c7688f 100644
--- a/flask_admin/form/rules.py
+++ b/flask_admin/form/rules.py
@@ -82,8 +82,7 @@ def visible_fields(self):
"""
visible_fields = []
for rule in self.rules:
- for field in rule.visible_fields:
- visible_fields.append(field)
+ visible_fields.extend(iter(rule.visible_fields))
return visible_fields
def __iter__(self):
@@ -103,10 +102,7 @@ def __call__(self, form, form_opts=None, field_args={}):
:param field_args:
Optional arguments that should be passed to template or the field
"""
- result = []
-
- for r in self.rules:
- result.append(r(form, form_opts, field_args))
+ result = [r(form, form_opts, field_args) for r in self.rules]
return Markup(self.separator.join(result))
@@ -130,10 +126,7 @@ def __init__(self, text, escape=True):
self.escape = escape
def __call__(self, form, form_opts=None, field_args={}):
- if self.escape:
- return self.text
-
- return Markup(self.text)
+ return self.text if self.escape else Markup(self.text)
class HTML(Text):
@@ -205,10 +198,10 @@ def __call__(self, form, form_opts=None, field_args={}):
macro = self._resolve(context, self.macro_name)
if not macro:
- raise ValueError('Cannot find macro %s in current context.' % self.macro_name)
+ raise ValueError(f'Cannot find macro {self.macro_name} in current context.')
opts = dict(self.default_args)
- opts.update(field_args)
+ opts |= field_args
return macro(**opts)
@@ -302,12 +295,12 @@ def __call__(self, form, form_opts=None, field_args={}):
field = getattr(form, self.field_name, None)
if field is None:
- raise ValueError('Form %s does not have field %s' % (form, self.field_name))
+ raise ValueError(f'Form {form} does not have field {self.field_name}')
opts = {}
if form_opts:
- opts.update(form_opts.widget_args.get(self.field_name, {}))
+ opts |= form_opts.widget_args.get(self.field_name, {})
opts.update(field_args)
@@ -351,11 +344,7 @@ def __init__(self, rules, header=None, separator=''):
:param separator:
Child rule separator
"""
- if header:
- rule_set = [Header(header)] + list(rules)
- else:
- rule_set = list(rules)
-
+ rule_set = [Header(header)] + list(rules) if header else list(rules)
super(FieldSet, self).__init__(rule_set, separator=separator)
@@ -425,9 +414,7 @@ def __init__(self, field_name, prepend=None, append=None, **kwargs):
@property
def visible_fields(self):
fields = [self.field_name]
- for cnf in self._addons:
- if cnf['type'] == 'field':
- fields.append(cnf['name'])
+ fields.extend(cnf['name'] for cnf in self._addons if cnf['type'] == 'field')
return fields
def __call__(self, form, form_opts=None, field_args={}):
@@ -444,13 +431,9 @@ def __call__(self, form, form_opts=None, field_args={}):
field = getattr(form, self.field_name, None)
if field is None:
- raise ValueError('Form %s does not have field %s' % (form, self.field_name))
-
- if form_opts:
- widget_args = form_opts.widget_args
- else:
- widget_args = {}
+ raise ValueError(f'Form {form} does not have field {self.field_name}')
+ widget_args = form_opts.widget_args if form_opts else {}
opts = {}
prepend = []
append = []
@@ -459,8 +442,7 @@ def __call__(self, form, form_opts=None, field_args={}):
typ = cnf['type']
if typ == 'field':
name = cnf['name']
- fld = form._fields.get(name, None)
- if fld:
+ if fld := form._fields.get(name, None):
w_args = widget_args.setdefault(name, {})
if fld.type in ('BooleanField', 'RadioField'):
w_args.setdefault('class', 'form-check-input')
@@ -484,7 +466,7 @@ def __call__(self, form, form_opts=None, field_args={}):
if append:
opts['append'] = Markup(''.join(append))
- opts.update(widget_args.get(self.field_name, {}))
+ opts |= widget_args.get(self.field_name, {})
opts.update(field_args)
params = {
@@ -516,8 +498,7 @@ def __init__(self, view, rules):
def visible_fields(self):
visible_fields = []
for rule in self.rules:
- for field in rule.visible_fields:
- visible_fields.append(field)
+ visible_fields.extend(iter(rule.visible_fields))
return visible_fields
def convert_string(self, value):
@@ -558,5 +539,4 @@ def __iter__(self):
"""
Iterate through registered rules.
"""
- for r in self.rules:
- yield r
+ yield from self.rules
diff --git a/flask_admin/form/upload.py b/flask_admin/form/upload.py
index 009a79f3e..cb9894841 100644
--- a/flask_admin/form/upload.py
+++ b/flask_admin/form/upload.py
@@ -61,16 +61,21 @@ def __call__(self, field, **kwargs):
else:
value = field.data or ''
- return Markup(template % {
- 'text': html_params(type='text',
- readonly='readonly',
- value=value,
- name=field.name),
- 'file': html_params(type='file',
- value=value,
- **kwargs),
- 'marker': '_%s-delete' % field.name
- })
+ return Markup(
+ (
+ template
+ % {
+ 'text': html_params(
+ type='text',
+ readonly='readonly',
+ value=value,
+ name=field.name,
+ ),
+ 'file': html_params(type='file', value=value, **kwargs),
+ 'marker': f'_{field.name}-delete',
+ }
+ )
+ )
class ImageUploadInput(object):
@@ -95,14 +100,12 @@ def __call__(self, field, **kwargs):
kwargs.setdefault('name', field.name)
args = {
- 'text': html_params(type='hidden',
- value=field.data,
- name=field.name),
- 'file': html_params(type='file',
- **kwargs),
- 'marker': '_%s-delete' % field.name
+ 'text': html_params(type='hidden', value=field.data, name=field.name),
+ 'file': html_params(type='file', **kwargs),
+ 'marker': f'_{field.name}-delete',
}
+
if field.data and isinstance(field.data, string_types):
url = self.get_url(field)
args['image'] = html_params(src=url)
@@ -198,12 +201,15 @@ def is_file_allowed(self, filename):
:param filename:
File name to check
"""
- if not self.allowed_extensions:
- return True
-
- return ('.' in filename and
- filename.rsplit('.', 1)[1].lower() in
- map(lambda x: x.lower(), self.allowed_extensions))
+ return (
+ (
+ '.' in filename
+ and filename.rsplit('.', 1)[1].lower()
+ in map(lambda x: x.lower(), self.allowed_extensions)
+ )
+ if self.allowed_extensions
+ else True
+ )
def _is_uploaded_file(self, data):
return (data and isinstance(data, FileStorage) and data.filename)
@@ -221,7 +227,7 @@ def pre_validate(self, form):
def process(self, formdata, data=unset_value, extra_filters=None):
if formdata:
- marker = '_%s-delete' % self.name
+ marker = f'_{self.name}-delete'
if marker in formdata:
self._should_delete = True
@@ -241,12 +247,10 @@ def process_formdata(self, valuelist):
def populate_obj(self, obj, name):
field = getattr(obj, name, None)
- if field:
- # If field should be deleted, clean it up
- if self._should_delete:
- self._delete_file(field)
- setattr(obj, name, None)
- return
+ if field and self._should_delete:
+ self._delete_file(field)
+ setattr(obj, name, None)
+ return
if self._is_uploaded_file(self.data):
if field:
@@ -412,7 +416,7 @@ def pre_validate(self, form):
try:
self.image = Image.open(self.data)
except Exception as e:
- raise ValidationError('Invalid image: %s' % e)
+ raise ValidationError(f'Invalid image: {e}')
# Deletion
def _delete_file(self, filename):
@@ -465,10 +469,9 @@ def _resize(self, image, size):
if image.size[0] > width or image.size[1] > height:
if force:
return ImageOps.fit(self.image, (width, height), Image.ANTIALIAS)
- else:
- thumb = self.image.copy()
- thumb.thumbnail((width, height), Image.ANTIALIAS)
- return thumb
+ thumb = self.image.copy()
+ thumb.thumbnail((width, height), Image.ANTIALIAS)
+ return thumb
return image
@@ -485,7 +488,7 @@ def _save_image(self, image, path, format='JPEG'):
def _get_save_format(self, filename, image):
if image.format not in self.keep_image_formats:
name, ext = op.splitext(filename)
- filename = '%s.jpg' % name
+ filename = f'{name}.jpg'
return filename, 'JPEG'
return filename, image.format
@@ -504,4 +507,4 @@ def thumbgen_filename(filename):
Generate thumbnail name from filename.
"""
name, ext = op.splitext(filename)
- return '%s_thumb%s' % (name, ext)
+ return f'{name}_thumb{ext}'
diff --git a/flask_admin/helpers.py b/flask_admin/helpers.py
index da6033715..ec4a8772d 100644
--- a/flask_admin/helpers.py
+++ b/flask_admin/helpers.py
@@ -51,10 +51,12 @@ def is_required_form_field(field):
WTForms field to check
"""
from flask_admin.form.validators import FieldListInputRequired
- for validator in field.validators:
- if isinstance(validator, (DataRequired, InputRequired, FieldListInputRequired)):
- return True
- return False
+ return any(
+ isinstance(
+ validator, (DataRequired, InputRequired, FieldListInputRequired)
+ )
+ for validator in field.validators
+ )
def is_form_submitted():
@@ -104,7 +106,7 @@ def is_field_error(errors):
def flash_errors(form, message):
from flask_admin.babel import gettext
for field_name, errors in iteritems(form.errors):
- errors = form[field_name].label.text + u": " + u", ".join(errors)
+ errors = f"{form[field_name].label.text}: " + u", ".join(errors)
flash(gettext(message, error=str(errors)), 'error')
diff --git a/flask_admin/menu.py b/flask_admin/menu.py
index 2257d3384..a983dabb9 100644
--- a/flask_admin/menu.py
+++ b/flask_admin/menu.py
@@ -27,11 +27,7 @@ def is_category(self):
return False
def is_active(self, view):
- for c in self._children:
- if c.is_active(view):
- return True
-
- return False
+ return any(c.is_active(view) for c in self._children)
def get_class_name(self):
return self.class_name
@@ -63,18 +59,10 @@ def is_category(self):
return True
def is_visible(self):
- for c in self._children:
- if c.is_visible():
- return True
-
- return False
+ return any(c.is_visible() for c in self._children)
def is_accessible(self):
- for c in self._children:
- if c.is_accessible():
- return True
-
- return False
+ return any(c.is_accessible() for c in self._children)
class MenuView(BaseMenu):
@@ -100,7 +88,7 @@ def get_url(self):
if self._cached_url:
return self._cached_url
- url = self._view.get_url('%s.%s' % (self._view.endpoint, self._view._default_view))
+ url = self._view.get_url(f'{self._view.endpoint}.{self._view._default_view}')
if self._cache:
self._cached_url = url
@@ -108,22 +96,13 @@ def get_url(self):
return url
def is_active(self, view):
- if view == self._view:
- return True
-
- return super(MenuView, self).is_active(view)
+ return True if view == self._view else super(MenuView, self).is_active(view)
def is_visible(self):
- if self._view is None:
- return False
-
- return self._view.is_visible()
+ return False if self._view is None else self._view.is_visible()
def is_accessible(self):
- if self._view is None:
- return False
-
- return self._view.is_accessible()
+ return False if self._view is None else self._view.is_accessible()
class MenuLink(BaseMenu):
diff --git a/flask_admin/model/base.py b/flask_admin/model/base.py
index 665e64680..01c046840 100755
--- a/flask_admin/model/base.py
+++ b/flask_admin/model/base.py
@@ -56,11 +56,7 @@ def __init__(self, page=None, page_size=None, sort=None, sort_desc=None,
self.extra_args = extra_args or dict()
def clone(self, **kwargs):
- if self.filters:
- flt = list(self.filters)
- else:
- flt = None
-
+ flt = list(self.filters) if self.filters else None
kwargs.setdefault('page', self.page)
kwargs.setdefault('page_size', self.page_size)
kwargs.setdefault('sort', self.sort)
@@ -85,8 +81,7 @@ def non_lazy(self):
for item in self.filters:
copy = dict(item)
copy['operation'] = as_unicode(copy['operation'])
- options = copy['options']
- if options:
+ if options := copy['options']:
copy['options'] = [(k, text_type(v)) for k, v in options]
filters.append(copy)
@@ -803,7 +798,7 @@ def __init__(self, model,
# If name not provided, it is model name
if name is None:
- name = '%s' % self._prettify_class_name(model.__name__)
+ name = f'{self._prettify_class_name(model.__name__)}'
super(BaseModelView, self).__init__(name, category, endpoint, url, static_folder,
menu_class_name=menu_class_name,
@@ -938,7 +933,7 @@ def _refresh_cache(self):
self.column_type_formatters_detail = dict(typefmt.DETAIL_FORMATTERS)
if self.column_descriptions is None:
- self.column_descriptions = dict()
+ self.column_descriptions = {}
# Filters
self._refresh_filters_cache()
@@ -1088,16 +1083,15 @@ def get_sortable_columns(self):
"""
if self.column_sortable_list is None:
return self.scaffold_sortable_columns() or dict()
- else:
- result = dict()
+ result = {}
- for c in self.column_sortable_list:
- if isinstance(c, tuple):
- result[c[0]] = c[1]
- else:
- result[c] = c
+ for c in self.column_sortable_list:
+ if isinstance(c, tuple):
+ result[c[0]] = c[1]
+ else:
+ result[c] = c
- return result
+ return result
def init_search(self):
"""
@@ -1150,21 +1144,18 @@ def get_filters(self):
If your model backend implementation does not support filters,
override this method and return `None`.
"""
- if self.column_filters:
- collection = []
-
- for n in self.column_filters:
- if self.is_valid_filter(n):
- collection.append(self.handle_filter(n))
- else:
- flt = self.scaffold_filters(n)
- if flt:
- collection.extend(flt)
- else:
- raise Exception('Unsupported filter type %s' % n)
- return collection
- else:
+ if not self.column_filters:
return None
+ collection = []
+
+ for n in self.column_filters:
+ if self.is_valid_filter(n):
+ collection.append(self.handle_filter(n))
+ elif flt := self.scaffold_filters(n):
+ collection.extend(flt)
+ else:
+ raise Exception(f'Unsupported filter type {n}')
+ return collection
def get_filter_arg(self, index, flt):
"""
@@ -1178,21 +1169,20 @@ def get_filter_arg(self, index, flt):
:param flt:
Filter instance
"""
- if self.named_filter_urls:
- operation = flt.operation()
+ if not self.named_filter_urls:
+ return str(index)
+ operation = flt.operation()
- try:
- # get lazy string original value
- operation = operation._args[0]
- except AttributeError:
- pass
+ try:
+ # get lazy string original value
+ operation = operation._args[0]
+ except AttributeError:
+ pass
- name = ('%s %s' % (flt.name, as_unicode(operation))).lower()
- name = filter_char_re.sub('', name)
- name = filter_compact_re.sub('_', name)
- return name
- else:
- return str(index)
+ name = f'{flt.name} {as_unicode(operation)}'.lower()
+ name = filter_char_re.sub('', name)
+ name = filter_compact_re.sub('_', name)
+ return name
def _get_filter_groups(self):
"""
@@ -1241,10 +1231,7 @@ def get_form(self):
Override to implement customized behavior.
"""
- if self.form is not None:
- return self.form
-
- return self.scaffold_form()
+ return self.form if self.form is not None else self.scaffold_form()
def get_list_form(self):
"""
@@ -1271,11 +1258,12 @@ def get_list_form(self):
"""
if self.form_args:
# get only validators, other form_args can break FieldList wrapper
- validators = dict(
- (key, {'validators': value["validators"]})
+ validators = {
+ key: {'validators': value["validators"]}
for key, value in iteritems(self.form_args)
if value.get("validators")
- )
+ }
+
else:
validators = None
@@ -1405,9 +1393,9 @@ def _get_ruleset_missing_fields(self, ruleset, form):
if ruleset:
visible_fields = ruleset.visible_fields
- for field in form:
- if field.name not in visible_fields:
- missing_fields.append(field.name)
+ missing_fields.extend(
+ field.name for field in form if field.name not in visible_fields
+ )
return missing_fields
@@ -1415,27 +1403,36 @@ def _show_missing_fields_warning(self, text):
warnings.warn(text)
def _validate_form_class(self, ruleset, form_class, remove_missing=True):
- form_fields = []
- for name, obj in iteritems(form_class.__dict__):
- if isinstance(obj, UnboundField):
- form_fields.append(name)
+ form_fields = [
+ name
+ for name, obj in iteritems(form_class.__dict__)
+ if isinstance(obj, UnboundField)
+ ]
missing_fields = []
if ruleset:
visible_fields = ruleset.visible_fields
- for field_name in form_fields:
- if field_name not in visible_fields:
- missing_fields.append(field_name)
+ missing_fields.extend(
+ field_name
+ for field_name in form_fields
+ if field_name not in visible_fields
+ )
if missing_fields:
- self._show_missing_fields_warning('Fields missing from ruleset: %s' % (','.join(missing_fields)))
+ self._show_missing_fields_warning(
+ f"Fields missing from ruleset: {','.join(missing_fields)}"
+ )
+
if remove_missing:
self._remove_fields_from_form_class(missing_fields, form_class)
def _validate_form_instance(self, ruleset, form, remove_missing=True):
missing_fields = self._get_ruleset_missing_fields(ruleset=ruleset, form=form)
if missing_fields:
- self._show_missing_fields_warning('Fields missing from ruleset: %s' % (','.join(missing_fields)))
+ self._show_missing_fields_warning(
+ f"Fields missing from ruleset: {','.join(missing_fields)}"
+ )
+
if remove_missing:
self._remove_fields_from_form_instance(missing_fields, form)
@@ -1711,33 +1708,32 @@ def get_invalid_value_msg(self, value, filter):
# URL generation helpers
def _get_list_filter_args(self):
- if self._filters:
- filters = []
-
- for arg in request.args:
- if not arg.startswith('flt'):
- continue
+ if not self._filters:
+ return None
+ filters = []
- if '_' not in arg:
- continue
+ for arg in request.args:
+ if not arg.startswith('flt'):
+ continue
- pos, key = arg[3:].split('_', 1)
+ if '_' not in arg:
+ continue
- if key in self._filter_args:
- idx, flt = self._filter_args[key]
+ pos, key = arg[3:].split('_', 1)
- value = request.args[arg]
+ if key in self._filter_args:
+ idx, flt = self._filter_args[key]
- if flt.validate(value):
- data = (pos, (idx, as_unicode(flt.name), value))
- filters.append(data)
- else:
- flash(self.get_invalid_value_msg(value, flt), 'error')
+ value = request.args[arg]
- # Sort filters
- return [v[1] for v in sorted(filters, key=lambda n: n[0])]
+ if flt.validate(value):
+ data = (pos, (idx, as_unicode(flt.name), value))
+ filters.append(data)
+ else:
+ flash(self.get_invalid_value_msg(value, flt), 'error')
- return None
+ # Sort filters
+ return [v[1] for v in sorted(filters, key=lambda n: n[0])]
def _get_list_extra_args(self):
"""
@@ -1788,7 +1784,7 @@ def _get_list_url(self, view_args):
desc = 1 if view_args.sort_desc else None
kwargs = dict(page=page, sort=view_args.sort, desc=desc, search=view_args.search)
- kwargs.update(view_args.extra_args)
+ kwargs |= view_args.extra_args
if view_args.page_size:
kwargs['page_size'] = view_args.page_size
@@ -1836,15 +1832,18 @@ def _get_list_value(self, context, model, name, column_formatters,
else:
value = self._get_field_value(model, name)
- choices_map = self._column_choices_map.get(name, {})
- if choices_map:
+ if choices_map := self._column_choices_map.get(name, {}):
return choices_map.get(value) or value
- type_fmt = None
- for typeobj, formatter in column_type_formatters.items():
- if isinstance(value, typeobj):
- type_fmt = formatter
- break
+ type_fmt = next(
+ (
+ formatter
+ for typeobj, formatter in column_type_formatters.items()
+ if isinstance(value, typeobj)
+ ),
+ None,
+ )
+
if type_fmt is not None:
value = type_fmt(self, value)
@@ -1912,10 +1911,7 @@ def get_export_name(self, export_type='csv'):
"""
:return: The exported csv file name.
"""
- filename = '%s_%s.%s' % (self.name,
- time.strftime("%Y-%m-%d_%H-%M-%S"),
- export_type)
- return filename
+ return f'{self.name}_{time.strftime("%Y-%m-%d_%H-%M-%S")}.{export_type}'
# AJAX references
def _process_ajax_references(self):
@@ -1932,7 +1928,7 @@ def _process_ajax_references(self):
elif isinstance(options, AjaxModelLoader):
result[name] = options
else:
- raise ValueError('%s.form_ajax_refs can not handle %s types' % (self, type(options)))
+ raise ValueError(f'{self}.form_ajax_refs can not handle {type(options)} types')
return result
@@ -2083,10 +2079,7 @@ def create_view(self):
self._validate_form_instance(ruleset=self._form_create_rules, form=form)
if self.validate_form(form):
- # in versions 1.1.0 and before, this returns a boolean
- # in later versions, this is the model itself
- model = self.create_model(form)
- if model:
+ if model := self.create_model(form):
flash(gettext('Record was successfully created.'), 'success')
if '_add_another' in request.form:
return redirect(request.url)
@@ -2138,16 +2131,15 @@ def edit_view(self):
if not hasattr(form, '_validated_ruleset') or not form._validated_ruleset:
self._validate_form_instance(ruleset=self._form_edit_rules, form=form)
- if self.validate_form(form):
- if self.update_model(form, model):
- flash(gettext('Record was successfully saved.'), 'success')
- if '_add_another' in request.form:
- return redirect(self.get_url('.create_view', url=return_url))
- elif '_continue_editing' in request.form:
- return redirect(self.get_url('.edit_view', id=self.get_pk_value(model)))
- else:
- # save button
- return redirect(self.get_save_return_url(model, is_created=False))
+ if self.validate_form(form) and self.update_model(form, model):
+ flash(gettext('Record was successfully saved.'), 'success')
+ if '_add_another' in request.form:
+ return redirect(self.get_url('.create_view', url=return_url))
+ elif '_continue_editing' in request.form:
+ return redirect(self.get_url('.edit_view', id=self.get_pk_value(model)))
+ else:
+ # save button
+ return redirect(self.get_save_return_url(model, is_created=False))
if request.method == 'GET' or form.errors:
self.on_form_prefill(form, id)
@@ -2334,13 +2326,13 @@ def _export_tablib(self, export_type, return_url):
filename = self.get_export_name(export_type)
- disposition = 'attachment;filename=%s' % (secure_filename(filename),)
+ disposition = f'attachment;filename={secure_filename(filename)}'
mimetype, encoding = mimetypes.guess_type(filename)
if not mimetype:
mimetype = 'application/octet-stream'
if encoding:
- mimetype = '%s; charset=%s' % (mimetype, encoding)
+ mimetype = f'{mimetype}; charset={encoding}'
ds = tablib.Dataset(headers=[csv_encode(c[1]) for c in self._export_columns])
@@ -2394,9 +2386,7 @@ def ajax_update(self):
# prevent validation issues due to submitting a single field
# delete all fields except the submitted fields and csrf token
for field in list(form):
- if (field.name in request.form) or (field.name == 'csrf_token'):
- pass
- else:
+ if field.name not in request.form and field.name != 'csrf_token':
form.__delitem__(field.name)
if self.validate_form(form):
@@ -2409,11 +2399,10 @@ def ajax_update(self):
if self.update_model(form, record):
# Success
return gettext('Record was successfully saved.')
- else:
# Error: No records changed, or problem saving to database.
- msgs = ", ".join([msg for msg in get_flashed_messages()])
- return gettext('Failed to update record. %(error)s',
- error=msgs), 500
+ msgs = ", ".join(list(get_flashed_messages()))
+ return gettext('Failed to update record. %(error)s',
+ error=msgs), 500
else:
for field in form:
for error in field.errors:
diff --git a/flask_admin/model/fields.py b/flask_admin/model/fields.py
index ed28c7e88..da4bc46d4 100644
--- a/flask_admin/model/fields.py
+++ b/flask_admin/model/fields.py
@@ -20,9 +20,7 @@ def __init__(self, *args, **kwargs):
super(InlineFieldList, self).__init__(*args, **kwargs)
def __call__(self, **kwargs):
- # Create template
- meta = getattr(self, 'meta', None)
- if meta:
+ if meta := getattr(self, 'meta', None):
template = self.unbound_field.bind(form=None, name='', _meta=meta)
else:
template = self.unbound_field.bind(form=None, name='')
@@ -47,7 +45,7 @@ def process(self, formdata, data=unset_value, extra_filters=None):
# Postprocess - contribute flag
if formdata:
for f in self.entries:
- key = 'del-%s' % f.id
+ key = f'del-{f.id}'
f._should_delete = key in formdata
return res
@@ -60,17 +58,17 @@ def validate(self, form, extra_validators=tuple()):
that FieldList validates all its enclosed fields first before running any
of its own validators.
"""
- self.errors = []
+ self.errors = [
+ subfield.errors
+ for subfield in self.entries
+ if not self.should_delete(subfield) and not subfield.validate(form)
+ ]
- # Run validators on all entries within
- for subfield in self.entries:
- if not self.should_delete(subfield) and not subfield.validate(form):
- self.errors.append(subfield.errors)
chain = itertools.chain(self.validators, extra_validators)
self._run_validation_chain(form, chain)
- return len(self.errors) == 0
+ return not self.errors
def should_delete(self, field):
return getattr(field, '_should_delete', False)
@@ -83,7 +81,7 @@ def populate_obj(self, obj, name):
ivalues = iter([])
candidates = itertools.chain(ivalues, itertools.repeat(None))
- _fake = type(str('_fake'), (object, ), {})
+ _fake = type('_fake', (object, ), {})
output = []
for field, data in zip(self.entries, candidates):
@@ -192,8 +190,7 @@ def __init__(self, loader, label=None, validators=None, default=None, **kwargs):
self._invalid_formdata = False
def _get_data(self):
- formdata = self._formdata
- if formdata:
+ if formdata := self._formdata:
data = []
# TODO: Optimize?
diff --git a/flask_admin/model/filters.py b/flask_admin/model/filters.py
index 414392e95..47f19f6f2 100644
--- a/flask_admin/model/filters.py
+++ b/flask_admin/model/filters.py
@@ -36,9 +36,7 @@ def get_options(self, view):
:param view:
Associated administrative view class.
"""
- options = self.options
-
- if options:
+ if options := self.options:
if callable(options):
options = options()
@@ -178,10 +176,7 @@ def validate(self, value):
for range in value.split(' to ')]
# if " to " is missing, fail validation
# sqlalchemy's .between() will not work if end date is before start date
- if (len(value) == 2) and (value[0] <= value[1]):
- return True
- else:
- return False
+ return len(value) == 2 and value[0] <= value[1]
except ValueError:
return False
@@ -216,10 +211,7 @@ def validate(self, value):
try:
value = [datetime.datetime.strptime(range, '%Y-%m-%d %H:%M:%S')
for range in value.split(' to ')]
- if (len(value) == 2) and (value[0] <= value[1]):
- return True
- else:
- return False
+ return len(value) == 2 and value[0] <= value[1]
except ValueError:
return False
@@ -261,13 +253,9 @@ def validate(self, value):
try:
timetuples = [time.strptime(range, '%H:%M:%S')
for range in value.split(' to ')]
- if (len(timetuples) == 2) and (timetuples[0] <= timetuples[1]):
- return True
- else:
- return False
+ return len(timetuples) == 2 and timetuples[0] <= timetuples[1]
except ValueError:
raise
- return False
class BaseUuidFilter(BaseFilter):
@@ -313,7 +301,7 @@ class BaseFilterConverter(object):
logic.
"""
def __init__(self):
- self.converters = dict()
+ self.converters = {}
for p in dir(self):
attr = getattr(self, p)
diff --git a/flask_admin/model/form.py b/flask_admin/model/form.py
index 63da1ae1b..9f9d53641 100644
--- a/flask_admin/model/form.py
+++ b/flask_admin/model/form.py
@@ -77,10 +77,7 @@ def __init__(self, **kwargs):
for k, v in iteritems(kwargs):
setattr(self, k, v)
- # Convert form rules
- form_rules = getattr(self, 'form_rules', None)
-
- if form_rules:
+ if form_rules := getattr(self, 'form_rules', None):
self._form_rules = rules.RuleSet(self, form_rules)
else:
self._form_rules = None
@@ -175,17 +172,19 @@ def get_converter(self, column):
# Search by module + name
for col_type in types:
- type_string = '%s.%s' % (col_type.__module__, col_type.__name__)
+ type_string = f'{col_type.__module__}.{col_type.__name__}'
if type_string in self.converters:
return self.converters[type_string]
- # Search by name
- for col_type in types:
- if col_type.__name__ in self.converters:
- return self.converters[col_type.__name__]
-
- return None
+ return next(
+ (
+ self.converters[col_type.__name__]
+ for col_type in types
+ if col_type.__name__ in self.converters
+ ),
+ None,
+ )
def get_form(self, model, base_class=BaseForm,
only=None, exclude=None,
@@ -214,8 +213,7 @@ def get_label(self, info, name):
:param name:
Field name
"""
- form_name = getattr(info, 'form_label', None)
- if form_name:
+ if form_name := getattr(info, 'form_label', None):
return form_name
column_labels = getattr(self.view, 'column_labels', None)
diff --git a/flask_admin/model/template.py b/flask_admin/model/template.py
index 495e4199a..8ecd98897 100644
--- a/flask_admin/model/template.py
+++ b/flask_admin/model/template.py
@@ -14,12 +14,11 @@ def render_ctx(self, context, row_id, row):
return self.render(context, row_id, row)
def _resolve_symbol(self, context, symbol):
- if '.' in symbol:
- parts = symbol.split('.')
- m = context.resolve(parts[0])
- return reduce(getattr, parts[1:], m)
- else:
+ if '.' not in symbol:
return context.resolve(symbol)
+ parts = symbol.split('.')
+ m = context.resolve(parts[0])
+ return reduce(getattr, parts[1:], m)
class LinkRowAction(BaseListRowAction):
diff --git a/flask_admin/model/widgets.py b/flask_admin/model/widgets.py
index e906790dc..18d2619d5 100644
--- a/flask_admin/model/widgets.py
+++ b/flask_admin/model/widgets.py
@@ -52,12 +52,9 @@ def __call__(self, field, **kwargs):
kwargs['value'] = separator.join(ids)
kwargs['data-json'] = json.dumps(result)
kwargs['data-multiple'] = u'1'
- else:
- data = field.loader.format(field.data)
-
- if data:
- kwargs['value'] = data[0]
- kwargs['data-json'] = json.dumps(data)
+ elif data := field.loader.format(field.data):
+ kwargs['value'] = data[0]
+ kwargs['data-json'] = json.dumps(data)
placeholder = field.loader.options.get('placeholder', gettext('Please select model'))
kwargs.setdefault('data-placeholder', placeholder)
@@ -65,7 +62,7 @@ def __call__(self, field, **kwargs):
minimum_input_length = int(field.loader.options.get('minimum_input_length', 1))
kwargs.setdefault('data-minimum-input-length', minimum_input_length)
- return Markup('' % html_params(name=field.name, **kwargs))
+ return Markup(f'')
class XEditableWidget(object):
@@ -94,10 +91,7 @@ def __call__(self, field, **kwargs):
kwargs = self.get_kwargs(field, kwargs)
- return Markup(
- '%s' % (html_params(**kwargs),
- escape(display_value))
- )
+ return Markup(f'{escape(display_value)}')
def get_kwargs(self, field, kwargs):
"""
@@ -177,6 +171,6 @@ def get_kwargs(self, field, kwargs):
else:
kwargs['data-value'] = text_type(selected_ids[0])
else:
- raise Exception('Unsupported field type: %s' % (type(field),))
+ raise Exception(f'Unsupported field type: {type(field)}')
return kwargs
diff --git a/flask_admin/tests/fileadmin/test_fileadmin_azure.py b/flask_admin/tests/fileadmin/test_fileadmin_azure.py
index 31a5be066..82109b864 100644
--- a/flask_admin/tests/fileadmin/test_fileadmin_azure.py
+++ b/flask_admin/tests/fileadmin/test_fileadmin_azure.py
@@ -16,7 +16,7 @@ def setUp(self):
if not azure.BlockBlobService:
raise SkipTest('AzureFileAdmin dependencies not installed')
- self._container_name = 'fileadmin-tests-%s' % uuid4()
+ self._container_name = f'fileadmin-tests-{uuid4()}'
if not self._test_storage or not self._container_name:
raise SkipTest('AzureFileAdmin test credentials not set')
diff --git a/flask_admin/tests/geoa/test_basic.py b/flask_admin/tests/geoa/test_basic.py
index d6718dcda..7cbb09ae8 100644
--- a/flask_admin/tests/geoa/test_basic.py
+++ b/flask_admin/tests/geoa/test_basic.py
@@ -89,12 +89,12 @@ def test_model():
html = rv.data.decode('utf-8')
pattern = r'(.|\n)+({.*"type": ?"Point".*})(.|\n)+'
- group = re.match(pattern, html).group(2)
+ group = re.match(pattern, html)[2]
p = json.loads(group)
assert p['coordinates'][0] == 125.8
assert p['coordinates'][1] == 10.0
- url = '/admin/geomodel/edit/?id=%s' % model.id
+ url = f'/admin/geomodel/edit/?id={model.id}'
rv = client.get(url)
assert rv.status_code == 200
data = rv.data.decode('utf-8')
@@ -121,7 +121,7 @@ def test_model():
# assert list(to_shape(model.multi).geoms[0].coords) == [(100.0, 0.0])
# assert list(to_shape(model.multi).geoms[1].coords) == [(101.0, 1.0])
- url = '/admin/geomodel/delete/?id=%s' % model.id
+ url = f'/admin/geomodel/delete/?id={model.id}'
rv = client.post(url)
assert rv.status_code == 302
assert db.session.query(GeoModel).count() == 0
@@ -147,7 +147,7 @@ def test_none():
model = db.session.query(GeoModel).first()
- url = '/admin/geomodel/edit/?id=%s' % model.id
+ url = f'/admin/geomodel/edit/?id={model.id}'
rv = client.get(url)
assert rv.status_code == 200
data = rv.data.decode('utf-8')
diff --git a/flask_admin/tests/mongoengine/test_basic.py b/flask_admin/tests/mongoengine/test_basic.py
index b977d23b5..8967ec1dc 100644
--- a/flask_admin/tests/mongoengine/test_basic.py
+++ b/flask_admin/tests/mongoengine/test_basic.py
@@ -123,7 +123,7 @@ def test_model():
assert rv.status_code == 200
assert b'test1large' in rv.data
- url = '/admin/model1/edit/?id=%s' % model.id
+ url = f'/admin/model1/edit/?id={model.id}'
rv = client.get(url)
assert rv.status_code == 200
@@ -137,7 +137,7 @@ def test_model():
assert model.test3 == ''
assert model.test4 == ''
- url = '/admin/model1/delete/?id=%s' % model.id
+ url = f'/admin/model1/delete/?id={model.id}'
rv = client.post(url)
assert rv.status_code == 302
assert Model1.objects.count() == 0
@@ -253,12 +253,12 @@ def test_details_view():
assert '/admin/model2/details/' in data
# test redirection when details are disabled
- url = '/admin/model1/details/?url=%2Fadmin%2Fmodel1%2F&id=' + str(m1_id)
+ url = f'/admin/model1/details/?url=%2Fadmin%2Fmodel1%2F&id={str(m1_id)}'
rv = client.get(url)
assert rv.status_code == 302
# test if correct data appears in details view when enabled
- url = '/admin/model2/details/?url=%2Fadmin%2Fmodel2%2F&id=' + str(m2_id)
+ url = f'/admin/model2/details/?url=%2Fadmin%2Fmodel2%2F&id={str(m2_id)}'
rv = client.get(url)
data = rv.data.decode('utf-8')
assert 'String Field' in data
@@ -266,7 +266,7 @@ def test_details_view():
assert 'Int Field' in data
# test column_details_list
- url = '/admin/sf_view/details/?url=%2Fadmin%2Fsf_view%2F&id=' + str(m2_id)
+ url = f'/admin/sf_view/details/?url=%2Fadmin%2Fsf_view%2F&id={str(m2_id)}'
rv = client.get(url)
data = rv.data.decode('utf-8')
assert 'String Field' in data
@@ -1185,7 +1185,7 @@ def test_export_csv():
endpoint='row_limit_2')
admin.add_view(view)
- for x in range(5):
+ for _ in range(5):
fill_db(Model1, Model2)
client = app.test_client()
diff --git a/flask_admin/tests/peeweemodel/test_basic.py b/flask_admin/tests/peeweemodel/test_basic.py
index fa7f97645..4ad062d30 100644
--- a/flask_admin/tests/peeweemodel/test_basic.py
+++ b/flask_admin/tests/peeweemodel/test_basic.py
@@ -154,7 +154,7 @@ def test_model():
assert rv.status_code == 200
assert b'test1large' in rv.data
- url = '/admin/model1/edit/?id=%s' % model.id
+ url = f'/admin/model1/edit/?id={model.id}'
rv = client.get(url)
assert rv.status_code == 200
@@ -168,7 +168,7 @@ def test_model():
assert model.test3 is None or model.test3 == ''
assert model.test4 is None or model.test4 == ''
- url = '/admin/model1/delete/?id=%s' % model.id
+ url = f'/admin/model1/delete/?id={model.id}'
rv = client.post(url)
assert rv.status_code == 302
assert Model1.select().count() == 0
@@ -1045,7 +1045,7 @@ def test_export_csv():
endpoint='row_limit_2')
admin.add_view(view)
- for x in range(5):
+ for _ in range(5):
fill_db(Model1, Model2)
client = app.test_client()
diff --git a/flask_admin/tests/pymongo/test_basic.py b/flask_admin/tests/pymongo/test_basic.py
index c7d598607..d21dc9748 100644
--- a/flask_admin/tests/pymongo/test_basic.py
+++ b/flask_admin/tests/pymongo/test_basic.py
@@ -61,7 +61,7 @@ def test_model():
assert rv.status_code == 200
assert 'test1large' in rv.data.decode('utf-8')
- url = '/admin/testview/edit/?id=%s' % model['_id']
+ url = f"/admin/testview/edit/?id={model['_id']}"
rv = client.get(url)
assert rv.status_code == 200
@@ -75,7 +75,7 @@ def test_model():
assert model['test1'] == 'test1small'
assert model['test2'] == 'test2large'
- url = '/admin/testview/delete/?id=%s' % model['_id']
+ url = f"/admin/testview/delete/?id={model['_id']}"
rv = client.post(url)
assert rv.status_code == 302
assert db.test.count() == 0
diff --git a/flask_admin/tests/sqla/test_basic.py b/flask_admin/tests/sqla/test_basic.py
index 3aaba3325..ce919ed54 100644
--- a/flask_admin/tests/sqla/test_basic.py
+++ b/flask_admin/tests/sqla/test_basic.py
@@ -254,7 +254,7 @@ def test_model():
assert u'test1large' in rv.data.decode('utf-8')
# check that we can retrieve an edit view
- url = '/admin/model1/edit/?id=%s' % model.id
+ url = f'/admin/model1/edit/?id={model.id}'
rv = client.get(url)
assert rv.status_code == 200
@@ -299,7 +299,7 @@ def test_model():
assert model.sqla_utils_color is None
# check that the model can be deleted
- url = '/admin/model1/delete/?id=%s' % model.id
+ url = f'/admin/model1/delete/?id={model.id}'
rv = client.post(url)
assert rv.status_code == 302
assert db.session.query(Model1).count() == 0
@@ -1583,8 +1583,8 @@ def number_of_pixels_str(self):
return str(self.number_of_pixels())
@number_of_pixels_str.expression
- def number_of_pixels_str(cls):
- return cast(cls.width * cls.height, db.String)
+ def number_of_pixels_str(self):
+ return cast(self.width * self.height, db.String)
db.create_all()
@@ -1643,7 +1643,7 @@ class Model1(db.Model):
@hybrid_property
def fullname(self):
- return '{} {}'.format(self.firstname, self.lastname)
+ return f'{self.firstname} {self.lastname}'
class Model2(db.Model):
id = db.Column(db.Integer, primary_key=True)
@@ -2605,7 +2605,7 @@ def test_export_csv():
app, db, admin = setup()
Model1, Model2 = create_models(db)
- for x in range(5):
+ for _ in range(5):
fill_db(db, Model1, Model2)
view = CustomModelView(Model1, db.session, can_export=True,
diff --git a/flask_admin/tests/test_base.py b/flask_admin/tests/test_base.py
index ea0d69d2e..d2848c82a 100644
--- a/flask_admin/tests/test_base.py
+++ b/flask_admin/tests/test_base.py
@@ -29,16 +29,10 @@ def _handle_view(self, name, **kwargs):
return 'Failure!'
def is_accessible(self):
- if self.allow_access:
- return super(MockView, self).is_accessible()
-
- return False
+ return super(MockView, self).is_accessible() if self.allow_access else False
def is_visible(self):
- if self.visible:
- return super(MockView, self).is_visible()
-
- return False
+ return super(MockView, self).is_visible() if self.visible else False
class MockMethodView(base.BaseView):
@@ -465,7 +459,7 @@ def check_class_name():
def check_endpoint():
class CustomView(MockView):
def _get_endpoint(self, endpoint):
- return 'admin.' + super(CustomView, self)._get_endpoint(endpoint)
+ return f'admin.{super(CustomView, self)._get_endpoint(endpoint)}'
view = CustomView()
assert view.endpoint == 'admin.customview'
diff --git a/flask_admin/tests/test_model.py b/flask_admin/tests/test_model.py
index eee8c401e..c35c678c6 100644
--- a/flask_admin/tests/test_model.py
+++ b/flask_admin/tests/test_model.py
@@ -62,11 +62,7 @@ def __init__(self, model, data=None, name=None, category=None,
self.search_arguments = []
- if data is None:
- self.all_models = {1: Model(1), 2: Model(2)}
- else:
- self.all_models = data
-
+ self.all_models = {1: Model(1), 2: Model(2)} if data is None else data
self.last_id = len(self.all_models) + 1
# Scaffolding
@@ -355,8 +351,7 @@ def scaffold_form(self):
def get_csrf_token(data):
data = data.split('name="csrf_token" type="hidden" value="')[1]
- token = data.split('"')[0]
- return token
+ return data.split('"')[0]
app, admin = setup()
diff --git a/flask_admin/tools.py b/flask_admin/tools.py
index b0533aae2..9933d3d05 100644
--- a/flask_admin/tools.py
+++ b/flask_admin/tools.py
@@ -94,11 +94,14 @@ def get_dict_attr(obj, attr, default=None):
:param default:
Default value if attribute was not found
"""
- for obj in [obj] + obj.__class__.mro():
- if attr in obj.__dict__:
- return obj.__dict__[attr]
-
- return default
+ return next(
+ (
+ obj.__dict__[attr]
+ for obj in [obj] + obj.__class__.mro()
+ if attr in obj.__dict__
+ ),
+ default,
+ )
def escape(value):
@@ -134,17 +137,16 @@ def iterdecode(value):
escaped = False
for c in value:
- if not escaped:
- if c == CHAR_ESCAPE:
- escaped = True
- continue
- elif c == CHAR_SEPARATOR:
- result.append(accumulator)
- accumulator = u''
- continue
- else:
+ if escaped:
escaped = False
+ elif c == CHAR_ESCAPE:
+ escaped = True
+ continue
+ elif c == CHAR_SEPARATOR:
+ result.append(accumulator)
+ accumulator = u''
+ continue
accumulator += c
result.append(accumulator)