Skip to content

Commit fe28011

Browse files
committed
gh-74667: Make socket.getservbyname(), getservbyport() and getprotobyname() thread-safe
1 parent 1b9fe5c commit fe28011

6 files changed

Lines changed: 365 additions & 10 deletions

File tree

Lib/test/test_socket.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7528,6 +7528,56 @@ def detach():
75287528
with threading_helper.start_threads([t1, t2]):
75297529
pass
75307530

7531+
def test_getservby_getprotobyname_race(self):
7532+
# gh-74667: these used to share a static buffer with no lock, so
7533+
# concurrent calls clobbered each other's result.
7534+
try:
7535+
http_port = socket.getservbyname('http', 'tcp')
7536+
https_port = socket.getservbyname('https', 'tcp')
7537+
http_name = socket.getservbyport(http_port, 'tcp')
7538+
https_name = socket.getservbyport(https_port, 'tcp')
7539+
tcp = socket.getprotobyname('tcp')
7540+
udp = socket.getprotobyname('udp')
7541+
except OSError:
7542+
self.skipTest('required services/protocols are not available')
7543+
7544+
loops = 10000
7545+
errors = []
7546+
7547+
def check_servbyname(name, proto, expected):
7548+
for _ in range(loops):
7549+
if socket.getservbyname(name, proto) != expected:
7550+
errors.append('getservbyname')
7551+
return
7552+
7553+
def check_servbyport(port, proto, expected):
7554+
for _ in range(loops):
7555+
if socket.getservbyport(port, proto) != expected:
7556+
errors.append('getservbyport')
7557+
return
7558+
7559+
def check_protobyname(name, expected):
7560+
for _ in range(loops):
7561+
if socket.getprotobyname(name) != expected:
7562+
errors.append('getprotobyname')
7563+
return
7564+
7565+
threads = [
7566+
threading.Thread(target=check_servbyname,
7567+
args=('http', 'tcp', http_port)),
7568+
threading.Thread(target=check_servbyname,
7569+
args=('https', 'tcp', https_port)),
7570+
threading.Thread(target=check_servbyport,
7571+
args=(http_port, 'tcp', http_name)),
7572+
threading.Thread(target=check_servbyport,
7573+
args=(https_port, 'tcp', https_name)),
7574+
threading.Thread(target=check_protobyname, args=('tcp', tcp)),
7575+
threading.Thread(target=check_protobyname, args=('udp', udp)),
7576+
]
7577+
with threading_helper.start_threads(threads):
7578+
pass
7579+
self.assertEqual(errors, [])
7580+
75317581

