Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions Lib/test/test_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -7528,6 +7528,62 @@ def detach():
with threading_helper.start_threads([t1, t2]):
pass

def test_getservby_getprotobyname_race(self):
# gh-74667: these used to share a static buffer with no lock, so
# concurrent calls clobbered each other's result.
try:
http_port = socket.getservbyname('http', 'tcp')
https_port = socket.getservbyname('https', 'tcp')
http_name = socket.getservbyport(http_port, 'tcp')
https_name = socket.getservbyport(https_port, 'tcp')
tcp = socket.getprotobyname('tcp')
udp = socket.getprotobyname('udp')
except OSError:
self.skipTest('required services/protocols are not available')

loops = 10000
errors = []

def check_servbyname(name, proto, expected):
for _ in range(loops):
got = socket.getservbyname(name, proto)
if got != expected:
errors.append(f'getservbyname({name!r}, {proto!r}): '
f'{got!r} != {expected!r}')
return

def check_servbyport(port, proto, expected):
for _ in range(loops):
got = socket.getservbyport(port, proto)
if got != expected:
errors.append(f'getservbyport({port!r}, {proto!r}): '
f'{got!r} != {expected!r}')
return

def check_protobyname(name, expected):
for _ in range(loops):
got = socket.getprotobyname(name)
if got != expected:
errors.append(f'getprotobyname({name!r}): '
f'{got!r} != {expected!r}')
return

threads = [
threading.Thread(target=check_servbyname,
args=('http', 'tcp', http_port)),
threading.Thread(target=check_servbyname,
args=('https', 'tcp', https_port)),
threading.Thread(target=check_servbyport,
args=(http_port, 'tcp', http_name)),
threading.Thread(target=check_servbyport,
args=(https_port, 'tcp', https_name)),
threading.Thread(target=check_protobyname, args=('tcp', tcp)),
threading.Thread(target=check_protobyname, args=('udp', udp)),
]
with threading_helper.start_threads(threads):
pass
self.assertEqual(errors, [])


