Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add typings to the redis.Redis class #3252

Draft
wants to merge 4 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ dump.rdb
_build
vagrant/.vagrant
.python-version
.tool-versions
.cache
.eggs
.idea
Expand Down
89 changes: 43 additions & 46 deletions redis/client.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from __future__ import annotations

import copy
import re
import threading
import time
import warnings
from itertools import chain
from typing import Any, Callable, Dict, List, Optional, Type, Union
from typing import Any, Callable, Dict, Optional, List, Union, Type

from redis._cache import (
DEFAULT_ALLOW_LIST,
Expand Down Expand Up @@ -105,7 +107,7 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands):
"""

@classmethod
def from_url(cls, url: str, **kwargs) -> "Redis":
bitterteriyaki marked this conversation as resolved.
Show resolved Hide resolved
def from_url(cls, url: str, **kwargs: Any) -> Redis:
"""
Return a Redis client object configured from the given URL

Expand Down Expand Up @@ -156,66 +158,61 @@ class initializer. In the case of conflicting arguments, querystring
return client

@classmethod
def from_pool(
bitterteriyaki marked this conversation as resolved.
Show resolved Hide resolved
cls: Type["Redis"],
connection_pool: ConnectionPool,
) -> "Redis":
def from_pool(cls, connection_pool: ConnectionPool) -> Redis:
"""
Return a Redis client from the given connection pool.
The Redis client will take ownership of the connection pool and
close it when the Redis client is closed.
"""
client = cls(
connection_pool=connection_pool,
)
client = cls(connection_pool=connection_pool)
client.auto_close_connection_pool = True
return client

def __init__(
self,
host="localhost",
port=6379,
db=0,
password=None,
socket_timeout=None,
socket_connect_timeout=None,
host: str = "localhost",
port: int = 6379,
db: int = 0,
password: Optional[str] = None,
socket_timeout: Optional[float] = None,
socket_connect_timeout: Optional[float] = None,
socket_keepalive=None,
socket_keepalive_options=None,
connection_pool=None,
unix_socket_path=None,
encoding="utf-8",
encoding_errors="strict",
encoding: str = "utf-8",
encoding_errors: str = "strict",
charset=None,
errors=None,
decode_responses=False,
retry_on_timeout=False,
decode_responses: bool = False,
retry_on_timeout: bool = False,
retry_on_error=None,
ssl=False,
ssl: bool = False,
ssl_keyfile=None,
ssl_certfile=None,
ssl_cert_reqs="required",
ssl_cert_reqs: str = "required",
ssl_ca_certs=None,
ssl_ca_path=None,
ssl_ca_data=None,
ssl_check_hostname=False,
ssl_check_hostname: bool = False,
ssl_password=None,
ssl_validate_ocsp=False,
ssl_validate_ocsp_stapled=False,
ssl_validate_ocsp: bool = False,
ssl_validate_ocsp_stapled: bool = False,
ssl_ocsp_context=None,
ssl_ocsp_expected_cert=None,
ssl_min_version=None,
ssl_ciphers=None,
max_connections=None,
single_connection_client=False,
health_check_interval=0,
single_connection_client: bool = False,
health_check_interval: int = 0,
client_name=None,
lib_name="redis-py",
lib_version=get_lib_version(),
lib_name: str = "redis-py",
lib_version: str = get_lib_version(),
username=None,
retry=None,
redis_connect_func=None,
credential_provider: Optional[CredentialProvider] = None,
protocol: Optional[int] = 2,
protocol: int = 2,
cache_enabled: bool = False,
client_cache: Optional[AbstractCache] = None,
cache_max_size: int = 10000,
Expand Down Expand Up @@ -345,18 +342,18 @@ def __repr__(self) -> str:
f"({repr(self.connection_pool)})>"
)

def get_encoder(self) -> "Encoder":
def get_encoder(self) -> Encoder:
"""Get the connection pool's encoder"""
return self.connection_pool.get_encoder()

def get_connection_kwargs(self) -> Dict:
def get_connection_kwargs(self) -> Dict[str, Any]:
bitterteriyaki marked this conversation as resolved.
Show resolved Hide resolved
"""Get the connection's key-word arguments"""
return self.connection_pool.connection_kwargs

def get_retry(self) -> Optional["Retry"]:
def get_retry(self) -> Optional[Retry]:
return self.get_connection_kwargs().get("retry")

def set_retry(self, retry: "Retry") -> None:
def set_retry(self, retry: Retry) -> None:
self.get_connection_kwargs().update({"retry": retry})
self.connection_pool.set_retry(retry)

Expand Down Expand Up @@ -387,7 +384,7 @@ def load_external_module(self, funcname, func) -> None:
"""
setattr(self, funcname, func)

def pipeline(self, transaction=True, shard_hint=None) -> "Pipeline":
def pipeline(self, transaction: bool = True, shard_hint=None) -> Pipeline:
"""
Return a new pipeline object that can queue multiple commands for
later execution. ``transaction`` indicates whether all commands
Expand All @@ -400,7 +397,7 @@ def pipeline(self, transaction=True, shard_hint=None) -> "Pipeline":
)

def transaction(
self, func: Callable[["Pipeline"], None], *watches, **kwargs
self, func: Callable[[Pipeline], None], *watches, **kwargs
) -> None:
"""
Convenience method for executing the callable `func` as a transaction
Expand Down Expand Up @@ -430,7 +427,7 @@ def lock(
sleep: float = 0.1,
blocking: bool = True,
blocking_timeout: Optional[float] = None,
lock_class: Union[None, Any] = None,
lock_class: Optional[Type[Lock]] = None,
thread_local: bool = True,
):
"""
Expand Down Expand Up @@ -497,32 +494,32 @@ def lock(
thread_local=thread_local,
)

def pubsub(self, **kwargs):
def pubsub(self, **kwargs) -> PubSub:
"""
Return a Publish/Subscribe object. With this object, you can
subscribe to channels and listen for messages that get published to
them.
"""
return PubSub(self.connection_pool, **kwargs)

def monitor(self):
def monitor(self) -> Monitor:
return Monitor(self.connection_pool)

def client(self):
def client(self) -> Redis:
return self.__class__(
connection_pool=self.connection_pool, single_connection_client=True
)

def __enter__(self):
def __enter__(self) -> Redis:
return self

def __exit__(self, exc_type, exc_value, traceback):
def __exit__(self, exc_type, exc_value, traceback) -> None:
self.close()

def __del__(self):
def __del__(self) -> None:
self.close()

def close(self):
def close(self) -> None:
# In case a connection property does not yet exist
# (due to a crash earlier in the Redis() constructor), return
# immediately as there is nothing to clean-up.
Expand Down Expand Up @@ -602,19 +599,19 @@ def parse_response(self, connection, command_name, **options):
return self.response_callbacks[command_name](response, **options)
return response

def flush_cache(self):
def flush_cache(self) -> None:
if self.connection:
self.connection.flush_cache()
else:
self.connection_pool.flush_cache()

def delete_command_from_cache(self, command):
def delete_command_from_cache(self, command) -> None:
if self.connection:
self.connection.delete_command_from_cache(command)
else:
self.connection_pool.delete_command_from_cache(command)

def invalidate_key_from_cache(self, key):
def invalidate_key_from_cache(self, key) -> None:
if self.connection:
self.connection.invalidate_key_from_cache(key)
else:
Expand Down
Loading