Coverage for aiocoap/transports/tcp.py: 92%
251 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-09-30 11:17 +0000
« prev ^ index » next coverage.py v7.10.7, created at 2025-09-30 11:17 +0000
1# SPDX-FileCopyrightText: Christian Amsüss and the aiocoap contributors
2#
3# SPDX-License-Identifier: MIT
5import asyncio
6import socket
7from logging import Logger
8from typing import Dict, Optional, Set, Tuple
10from aiocoap.transports import rfc8323common
11from aiocoap import interfaces, error, util
12from aiocoap import COAP_PORT, Message
13from aiocoap import defaults
14from aiocoap.message import Direction
17def _extract_message_size(data: bytes):
18 """Read out the full length of a CoAP messsage represented by data.
20 Returns None if data is too short to read the (full) length.
22 The number returned is the number of bytes that has to be read into data to
23 start reading the next message; it consists of a constant term, the token
24 length and the extended length of options-plus-payload."""
26 if not data:
27 return None
29 length = data[0] >> 4
30 tokenoffset = 2
31 tkl = data[0] & 0x0F
33 if length >= 13:
34 if length == 13:
35 extlen = 1
36 offset = 13
37 elif length == 14:
38 extlen = 2
39 offset = 269
40 else:
41 extlen = 4
42 offset = 65805
43 if len(data) < extlen + 1:
44 return None
45 tokenoffset = 2 + extlen
46 length = int.from_bytes(data[1 : 1 + extlen], "big") + offset
47 return tokenoffset, tkl, length
def _decode_message(data: bytes) -> Message:
    """Build an incoming Message from one complete CoAP-over-TCP frame.

    *data* must hold exactly one message (its size having been determined by
    ``_extract_message_size``). Token lengths above 8 raise
    ``error.UnparsableMessage``.
    """
    header_end, token_length, _ = _extract_message_size(data)
    if token_length > 8:
        raise error.UnparsableMessage("Overly long token")

    # The code byte immediately precedes the token.
    message = Message(
        code=data[header_end - 1],
        _token=data[header_end : header_end + token_length],
    )

    # Everything after the token is options, optionally followed by payload.
    message.payload = message.opt.decode(data[header_end + token_length :])
    message.direction = Direction.INCOMING

    return message
65def _encode_length(length: int):
66 if length < 13:
67 return (length, b"")
68 elif length < 269:
69 return (13, (length - 13).to_bytes(1, "big"))
70 elif length < 65805:
71 return (14, (length - 269).to_bytes(2, "big"))
72 else:
73 return (15, (length - 65805).to_bytes(4, "big"))
def _serialize(msg: Message) -> bytes:
    """Render *msg* as a CoAP-over-TCP frame.

    Layout: Len/TKL byte, extended length, code byte, token, options and (if
    present) the 0xFF payload marker followed by the payload. Raises
    ValueError for tokens longer than 8 bytes.
    """
    body_parts = [msg.opt.encode()]
    if msg.payload:
        body_parts.extend((b"\xff", msg.payload))
    body = b"".join(body_parts)
    len_nibble, ext_length = _encode_length(len(body))

    token_length = len(msg.token)
    if token_length > 8:
        raise ValueError("Overly long token")

    first_byte = (len_nibble << 4) | token_length
    return b"".join(
        [bytes([first_byte]), ext_length, bytes([msg.code]), msg.token, body]
    )
