Coverage for aiocoap/transports/tcp.py: 92%

251 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-09-30 11:17 +0000

1# SPDX-FileCopyrightText: Christian Amsüss and the aiocoap contributors 

2# 

3# SPDX-License-Identifier: MIT 

4 

5import asyncio 

6import socket 

7from logging import Logger 

8from typing import Dict, Optional, Set, Tuple 

9 

10from aiocoap.transports import rfc8323common 

11from aiocoap import interfaces, error, util 

12from aiocoap import COAP_PORT, Message 

13from aiocoap import defaults 

14from aiocoap.message import Direction 

15 

16 

17def _extract_message_size(data: bytes): 

18 """Read out the full length of a CoAP messsage represented by data. 

19 

20 Returns None if data is too short to read the (full) length. 

21 

22 The number returned is the number of bytes that has to be read into data to 

23 start reading the next message; it consists of a constant term, the token 

24 length and the extended length of options-plus-payload.""" 

25 

26 if not data: 

27 return None 

28 

29 length = data[0] >> 4 

30 tokenoffset = 2 

31 tkl = data[0] & 0x0F 

32 

33 if length >= 13: 

34 if length == 13: 

35 extlen = 1 

36 offset = 13 

37 elif length == 14: 

38 extlen = 2 

39 offset = 269 

40 else: 

41 extlen = 4 

42 offset = 65805 

43 if len(data) < extlen + 1: 

44 return None 

45 tokenoffset = 2 + extlen 

46 length = int.from_bytes(data[1 : 1 + extlen], "big") + offset 

47 return tokenoffset, tkl, length 

48 

49 

def _decode_message(data: bytes) -> Message:
    """Parse one complete CoAP-over-TCP message from data.

    data is expected to hold exactly one message, as delimited with
    _extract_message_size; everything after the token is handed to the
    option decoder, whose return value becomes the payload. The resulting
    message is marked as incoming.

    Raises error.UnparsableMessage if the header is truncated, if the token
    is longer than 8 bytes, or if data ends before the token does.
    """
    header = _extract_message_size(data)
    if header is None:
        # Previously this surfaced as a TypeError from unpacking None.
        raise error.UnparsableMessage("Truncated message header")
    tokenoffset, tkl, _ = header
    if tkl > 8:
        raise error.UnparsableMessage("Overly long token")
    if len(data) < tokenoffset + tkl:
        # Code byte or token missing; don't silently build a short token.
        raise error.UnparsableMessage("Truncated message header")
    code = data[tokenoffset - 1]
    token = data[tokenoffset : tokenoffset + tkl]

    msg = Message(code=code, _token=token)

    msg.payload = msg.opt.decode(data[tokenoffset + tkl :])
    msg.direction = Direction.INCOMING

    return msg

63 

64 

65def _encode_length(length: int): 

66 if length < 13: 

67 return (length, b"") 

68 elif length < 269: 

69 return (13, (length - 13).to_bytes(1, "big")) 

70 elif length < 65805: 

71 return (14, (length - 269).to_bytes(2, "big")) 

72 else: 

73 return (15, (length - 65805).to_bytes(4, "big")) 

74 

75 

def _serialize(msg: Message) -> bytes:
    """Encode msg into its CoAP-over-TCP byte representation.

    Layout: Len/TKL byte, extended length bytes, code byte, token, then the
    encoded options followed (only if there is a payload) by the 0xFF marker
    and the payload.

    Raises ValueError if the token is longer than 8 bytes.
    """
    body = msg.opt.encode()
    if msg.payload:
        body = b"".join((body, b"\xff", msg.payload))
    len_nibble, ext_len = _encode_length(len(body))

    token = msg.token
    token_length = len(token)
    if token_length > 8:
        raise ValueError("Overly long token")

    first_byte = (len_nibble << 4) | token_length
    return b"".join(
        (bytes((first_byte,)), ext_len, bytes((msg.code,)), token, body)
    )

90 

91 

