diff --git a/src/aioice/ice.py b/src/aioice/ice.py index f33a8dc..0eb99bd 100644 --- a/src/aioice/ice.py +++ b/src/aioice/ice.py @@ -865,7 +865,6 @@ async def get_component_candidates( except OSError as exc: self.__log_info("Could not bind to %s - %s", address, exc) continue - protocol = cast(StunProtocol, protocol) host_protocols.append(protocol) # add host candidate @@ -911,7 +910,6 @@ async def get_component_candidates( ssl=self.turn_ssl, transport=self.turn_transport, ) - protocol = cast(StunProtocol, protocol) self._protocols.append(protocol) # add relayed candidate diff --git a/src/aioice/mdns.py b/src/aioice/mdns.py index 2eac59e..7982a9c 100644 --- a/src/aioice/mdns.py +++ b/src/aioice/mdns.py @@ -189,4 +189,4 @@ async def create_mdns_protocol() -> MDnsProtocol: sock=rx_sock, ) - return cast(MDnsProtocol, protocol) + return protocol diff --git a/src/aioice/turn.py b/src/aioice/turn.py index 360cf50..c54c9e6 100644 --- a/src/aioice/turn.py +++ b/src/aioice/turn.py @@ -4,7 +4,18 @@ import socket import struct import time -from typing import Any, Callable, Dict, List, Optional, Text, Tuple, Union, cast +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Text, + Tuple, + TypeVar, + Union, + cast, +) from . import stun from .utils import random_transaction_id @@ -17,6 +28,8 @@ UDP_TRANSPORT = 0x11000000 UDP_SOCKET_BUFFER_SIZE = 262144 +_ProtocolT = TypeVar("_ProtocolT", bound=asyncio.BaseProtocol) + def is_channel_data(data: bytes) -> bool: return (data[0] & 0xC0) == 0x40 @@ -371,7 +384,7 @@ async def _connect(self) -> None: async def create_turn_endpoint( - protocol_factory: Callable, + protocol_factory: Callable[[], _ProtocolT], server_addr: Tuple[str, int], username: Optional[str], password: Optional[str], @@ -379,11 +392,13 @@ async def create_turn_endpoint( channel_refresh_time: int = DEFAULT_CHANNEL_REFRESH_TIME, ssl: bool = False, transport: str = "udp", -) -> Tuple[TurnTransport, asyncio.Protocol]: +) -> Tuple[TurnTransport, _ProtocolT]: """ Create datagram connection relayed over TURN. """ loop = asyncio.get_event_loop() + inner_protocol: asyncio.BaseProtocol + inner_transport: asyncio.BaseTransport if transport == "tcp": inner_transport, inner_protocol = await loop.create_connection( lambda: TurnClientTcpProtocol(