Coverage for aiocoap/transports/tcp.py: 91%

249 statements  

« prev     ^ index     » next       coverage.py v7.6.8, created at 2024-11-28 12:34 +0000

1# SPDX-FileCopyrightText: Christian Amsüss and the aiocoap contributors 

2# 

3# SPDX-License-Identifier: MIT 

4 

5import asyncio 

6import socket 

7from logging import Logger 

8from typing import Dict, Optional, Set, Tuple 

9 

10from aiocoap.transports import rfc8323common 

11from aiocoap import interfaces, error, util 

12from aiocoap import COAP_PORT, Message 

13from aiocoap import defaults 

14 

15 

16def _extract_message_size(data: bytes): 

17 """Read out the full length of a CoAP messsage represented by data. 

18 

19 Returns None if data is too short to read the (full) length. 

20 

21 The number returned is the number of bytes that has to be read into data to 

22 start reading the next message; it consists of a constant term, the token 

23 length and the extended length of options-plus-payload.""" 

24 

25 if not data: 

26 return None 

27 

28 length = data[0] >> 4 

29 tokenoffset = 2 

30 tkl = data[0] & 0x0F 

31 

32 if length >= 13: 

33 if length == 13: 

34 extlen = 1 

35 offset = 13 

36 elif length == 14: 

37 extlen = 2 

38 offset = 269 

39 else: 

40 extlen = 4 

41 offset = 65805 

42 if len(data) < extlen + 1: 

43 return None 

44 tokenoffset = 2 + extlen 

45 length = int.from_bytes(data[1 : 1 + extlen], "big") + offset 

46 return tokenoffset, tkl, length 

47 

48 

def _decode_message(data: bytes) -> Message:
    """Parse one CoAP-over-TCP message from data.

    data is expected to contain exactly one message, delimited as reported
    by _extract_message_size.

    Raises error.UnparsableMessage if the header or token is truncated, or
    if the token is longer than 8 bytes. (Previously, truncated input caused
    a TypeError or IndexError, or a silently shortened token.)
    """
    sizes = _extract_message_size(data)
    if sizes is None:
        raise error.UnparsableMessage("Truncated message")
    tokenoffset, tkl, _ = sizes
    if tkl > 8:
        raise error.UnparsableMessage("Overly long token")
    # The code byte sits right before the token; both have to be present.
    if len(data) < tokenoffset + tkl:
        raise error.UnparsableMessage("Truncated message")
    code = data[tokenoffset - 1]
    token = data[tokenoffset : tokenoffset + tkl]

    msg = Message(code=code, token=token)

    # Everything after the token is options, optionally followed by payload.
    msg.payload = msg.opt.decode(data[tokenoffset + tkl :])

    return msg

61 

62 

63def _encode_length(length: int): 

64 if length < 13: 

65 return (length, b"") 

66 elif length < 269: 

67 return (13, (length - 13).to_bytes(1, "big")) 

68 elif length < 65805: 

69 return (14, (length - 269).to_bytes(2, "big")) 

70 else: 

71 return (15, (length - 65805).to_bytes(4, "big")) 

72 

73 

def _serialize(msg: Message) -> bytes:
    """Encode msg into its CoAP-over-TCP wire representation.

    The layout is: Len/TKL byte, extended length bytes, code byte, token,
    then options, and (if there is a payload) a 0xFF marker and the payload.

    Raises ValueError if the token is longer than 8 bytes.
    """
    options = msg.opt.encode()
    if msg.payload:
        body = b"".join((options, b"\xff", msg.payload))
    else:
        body = b"".join((options,))
    len_nibble, ext_length = _encode_length(len(body))

    token_length = len(msg.token)
    if token_length > 8:
        raise ValueError("Overly long token")

    first_byte = (len_nibble << 4) | token_length
    return b"".join(
        (bytes((first_byte,)), ext_length, bytes((msg.code,)), msg.token, body)
    )

88 

89 