class TcpConnection(
    asyncio.Protocol, rfc8323common.RFC8323Remote, interfaces.EndpointAddress
):
    """A single CoAP-over-TCP connection.

    The object is the asyncio protocol of the connection and, at the same
    time, the remote address that messages on this connection carry. The
    owning context (``ctx``) provides ``_scheme``, ``_default_port``,
    ``_dispatch_incoming`` and ``_dispatch_error``.
    """

    # currently, both the protocol and the EndpointAddress are the same object.
    # if, at a later point in time, the keepaliving of TCP connections should
    # depend on whether the library user still keeps a usable address around,
    # those functions could be split.

    def __init__(self, ctx, log, loop, *, is_server) -> None:
        super().__init__()
        self._ctx = ctx
        self.log = log
        self.loop = loop

        # Received bytes not yet consumed as complete messages.
        self._spool = b""

        self._remote_settings = None

        self._transport: Optional[asyncio.Transport] = None
        self._local_is_server = is_server

    @property
    def scheme(self):
        return self._ctx._scheme

    def _send_message(self, msg: Message):
        """Serialize msg and write it to the transport."""
        self.log.debug("Sending message: %r", msg)
        assert self._transport is not None, (
            "Attempted to send message before connection"
        )
        self._transport.write(_serialize(msg))

    def _abort_with(self, abort_msg):
        """Send abort_msg and close the transport (if connected)."""
        if self._transport is not None:
            self._send_message(abort_msg)
            self._transport.close()
        else:
            # FIXME: find out how this happens; i've only seen it after nmap
            # runs against an aiocoap server and then shutting it down.
            # "poisoning" the object to make sure this can not be exploited to
            # bypass the server shutdown.
            self._ctx = None

    # implementing asyncio.Protocol

    def connection_made(self, transport):
        self._transport = transport

        ssl_object = transport.get_extra_info("ssl_object")
        if ssl_object is not None:
            server_name = getattr(ssl_object, "indicated_server_name", None)
        else:
            server_name = None

        # `host` already contains the interface identifier, so throwing away
        # scope and interface identifier
        self._local_hostinfo = transport.get_extra_info("sockname")[:2]
        self._remote_hostinfo = transport.get_extra_info("peername")[:2]

        def none_default_port(sockname):
            # The context's default port is represented as None.
            return (
                sockname[0],
                None if sockname[1] == self._ctx._default_port else sockname[1],
            )

        self._local_hostinfo = none_default_port(self._local_hostinfo)
        self._remote_hostinfo = none_default_port(self._remote_hostinfo)

        # SNI information available
        if server_name is not None:
            if self._local_is_server:
                self._local_hostinfo = (server_name, self._local_hostinfo[1])
            else:
                self._remote_hostinfo = (server_name, self._remote_hostinfo[1])

        self._send_initial_csm()

    def connection_lost(self, exc):
        # FIXME react meaningfully:
        # * send event through pool so it can propagate the error to all
        #   requests on the same remote
        # * mark the address as erroneous so it won't be recognized by
        #   fill_or_recognize_remote

        self._ctx._dispatch_error(self, exc)

    def data_received(self, data):
        """Append data to the spool and process every complete message in it.

        Aborts the connection if a message announces a size above
        ``_my_max_message_size``, if it can not be parsed, or if a
        non-signalling message arrives before the peer's CSM.
        """
        # A rope would be more efficient here, but the expected case is that
        # _spool is b"" and spool gets emptied soon -- most messages will just
        # fit in a single TCP package and not be nagled together.
        #
        # (If this does become a bottleneck, say self._spool = SomeRope(b"")
        # and barely change anything else).

        self._spool += data

        while True:
            msglen = _extract_message_size(self._spool)
            if msglen is None:
                break
            msglen = sum(msglen)
            if msglen > self._my_max_message_size:
                self.abort("Overly large message announced")
                return

            if msglen > len(self._spool):
                # Message not fully received yet; wait for more data.
                break

            msg = self._spool[:msglen]
            try:
                msg = _decode_message(msg)
            except error.UnparsableMessage:
                self.abort("Failed to parse message")
                return
            msg.remote = self

            self.log.debug("Received message: %r", msg)

            self._spool = self._spool[msglen:]

            if msg.code.is_signalling():
                try:
                    self._process_signaling(msg)
                except rfc8323common.CloseConnection as e:
                    self._ctx._dispatch_error(self, e.args[0])
                    self._transport.close()
                    # The connection is being torn down; nothing else still
                    # buffered from this peer may be acted upon.
                    return
                continue

            if self._remote_settings is None:
                self.abort("No CSM received")
                return

            self._ctx._dispatch_incoming(self, msg)

    def eof_received(self):
        # FIXME: as with connection_lost, but less noisy if announced
        # FIXME: return true and initiate own shutdown if that is what CoAP prescribes
        pass

    def pause_writing(self):
        # FIXME: do something ;-)
        pass

    def resume_writing(self):
        # FIXME: do something ;-)
        pass

    # RFC8323Remote.release recommends subclassing this, but there's no easy
    # awaitable here yet, and no important business to finish, timeout-wise.

