Coverage for aiocoap / transports / tcp.py: 92%

249 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-29 12:32 +0000

1# SPDX-FileCopyrightText: Christian Amsüss and the aiocoap contributors 

2# 

3# SPDX-License-Identifier: MIT 

4 

5import asyncio 

6import socket 

7from logging import Logger 

8from typing import Dict, Optional, Set, Tuple 

9 

10from aiocoap.transports import rfc8323common 

11from aiocoap import interfaces, error, util 

12from aiocoap import COAP_PORT, Message 

13from aiocoap import defaults 

14from aiocoap.message import Direction 

15 

16 

17def _extract_message_size(data: bytes): 

18 """Read out the full length of a CoAP message represented by data. 

19 

20 Returns None if data is too short to read the (full) length. 

21 

22 The number returned is the number of bytes that has to be read into data to 

23 start reading the next message; it consists of a constant term, the token 

24 length and the extended length of options-plus-payload.""" 

25 

26 if not data: 

27 return None 

28 

29 length = data[0] >> 4 

30 tokenoffset = 2 

31 tkl = data[0] & 0x0F 

32 

33 if length >= 13: 

34 if length == 13: 

35 extlen = 1 

36 offset = 13 

37 elif length == 14: 

38 extlen = 2 

39 offset = 269 

40 else: 

41 extlen = 4 

42 offset = 65805 

43 if len(data) < extlen + 1: 

44 return None 

45 tokenoffset = 2 + extlen 

46 length = int.from_bytes(data[1 : 1 + extlen], "big") + offset 

47 return tokenoffset, tkl, length 

48 

49 

def _decode_message(data: bytes) -> Message:
    """Build an incoming Message from the bytes of one CoAP-over-TCP frame.

    ``data`` starts with the framing header understood by
    ``_extract_message_size``, followed by the code byte, the token, and the
    encoded options and payload.

    Raises error.UnparsableMessage if the header is incomplete, if the token
    length exceeds 8 bytes, or if ``data`` ends before the code and token
    have been fully read."""
    header = _extract_message_size(data)
    if header is None:
        # Unpacking None would raise a TypeError; report short input with
        # the same exception type as every other parse failure here.
        raise error.UnparsableMessage("Incomplete message header")
    tokenoffset, tkl, _ = header

    if tkl > 8:
        raise error.UnparsableMessage("Overly long token")
    if len(data) < tokenoffset + tkl:
        # Without this check a missing code byte raises IndexError and a
        # short token is silently truncated by the slice below.
        raise error.UnparsableMessage("Message ends before code and token")

    # The code byte sits immediately before the token.
    code = data[tokenoffset - 1]
    token = data[tokenoffset : tokenoffset + tkl]

    msg = Message(code=code, _token=token)

    # Decoding the options returns what is left over as the payload.
    msg.payload = msg.opt.decode(data[tokenoffset + tkl :])
    msg.direction = Direction.INCOMING

    return msg

63 

64 

65def _encode_length(length: int): 

66 if length < 13: 

67 return (length, b"") 

68 elif length < 269: 

69 return (13, (length - 13).to_bytes(1, "big")) 

70 elif length < 65805: 

71 return (14, (length - 269).to_bytes(2, "big")) 

72 else: 

73 return (15, (length - 65805).to_bytes(4, "big")) 

74 

75 

def _serialize(msg: Message) -> bytes:
    """Encode a message into its CoAP-over-TCP wire representation.

    The result is the length/token-length byte, any extended length bytes,
    the code byte, the token, the encoded options and, if the payload is
    non-empty, a 0xFF marker followed by the payload.

    Raises ValueError if the token is longer than 8 bytes."""
    encoded_options = msg.opt.encode()
    if msg.payload:
        parts = (encoded_options, b"\xff", msg.payload)
    else:
        parts = (encoded_options,)
    body = b"".join(parts)

    # The announced length covers options, payload marker and payload only.
    nibble, extended = _encode_length(len(body))

    token_length = len(msg.token)
    if token_length > 8:
        raise ValueError("Overly long token")

    first_byte = bytes([(nibble << 4) | token_length])
    code_byte = bytes([msg.code])
    return b"".join([first_byte, extended, code_byte, msg.token, body])

90 

91 