class TcpConnection(
    asyncio.Protocol, rfc8323common.RFC8323Remote, interfaces.EndpointAddress
):
    """A single CoAP-over-TCP connection.

    An instance is both the asyncio protocol driving one socket and the
    remote address (``message.remote``) of the messages exchanged over it.
    It is created by TCPServer (``is_server=True``) or TCPClient
    (``is_server=False``), which pass themselves in as ``ctx``.
    """

    # currently, both the protocol and the EndpointAddress are the same object.
    # if, at a later point in time, the keepaliving of TCP connections should
    # depend on whether the library user still keeps a usable address around,
    # those functions could be split.

    def __init__(self, ctx, log, loop, *, is_server) -> None:
        super().__init__()
        # Owning transport; incoming messages and errors are handed to it.
        self._ctx = ctx
        self.log = log
        self.loop = loop

        # Received bytes that have not yet been consumed as complete messages.
        self._spool = b""

        # While this is None, any non-signaling message causes an abort (see
        # data_received). Presumably set by RFC8323Remote when the peer's CSM
        # is processed -- TODO confirm.
        self._remote_settings = None

        self._transport: Optional[asyncio.Transport] = None
        # Decides whether a TLS SNI name replaces the local or the remote host.
        self._local_is_server = is_server

    @property
    def scheme(self):
        """URI scheme of the owning transport (e.g. ``coap+tcp``)."""
        return self._ctx._scheme

    def _send_message(self, msg: Message):
        """Serialize *msg* and write it to the transport.

        Must only be called once connection_made has stored the transport.
        """
        self.log.debug("Sending message: %r", msg)
        assert self._transport is not None, (
            "Attempted to send message before connection"
        )
        self._transport.write(_serialize(msg))

    def _abort_with(self, abort_msg):
        """Send *abort_msg* and close the transport.

        *abort_msg* is presumably the Abort signaling message built by
        RFC8323Remote -- TODO confirm. Without a transport, nothing is sent and
        the object is detached from its context instead.
        """
        if self._transport is not None:
            self._send_message(abort_msg)
            self._transport.close()
        else:
            # FIXME: find out how this happens; i've only seen it after nmap
            # runs against an aiocoap server and then shutting it down.
            # "poisoning" the object to make sure this can not be exploited to
            # bypass the server shutdown.
            self._ctx = None

    # implementing asyncio.Protocol

    def connection_made(self, transport):
        """Store the transport, derive local/remote host info, send our CSM.

        Ports equal to the scheme's default port are stored as None. If TLS
        reports a server name, it replaces the host part on the server side of
        the connection.
        """
        self._transport = transport

        ssl_object = transport.get_extra_info("ssl_object")
        if ssl_object is not None:
            # NOTE(review): `indicated_server_name` is not a standard ssl
            # attribute; presumably attached by an SNI callback elsewhere.
            server_name = getattr(ssl_object, "indicated_server_name", None)
        else:
            server_name = None

        # `host` already contains the interface identifier, so throwing away
        # scope and interface identifier
        self._local_hostinfo = transport.get_extra_info("sockname")[:2]
        self._remote_hostinfo = transport.get_extra_info("peername")[:2]

        def none_default_port(sockname):
            # (host, port) with the port replaced by None if it is the default
            return (
                sockname[0],
                None if sockname[1] == self._ctx._default_port else sockname[1],
            )

        self._local_hostinfo = none_default_port(self._local_hostinfo)
        self._remote_hostinfo = none_default_port(self._remote_hostinfo)

        # SNI information available
        if server_name is not None:
            if self._local_is_server:
                self._local_hostinfo = (server_name, self._local_hostinfo[1])
            else:
                self._remote_hostinfo = (server_name, self._remote_hostinfo[1])

        self._send_initial_csm()

    def connection_lost(self, exc):
        """Report the loss to the owning transport (exc is None on a regular
        close)."""
        # FIXME react meaningfully:
        # * send event through pool so it can propagate the error to all
        #   requests on the same remote
        # * mark the address as erroneous so it won't be recognized by
        #   fill_or_recognize_remote

        self._ctx._dispatch_error(self, exc)

    def data_received(self, data):
        """Spool *data* and process every complete message it contains.

        Aborts the connection when a message announces a size above
        ``self._my_max_message_size``, when a message cannot be parsed, or
        when a non-signaling message arrives while ``_remote_settings`` is
        still None.
        """
        # A rope would be more efficient here, but the expected case is that
        # _spool is b"" and spool gets emptied soon -- most messages will just
        # fit in a single TCP package and not be nagled together.
        #
        # (If this does become a bottleneck, say self._spool = SomeRope(b"")
        # and barely change anything else).

        self._spool += data

        while True:
            msglen = _extract_message_size(self._spool)
            if msglen is None:
                # The length header itself is not complete yet.
                break
            # (tokenoffset, tkl, length) add up to the full message size.
            msglen = sum(msglen)
            if msglen > self._my_max_message_size:
                # Checked before waiting for the body, so an oversized
                # announcement is rejected without buffering it.
                self.abort("Overly large message announced")
                return

            if msglen > len(self._spool):
                # Wait for the rest of the message.
                break

            msg = self._spool[:msglen]
            try:
                msg = _decode_message(msg)
            except error.UnparsableMessage:
                self.abort("Failed to parse message")
                return
            msg.remote = self

            self.log.debug("Received message: %r", msg)

            # Consume the message before dispatching it.
            self._spool = self._spool[msglen:]

            if msg.code.is_signalling():
                try:
                    self._process_signaling(msg)
                except rfc8323common.CloseConnection as e:
                    self._ctx._dispatch_error(self, e.args[0])
                    self._transport.close()
                    # NOTE(review): the loop continues with further spooled
                    # messages even after the transport was closed -- confirm
                    # this is intended.
                continue

            if self._remote_settings is None:
                self.abort("No CSM received")
                return

            self._ctx._dispatch_incoming(self, msg)

    def eof_received(self):
        # FIXME: as with connection_lost, but less noisy if announced
        # FIXME: return true and initiate own shutdown if that is what CoAP prescribes
        pass

    def pause_writing(self):
        # Flow-control callback; currently ignored.
        # FIXME: do something ;-)
        pass

    def resume_writing(self):
        # Flow-control callback; currently ignored.
        # FIXME: do something ;-)
        pass

    # RFC8323Remote.release recommends subclassing this, but there's no easy
    # awaitable here yet, and no important business to finish, timeout-wise.
