@@ -405,6 +405,114 @@ def test_create_unix_connection_6(self):
405405 ssl_handshake_timeout = SSL_HANDSHAKE_TIMEOUT ))
406406
407407
408+ def test_create_unix_connection_sock_cancel_detaches (self ):
409+ async def test ():
410+ srv_path = os .path .join (tempfile .mkdtemp (), 'test.sock' )
411+ srv = await asyncio .start_unix_server (
412+ lambda r , w : None , path = srv_path )
413+
414+ sock = socket .socket (socket .AF_UNIX , socket .SOCK_STREAM )
415+ sock .setblocking (False )
416+ try :
417+ sock .connect (srv_path )
418+ except BlockingIOError :
419+ pass
420+ await asyncio .sleep (0.01 )
421+
422+ task = asyncio .ensure_future (
423+ self .loop .create_unix_connection (
424+ asyncio .Protocol , sock = sock ))
425+ await asyncio .sleep (0 )
426+ task .cancel ()
427+ with self .assertRaises (asyncio .CancelledError ):
428+ await task
429+
430+ self .assertEqual (sock .fileno (), - 1 )
431+
432+ srv .close ()
433+ await srv .wait_closed ()
434+ if os .path .exists (srv_path ):
435+ os .unlink (srv_path )
436+
437+ self .loop .run_until_complete (test ())
438+
439+ def test_create_unix_connection_sock_cancel_fd_leak (self ):
440+ # Same as test_create_connection_sock_cancel_fd_leak but for
441+ # the create_unix_connection(sock=) path.
442+
443+ async def test ():
444+ srv_path = os .path .join (tempfile .mkdtemp (), 'test.sock' )
445+ srv = await asyncio .start_unix_server (
446+ lambda r , w : None , path = srv_path )
447+
448+ sock = socket .socket (socket .AF_UNIX , socket .SOCK_STREAM )
449+ sock .setblocking (False )
450+ await self .loop .sock_connect (sock , srv_path )
451+ stale_fd = sock .fileno ()
452+
453+ task = self .loop .create_task (
454+ self .loop .create_unix_connection (
455+ asyncio .Protocol , sock = sock ))
456+ await asyncio .sleep (0 )
457+ task .cancel ()
458+ with self .assertRaises (asyncio .CancelledError ):
459+ await task
460+
461+ # Create victim that reuses the fd.
462+ victim_sock = socket .socket (socket .AF_UNIX , socket .SOCK_STREAM )
463+ victim_sock .setblocking (False )
464+ await self .loop .sock_connect (victim_sock , srv_path )
465+ victim_tr , _ = await self .loop .create_unix_connection (
466+ asyncio .Protocol , sock = victim_sock )
467+ victim_fd = victim_tr .get_extra_info ('socket' ).fileno ()
468+ if victim_fd != stale_fd :
469+ victim_tr .close ()
470+ sock .close ()
471+ srv .close ()
472+ await srv .wait_closed ()
473+ if os .path .exists (srv_path ):
474+ os .unlink (srv_path )
475+ raise unittest .SkipTest (
476+ f'fd not reused (got { victim_fd } , need { stale_fd } )' )
477+
478+ spy_a , spy_b = socket .socketpair ()
479+ spy_b .setblocking (False )
480+
481+ sock .close ()
482+
483+ victim_broken = False
484+ try :
485+ os .fstat (victim_fd )
486+ except OSError :
487+ victim_broken = True
488+
489+ if victim_broken :
490+ os .dup2 (spy_a .fileno (), stale_fd )
491+ spy_a .close ()
492+
493+ victim_tr .write (b'LEAKED' )
494+
495+ try :
496+ leaked = spy_b .recv (4096 )
497+ except BlockingIOError :
498+ leaked = b''
499+
500+ if victim_broken :
501+ os .close (stale_fd )
502+ spy_b .close ()
503+ victim_tr .close ()
504+ srv .close ()
505+ await srv .wait_closed ()
506+ if os .path .exists (srv_path ):
507+ os .unlink (srv_path )
508+
509+ self .assertEqual (leaked , b'' ,
510+ f"Data leaked to an unrelated socket: "
511+ f"got { leaked !r} " )
512+
513+ self .loop .run_until_complete (test ())
514+
515+
408516class Test_UV_Unix (_TestUnix , tb .UVTestCase ):
409517
410518 @unittest .skipUnless (hasattr (os , 'fspath' ), 'no os.fspath()' )
0 commit comments