class TcpConnection(
    asyncio.Protocol, rfc8323common.RFC8323Remote, interfaces.EndpointAddress
):
    """asyncio protocol driving a single CoAP-over-TCP connection.

    The same object is set as ``remote`` on every message received through
    it, so it doubles as the address of the peer."""

    # currently, both the protocol and the EndpointAddress are the same object.
    # if, at a later point in time, the keepaliving of TCP connections should
    # depend on whether the library user still keeps a usable address around,
    # those functions could be split.

    def __init__(self, ctx, log, loop, *, is_server) -> None:
        """Set up an unconnected protocol instance.

        ``ctx`` is the pool object that incoming messages and errors are
        dispatched to; ``is_server`` tells whether the local side accepted
        the connection."""
        super().__init__()
        self._ctx = ctx
        self.log = log
        self.loop = loop

        # Received bytes that have not been consumed as whole messages yet.
        self._spool = b""

        # Stays None until set elsewhere; data_received refuses to dispatch
        # non-signalling messages while it is None.
        self._remote_settings = None

        self._transport: Optional[asyncio.Transport] = None
        self._local_is_server = is_server

    @property
    def scheme(self):
        """URI scheme of this connection, as configured on the pool."""
        return self._ctx._scheme

    def _send_message(self, msg: Message):
        """Serialize ``msg`` and write it to the transport."""
        self.log.debug("Sending message: %r", msg)
        transport = self._transport
        assert transport is not None, "Attempted to send message before connection"
        transport.write(_serialize(msg))

    def _abort_with(self, abort_msg):
        """Send ``abort_msg`` and close the transport, if there is one."""
        if self._transport is None:
            # FIXME: find out how this happens; i've only seen it after nmap
            # runs against an aiocoap server and then shutting it down.
            # "poisoning" the object to make sure this can not be exploited to
            # bypass the server shutdown.
            self._ctx = None
            return

        self._send_message(abort_msg)
        self._transport.close()

    # implementing asyncio.Protocol

    def connection_made(self, transport):
        """Record the transport and both host infos, then send the initial CSM."""
        self._transport = transport

        ssl_object = transport.get_extra_info("ssl_object")
        server_name = (
            None
            if ssl_object is None
            else getattr(ssl_object, "indicated_server_name", None)
        )

        def hide_default_port(hostinfo):
            # A port equal to the pool's default port is represented as None.
            port = hostinfo[1]
            return (hostinfo[0], None if port == self._ctx._default_port else port)

        # `host` already contains the interface identifier, so throwing away
        # scope and interface identifier
        local = transport.get_extra_info("sockname")[:2]
        remote = transport.get_extra_info("peername")[:2]

        local = hide_default_port(local)
        remote = hide_default_port(remote)

        # SNI information available: it names whichever side is the server.
        if server_name is not None:
            if self._local_is_server:
                local = (server_name, local[1])
            else:
                remote = (server_name, remote[1])

        self._local_hostinfo = local
        self._remote_hostinfo = remote

        self._send_initial_csm()

    def connection_lost(self, exc):
        """Report the loss of the connection to the pool."""
        # FIXME react meaningfully:
        # * send event through pool so it can propagate the error to all
        #   requests on the same remote
        # * mark the address as erroneous so it won't be recognized by
        #   determine_remote

        self._ctx._dispatch_error(self, exc)

    def data_received(self, data):
        """Append ``data`` to the spool and process every complete message.

        Signalling messages are handled on the connection itself; all other
        messages are handed to the pool. The connection is aborted when a
        message announces more than the permitted size, fails to parse, or
        is a non-signalling message arriving before remote settings are
        known."""
        # A rope would be more efficient here, but the expected case is that
        # _spool is b"" and spool gets emptied soon -- most messages will just
        # fit in a single TCP package and not be nagled together.
        #
        # (If this does become a bottleneck, say self._spool = SomeRope(b"")
        # and barely change anything else).

        self._spool += data

        while True:
            header = _extract_message_size(self._spool)
            if header is None:
                # Not even the length is readable yet; wait for more data.
                break

            total = sum(header)
            if total > self._my_max_message_size:
                self.abort("Overly large message announced")
                return

            if len(self._spool) < total:
                # The message has not fully arrived yet.
                break

            try:
                msg = _decode_message(self._spool[:total])
            except error.UnparsableMessage:
                self.abort("Failed to parse message")
                return
            msg.remote = self

            self.log.debug("Received message: %r", msg)

            # Only now is the message removed from the spool.
            self._spool = self._spool[total:]

            if msg.code.is_signalling():
                try:
                    self._process_signaling(msg)
                except rfc8323common.CloseConnection as e:
                    self._ctx._dispatch_error(self, e.args[0])
                    self._transport.close()
                # Signalling messages are never dispatched to the pool.
                continue

            if self._remote_settings is None:
                self.abort("No CSM received")
                return

            self._ctx._dispatch_incoming(self, msg)

    def eof_received(self):
        """Currently takes no action when the peer signals end of file."""
        # FIXME: as with connection_lost, but less noisy if announced
        # FIXME: return true and initiate own shutdown if that is what CoAP prescribes
        pass

    def pause_writing(self):
        """Currently takes no action."""
        # FIXME: do something ;-)
        pass

    def resume_writing(self):
        """Currently takes no action."""
        # FIXME: do something ;-)
        pass

    # RFC8323Remote.release recommends subclassing this, but there's no easy
    # awaitable here yet, and no important business to finish, timeout-wise.