class TcpConnection(
    asyncio.Protocol, rfc8323common.RFC8323Remote, interfaces.EndpointAddress
):
    """A single CoAP-over-TCP connection.

    The object is both the asyncio protocol that drives the byte stream and
    the endpoint address stored in the ``.remote`` of messages exchanged over
    it. Signalling messages are processed locally; all other messages are
    handed to the owning context (ctx).
    """

    # currently, both the protocol and the EndpointAddress are the same object.
    # if, at a later point in time, the keepaliving of TCP connections should
    # depend on whether the library user still keeps a usable address around,
    # those functions could be split.

    def __init__(self, ctx, log, loop, *, is_server) -> None:
        super().__init__()
        # Owning context; reset to None by _abort_with to poison this object.
        self._ctx = ctx
        self.log = log
        self.loop = loop

        # Bytes received but not yet consumed as a complete message.
        self._spool = b""

        # Stays None until the peer's CSM has been processed (presumably set
        # by RFC8323Remote._process_signaling -- verify there). Non-signalling
        # messages are refused while it is None.
        self._remote_settings = None

        self._transport: Optional[asyncio.Transport] = None
        # Decides which side's host info the TLS SNI name replaces.
        self._local_is_server = is_server

    @property
    def scheme(self):
        """URI scheme of the owning context (e.g. ``coap+tcp``)."""
        return self._ctx._scheme

    def _send_message(self, msg: Message):
        """Serialize msg and write it to the transport."""
        self.log.debug("Sending message: %r", msg)
        assert (
            self._transport is not None
        ), "Attempted to send message before connection"
        self._transport.write(_serialize(msg))

    def _abort_with(self, abort_msg):
        """Send abort_msg and close the transport.

        If there is no transport (yet), poison the object instead."""
        if self._transport is not None:
            self._send_message(abort_msg)
            self._transport.close()
        else:
            # FIXME: find out how this happens; i've only seen it after nmap
            # runs against an aiocoap server and then shutting it down.
            # "poisoning" the object to make sure this can not be exploited to
            # bypass the server shutdown.
            self._ctx = None

    # implementing asyncio.Protocol

    def connection_made(self, transport):
        """Record transport and host information, then send our initial CSM."""
        self._transport = transport

        ssl_object = transport.get_extra_info("ssl_object")
        if ssl_object is not None:
            server_name = getattr(ssl_object, "indicated_server_name", None)
        else:
            server_name = None

        # `host` already contains the interface identifier, so throwing away
        # scope and interface identifier
        self._local_hostinfo = transport.get_extra_info("sockname")[:2]
        self._remote_hostinfo = transport.get_extra_info("peername")[:2]

        def none_default_port(sockname):
            # The context's default port is represented as None (implied).
            return (
                sockname[0],
                None if sockname[1] == self._ctx._default_port else sockname[1],
            )

        self._local_hostinfo = none_default_port(self._local_hostinfo)
        self._remote_hostinfo = none_default_port(self._remote_hostinfo)

        # SNI information available
        if server_name is not None:
            if self._local_is_server:
                self._local_hostinfo = (server_name, self._local_hostinfo[1])
            else:
                self._remote_hostinfo = (server_name, self._remote_hostinfo[1])

        self._send_initial_csm()

    def connection_lost(self, exc):
        """Report the loss (exc is None on a regular close) to the context."""
        # FIXME react meaningfully:
        # * send event through pool so it can propagate the error to all
        #   requests on the same remote
        # * mark the address as erroneous so it won't be recognized by
        #   fill_or_recognize_remote

        self._ctx._dispatch_error(self, exc)

    def data_received(self, data):
        """Append data to the spool and process every complete message in it.

        The connection is aborted if a message announces more than the local
        maximum message size, cannot be parsed, or is not a signalling message
        while no CSM has been received yet. After a signalling message closes
        the connection, any remaining spooled data is no longer processed.
        """
        # A rope would be more efficient here, but the expected case is that
        # _spool is b"" and spool gets emptied soon -- most messages will just
        # fit in a single TCP package and not be nagled together.
        #
        # (If this does become a bottleneck, say self._spool = SomeRope(b"")
        # and barely change anything else).

        self._spool += data

        while True:
            msglen = _extract_message_size(self._spool)
            if msglen is None:
                break
            msglen = sum(msglen)
            if msglen > self._my_max_message_size:
                self.abort("Overly large message announced")
                return

            if msglen > len(self._spool):
                # Message not complete yet; wait for more data.
                break

            msg = self._spool[:msglen]
            try:
                msg = _decode_message(msg)
            except error.UnparsableMessage:
                self.abort("Failed to parse message")
                return
            msg.remote = self

            self.log.debug("Received message: %r", msg)

            self._spool = self._spool[msglen:]

            if msg.code.is_signalling():
                try:
                    self._process_signaling(msg)
                except rfc8323common.CloseConnection as e:
                    self._ctx._dispatch_error(self, e.args[0])
                    self._transport.close()
                    # The connection is being torn down (and was evicted from
                    # the pool by _dispatch_error); whatever is still spooled
                    # must not be dispatched any more.
                    return
                continue

            if self._remote_settings is None:
                self.abort("No CSM received")
                return

            self._ctx._dispatch_incoming(self, msg)

    def eof_received(self):
        # FIXME: as with connection_lost, but less noisy if announced
        # FIXME: return true and initiate own shutdown if that is what CoAP prescribes
        pass

    def pause_writing(self):
        # FIXME: do something ;-)
        pass

    def resume_writing(self):
        # FIXME: do something ;-)
        pass

    # RFC8323Remote.release recommends subclassing this, but there's no easy
    # awaitable here yet, and no important business to finish, timeout-wise.