class _TCPPooling:
    """Behavior shared by the TCP server and client transports.

    Subclasses set ``_tokenmanager`` and ``log`` and implement
    ``_evict_from_pool``. TcpConnection instances call back into
    ``_dispatch_incoming`` and ``_dispatch_error``.
    """

    # implementing TokenInterface

    def send_message(self, message, messageerror_monitor):
        """Send *message* over the connection stored in ``message.remote``.

        A response whose class is suppressed by the No-Response option
        (RFC 7967) is dropped silently.
        """
        # Ignoring messageerror_monitor: CoAP over reliable transports has no
        # way of indicating that a particular message was bad, it always shuts
        # down the complete connection

        if message.code.is_response():
            # No-Response bit (class - 1) suppresses that response class,
            # i.e. value 2 for 2.xx, 8 for 4.xx and 16 for 5.xx.
            no_response = (message.opt.no_response or 0) & (
                1 << (message.code.class_ - 1)
            ) != 0
            if no_response:
                return

        # Clear the option before the message is serialized.
        message.opt.no_response = None

        message.remote._send_message(message)

    # used by the TcpConnection instances

    def _dispatch_incoming(self, connection, msg):
        """Hand a received non-signaling message to the token manager."""
        if msg.code == 0:
            # Empty messages (0.00) serve as keepalives and MUST be ignored by
            # the recipient (RFC 8323 Section 3.4); they are neither requests
            # nor responses.
            return

        if msg.code.is_response():
            self._tokenmanager.process_response(msg)
            # ignoring the return value; unexpected responses can be the
            # asynchronous result of cancelled observations
        else:
            self._tokenmanager.process_request(msg)

    def _dispatch_error(self, connection, exc):
        """Remove *connection* from the pool and report *exc* for it.

        *exc* is None for a regular connection loss. Once shut down (no token
        manager any more), errors are only logged.
        """
        self._evict_from_pool(connection)

        if self._tokenmanager is None:
            if exc is not None:
                self.log.warning("Ignoring late error during shutdown: %s", exc)
            else:
                # it's just a regular connection loss, that's to be expected during shutdown
                pass
            return

        self._tokenmanager.dispatch_error(exc, connection)

    # for diverting behavior of _TLSMixIn
    _scheme = "coap+tcp"
    _default_port = COAP_PORT
class TCPServer(_TCPPooling, interfaces.TokenInterface):
    """CoAP-over-TCP server transport; tracks every accepted connection."""

    def __init__(self) -> None:
        self._pool: Set[TcpConnection] = set()
        self.log: Optional[Logger] = None
        self.server = None

    @classmethod
    async def create_server(
        cls, bind, tman: interfaces.TokenManager, log, loop, *, _server_context=None
    ):
        """Create a listening server bound to *bind* (host, port).

        A missing or falsy port selects the scheme's default port. An
        explicit port is shifted by the difference between that default and
        plain CoAP's port. Unresolvable hosts raise ``error.ResolutionError``.
        """
        self = cls()
        self._tokenmanager = tman
        self.log = log

        requested = bind or ("::", None)
        host = requested[0]
        if requested[1]:
            port = requested[1] + (self._default_port - COAP_PORT)
        else:
            port = self._default_port

        def accept_connection():
            connection = TcpConnection(self, log, loop, is_server=True)
            self._pool.add(connection)
            return connection

        try:
            listener = await loop.create_server(
                accept_connection,
                host,
                port,
                ssl=_server_context,
                reuse_port=defaults.has_reuse_port(),
            )
        except socket.gaierror as e:
            raise error.ResolutionError(
                "No local bindable address found for %s" % host
            ) from e
        self.server = listener

        return self

    def _evict_from_pool(self, connection):
        # Tolerates repeated eviction: an error and the subsequent close can
        # both report the same connection.
        self._pool.discard(connection)

    # implementing TokenInterface

    async def fill_or_recognize_remote(self, message):
        """Accept only messages addressed to one of this server's connections."""
        remote = message.remote
        return (
            remote is not None
            and isinstance(remote, TcpConnection)
            and remote._ctx is self
        )

    async def shutdown(self):
        """Stop listening, then release all connections and wait for both."""
        self.log.debug("Shutting down server %r", self)
        self._tokenmanager = None
        self.server.close()
        # With the listener closed no further connections can join the pool,
        # so the current set is complete.
        pending = [
            asyncio.create_task(c.release(), name="Close client %s" % c)
            for c in self._pool
        ]
        pending.append(
            asyncio.create_task(
                self.server.wait_closed(), name="Close server %s" % self
            )
        )
        # pending always holds at least the server task, so wait() is safe.
        await asyncio.wait(pending)