241 

242 

class _TCPPooling:
    """Message plumbing shared by the TCP server and client transports.

    The methods here rely on ``self._tokenmanager``, ``self.log`` and
    ``self._evict_from_pool``, none of which this class defines itself."""

    # implementing TokenInterface

    def send_message(self, message, messageerror_monitor):
        """Send ``message`` through the connection stored in ``message.remote``.

        A response whose class is suppressed by its ``no_response`` option is
        dropped without being sent."""
        # Ignoring messageerror_monitor: CoAP over reliable transports has no
        # way of indicating that a particular message was bad, it always shuts
        # down the complete connection

        if message.code.is_response():
            # Bit (class - 1) of the option value suppresses that class.
            no_response = (message.opt.no_response or 0) & (
                1 << message.code.class_ - 1
            ) != 0
            if no_response:
                return

            # NOTE(review): kept inside the response branch so that requests
            # keep any No-Response option they carry -- confirm placement.
            message.opt.no_response = None

        message.remote._send_message(message)

    # used by the TcpConnection instances

    def _dispatch_incoming(self, connection, msg):
        """Route a received non-signalling message to the token manager."""
        if msg.code == 0:
            # Empty messages must be ignored on reliable transports; a bare
            # `pass` here would let them fall through and be processed as
            # requests.
            return

        if msg.code.is_response():
            self._tokenmanager.process_response(msg)
            # ignoring the return value; unexpected responses can be the
            # asynchronous result of cancelled observations
        else:
            self._tokenmanager.process_request(msg)

    def _dispatch_error(self, connection, exc):
        """Drop ``connection`` from the pool and pass ``exc`` to the token manager.

        Once the token manager has been cleared by shutdown, the error is
        only logged (or silently dropped if ``exc`` is None)."""
        self._evict_from_pool(connection)

        if self._tokenmanager is None:
            if exc is not None:
                self.log.warning("Ignoring late error during shutdown: %s", exc)
            else:
                # it's just a regular connection loss, that's to be expected during shutdown
                pass
            return

        self._tokenmanager.dispatch_error(exc, connection)

    # for diverting behavior of _TLSMixIn
    _scheme = "coap+tcp"
    _default_port = COAP_PORT

291 

292 

class TCPServer(_TCPPooling, interfaces.TokenInterface):
    """Token interface that accepts incoming CoAP-over-TCP connections."""

    def __init__(self) -> None:
        """Create an empty server; use create_server to obtain a bound one."""
        self._pool: Set[TcpConnection] = set()
        self.log: Optional[Logger] = None
        self.server = None

    @classmethod
    async def create_server(
        cls, bind, tman: interfaces.TokenManager, log, loop, *, _server_context=None
    ):
        """Create an instance listening on ``bind``.

        ``bind`` is a ``(host, port)`` pair; if it is falsy, ``"::"`` and the
        default port are used. A given port is shifted by the difference
        between this class's default port and COAP_PORT.

        Raises error.ResolutionError if the bind host can not be resolved."""
        self = cls()
        self._tokenmanager = tman
        self.log = log
        # self.loop = loop

        requested = bind or ("::", None)
        host = requested[0]
        if requested[1]:
            port = requested[1] + (self._default_port - COAP_PORT)
        else:
            port = self._default_port

        def new_connection():
            # Protocol factory: every accepted connection joins the pool.
            connection = TcpConnection(self, log, loop, is_server=True)
            self._pool.add(connection)
            return connection

        try:
            self.server = await loop.create_server(
                new_connection,
                host,
                port,
                ssl=_server_context,
                reuse_port=defaults.has_reuse_port(),
            )
        except socket.gaierror as e:
            raise error.ResolutionError(
                "No local bindable address found for %s" % host
            ) from e

        return self

    def _evict_from_pool(self, connection):
        """Remove ``connection`` from the pool if it is still in there."""
        # May easily happen twice, once when an error comes in and once when
        # the connection is (subsequently) closed.
        self._pool.discard(connection)

    # implementing TokenInterface

    async def recognize_remote(self, message):
        """Tell whether ``message.remote`` is a connection of this server."""
        remote = message.remote
        return isinstance(remote, TcpConnection) and remote._ctx is self

    async def determine_remote(self, message):
        """Always return None: a server does not open connections."""
        # We're a server at the transport level
        return None

    async def shutdown(self):
        """Stop listening and release all pooled connections."""
        self.log.debug("Shutting down server %r", self)
        self._tokenmanager = None
        self.server.close()

        # Since server has been closed, we won't be getting any *more*
        # connections, so we can process them all now:
        pending = []
        for connection in self._pool:
            pending.append(
                asyncio.create_task(
                    connection.release(), name="Close client %s" % connection
                )
            )
        pending.append(
            asyncio.create_task(
                self.server.wait_closed(), name="Close server %s" % self
            )
        )

        # There is at least one member, so we can just .wait()
        await asyncio.wait(pending)

