Coverage for aiocoap/transports/tcp.py: 91%
249 statements
« prev ^ index » next coverage.py v7.6.8, created at 2024-11-28 12:34 +0000
1# SPDX-FileCopyrightText: Christian Amsüss and the aiocoap contributors
2#
3# SPDX-License-Identifier: MIT
5import asyncio
6import socket
7from logging import Logger
8from typing import Dict, Optional, Set, Tuple
10from aiocoap.transports import rfc8323common
11from aiocoap import interfaces, error, util
12from aiocoap import COAP_PORT, Message
13from aiocoap import defaults
16def _extract_message_size(data: bytes):
17 """Read out the full length of a CoAP messsage represented by data.
19 Returns None if data is too short to read the (full) length.
21 The number returned is the number of bytes that has to be read into data to
22 start reading the next message; it consists of a constant term, the token
23 length and the extended length of options-plus-payload."""
25 if not data:
26 return None
28 length = data[0] >> 4
29 tokenoffset = 2
30 tkl = data[0] & 0x0F
32 if length >= 13:
33 if length == 13:
34 extlen = 1
35 offset = 13
36 elif length == 14:
37 extlen = 2
38 offset = 269
39 else:
40 extlen = 4
41 offset = 65805
42 if len(data) < extlen + 1:
43 return None
44 tokenoffset = 2 + extlen
45 length = int.from_bytes(data[1 : 1 + extlen], "big") + offset
46 return tokenoffset, tkl, length
def _decode_message(data: bytes) -> Message:
    """Decode data, which holds one complete framed message, into a Message.

    The code byte sits directly before the token; everything after the token
    is handed to ``msg.opt.decode``, whose return value becomes the payload.

    Raises error.UnparsableMessage if the length header is incomplete, if the
    token is longer than 8 bytes, or if data ends before the end of the
    token. Errors raised by ``msg.opt.decode`` propagate unchanged.
    """
    header = _extract_message_size(data)
    if header is None:
        # Without this check, unpacking None would raise a TypeError that
        # callers catching UnparsableMessage do not expect.
        raise error.UnparsableMessage("Truncated message header")
    tokenoffset, tkl, _ = header
    if tkl > 8:
        raise error.UnparsableMessage("Overly long token")
    if len(data) < tokenoffset + tkl:
        # Code byte and/or token missing; indexing would raise IndexError.
        raise error.UnparsableMessage("Truncated message header")
    code = data[tokenoffset - 1]
    token = data[tokenoffset : tokenoffset + tkl]

    msg = Message(code=code, token=token)

    msg.payload = msg.opt.decode(data[tokenoffset + tkl :])

    return msg
63def _encode_length(length: int):
64 if length < 13:
65 return (length, b"")
66 elif length < 269:
67 return (13, (length - 13).to_bytes(1, "big"))
68 elif length < 65805:
69 return (14, (length - 269).to_bytes(2, "big"))
70 else:
71 return (15, (length - 65805).to_bytes(4, "big"))
def _serialize(msg: Message) -> bytes:
    """Frame msg for transmission on a CoAP-over-TCP stream.

    Output layout: Len/TKL byte, optional extended length, code byte, token,
    encoded options and -- only if there is a payload -- the 0xFF marker
    followed by the payload.

    Raises ValueError if the token is longer than 8 bytes.
    """
    pieces = [msg.opt.encode()]
    if msg.payload:
        pieces.extend((b"\xff", msg.payload))
    body = b"".join(pieces)
    nibble, extended = _encode_length(len(body))

    token = msg.token
    if len(token) > 8:
        raise ValueError("Overly long token")

    first_byte = bytes([(nibble << 4) | len(token)])
    code_byte = bytes([msg.code])
    return b"".join((first_byte, extended, code_byte, token, body))