239 

240 

class _TCPPooling:
    """Behavior shared by TCPServer and TCPClient.

    Subclasses provide ``_tokenmanager``, ``log`` and ``_evict_from_pool``.
    Their TcpConnection objects call back into ``_dispatch_incoming`` and
    ``_dispatch_error``.
    """

    # implementing TokenInterface

    def send_message(self, message, messageerror_monitor):
        """Send message over the connection stored in its ``.remote``.

        Responses whose class is suppressed by the message's No-Response
        option are dropped silently. The option itself is cleared before
        sending.
        """
        # Ignoring messageerror_monitor: CoAP over reliable transports has no
        # way of indicating that a particular message was bad, it always shuts
        # down the complete connection

        if message.code.is_response():
            # No-Response bit for response class c is 1 << (c - 1)
            no_response = (message.opt.no_response or 0) & (
                1 << message.code.class_ - 1
            ) != 0
            if no_response:
                return

        message.opt.no_response = None

        message.remote._send_message(message)

    # used by the TcpConnection instances

    def _dispatch_incoming(self, connection, msg):
        """Hand a non-signalling message received on connection to the
        token manager."""
        if msg.code == 0:
            # RFC 8323: Empty messages (Code 0.00) can always be sent and MUST
            # be ignored by the recipient. Previously they fell through and
            # were processed as requests.
            return

        if self._tokenmanager is None:
            # Shutting down (shutdown() clears the token manager before the
            # connections are released); there is nobody to deliver to.
            self.log.debug("Ignoring incoming message during shutdown: %r", msg)
            return

        if msg.code.is_response():
            self._tokenmanager.process_response(msg)
            # ignoring the return value; unexpected responses can be the
            # asynchronous result of cancelled observations
        else:
            self._tokenmanager.process_request(msg)

    def _dispatch_error(self, connection, exc):
        """Evict connection from the pool and report exc (None on a regular
        close) to the token manager, unless shutting down."""
        self._evict_from_pool(connection)

        if self._tokenmanager is None:
            if exc is not None:
                self.log.warning("Ignoring late error during shutdown: %s", exc)
            else:
                # it's just a regular connection loss, that's to be expected during shutdown
                pass
            return

        self._tokenmanager.dispatch_error(exc, connection)

    # for diverting behavior of _TLSMixIn
    _scheme = "coap+tcp"
    _default_port = COAP_PORT

289 

290 

class TCPServer(_TCPPooling, interfaces.TokenInterface):
    """Listening CoAP-over-TCP endpoint.

    Every accepted connection becomes a TcpConnection that is kept in the
    pool until it errs or is closed.
    """

    def __init__(self) -> None:
        # all currently open accepted connections
        self._pool: Set[TcpConnection] = set()
        self.log: Optional[Logger] = None
        self.server = None

    @classmethod
    async def create_server(
        cls, bind, tman: interfaces.TokenManager, log, loop, *, _server_context=None
    ):
        """Create a server bound to bind (host, port), serving through tman.

        A missing bind means all interfaces on the default port. An explicit
        port is shifted by the distance between this scheme's default port
        and COAP_PORT.

        Raises error.ResolutionError if the bind host cannot be resolved.
        """
        self = cls()
        self._tokenmanager = tman
        self.log = log
        # self.loop = loop

        if not bind:
            bind = ("::", None)
        host = bind[0]
        if bind[1]:
            port = bind[1] + (self._default_port - COAP_PORT)
        else:
            port = self._default_port

        def protocol_factory():
            # Every accepted connection is tracked so shutdown can release it.
            connection = TcpConnection(self, log, loop, is_server=True)
            self._pool.add(connection)
            return connection

        try:
            self.server = await loop.create_server(
                protocol_factory,
                host,
                port,
                ssl=_server_context,
                reuse_port=defaults.has_reuse_port(),
            )
        except socket.gaierror as e:
            raise error.ResolutionError(
                "No local bindable address found for %s" % host
            ) from e

        return self

    def _evict_from_pool(self, connection):
        # May easily happen twice, once when an error comes in and once when
        # the connection is (subsequently) closed, so a missing entry is fine.
        self._pool.discard(connection)

    # implementing TokenInterface

    async def fill_or_recognize_remote(self, message):
        """Claim message if its remote is one of this server's connections."""
        remote = message.remote
        return (
            remote is not None
            and isinstance(remote, TcpConnection)
            and remote._ctx is self
        )

    async def shutdown(self):
        """Stop listening and release every open connection."""
        self.log.debug("Shutting down server %r", self)
        # Late errors from connections are ignored from here on.
        self._tokenmanager = None
        self.server.close()
        # With the server closed, no further connections can come in, so the
        # current pool is the complete set to release.
        pending = [
            asyncio.create_task(conn.release(), name="Close client %s" % conn)
            for conn in self._pool
        ]
        pending.append(
            asyncio.create_task(
                self.server.wait_closed(), name="Close server %s" % self
            )
        )
        # The server task guarantees the non-empty set asyncio.wait requires.
        await asyncio.wait(pending)