372 

373 

class TCPClient(_TCPPooling, interfaces.TokenInterface):
    """Token interface that opens outgoing CoAP-over-TCP connections."""

    def __init__(self) -> None:
        """Create an empty client; use create_client_transport to set it up."""
        self._pool: Dict[
            Tuple[str, int], TcpConnection
        ] = {}  #: (host, port) -> connection
        # note that connections are filed by host name, so different names for
        # the same address might end up with different connections, which is
        # probably okay for TCP, and crucial for later work with TLS.
        self.log: Optional[Logger] = None
        self.loop: Optional[asyncio.AbstractEventLoop] = None
        self.credentials = None

    async def _spawn_protocol(self, message):
        """Return the pooled connection for ``message``, opening one if needed.

        The destination is taken from ``message.unresolved_remote`` if set,
        otherwise from the Uri-Host and Uri-Port options; a missing port is
        replaced by the default port.

        Raises ValueError if no host can be determined,
        error.ResolutionError if the host can not be resolved, and
        error.NetworkError if connecting fails with another OSError."""
        if message.unresolved_remote is None:
            host = message.opt.uri_host
            port = message.opt.uri_port or self._default_port
            if host is None:
                raise ValueError(
                    "No location found to send message to (neither in .opt.uri_host nor in .remote)"
                )
        else:
            host, port = util.hostportsplit(message.unresolved_remote)
            port = port or self._default_port

        key = (host, port)
        if key in self._pool:
            return self._pool[key]

        try:
            _, protocol = await self.loop.create_connection(
                lambda: TcpConnection(self, self.log, self.loop, is_server=False),
                host,
                port,
                ssl=self._ssl_context_factory(message.unresolved_remote),
            )
        except socket.gaierror as e:
            raise error.ResolutionError(
                "No address information found for requests to %r" % host
            ) from e
        except OSError as e:
            raise error.NetworkError("Connection failed to %r" % host) from e

        self._pool[key] = protocol

        return protocol

    # for diverting behavior of TLSClient
    def _ssl_context_factory(self, hostinfo):
        """Return the ``ssl`` argument for new connections; None here."""
        return None

    def _evict_from_pool(self, connection):
        """Remove every pool entry that refers to ``connection``."""
        # should really be zero or one
        stale_keys = [k for k, p in self._pool.items() if p is connection]
        for k in stale_keys:
            self._pool.pop(k)

    @classmethod
    async def create_client_transport(
        cls, tman: interfaces.TokenManager, log, loop, credentials=None
    ):
        """Create an instance wired to the given token manager, log and loop."""
        # this is not actually asynchronous, and even though the interface
        # between the context and the creation of interfaces is not fully
        # standardized, this stays in the other interfaces' style.
        self = cls()
        self._tokenmanager = tman
        self.log = log
        self.loop = loop
        # used by the TLS variant; FIXME not well thought through
        self.credentials = credentials

        return self

    # implementing TokenInterface

    async def recognize_remote(self, message):
        """Tell whether ``message.remote`` is a connection of this client."""
        return isinstance(message.remote, TcpConnection) and message.remote._ctx is self

    async def determine_remote(self, message):
        """Return a connection for ``message`` if its scheme is ours, else None."""
        if message.requested_scheme == self._scheme:
            # FIXME: This could pool outgoing connections.
            # (Checking if an incoming connection is a pool candidate is
            # probably overkill because even if a URI can be constructed from a
            # ephemeral client port, nobody but us can use it, and we can just
            # set .remote).
            return await self._spawn_protocol(message)
        return None

    async def shutdown(self):
        """Release all pooled connections and wait until they are done."""
        self.log.debug("Shutting down any outgoing connections on %r", self)
        self._tokenmanager = None

        shutdowns = [
            asyncio.create_task(
                c.release(),
                name="Close client %s" % c,
            )
            for c in self._pool.values()
        ]
        if not shutdowns:
            # wait is documented to require a non-empty set
            return
        await asyncio.wait(shutdowns)