class TcpConnection(
    asyncio.Protocol, rfc8323common.RFC8323Remote, interfaces.EndpointAddress
):
    """A single CoAP-over-TCP connection.

    The same object is the asyncio protocol driving the socket and the
    EndpointAddress that messages carry as their ``remote``. It is created by
    TCPServer (``is_server=True``) or TCPClient (``is_server=False``), which
    it keeps as ``_ctx`` and reports incoming messages and errors to.
    """

    # currently, both the protocol and the EndpointAddress are the same object.
    # if, at a later point in time, the keepaliving of TCP connections should
    # depend on whether the library user still keeps a usable address around,
    # those functions could be split.

    def __init__(self, ctx, log, loop, *, is_server) -> None:
        """Set up a not-yet-connected instance.

        ctx is the owning transport; it receives messages and errors via its
        ``_dispatch_incoming`` / ``_dispatch_error`` methods and supplies
        ``_scheme`` and ``_default_port``. is_server tells whether this side
        accepted the connection, which decides whose host information is
        replaced by TLS SNI data in connection_made.
        """
        super().__init__()
        self._ctx = ctx
        self.log = log
        self.loop = loop

        # Received bytes that do not form a complete message yet.
        self._spool = b""

        # Stays None until the peer's CSM has been processed (presumably set
        # by the RFC8323Remote signaling code, which is not visible here);
        # data_received aborts on non-signaling messages while it is None.
        self._remote_settings = None

        # Filled in by connection_made.
        self._transport: Optional[asyncio.Transport] = None
        self._local_is_server = is_server

    @property
    def scheme(self):
        """URI scheme of the owning transport (its ``_scheme`` attribute)."""
        return self._ctx._scheme

    def _send_message(self, msg: Message):
        """Frame msg and write it to the transport.

        Must only be called after connection_made stored the transport.
        """
        self.log.debug("Sending message: %r", msg)
        assert (
            self._transport is not None
        ), "Attempted to send message before connection"
        self._transport.write(_serialize(msg))

    def _abort_with(self, abort_msg):
        """Send abort_msg and close the transport, if one was ever attached."""
        if self._transport is not None:
            self._send_message(abort_msg)
            self._transport.close()
        else:
            # FIXME: find out how this happens; i've only seen it after nmap
            # runs against an aiocoap server and then shutting it down.
            # "poisoning" the object to make sure this can not be exploited to
            # bypass the server shutdown.
            self._ctx = None

    # implementing asyncio.Protocol

    def connection_made(self, transport):
        """Store the transport and host information, then send our CSM.

        Ports equal to the owning transport's default port are stored as
        None. If the TLS object carries an ``indicated_server_name``, that
        name replaces the host of the local side (server) or the remote side
        (client). Finally calls ``_send_initial_csm``.
        """
        self._transport = transport

        ssl_object = transport.get_extra_info("ssl_object")
        if ssl_object is not None:
            server_name = getattr(ssl_object, "indicated_server_name", None)
        else:
            server_name = None

        # `host` already contains the interface identifier, so throwing away
        # scope and interface identifier
        self._local_hostinfo = transport.get_extra_info("sockname")[:2]
        self._remote_hostinfo = transport.get_extra_info("peername")[:2]

        def none_default_port(sockname):
            # (host, port) with the port replaced by None if it is the default
            return (
                sockname[0],
                None if sockname[1] == self._ctx._default_port else sockname[1],
            )

        self._local_hostinfo = none_default_port(self._local_hostinfo)
        self._remote_hostinfo = none_default_port(self._remote_hostinfo)

        # SNI information available
        if server_name is not None:
            if self._local_is_server:
                self._local_hostinfo = (server_name, self._local_hostinfo[1])
            else:
                self._remote_hostinfo = (server_name, self._remote_hostinfo[1])

        self._send_initial_csm()

    def connection_lost(self, exc):
        """Report the loss to the owning transport (exc may be None)."""
        # FIXME react meaningfully:
        # * send event through pool so it can propagate the error to all
        #   requests on the same remote
        # * mark the address as erroneous so it won't be recognized by
        #   fill_or_recognize_remote

        self._ctx._dispatch_error(self, exc)

    def data_received(self, data):
        """Buffer data and process every complete message in the buffer.

        Signaling messages are handled locally via ``_process_signaling``;
        other messages are passed to the owning transport, but only once
        ``_remote_settings`` is set. Messages announcing more than
        ``_my_max_message_size`` bytes, unparsable messages and messages
        arriving before the peer's CSM abort the connection.
        """
        # A rope would be more efficient here, but the expected case is that
        # _spool is b"" and spool gets emptied soon -- most messages will just
        # fit in a single TCP package and not be nagled together.
        #
        # (If this does become a bottleneck, say self._spool = SomeRope(b"")
        # and barely change anything else).

        self._spool += data

        while True:
            msglen = _extract_message_size(self._spool)
            if msglen is None:
                # Not even the length header is complete yet.
                break
            # (tokenoffset, tkl, length) add up to the full framed size.
            msglen = sum(msglen)
            # Checked before the message is complete, so an oversized
            # announcement is rejected without buffering the whole message.
            if msglen > self._my_max_message_size:
                self.abort("Overly large message announced")
                return

            if msglen > len(self._spool):
                # Wait for the rest of the message.
                break

            msg = self._spool[:msglen]
            try:
                msg = _decode_message(msg)
            except error.UnparsableMessage:
                self.abort("Failed to parse message")
                return
            msg.remote = self

            self.log.debug("Received message: %r", msg)

            # Only consume the bytes once the message decoded successfully.
            self._spool = self._spool[msglen:]

            if msg.code.is_signalling():
                try:
                    self._process_signaling(msg)
                except rfc8323common.CloseConnection as e:
                    self._ctx._dispatch_error(self, e.args[0])
                    self._transport.close()
                # NOTE(review): after the close above, the loop still goes on
                # with any further messages already in the spool.
                continue

            if self._remote_settings is None:
                self.abort("No CSM received")
                return

            self._ctx._dispatch_incoming(self, msg)

    def eof_received(self):
        # NOTE(review): returns None; per the asyncio Protocol contract that
        # presumably lets the transport close itself -- confirm.
        # FIXME: as with connection_lost, but less noisy if announced
        # FIXME: return true and initiate own shutdown if that is what CoAP prescribes
        pass

    def pause_writing(self):
        # Flow control is not implemented; writes keep being buffered.
        # FIXME: do something ;-)
        pass

    def resume_writing(self):
        # FIXME: do something ;-)
        pass

    # RFC8323Remote.release recommends subclassing this, but there's no easy
    # awaitable here yet, and no important business to finish, timeout-wise.