241 

242 

class _TCPPooling:
    """Behaviour shared by TCPServer and TCPClient.

    Subclasses provide ``_tokenmanager`` (set to None on shutdown), ``log``
    and ``_evict_from_pool``.
    """

    # implementing TokenInterface

    def send_message(self, message, messageerror_monitor):
        """Send message over the connection stored in ``message.remote``.

        Responses whose class the peer opted out of through the No-Response
        option are dropped silently.
        """
        # Ignoring messageerror_monitor: CoAP over reliable transports has no
        # way of indicating that a particular message was bad, it always shuts
        # down the complete connection

        if message.code.is_response():
            # Bit (class - 1) of No-Response suppresses that response class.
            suppressed = message.opt.no_response or 0
            if suppressed & (1 << (message.code.class_ - 1)):
                return

            # No-Response is a request option; strip it only from responses
            # so that outgoing requests keep it.
            message.opt.no_response = None

        message.remote._send_message(message)

    # used by the TcpConnection instances

    def _dispatch_incoming(self, connection, msg):
        """Hand an incoming non-signalling message to the token manager."""
        if msg.code == 0:
            # Empty messages (0.00) MUST be ignored by the recipient
            # (RFC 8323 Section 3.4); they only serve as keepalives.
            return

        if self._tokenmanager is None:
            # Shut down already; connections may still deliver data while
            # they are being released.
            self.log.debug("Ignoring incoming message during shutdown: %r", msg)
            return

        if msg.code.is_response():
            self._tokenmanager.process_response(msg)
            # ignoring the return value; unexpected responses can be the
            # asynchronous result of cancelled observations
        else:
            self._tokenmanager.process_request(msg)

    def _dispatch_error(self, connection, exc):
        """Drop connection from the pool and report exc to the token manager."""
        self._evict_from_pool(connection)

        if self._tokenmanager is None:
            if exc is not None:
                self.log.warning("Ignoring late error during shutdown: %s", exc)
            else:
                # it's just a regular connection loss, that's to be expected during shutdown
                pass
            return

        self._tokenmanager.dispatch_error(exc, connection)

    # for diverting behavior of _TLSMixIn
    _scheme = "coap+tcp"
    _default_port = COAP_PORT

291 

292 

class TCPServer(_TCPPooling, interfaces.TokenInterface):
    """Listening side of CoAP over TCP; accepted connections live in ``_pool``."""

    def __init__(self) -> None:
        self._pool: Set[TcpConnection] = set()
        self.log: Optional[Logger] = None
        self.server = None

    @classmethod
    async def create_server(
        cls, bind, tman: interfaces.TokenManager, log, loop, *, _server_context=None
    ):
        """Create an instance and start listening on bind.

        bind defaults to ``("::", None)``. A missing port selects
        ``_default_port``; an explicit port is shifted by the distance between
        ``_default_port`` and COAP_PORT. Raises error.ResolutionError if the
        host can not be resolved.
        """
        instance = cls()
        instance._tokenmanager = tman
        instance.log = log

        if not bind:
            bind = ("::", None)
        host, requested_port = bind[0], bind[1]
        if requested_port:
            port = requested_port + (instance._default_port - COAP_PORT)
        else:
            port = instance._default_port

        def new_connection():
            # Every accepted connection is tracked until it gets evicted.
            connection = TcpConnection(instance, log, loop, is_server=True)
            instance._pool.add(connection)
            return connection

        try:
            instance.server = await loop.create_server(
                new_connection,
                host,
                port,
                ssl=_server_context,
                reuse_port=defaults.has_reuse_port(),
            )
        except socket.gaierror as e:
            raise error.ResolutionError(
                "No local bindable address found for %s" % host
            ) from e

        return instance

    def _evict_from_pool(self, connection):
        # May easily happen twice, once when an error comes in and once when
        # the connection is (subsequently) closed -- discard tolerates that.
        self._pool.discard(connection)

    # implementing TokenInterface

    async def fill_or_recognize_remote(self, message):
        remote = message.remote
        return (
            remote is not None
            and isinstance(remote, TcpConnection)
            and remote._ctx is self
        )

    async def shutdown(self):
        self.log.debug("Shutting down server %r", self)
        self._tokenmanager = None
        self.server.close()

        # close() stops new connections from arriving, so the pool is final.
        pending = [
            asyncio.create_task(conn.release(), name="Close client %s" % conn)
            for conn in self._pool
        ]
        pending.append(
            asyncio.create_task(
                self.server.wait_closed(), name="Close server %s" % self
            )
        )
        # Never empty (the wait_closed task is always there), as required.
        await asyncio.wait(pending)