373 

374 

class TCPClient(_TCPPooling, interfaces.TokenInterface):
    """Outgoing CoAP-over-TCP connections, opened on demand and pooled by
    the (host, port) they were requested for."""

    def __init__(self) -> None:
        self._pool: Dict[
            Tuple[str, int], TcpConnection
        ] = {}  #: (host, port) -> connection
        # note that connections are filed by host name, so different names for
        # the same address might end up with different connections, which is
        # probably okay for TCP, and crucial for later work with TLS.
        self.log: Optional[Logger] = None
        self.loop: Optional[asyncio.AbstractEventLoop] = None
        self.credentials = None

    async def _spawn_protocol(self, message):
        """Return a connection to message's destination, opening one if none
        is pooled yet.

        The destination is taken from ``message.unresolved_remote`` if set,
        and from the Uri-Host / Uri-Port options otherwise.

        Raises ValueError if no destination host is known,
        error.ResolutionError if the host cannot be resolved, and
        error.NetworkError if connecting fails for other OS-level reasons.
        """
        if message.unresolved_remote is None:
            host = message.opt.uri_host
            port = message.opt.uri_port or self._default_port
            if host is None:
                raise ValueError(
                    "No location found to send message to (neither in .opt.uri_host nor in .remote)"
                )
        else:
            host, port = util.hostportsplit(message.unresolved_remote)
            port = port or self._default_port

        if (host, port) in self._pool:
            return self._pool[(host, port)]

        # NOTE(review): the pool is only consulted before the await below, so
        # concurrent calls for the same (host, port) may each open a
        # connection; the one finishing last replaces the pool entry and the
        # other is no longer tracked (e.g. not released on shutdown).
        try:
            _, protocol = await self.loop.create_connection(
                lambda: TcpConnection(self, self.log, self.loop, is_server=False),
                host,
                port,
                ssl=self._ssl_context_factory(message.unresolved_remote),
            )
        except socket.gaierror as e:
            raise error.ResolutionError(
                "No address information found for requests to %r" % host
            ) from e
        except OSError as e:
            raise error.NetworkError("Connection failed to %r" % host) from e

        self._pool[(host, port)] = protocol

        return protocol

    # for diverting behavior of TLSClient
    def _ssl_context_factory(self, hostinfo):
        """SSL context for connecting to hostinfo; None means plain TCP."""
        return None

    def _evict_from_pool(self, connection):
        """Remove every pool entry that refers to connection."""
        # should really be zero or one
        stale_keys = [key for key, conn in self._pool.items() if conn is connection]
        for key in stale_keys:
            del self._pool[key]

    @classmethod
    async def create_client_transport(
        cls, tman: interfaces.TokenManager, log, loop, credentials=None
    ):
        """Create a client transport that reports to tman."""
        # this is not actually asynchronous, and even though the interface
        # between the context and the creation of interfaces is not fully
        # standardized, this stays in the other inferfaces' style.
        self = cls()
        self._tokenmanager = tman
        self.log = log
        self.loop = loop
        # used by the TLS variant; FIXME not well thought through
        self.credentials = credentials

        return self

    # implementing TokenInterface

    async def fill_or_recognize_remote(self, message):
        """Claim message if its remote is one of our connections, or if it
        requests this transport's scheme (opening a connection as needed)."""
        if (
            message.remote is not None
            and isinstance(message.remote, TcpConnection)
            and message.remote._ctx is self
        ):
            return True

        if message.requested_scheme == self._scheme:
            # FIXME: This could pool outgoing connections.
            # (Checking if an incoming connection is a pool candidate is
            # probably overkill because even if a URI can be constructed from a
            # ephemeral client port, nobody but us can use it, and we can just
            # set .remote).
            message.remote = await self._spawn_protocol(message)
            return True

        return False

    async def shutdown(self):
        """Release all pooled outgoing connections."""
        self.log.debug("Shutting down any outgoing connections on %r", self)
        # Late errors from connections are ignored from here on.
        self._tokenmanager = None

        shutdowns = [
            asyncio.create_task(
                c.release(),
                name="Close client %s" % c,
            )
            for c in self._pool.values()
        ]
        if not shutdowns:
            # wait is documented to require a non-empty set
            return
        await asyncio.wait(shutdowns)