class _TCPPooling:
    """Behavior shared by TCPServer and TCPClient.

    Sends messages through the TcpConnection stored in ``message.remote`` and
    routes what connections receive (or fail with) to ``_tokenmanager``.
    Subclasses provide ``_tokenmanager``, ``log`` and ``_evict_from_pool``.
    """

    # implementing TokenInterface

    def send_message(self, message, messageerror_monitor):
        """Send message over its TcpConnection.

        Responses whose class is suppressed by the No-Response option are
        dropped. The No-Response option is cleared before sending.
        """
        # Ignoring messageerror_monitor: CoAP over reliable transports has no
        # way of indicating that a particular message was bad, it always shuts
        # down the complete connection

        if message.code.is_response():
            # No-Response uses bit (class - 1): 2 for 2.xx, 8 for 4.xx,
            # 16 for 5.xx.
            suppressed = message.opt.no_response or 0
            if suppressed & (1 << (message.code.class_ - 1)):
                return

        message.opt.no_response = None

        message.remote._send_message(message)

    # used by the TcpConnection instances

    def _dispatch_incoming(self, connection, msg):
        """Route a received non-signaling message to the token manager."""
        if msg.code == 0:
            # Empty messages (0.00) carry neither a request nor a response;
            # RFC 8323 requires the recipient to ignore them. Previously this
            # branch was a no-op and they reached process_request.
            return

        if msg.code.is_response():
            self._tokenmanager.process_response(msg)
            # ignoring the return value; unexpected responses can be the
            # asynchronous result of cancelled observations
        else:
            self._tokenmanager.process_request(msg)

    def _dispatch_error(self, connection, exc):
        """Drop connection from the pool and report exc for it.

        After shutdown (``_tokenmanager`` is None) errors are only logged.
        """
        self._evict_from_pool(connection)

        if self._tokenmanager is None:
            if exc is not None:
                self.log.warning("Ignoring late error during shutdown: %s", exc)
            else:
                # it's just a regular connection loss, that's to be expected during shutdown
                pass
            return

        self._tokenmanager.dispatch_error(exc, connection)

    # for diverting behavior of _TLSMixIn
    _scheme = "coap+tcp"
    _default_port = COAP_PORT