75327582
class ReentrantMutationTests(unittest.TestCase):
75337583
"""Regression tests for re-entrant mutation in sendmsg/recvmsg_into.
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Made :func:`socket.getservbyname`, :func:`socket.getservbyport` and
2+
:func:`socket.getprotobyname` thread-safe.

Modules/socketmodule.c

Lines changed: 126 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,19 @@ shutdown(how) -- shut down traffic in one or both directions\n\
205205
# define USE_GETHOSTBYNAME_LOCK
206206
#endif
207207

208+
/* These return a pointer to a static buffer shared between threads; a lock is
209+
needed wherever the reentrant *_r() variants are unavailable. */
210+
#if (!defined(HAVE_GETSERVBYNAME_R) || !defined(HAVE_GETSERVBYPORT_R) || \
211+
!defined(HAVE_GETPROTOBYNAME_R)) && !defined(MS_WINDOWS)
212+
# define USE_GETSERVBYNAME_LOCK
213+
#endif
214+
215+
/* netdb_lock is needed if any of the netdb lookups falls back to the
216+
non-reentrant function. */
217+
#if defined(USE_GETHOSTBYNAME_LOCK) || defined(USE_GETSERVBYNAME_LOCK)
218+
# define USE_NETDB_LOCK
219+
#endif
220+
208221
#if defined(__APPLE__) || defined(__CYGWIN__) || defined(__NetBSD__)
209222
# include <sys/ioctl.h>
210223
#endif
@@ -1168,8 +1181,9 @@ new_sockobject(socket_state *state, SOCKET_T fd, int family, int type,
11681181

11691182

11701183
/* Lock to allow python interpreter to continue, but only allow one
1171-
thread to be in gethostbyname or getaddrinfo */
1172-
#if defined(USE_GETHOSTBYNAME_LOCK)
1184+
thread to be in gethostbyname, getaddrinfo, getservby*() or
1185+
getprotobyname() */
1186+
#if defined(USE_NETDB_LOCK)
11731187
static PyThread_type_lock netdb_lock;
11741188
#endif
11751189

@@ -6361,21 +6375,55 @@ socket_getservbyname(PyObject *self, PyObject *args)
63616375
{
63626376
const char *name, *proto=NULL;
63636377
struct servent *sp;
6378+
PyObject *ret = NULL;
6379+
#ifdef HAVE_GETSERVBYNAME_R
6380+
struct servent serv;
6381+
char *buf = NULL;
6382+
size_t buf_len = 1024;
6383+
int err = 0;
6384+
#endif
63646385
if (!PyArg_ParseTuple(args, "s|s:getservbyname", &name, &proto))
63656386
return NULL;
63666387

63676388
if (PySys_Audit("socket.getservbyname", "ss", name, proto) < 0) {
63686389
return NULL;
63696390
}
63706391

6392+
#ifdef HAVE_GETSERVBYNAME_R
6393+
do {
6394+
char *new_buf = PyMem_RawRealloc(buf, buf_len);
6395+
if (new_buf == NULL) {
6396+
PyMem_RawFree(buf);
6397+
return PyErr_NoMemory();
6398+
}
6399+
buf = new_buf;
6400+
Py_BEGIN_ALLOW_THREADS
6401+
err = getservbyname_r(name, proto, &serv, buf, buf_len, &sp);
6402+
Py_END_ALLOW_THREADS
6403+
buf_len *= 2;
6404+
} while (err == ERANGE);
6405+
#else
63716406
Py_BEGIN_ALLOW_THREADS
6407+
#ifdef USE_GETSERVBYNAME_LOCK
6408+
PyThread_acquire_lock(netdb_lock, 1);
6409+
#endif
63726410
sp = getservbyname(name, proto);
63736411
Py_END_ALLOW_THREADS
6412+
#endif /* HAVE_GETSERVBYNAME_R */
63746413
if (sp == NULL) {
63756414
PyErr_SetString(PyExc_OSError, "service/proto not found");
6376-
return NULL;
63776415
}
6378-
return PyLong_FromLong((long) ntohs(sp->s_port));
6416+
else {
6417+
ret = PyLong_FromLong((long) ntohs(sp->s_port));
6418+
}
6419+
#ifdef HAVE_GETSERVBYNAME_R
6420+
PyMem_RawFree(buf);
6421+
#else
6422+
#ifdef USE_GETSERVBYNAME_LOCK
6423+
PyThread_release_lock(netdb_lock);
6424+
#endif
6425+
#endif /* HAVE_GETSERVBYNAME_R */
6426+
return ret;
63796427
}
63806428

63816429
PyDoc_STRVAR(getservbyname_doc,
@@ -6398,6 +6446,13 @@ socket_getservbyport(PyObject *self, PyObject *args)
63986446
int port;
63996447
const char *proto=NULL;
64006448
struct servent *sp;
6449+
PyObject *ret = NULL;
6450+
#ifdef HAVE_GETSERVBYPORT_R
6451+
struct servent serv;
6452+
char *buf = NULL;
6453+
size_t buf_len = 1024;
6454+
int err = 0;
6455+
#endif
64016456
if (!PyArg_ParseTuple(args, "i|s:getservbyport", &port, &proto))
64026457
return NULL;
64036458
if (port < 0 || port > 0xffff) {
@@ -6411,14 +6466,42 @@ socket_getservbyport(PyObject *self, PyObject *args)
64116466
return NULL;
64126467
}
64136468

6469+
#ifdef HAVE_GETSERVBYPORT_R
6470+
do {
6471+
char *new_buf = PyMem_RawRealloc(buf, buf_len);
6472+
if (new_buf == NULL) {
6473+
PyMem_RawFree(buf);
6474+
return PyErr_NoMemory();
6475+
}
6476+
buf = new_buf;
6477+
Py_BEGIN_ALLOW_THREADS
6478+
err = getservbyport_r(htons((short)port), proto, &serv, buf, buf_len,
6479+
&sp);
6480+
Py_END_ALLOW_THREADS
6481+
buf_len *= 2;
6482+
} while (err == ERANGE);
6483+
#else
64146484
Py_BEGIN_ALLOW_THREADS
6485+
#ifdef USE_GETSERVBYNAME_LOCK
6486+
PyThread_acquire_lock(netdb_lock, 1);
6487+
#endif
64156488
sp = getservbyport(htons((short)port), proto);
64166489
Py_END_ALLOW_THREADS
6490+
#endif /* HAVE_GETSERVBYPORT_R */
64176491
if (sp == NULL) {
64186492
PyErr_SetString(PyExc_OSError, "port/proto not found");
6419-
return NULL;
64206493
}
6421-
return PyUnicode_FromString(sp->s_name);
6494+
else {
6495+
ret = PyUnicode_FromString(sp->s_name);
6496+
}
6497+
#ifdef HAVE_GETSERVBYPORT_R
6498+
PyMem_RawFree(buf);
6499+
#else
6500+
#ifdef USE_GETSERVBYNAME_LOCK
6501+
PyThread_release_lock(netdb_lock);
6502+
#endif
6503+
#endif /* HAVE_GETSERVBYPORT_R */
6504+
return ret;
64226505
}
64236506

64246507
PyDoc_STRVAR(getservbyport_doc,
@@ -6440,16 +6523,50 @@ socket_getprotobyname(PyObject *self, PyObject *args)
64406523
{
64416524
const char *name;
64426525
struct protoent *sp;
6526+
PyObject *ret = NULL;
6527+
#ifdef HAVE_GETPROTOBYNAME_R
6528+
struct protoent proto;
6529+
char *buf = NULL;
6530+
size_t buf_len = 1024;
6531+
int err = 0;
6532+
#endif
64436533
if (!PyArg_ParseTuple(args, "s:getprotobyname", &name))
64446534
return NULL;
6535+
#ifdef HAVE_GETPROTOBYNAME_R
6536+
do {
6537+
char *new_buf = PyMem_RawRealloc(buf, buf_len);
6538+
if (new_buf == NULL) {
6539+
PyMem_RawFree(buf);
6540+
return PyErr_NoMemory();
6541+
}
6542+
buf = new_buf;
6543+
Py_BEGIN_ALLOW_THREADS
6544+
err = getprotobyname_r(name, &proto, buf, buf_len, &sp);
6545+
Py_END_ALLOW_THREADS
6546+
buf_len *= 2;
6547+
} while (err == ERANGE);
6548+
#else
64456549
Py_BEGIN_ALLOW_THREADS
6550+
#ifdef USE_GETSERVBYNAME_LOCK
6551+
PyThread_acquire_lock(netdb_lock, 1);
6552+
#endif
64466553
sp = getprotobyname(name);
64476554
Py_END_ALLOW_THREADS
6555+
#endif /* HAVE_GETPROTOBYNAME_R */
64486556
if (sp == NULL) {
64496557
PyErr_SetString(PyExc_OSError, "protocol not found");
6450-
return NULL;
64516558
}
6452-
return PyLong_FromLong((long) sp->p_proto);
6559+
else {
6560+
ret = PyLong_FromLong((long) sp->p_proto);
6561+
}
6562+
#ifdef HAVE_GETPROTOBYNAME_R
6563+
PyMem_RawFree(buf);
6564+
#else
6565+
#ifdef USE_GETSERVBYNAME_LOCK
6566+
PyThread_release_lock(netdb_lock);
6567+
#endif
6568+
#endif /* HAVE_GETPROTOBYNAME_R */
6569+
return ret;
64536570
}
64546571

64556572
PyDoc_STRVAR(getprotobyname_doc,
@@ -9292,8 +9409,7 @@ socket_exec(PyObject *m)
92929409
#endif
92939410
#endif /* _MSTCPIP_ */
92949411

9295-
/* Initialize gethostbyname lock */
9296-
#if defined(USE_GETHOSTBYNAME_LOCK)
9412+
#if defined(USE_NETDB_LOCK)
92979413
netdb_lock = PyThread_allocate_lock();
92989414
if (netdb_lock == NULL) {
92999415
goto error;

0 commit comments

Comments
 (0)