375 

376 

class TCPClient(_TCPPooling, interfaces.TokenInterface):
    """Client side of CoAP over TCP: opens outgoing connections on demand and
    reuses them per (host, port)."""

    def __init__(self) -> None:
        self._pool: Dict[
            Tuple[str, int], TcpConnection
        ] = {}  #: (host, port) -> connection
        # note that connections are filed by host name, so different names for
        # the same address might end up with different connections, which is
        # probably okay for TCP, and crucial for later work with TLS.
        self.log: Optional[Logger] = None
        self.loop: Optional[asyncio.AbstractEventLoop] = None
        self.credentials = None

    async def _spawn_protocol(self, message):
        """Return a pooled or newly opened connection for message.

        The destination is taken from ``message.unresolved_remote`` if set,
        otherwise from the Uri-Host / Uri-Port options; the port falls back to
        ``_default_port``.

        Raises ValueError if no destination host is known,
        error.ResolutionError if the host can not be resolved, and
        error.NetworkError if connecting fails otherwise.
        """
        if message.unresolved_remote is None:
            host = message.opt.uri_host
            port = message.opt.uri_port or self._default_port
            if host is None:
                raise ValueError(
                    "No location found to send message to (neither in .opt.uri_host nor in .remote)"
                )
        else:
            host, port = util.hostportsplit(message.unresolved_remote)
            port = port or self._default_port

        if (host, port) in self._pool:
            return self._pool[(host, port)]

        # NOTE(review): the pool is checked above and only filled after the
        # await below, so concurrent requests to the same (host, port) may
        # each open a connection; only the last one is kept in the pool.
        try:
            _, protocol = await self.loop.create_connection(
                lambda: TcpConnection(self, self.log, self.loop, is_server=False),
                host,
                port,
                ssl=self._ssl_context_factory(message.unresolved_remote),
            )
        except socket.gaierror as e:
            raise error.ResolutionError(
                "No address information found for requests to %r" % host
            ) from e
        except OSError as e:
            raise error.NetworkError("Connection failed to %r" % host) from e

        self._pool[(host, port)] = protocol

        return protocol

    # for diverting behavior of TLSClient
    def _ssl_context_factory(self, hostinfo):
        return None

    def _evict_from_pool(self, connection):
        """Remove every pool entry pointing at connection (should be zero or one)."""
        stale = [key for key, pooled in self._pool.items() if pooled is connection]
        for key in stale:
            del self._pool[key]

    @classmethod
    async def create_client_transport(
        cls, tman: interfaces.TokenManager, log, loop, credentials=None
    ):
        # this is not actually asynchronous, and even though the interface
        # between the context and the creation of interfaces is not fully
        # standardized, this stays in the other interfaces' style.
        self = cls()
        self._tokenmanager = tman
        self.log = log
        self.loop = loop
        # used by the TLS variant; FIXME not well thought through
        self.credentials = credentials

        return self

    # implementing TokenInterface

    async def fill_or_recognize_remote(self, message):
        if (
            message.remote is not None
            and isinstance(message.remote, TcpConnection)
            and message.remote._ctx is self
        ):
            return True

        if message.requested_scheme == self._scheme:
            # FIXME: This could pool outgoing connections.
            # (Checking if an incoming connection is a pool candidate is
            # probably overkill because even if a URI can be constructed from a
            # ephemeral client port, nobody but us can use it, and we can just
            # set .remote).
            message.remote = await self._spawn_protocol(message)
            return True

        return False

    async def shutdown(self):
        self.log.debug("Shutting down any outgoing connections on %r", self)
        self._tokenmanager = None

        shutdowns = [
            asyncio.create_task(
                c.release(),
                name="Close client %s" % c,
            )
            for c in self._pool.values()
        ]
        if not shutdowns:
            # wait is documented to require a non-empty set
            return
        await asyncio.wait(shutdowns)