class TCPClient(_TCPPooling, interfaces.TokenInterface):
    """CoAP-over-TCP client transport.

    Outgoing connections are opened on demand and pooled by the (host, port)
    they were requested for.
    """

    def __init__(self) -> None:
        self._pool: Dict[
            Tuple[str, int], TcpConnection
        ] = {}  #: (host, port) -> connection
        # note that connections are filed by host name, so different names for
        # the same address might end up with different connections, which is
        # probably okay for TCP, and crucial for later work with TLS.
        self.log: Optional[Logger] = None
        self.loop: Optional[asyncio.AbstractEventLoop] = None
        self.credentials = None

    async def _spawn_protocol(self, message):
        """Return the pooled connection for *message*'s destination, opening
        one if needed.

        The destination is ``message.unresolved_remote`` if set, otherwise the
        Uri-Host/Uri-Port options; a missing port means the default port.

        Raises ValueError if no host is known, ``error.ResolutionError`` if
        name resolution fails and ``error.NetworkError`` on other connection
        errors.
        """
        if message.unresolved_remote is None:
            host = message.opt.uri_host
            port = message.opt.uri_port or self._default_port
            if host is None:
                raise ValueError(
                    "No location found to send message to (neither in .opt.uri_host nor in .remote)"
                )
        else:
            host, port = util.hostportsplit(message.unresolved_remote)
            port = port or self._default_port

        if (host, port) in self._pool:
            return self._pool[(host, port)]

        # NOTE(review): two concurrent calls for the same (host, port) can both
        # miss the pool above and open separate connections; the first one is
        # then overwritten in the pool and no longer released on shutdown.
        try:
            _, protocol = await self.loop.create_connection(
                lambda: TcpConnection(self, self.log, self.loop, is_server=False),
                host,
                port,
                ssl=self._ssl_context_factory(message.unresolved_remote),
            )
        except socket.gaierror as e:
            raise error.ResolutionError(
                "No address information found for requests to %r" % host
            ) from e
        except OSError as e:
            raise error.NetworkError("Connection failed to %r" % host) from e

        self._pool[(host, port)] = protocol

        return protocol

    # for diverting behavior of TLSClient
    def _ssl_context_factory(self, hostinfo):
        """SSL context for a connection to *hostinfo*; None for plain TCP."""
        return None

    def _evict_from_pool(self, connection):
        # Should match zero or one key. Keys are collected first so the dict
        # is not mutated while it is being iterated.
        stale_keys = [k for k, p in self._pool.items() if p is connection]
        for k in stale_keys:
            del self._pool[k]

    @classmethod
    async def create_client_transport(
        cls, tman: interfaces.TokenManager, log, loop, credentials=None
    ):
        """Create a client transport; no connection is opened yet."""
        # this is not actually asynchronous, and even though the interface
        # between the context and the creation of interfaces is not fully
        # standardized, this stays in the other inferfaces' style.
        self = cls()
        self._tokenmanager = tman
        self.log = log
        self.loop = loop
        # used by the TLS variant; FIXME not well thought through
        self.credentials = credentials

        return self

    # implementing TokenInterface

    async def fill_or_recognize_remote(self, message):
        """Accept messages already bound to one of our connections, or bind
        messages requesting our scheme to a (possibly new) pooled connection."""
        if (
            message.remote is not None
            and isinstance(message.remote, TcpConnection)
            and message.remote._ctx is self
        ):
            return True

        if message.requested_scheme == self._scheme:
            # FIXME: This could pool outgoing connections.
            # (Checking if an incoming connection is a pool candidate is
            # probably overkill because even if a URI can be constructed from a
            # ephemeral client port, nobody but us can use it, and we can just
            # set .remote).
            message.remote = await self._spawn_protocol(message)
            return True

        return False

    async def shutdown(self):
        """Release all pooled outgoing connections and wait for them."""
        self.log.debug("Shutting down any outgoing connections on %r", self)
        self._tokenmanager = None

        shutdowns = [
            asyncio.create_task(
                c.release(),
                name="Close client %s" % c,
            )
            for c in self._pool.values()
        ]
        if not shutdowns:
            # wait is documented to require a non-empty set
            return
        await asyncio.wait(shutdowns)