class ReentrantMutationTests(unittest.TestCase):
"""Regression tests for re-entrant mutation in sendmsg/recvmsg_into.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Made :func:`socket.getservbyname`, :func:`socket.getservbyport` and
:func:`socket.getprotobyname` thread-safe.
136 changes: 126 additions & 10 deletions Modules/socketmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,19 @@ shutdown(how) -- shut down traffic in one or both directions\n\
# define USE_GETHOSTBYNAME_LOCK
#endif

/* These return a pointer to a static buffer shared between threads; a lock is
needed wherever the reentrant *_r() variants are unavailable. */
#if (!defined(HAVE_GETSERVBYNAME_R) || !defined(HAVE_GETSERVBYPORT_R) || \
!defined(HAVE_GETPROTOBYNAME_R)) && !defined(MS_WINDOWS)
# define USE_GETSERVBYNAME_LOCK
#endif

/* netdb_lock is needed if any of the netdb lookups falls back to the
non-reentrant function. */
#if defined(USE_GETHOSTBYNAME_LOCK) || defined(USE_GETSERVBYNAME_LOCK)
# define USE_NETDB_LOCK
#endif

#if defined(__APPLE__) || defined(__CYGWIN__) || defined(__NetBSD__)
# include <sys/ioctl.h>
#endif
Expand Down Expand Up @@ -1168,8 +1181,9 @@ new_sockobject(socket_state *state, SOCKET_T fd, int family, int type,


/* Lock to allow python interpreter to continue, but only allow one
thread to be in gethostbyname or getaddrinfo */
#if defined(USE_GETHOSTBYNAME_LOCK)
thread to be in gethostbyname, getaddrinfo, getservby*() or
getprotobyname() */
#if defined(USE_NETDB_LOCK)
static PyThread_type_lock netdb_lock;
#endif

Expand Down Expand Up @@ -6361,21 +6375,55 @@ socket_getservbyname(PyObject *self, PyObject *args)
{
const char *name, *proto=NULL;
struct servent *sp;
PyObject *ret = NULL;
#ifdef HAVE_GETSERVBYNAME_R
struct servent serv;
char *buf = NULL;
size_t buf_len = 16384;
int err = 0;
#endif
if (!PyArg_ParseTuple(args, "s|s:getservbyname", &name, &proto))
return NULL;

if (PySys_Audit("socket.getservbyname", "ss", name, proto) < 0) {
return NULL;
}

#ifdef HAVE_GETSERVBYNAME_R
do {
char *new_buf = PyMem_RawRealloc(buf, buf_len);
if (new_buf == NULL) {
PyMem_RawFree(buf);
return PyErr_NoMemory();
}
buf = new_buf;
Py_BEGIN_ALLOW_THREADS
err = getservbyname_r(name, proto, &serv, buf, buf_len, &sp);
Py_END_ALLOW_THREADS
buf_len *= 2;
} while (err == ERANGE);
#else
Py_BEGIN_ALLOW_THREADS
#ifdef USE_GETSERVBYNAME_LOCK
PyThread_acquire_lock(netdb_lock, 1);
#endif
sp = getservbyname(name, proto);
Py_END_ALLOW_THREADS
#endif /* HAVE_GETSERVBYNAME_R */
if (sp == NULL) {
PyErr_SetString(PyExc_OSError, "service/proto not found");
return NULL;
}
return PyLong_FromLong((long) ntohs(sp->s_port));
else {
ret = PyLong_FromLong((long) ntohs(sp->s_port));
}
#ifdef HAVE_GETSERVBYNAME_R
PyMem_RawFree(buf);
#else
#ifdef USE_GETSERVBYNAME_LOCK
PyThread_release_lock(netdb_lock);
#endif
#endif /* HAVE_GETSERVBYNAME_R */
return ret;
}

PyDoc_STRVAR(getservbyname_doc,
Expand All @@ -6398,6 +6446,13 @@ socket_getservbyport(PyObject *self, PyObject *args)
int port;
const char *proto=NULL;
struct servent *sp;
PyObject *ret = NULL;
#ifdef HAVE_GETSERVBYPORT_R
struct servent serv;
char *buf = NULL;
size_t buf_len = 16384;
int err = 0;
#endif
if (!PyArg_ParseTuple(args, "i|s:getservbyport", &port, &proto))
return NULL;
if (port < 0 || port > 0xffff) {
Expand All @@ -6411,14 +6466,42 @@ socket_getservbyport(PyObject *self, PyObject *args)
return NULL;
}

#ifdef HAVE_GETSERVBYPORT_R
do {
char *new_buf = PyMem_RawRealloc(buf, buf_len);
if (new_buf == NULL) {
PyMem_RawFree(buf);
return PyErr_NoMemory();
}
buf = new_buf;
Py_BEGIN_ALLOW_THREADS
err = getservbyport_r(htons((short)port), proto, &serv, buf, buf_len,
&sp);
Py_END_ALLOW_THREADS
buf_len *= 2;
} while (err == ERANGE);
#else
Py_BEGIN_ALLOW_THREADS
#ifdef USE_GETSERVBYNAME_LOCK
PyThread_acquire_lock(netdb_lock, 1);
#endif
sp = getservbyport(htons((short)port), proto);
Py_END_ALLOW_THREADS
#endif /* HAVE_GETSERVBYPORT_R */
if (sp == NULL) {
PyErr_SetString(PyExc_OSError, "port/proto not found");
return NULL;
}
return PyUnicode_FromString(sp->s_name);
else {
ret = PyUnicode_FromString(sp->s_name);
}
#ifdef HAVE_GETSERVBYPORT_R
PyMem_RawFree(buf);
#else
#ifdef USE_GETSERVBYNAME_LOCK
PyThread_release_lock(netdb_lock);
#endif
#endif /* HAVE_GETSERVBYPORT_R */
return ret;
}

PyDoc_STRVAR(getservbyport_doc,
Expand All @@ -6440,16 +6523,50 @@ socket_getprotobyname(PyObject *self, PyObject *args)
{
const char *name;
struct protoent *sp;
PyObject *ret = NULL;
#ifdef HAVE_GETPROTOBYNAME_R
struct protoent proto;
char *buf = NULL;
size_t buf_len = 16384;
int err = 0;
#endif
if (!PyArg_ParseTuple(args, "s:getprotobyname", &name))
return NULL;
#ifdef HAVE_GETPROTOBYNAME_R
do {
char *new_buf = PyMem_RawRealloc(buf, buf_len);
if (new_buf == NULL) {
PyMem_RawFree(buf);
return PyErr_NoMemory();
}
buf = new_buf;
Py_BEGIN_ALLOW_THREADS
err = getprotobyname_r(name, &proto, buf, buf_len, &sp);
Py_END_ALLOW_THREADS
buf_len *= 2;
} while (err == ERANGE);
#else
Py_BEGIN_ALLOW_THREADS
#ifdef USE_GETSERVBYNAME_LOCK
PyThread_acquire_lock(netdb_lock, 1);
#endif
sp = getprotobyname(name);
Py_END_ALLOW_THREADS
#endif /* HAVE_GETPROTOBYNAME_R */
if (sp == NULL) {
PyErr_SetString(PyExc_OSError, "protocol not found");
return NULL;
}
return PyLong_FromLong((long) sp->p_proto);
else {
ret = PyLong_FromLong((long) sp->p_proto);
}
#ifdef HAVE_GETPROTOBYNAME_R
PyMem_RawFree(buf);
#else
#ifdef USE_GETSERVBYNAME_LOCK
PyThread_release_lock(netdb_lock);
#endif
#endif /* HAVE_GETPROTOBYNAME_R */
return ret;
}

PyDoc_STRVAR(getprotobyname_doc,
Expand Down Expand Up @@ -9292,8 +9409,7 @@ socket_exec(PyObject *m)
#endif
#endif /* _MSTCPIP_ */

/* Initialize gethostbyname lock */
#if defined(USE_GETHOSTBYNAME_LOCK)
#if defined(USE_NETDB_LOCK)
netdb_lock = PyThread_allocate_lock();
if (netdb_lock == NULL) {
goto error;
Expand Down
Loading
Loading