# Source code for aiocometd.client

"""Client class implementation"""
import asyncio
import reprlib
import logging
from collections import abc
from contextlib import suppress

from .transports import create_transport
from .constants import DEFAULT_CONNECTION_TYPE, ConnectionType, MetaChannel, \
    SERVICE_CHANNEL_PREFIX, TransportState
from .exceptions import ServerError, ClientInvalidOperation, \
    TransportTimeoutError, ClientError
from .utils import is_server_error_message


#: Logger of this module; named after the module so its output can be
#: configured and filtered independently
LOGGER = logging.getLogger(name=__name__)


class Client:  # pylint: disable=too-many-instance-attributes
    """CometD client"""
    #: Predefined server error messages by channel name
    _SERVER_ERROR_MESSAGES = {
        MetaChannel.HANDSHAKE: "Handshake request failed.",
        MetaChannel.CONNECT: "Connect request failed.",
        MetaChannel.DISCONNECT: "Disconnect request failed.",
        MetaChannel.SUBSCRIBE: "Subscribe request failed.",
        MetaChannel.UNSUBSCRIBE: "Unsubscribe request failed."
    }
    #: Default connection types list
    _DEFAULT_CONNECTION_TYPES = [ConnectionType.WEBSOCKET,
                                 ConnectionType.LONG_POLLING]

    def __init__(self, url, connection_types=None, *,
                 connection_timeout=10.0, ssl=None, max_pending_count=100,
                 extensions=None, auth=None, loop=None):
        """
        :param str url: CometD service url
        :param connection_types: List of connection types in order of \
        preference, or a single connection type name. If ``None``, \
        [:obj:`~ConnectionType.WEBSOCKET`, \
        :obj:`~ConnectionType.LONG_POLLING`] will be used as a default value.
        :type connection_types: list[ConnectionType], ConnectionType or None
        :param connection_timeout: The maximum amount of time to wait for the \
        transport to re-establish a connection with the server when the \
        connection fails.
        :type connection_timeout: int, float or None
        :param ssl: SSL validation mode. None for default SSL check \
        (:func:`ssl.create_default_context` is used), False to skip SSL \
        certificate validation, \
        `aiohttp.Fingerprint <https://aiohttp.readthedocs.io/en/stable/\
client_reference.html#aiohttp.Fingerprint>`_ for fingerprint \
        validation, :obj:`ssl.SSLContext` for custom SSL certificate \
        validation.
        :param int max_pending_count: The maximum number of messages to \
        prefetch from the server. If the number of prefetched messages \
        reaches this size then the connection will be suspended until \
        messages are consumed. \
        If it is less than or equal to zero, the count is infinite.
        :param extensions: List of protocol extension objects
        :type extensions: list[Extension] or None
        :param AuthExtension auth: An auth extension
        :param loop: Event :obj:`loop <asyncio.BaseEventLoop>` used to
                     schedule tasks. If *loop* is ``None`` then
                     :func:`asyncio.get_event_loop` is used to get the default
                     event loop.
        """
        #: CometD service url
        self.url = url
        #: List of connection types to use in order of preference
        self._connection_types = None
        if isinstance(connection_types, ConnectionType):
            self._connection_types = [connection_types]
        elif isinstance(connection_types, abc.Iterable):
            self._connection_types = list(connection_types)
        else:
            # copy the class level default, so that instances never share
            # (and could never accidentally mutate) the same list object
            self._connection_types = list(self._DEFAULT_CONNECTION_TYPES)
        self._loop = loop or asyncio.get_event_loop()
        #: queue for consuming incoming event messages
        self._incoming_queue = None
        #: transport object
        self._transport = None
        #: marks whether the client is open or closed
        self._closed = True
        #: The maximum amount of time to wait for the transport to re-establish
        #: a connection with the server when the connection fails
        self.connection_timeout = connection_timeout
        #: SSL validation mode
        self.ssl = ssl
        #: the maximum number of messages to prefetch from the server
        self._max_pending_count = max_pending_count
        #: List of protocol extension objects
        self.extensions = extensions
        #: An auth extension
        self.auth = auth

    def __repr__(self):
        """Formal string representation"""
        cls_name = type(self).__name__
        fmt_spec = "{}({}, {}, connection_timeout={}, ssl={}, " \
                   "max_pending_count={}, extensions={}, auth={}, loop={})"
        return fmt_spec.format(cls_name,
                               reprlib.repr(self.url),
                               reprlib.repr(self._connection_types),
                               reprlib.repr(self.connection_timeout),
                               reprlib.repr(self.ssl),
                               reprlib.repr(self._max_pending_count),
                               reprlib.repr(self.extensions),
                               reprlib.repr(self.auth),
                               reprlib.repr(self._loop))

    @property
    def closed(self):
        """Marks whether the client is open or closed"""
        return self._closed

    @property
    def subscriptions(self):
        """Set of subscribed channels"""
        if self._transport:
            return self._transport.subscriptions
        return set()

    @property
    def connection_type(self):
        """The current connection type in use if the client is open,
        otherwise ``None``"""
        if self._transport is not None:
            return self._transport.connection_type
        return None

    @property
    def pending_count(self):
        """The number of pending incoming messages

        Once :obj:`open` is called the client starts listening for messages
        from the server. The incoming messages are retrieved and stored in an
        internal queue until they get consumed by calling :obj:`receive`.
        """
        if self._incoming_queue is None:
            return 0
        return self._incoming_queue.qsize()

    @property
    def has_pending_messages(self):
        """Marks whether the client has any pending incoming messages"""
        return self.pending_count > 0

    def _pick_connection_type(self, connection_types):
        """Pick a connection type based on the *connection_types*
        supported by the server and on the user's preferences

        :param list[str] connection_types: Connection types \
        supported by the server
        :return: The connection type with the highest precedence \
        which is supported by the server
        :rtype: ConnectionType or None
        """
        server_connection_types = []
        for type_string in connection_types:
            # silently skip connection types unknown to this client
            with suppress(ValueError):
                server_connection_types.append(ConnectionType(type_string))

        intersection = (set(server_connection_types) &
                        set(self._connection_types))
        if not intersection:
            return None

        # the user's preference order decides between the common types
        result = min(intersection, key=self._connection_types.index)
        return result

    async def _negotiate_transport(self):
        """Negotiate the transport type to use with the server and create the
        transport object

        :return: Transport object
        :rtype: Transport
        :raise ClientError: If none of the connection types offered by the \
        server are supported
        """
        self._incoming_queue = asyncio.Queue(maxsize=self._max_pending_count)
        # the handshake is always performed with the default connection type
        transport = create_transport(DEFAULT_CONNECTION_TYPE,
                                     url=self.url,
                                     incoming_queue=self._incoming_queue,
                                     ssl=self.ssl,
                                     extensions=self.extensions,
                                     auth=self.auth,
                                     loop=self._loop)

        try:
            response = await transport.handshake(self._connection_types)
            self._verify_response(response)

            LOGGER.info("Connection types supported by the server: %r",
                        response["supportedConnectionTypes"])
            connection_type = self._pick_connection_type(
                response["supportedConnectionTypes"]
            )
            if not connection_type:
                raise ClientError("None of the connection types offered by "
                                  "the server are supported.")

            # switch to the negotiated transport type, reusing the client id
            # obtained during the handshake
            if transport.connection_type != connection_type:
                client_id = transport.client_id
                await transport.close()
                transport = create_transport(
                    connection_type,
                    url=self.url,
                    incoming_queue=self._incoming_queue,
                    client_id=client_id,
                    ssl=self.ssl,
                    extensions=self.extensions,
                    auth=self.auth,
                    loop=self._loop)
            return transport
        except Exception:
            await transport.close()
            raise

    async def open(self):
        """Establish a connection with the CometD server

        This method works mostly the same way as the `handshake` method of
        CometD clients in the reference implementations.

        :raise ClientError: If none of the connection types offered by the \
        server are supported
        :raise ClientInvalidOperation:  If the client is already open, or in \
        other words if it isn't :obj:`closed`
        :raise TransportError: If a network or transport related error occurs
        :raise ServerError: If the handshake or the first connect request \
        gets rejected by the server.
        """
        if not self.closed:
            raise ClientInvalidOperation("Client is already open.")

        LOGGER.info("Opening client with connection types %r ...",
                    [t.value for t in self._connection_types])
        self._transport = await self._negotiate_transport()
        try:
            response = await self._transport.connect()
            self._verify_response(response)
        except Exception:
            # the client is still marked as closed at this point, which makes
            # close() a no-op, so the negotiated transport must be released
            # here, otherwise it would leak
            await self._transport.close()
            raise
        self._closed = False
        LOGGER.info("Client opened with connection_type %r",
                    self.connection_type.value)

    async def close(self):
        """Disconnect from the CometD server"""
        if not self.closed:
            if self.pending_count == 0:
                LOGGER.info("Closing client...")
            else:
                LOGGER.warning(
                    "Closing client while %s messages are still pending...",
                    self.pending_count)
            try:
                if self._transport:
                    try:
                        await self._transport.disconnect()
                    finally:
                        # release the transport's resources even if the
                        # disconnect request fails
                        await self._transport.close()
            finally:
                self._closed = True
                LOGGER.info("Client closed.")

    async def subscribe(self, channel):
        """Subscribe to *channel*

        :param str channel: Name of the channel
        :raise ClientInvalidOperation: If the client is :obj:`closed`
        :raise TransportError: If a network or transport related error occurs
        :raise ServerError: If the subscribe request gets rejected by the \
        server
        """
        if self.closed:
            raise ClientInvalidOperation("Can't send subscribe request while, "
                                         "the client is closed.")
        await self._check_server_disconnected()

        response = await self._transport.subscribe(channel)
        self._verify_response(response)
        LOGGER.info("Subscribed to channel %s", channel)

    async def unsubscribe(self, channel):
        """Unsubscribe from *channel*

        :param str channel: Name of the channel
        :raise ClientInvalidOperation: If the client is :obj:`closed`
        :raise TransportError: If a network or transport related error occurs
        :raise ServerError: If the unsubscribe request gets rejected by the \
        server
        """
        if self.closed:
            raise ClientInvalidOperation("Can't send unsubscribe request "
                                         "while, the client is closed.")
        await self._check_server_disconnected()

        response = await self._transport.unsubscribe(channel)
        self._verify_response(response)
        LOGGER.info("Unsubscribed from channel %s", channel)

    async def publish(self, channel, data):
        """Publish *data* to the given *channel*

        :param str channel: Name of the channel
        :param dict data: Data to send to the server
        :return: Publish response
        :rtype: dict
        :raise ClientInvalidOperation: If the client is :obj:`closed`
        :raise TransportError: If a network or transport related error occurs
        :raise ServerError: If the publish request gets rejected by the server
        """
        if self.closed:
            raise ClientInvalidOperation("Can't publish data while, "
                                         "the client is closed.")
        await self._check_server_disconnected()

        response = await self._transport.publish(channel, data)
        self._verify_response(response)
        return response

    def _verify_response(self, response):
        """Check the ``successful`` status of the *response* and raise \
        the appropriate :obj:`~aiocometd.exceptions.ServerError` if it's False

        If the *response* has no ``successful`` field, it's considered to be
        successful.

        :param dict response: Response message
        :raise ServerError: If the *response* is not ``successful``
        """
        if is_server_error_message(response):
            self._raise_server_error(response)

    def _raise_server_error(self, response):
        """Raise the appropriate :obj:`~aiocometd.exceptions.ServerError` for \
        the failed *response*

        :param dict response: Response message
        :raise ServerError: If the *response* is not ``successful``
        """
        channel = response["channel"]
        message = type(self)._SERVER_ERROR_MESSAGES.get(channel)
        if not message:
            # not a meta channel: either a service or a regular publish
            if channel.startswith(SERVICE_CHANNEL_PREFIX):
                message = "Service request failed."
            else:
                message = "Publish request failed."
        raise ServerError(message, response)

    async def receive(self):
        """Wait for incoming messages from the server

        :return: Incoming message
        :rtype: dict
        :raise ClientInvalidOperation: If the client is closed, and has no \
        more pending incoming messages
        :raise ServerError: If the client receives a confirmation message \
        which is not ``successful``
        :raise TransportTimeoutError: If the transport can't re-establish \
        connection with the server in :obj:`connection_timeout` time.
        """
        # already queued messages can still be consumed after closing
        if not self.closed or self.has_pending_messages:
            response = await self._get_message(self.connection_timeout)
            self._verify_response(response)
            return response
        else:
            raise ClientInvalidOperation("The client is closed and there are "
                                         "no pending messages.")

    async def __aiter__(self):
        """Asynchronous iterator

        :raise ServerError: If the client receives a confirmation message \
        which is not ``successful``
        :raise TransportTimeoutError: If the transport can't re-establish \
        connection with the server in :obj:`connection_timeout` time.
        """
        while True:
            try:
                yield await self.receive()
            except ClientInvalidOperation:
                # closed with no pending messages: end of iteration
                break

    async def __aenter__(self):
        """Enter the runtime context and call :obj:`open`

        :raise ClientInvalidOperation:  If the client is already open, or in \
        other words if it isn't :obj:`closed`
        :raise TransportError: If a network or transport related error occurs
        :raise ServerError: If the handshake or the first connect request \
        gets rejected by the server.
        :return: The client object itself
        :rtype: Client
        """
        try:
            await self.open()
        except Exception:
            await self.close()
            raise
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Exit the runtime context and call :obj:`close`"""
        await self.close()

    async def _get_message(self, connection_timeout):
        """Get the next incoming message

        Waits concurrently for an incoming message, for a server side
        disconnect and (if *connection_timeout* is truthy) for the connection
        timeout, and acts on whichever happens first.

        :param connection_timeout: The maximum amount of time to wait for the \
        transport to re-establish a connection with the server when the \
        connection fails.
        :return: Incoming message
        :rtype: dict
        :raise TransportTimeoutError: If the transport can't re-establish \
        connection with the server in :obj:`connection_timeout` time.
        :raise ServerError: If the connection gets closed by the server.
        """
        tasks = []
        # task waiting on connection timeout
        if connection_timeout:
            timeout_task = asyncio.ensure_future(
                self._wait_connection_timeout(connection_timeout),
                loop=self._loop
            )
            tasks.append(timeout_task)

        # task waiting on incoming messages
        get_task = asyncio.ensure_future(self._incoming_queue.get(),
                                         loop=self._loop)
        tasks.append(get_task)

        # task waiting on server side disconnect
        server_disconnected_task = asyncio.ensure_future(
            self._transport.wait_for_state(
                TransportState.SERVER_DISCONNECTED),
            loop=self._loop
        )
        tasks.append(server_disconnected_task)

        try:
            # the tasks are already bound to self._loop; asyncio.wait() must
            # not receive a loop argument (it was removed in Python 3.10)
            done, pending = await asyncio.wait(
                tasks, return_when=asyncio.FIRST_COMPLETED)

            # cancel all pending tasks
            for task in pending:
                task.cancel()

            # handle the completed task
            if get_task in done:
                return get_task.result()
            elif server_disconnected_task in done:
                await self._check_server_disconnected()
            else:
                raise TransportTimeoutError("Lost connection with the "
                                            "server.")
        except asyncio.CancelledError:
            # cancel all tasks
            for task in tasks:
                task.cancel()
            raise

    async def _wait_connection_timeout(self, timeout):
        """Wait for and return when the transport can't re-establish \
        connection with the server in *timeout* time

        Every time the transport enters the ``CONNECTING`` state it gets
        *timeout* time to reach the ``CONNECTED`` state again; the method
        returns the first time it doesn't.

        :param timeout: The maximum amount of time to wait for the \
        transport to re-establish a connection with the server when the \
        connection fails.
        """
        while True:
            await self._transport.wait_for_state(TransportState.CONNECTING)
            try:
                # no loop argument: it was removed from asyncio.wait_for()
                # in Python 3.10, and the running loop is used anyway
                await asyncio.wait_for(
                    self._transport.wait_for_state(TransportState.CONNECTED),
                    timeout
                )
            except asyncio.TimeoutError:
                break

    async def _check_server_disconnected(self):
        """Checks whether the current transport's state is
        :obj:`TransportState.SERVER_DISCONNECTED` and if it is then closes the
        client and raises an error

        :raise ServerError: If the current transport's state is \
        :obj:`TransportState.SERVER_DISCONNECTED`
        """
        if (self._transport and
                self._transport.state == TransportState.SERVER_DISCONNECTED):
            await self.close()
            raise ServerError("Connection closed by the server",
                              self._transport.last_connect_result)