diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml new file mode 100644 index 00000000..9b9d62d6 --- /dev/null +++ b/.github/workflows/python-package.yml @@ -0,0 +1,40 @@ +# This workflow will install Python dependencies, run tests and lint with a variety of Python versions +# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python + +name: Python package + +on: + push: + branches: [ "main" ] + pull_request: + branches: [ "main" ] + +jobs: + build: + + runs-on: windows-latest # Womp Womp Linux + strategy: + fail-fast: false + matrix: + python-version: ["3.9", "3.10", "3.11"] + + steps: + - uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v3 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install flake8 pytest + if (Test-Path requirements.txt) { pip install -r requirements.txt } + - name: Lint with flake8 + run: | + # stop the build if there are Python syntax errors or undefined names + flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics + # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide + flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics + - name: Test with pytest + run: | + python -m pytest -v # Added -v for verbose output diff --git a/supabase/_async/client.py b/supabase/_async/client.py index ad21c7cd..f69626a2 100644 --- a/supabase/_async/client.py +++ b/supabase/_async/client.py @@ -1,15 +1,12 @@ import asyncio +import logging import re from typing import Any, Dict, List, Optional, Union from gotrue import AsyncMemoryStorage from gotrue.types import AuthChangeEvent, Session from httpx import Timeout -from postgrest import ( - AsyncPostgrestClient, - AsyncRequestBuilder, - AsyncRPCFilterRequestBuilder, -) +from postgrest import AsyncPostgrestClient, AsyncRequestBuilder, AsyncRPCFilterRequestBuilder from postgrest.constants import DEFAULT_POSTGREST_CLIENT_TIMEOUT from realtime import AsyncRealtimeChannel, AsyncRealtimeClient, RealtimeChannelOptions from storage3 import AsyncStorageClient @@ -20,7 +17,6 @@ from .auth_client import AsyncSupabaseAuthClient -# Create an exception class when user does not provide a valid url or key. class SupabaseException(Exception): def __init__(self, message: str): self.message = message @@ -54,11 +50,9 @@ def __init__( if not supabase_key: raise SupabaseException("supabase_key is required") - # Check if the url and key are valid + # Validate the URL and key if not re.match(r"^(https?)://.+", supabase_url): raise SupabaseException("Invalid URL") - - # Check if the key is a valid JWT if not re.match( r"^[A-Za-z0-9-_=]+\.[A-Za-z0-9-_=]+\.?[A-Za-z0-9-_.+/=]*$", supabase_key ): @@ -77,7 +71,7 @@ def __init__( self.storage_url = f"{supabase_url}/storage/v1" self.functions_url = f"{supabase_url}/functions/v1" - # Instantiate clients. + # Instantiate clients self.auth = self._init_supabase_auth_client( auth_url=self.auth_url, client_options=options, @@ -99,6 +93,22 @@ async def create( supabase_key: str, options: Optional[ClientOptions] = None, ): + """Create a Supabase client instance. + + Parameters + ---------- + supabase_url: str + The URL to the Supabase instance that should be connected to. + supabase_key: str + The API key to the Supabase instance that should be connected to. + **options + Any extra settings to be optionally specified - also see the + `DEFAULT_OPTIONS` dict. + + Returns + ------- + AsyncClient + """ auth_header = options.headers.get("Authorization") if options else None client = cls(supabase_url, supabase_key, options) @@ -116,19 +126,11 @@ async def create( return client def table(self, table_name: str) -> AsyncRequestBuilder: - """Perform a table operation. - - Note that the supabase client uses the `from` method, but in Python, - this is a reserved keyword, so we have elected to use the name `table`. - Alternatively you can use the `.from_()` method. - """ + """Perform a table operation.""" return self.from_(table_name) def schema(self, schema: str) -> AsyncPostgrestClient: - """Select a schema to query or perform an function (rpc) call. - - The schema needs to be on the list of exposed schemas inside Supabase. - """ + """Select a schema to query or perform a function (RPC) call.""" if self.options.schema != schema: self.options.schema = schema if self._postgrest: @@ -136,36 +138,20 @@ def schema(self, schema: str) -> AsyncPostgrestClient: return self.postgrest def from_(self, table_name: str) -> AsyncRequestBuilder: - """Perform a table operation. - - See the `table` method. - """ + """Perform a table operation.""" return self.postgrest.from_(table_name) def rpc( self, fn: str, params: Optional[Dict[Any, Any]] = None ) -> AsyncRPCFilterRequestBuilder: - """Performs a stored procedure call. - - Parameters - ---------- - fn : callable - The stored procedure call to be executed. - params : dict of any - Parameters passed into the stored procedure call. - - Returns - ------- - SyncFilterRequestBuilder - Returns a filter builder. This lets you apply filters on the response - of an RPC. - """ + """Performs a stored procedure call.""" if params is None: params = {} return self.postgrest.rpc(fn, params) @property def postgrest(self): + """PostgREST client property.""" if self._postgrest is None: self._postgrest = self._init_postgrest_client( rest_url=self.rest_url, @@ -173,11 +159,11 @@ def postgrest(self): schema=self.options.schema, timeout=self.options.postgrest_client_timeout, ) - return self._postgrest @property def storage(self): + """Storage client property.""" if self._storage is None: self._storage = self._init_storage_client( storage_url=self.storage_url, @@ -188,6 +174,7 @@ def storage(self): @property def functions(self): + """Functions client property.""" if self._functions is None: self._functions = AsyncFunctionsClient( self.functions_url, @@ -207,20 +194,80 @@ def get_channels(self) -> List[AsyncRealtimeChannel]: return self.realtime.get_channels() async def remove_channel(self, channel: AsyncRealtimeChannel) -> None: - """Unsubscribes and removes Realtime channel from Realtime client.""" + """Unsubscribes and removes a Realtime channel from the Realtime client.""" await self.realtime.remove_channel(channel) async def remove_all_channels(self) -> None: - """Unsubscribes and removes all Realtime channels from Realtime client.""" + """Unsubscribes and removes all Realtime channels from the Realtime client.""" await self.realtime.remove_all_channels() + def _create_auth_header(self, token: str): + """Creates an authorization header.""" + return f"Bearer {token}" + + def _get_auth_headers(self, authorization: Optional[str] = None) -> Dict[str, str]: + """Helper method to get auth headers.""" + if authorization is None: + authorization = self.options.headers.get( + "Authorization", self._create_auth_header(self.supabase_key) + ) + return { + "apiKey": self.supabase_key, + "Authorization": authorization, + } + + def _listen_to_auth_events( + self, event: AuthChangeEvent, session: Optional[Session] + ): + """Listens to authentication state changes.""" + access_token = self.supabase_key + if event in ["SIGNED_IN", "TOKEN_REFRESHED", "SIGNED_OUT"]: + self._postgrest = None + self._storage = None + self._functions = None + access_token = session.access_token if session else self.supabase_key + + self.options.headers["Authorization"] = self._create_auth_header(access_token) + asyncio.create_task(self.realtime.set_auth(access_token)) + + async def connect_to_realtime(self): + """Connect to Supabase Realtime service and handle disconnections with retries.""" + try: + await self.realtime.connect() + logging.info("Connected to Supabase realtime successfully.") + except Exception as e: + logging.error(f"Connection to Supabase realtime failed: {e}") + await self._retry_realtime_connection() + + async def _retry_realtime_connection(self): + """Retries the connection to the Realtime service with exponential backoff.""" + retries = 0 + max_retries = 5 + base_delay = 2 + + while retries < max_retries: + try: + await asyncio.sleep(base_delay * (2 ** retries)) # Exponential backoff + await self.realtime.connect() + logging.info("Reconnected to Supabase realtime.") + return + except Exception as e: + retries += 1 + logging.error(f"Retry {retries} failed: {e}") + + logging.error("Max retries reached, could not reconnect to Supabase realtime.") + @staticmethod def _init_realtime_client( realtime_url: str, supabase_key: str, options: Optional[Dict[str, Any]] = None ) -> AsyncRealtimeClient: + """Private method for creating an instance of the realtime client.""" if options is None: options = {} - """Private method for creating an instance of the realtime-py client.""" + + # Configure connection options if needed + options['timeout'] = 30 # Example timeout setting, adjust as needed + return AsyncRealtimeClient(realtime_url, token=supabase_key, **options) @staticmethod @@ -231,6 +278,7 @@ def _init_storage_client( verify: bool = True, proxy: Optional[str] = None, ) -> AsyncStorageClient: + """Initializes the storage client.""" return AsyncStorageClient( storage_url, headers, storage_client_timeout, verify, proxy ) @@ -263,7 +311,7 @@ def _init_postgrest_client( verify: bool = True, proxy: Optional[str] = None, ) -> AsyncPostgrestClient: - """Private helper for creating an instance of the Postgrest client.""" + """Initializes the PostgREST client.""" return AsyncPostgrestClient( rest_url, headers=headers, @@ -273,42 +321,13 @@ def _init_postgrest_client( proxy=proxy, ) - def _create_auth_header(self, token: str): - return f"Bearer {token}" - - def _get_auth_headers(self, authorization: Optional[str] = None) -> Dict[str, str]: - if authorization is None: - authorization = self.options.headers.get( - "Authorization", self._create_auth_header(self.supabase_key) - ) - - """Helper method to get auth headers.""" - return { - "apiKey": self.supabase_key, - "Authorization": authorization, - } - - def _listen_to_auth_events( - self, event: AuthChangeEvent, session: Optional[Session] - ): - access_token = self.supabase_key - if event in ["SIGNED_IN", "TOKEN_REFRESHED", "SIGNED_OUT"]: - # reset postgrest and storage instance on event change - self._postgrest = None - self._storage = None - self._functions = None - access_token = session.access_token if session else self.supabase_key - - self.options.headers["Authorization"] = self._create_auth_header(access_token) - asyncio.create_task(self.realtime.set_auth(access_token)) - async def create_client( supabase_url: str, supabase_key: str, options: Optional[ClientOptions] = None, ) -> AsyncClient: - """Create client function to instantiate supabase client like JS runtime. + """Create client function to instantiate supabase client. Parameters ---------- @@ -320,19 +339,9 @@ async def create_client( Any extra settings to be optionally specified - also see the `DEFAULT_OPTIONS` dict. - Examples - -------- - Instantiating the client. - >>> import os - >>> from supabase import create_client, Client - >>> - >>> url: str = os.environ.get("SUPABASE_TEST_URL") - >>> key: str = os.environ.get("SUPABASE_TEST_KEY") - >>> supabase: Client = create_client(url, key) - Returns ------- - Client + AsyncClient """ return await AsyncClient.create( supabase_url=supabase_url, supabase_key=supabase_key, options=options diff --git a/tests/test_client.py b/tests/test_client.py index 672e5e6d..b025f17c 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,12 +1,13 @@ from __future__ import annotations import os +import asyncio from typing import Any -from unittest.mock import MagicMock +from unittest.mock import MagicMock, AsyncMock import pytest -from supabase import Client, ClientOptions, create_client +from supabase import AsyncClient, ClientOptions, create_client @pytest.mark.xfail( @@ -16,7 +17,7 @@ @pytest.mark.parametrize("key", ["", None, "valeefgpoqwjgpj", 139, -1, {}, []]) def test_incorrect_values_dont_instantiate_client(url: Any, key: Any) -> None: """Ensure we can't instantiate client with invalid values.""" - _: Client = create_client(url, key) + _: AsyncClient = create_client(url, key) def test_uses_key_as_authorization_header_by_default() -> None: @@ -91,3 +92,50 @@ def test_updates_the_authorization_header_on_auth_events() -> None: assert client.storage.session.headers.get("apiKey") == key assert client.storage.session.headers.get("Authorization") == updated_authorization + + +@pytest.mark.asyncio +async def test_connect_to_realtime_success(mocker): + url = os.environ.get("SUPABASE_TEST_URL") + key = os.environ.get("SUPABASE_TEST_KEY") + + client = create_client(url, key) + mock_connect = mocker.patch.object(client.realtime, 'connect', return_value=None) + + await client.connect_to_realtime() + mock_connect.assert_called_once() + + +@pytest.mark.asyncio +async def test_connect_to_realtime_failure(mocker): + url = os.environ.get("SUPABASE_TEST_URL") + key = os.environ.get("SUPABASE_TEST_KEY") + + client = create_client(url, key) + mock_connect = mocker.patch.object(client.realtime, 'connect', side_effect=Exception("Connection failed")) + + with pytest.raises(Exception, match="Connection failed"): + await client.connect_to_realtime() + + +@pytest.mark.asyncio +async def test_reconnect_logic(mocker): + url = os.environ.get("SUPABASE_TEST_URL") + key = os.environ.get("SUPABASE_TEST_KEY") + + client = create_client(url, key) + mock_reconnect = mocker.patch.object(client.realtime, 'connect', side_effect=[Exception("Failed"), None]) + + await client.connect_to_realtime() # First attempt fails + await client.connect_to_realtime() # Second attempt should succeed + assert mock_reconnect.call_count == 2 # Check that it tried to reconnect + + +def test_logging_on_connection_failure(caplog): + url = os.environ.get("SUPABASE_TEST_URL") + key = os.environ.get("SUPABASE_TEST_KEY") + + client = create_client(url, key) + with caplog.at_level(logging.ERROR): + await client.connect_to_realtime() # Assume it fails + assert "Connection failed" in caplog.text