class TCPServer(_TCPPooling, interfaces.TokenInterface):
    """Server side of CoAP over TCP.

    Every accepted connection is a TcpConnection kept in ``_pool`` until it
    reports an error or gets closed.
    """

    def __init__(self) -> None:
        self._pool: Set[TcpConnection] = set()
        self.log: Optional[Logger] = None
        self.server = None
        # Assigned by create_server; reset to None by shutdown so that late
        # errors are only logged.
        self._tokenmanager = None

    @classmethod
    async def create_server(
        cls, bind, tman: interfaces.TokenManager, log, loop, *, _server_context=None
    ):
        """Create a server listening on bind and return it.

        bind is a (host, port) tuple or None, which means ``("::", None)``.
        A missing or zero port is replaced by ``_default_port``. An explicit
        port is shifted by ``_default_port - COAP_PORT``, which is 0 for plain
        TCP, so subclasses with another default port listen at a matching
        offset. _server_context is passed as ``ssl`` to ``loop.create_server``.

        Raises error.ResolutionError if the bind host cannot be resolved.
        """
        self = cls()
        self._tokenmanager = tman
        self.log = log

        bind = bind or ("::", None)
        bind = (
            bind[0],
            bind[1] + (self._default_port - COAP_PORT)
            if bind[1]
            else self._default_port,
        )

        def new_connection():
            c = TcpConnection(self, log, loop, is_server=True)
            self._pool.add(c)
            return c

        try:
            server = await loop.create_server(
                new_connection,
                bind[0],
                bind[1],
                ssl=_server_context,
                reuse_port=defaults.has_reuse_port(),
            )
        except socket.gaierror as e:
            raise error.ResolutionError(
                "No local bindable address found for %s" % bind[0]
            ) from e
        self.server = server

        return self

    def _evict_from_pool(self, connection):
        # May easily happen twice, once when an error comes in and once when
        # the connection is (subsequently) closed; discard() tolerates that.
        self._pool.discard(connection)

    # implementing TokenInterface

    async def fill_or_recognize_remote(self, message):
        """Accept only messages whose remote is one of our own connections.

        A server never opens connections, so everything else is declined.
        """
        # isinstance() is False for None, so no separate None check is needed.
        return (
            isinstance(message.remote, TcpConnection)
            and message.remote._ctx is self
        )

    async def shutdown(self):
        """Stop accepting connections and release all existing ones."""
        self.log.debug("Shutting down server %r", self)
        self._tokenmanager = None
        self.server.close()
        # Since server has been closed, we won't be getting any *more*
        # connections, so we can process them all now:
        shutdowns = [
            asyncio.create_task(
                c.release(),
                name="Close client %s" % c,
            )
            for c in self._pool
        ]
        shutdowns.append(
            asyncio.create_task(
                self.server.wait_closed(), name="Close server %s" % self
            ),
        )
        # There is at least one member, so we can just .wait()
        await asyncio.wait(shutdowns)
class TCPClient(_TCPPooling, interfaces.TokenInterface):
    """Client side of CoAP over TCP: opens and reuses outgoing connections."""

    def __init__(self) -> None:
        self._pool: Dict[
            Tuple[str, int], TcpConnection
        ] = {}  #: (host, port) -> connection
        # note that connections are filed by host name, so different names for
        # the same address might end up with different connections, which is
        # probably okay for TCP, and crucial for later work with TLS.
        self.log: Optional[Logger] = None
        self.loop: Optional[asyncio.AbstractEventLoop] = None
        self.credentials = None
        # Assigned by create_client_transport; reset to None by shutdown.
        self._tokenmanager = None

    async def _spawn_protocol(self, message):
        """Return a pooled or newly opened connection to message's target.

        The target comes from message.unresolved_remote if set, otherwise
        from the Uri-Host / Uri-Port options; the port defaults to
        ``_default_port``.

        Raises ValueError if no host is known, error.ResolutionError if the
        host cannot be resolved and error.NetworkError if connecting fails.
        """
        if message.unresolved_remote is None:
            host = message.opt.uri_host
            port = message.opt.uri_port or self._default_port
            if host is None:
                raise ValueError(
                    "No location found to send message to (neither in .opt.uri_host nor in .remote)"
                )
        else:
            host, port = util.hostportsplit(message.unresolved_remote)
            port = port or self._default_port

        key = (host, port)
        existing = self._pool.get(key)
        if existing is not None:
            return existing

        # NOTE(review): two concurrent requests for the same key can both get
        # past the lookup above while create_connection is awaited; the later
        # one then replaces the earlier connection in _pool, so shutdown()
        # no longer releases the earlier one.
        try:
            _, protocol = await self.loop.create_connection(
                lambda: TcpConnection(self, self.log, self.loop, is_server=False),
                host,
                port,
                ssl=self._ssl_context_factory(message.unresolved_remote),
            )
        except socket.gaierror as e:
            raise error.ResolutionError(
                "No address information found for requests to %r" % host
            ) from e
        except OSError as e:
            raise error.NetworkError("Connection failed to %r" % host) from e

        self._pool[key] = protocol

        return protocol

    # for diverting behavior of TLSClient
    def _ssl_context_factory(self, hostinfo):
        """Return the ``ssl`` argument for create_connection (None: plain TCP)."""
        return None

    def _evict_from_pool(self, connection):
        """Remove every pool entry that points to connection."""
        # should really be zero or one
        stale = [key for key, conn in self._pool.items() if conn is connection]
        for key in stale:
            del self._pool[key]

    @classmethod
    async def create_client_transport(
        cls, tman: interfaces.TokenManager, log, loop, credentials=None
    ):
        """Create a client transport bound to tman, log and loop."""
        # this is not actually asynchronous, and even though the interface
        # between the context and the creation of interfaces is not fully
        # standardized, this stays in the other interfaces' style.
        self = cls()
        self._tokenmanager = tman
        self.log = log
        self.loop = loop
        # used by the TLS variant; FIXME not well thought through
        self.credentials = credentials

        return self

    # implementing TokenInterface

    async def fill_or_recognize_remote(self, message):
        """Accept messages on our connections or for our scheme.

        For a message of our scheme without a usable remote, a connection is
        looked up or opened and stored in message.remote.
        """
        # isinstance() is False for None, so no separate None check is needed.
        if (
            isinstance(message.remote, TcpConnection)
            and message.remote._ctx is self
        ):
            return True

        if message.requested_scheme == self._scheme:
            # FIXME: This could pool outgoing connections.
            # (Checking if an incoming connection is a pool candidate is
            # probably overkill because even if a URI can be constructed from a
            # ephemeral client port, nobody but us can use it, and we can just
            # set .remote).
            message.remote = await self._spawn_protocol(message)
            return True

        return False

    async def shutdown(self):
        """Release all pooled outgoing connections and wait for completion."""
        self.log.debug("Shutting down any outgoing connections on %r", self)
        self._tokenmanager = None

        shutdowns = [
            asyncio.create_task(
                c.release(),
                name="Close client %s" % c,
            )
            for c in self._pool.values()
        ]
        if not shutdowns:
            # wait is documented to require a non-empty set
            return
        await asyncio.wait(shutdowns)