Skip to content

Commit 57522ba

Browse files
Fix webserver thread safety (#350)
* add test * Use const when possible * Protect webserver::registered_resources_* * Protect webserver::bans and webserver::allowances * Split bans_and_allowances_mutex in two distinct mutexes * Simplify `policy_callback` condition * Fix cpplint style issues - Move opening brace to end of line in policy_callback - Replace using-directive with using-declaration for chrono_literals - Use static_cast instead of C-style cast - Use rand_r instead of rand for thread safety --------- Co-authored-by: Florian CHEVASSU <fchevassu@antidot.net>
1 parent 0e57fd2 commit 57522ba

File tree

3 files changed

+107
-47
lines changed

3 files changed

+107
-47
lines changed

src/httpserver/webserver.hpp

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
#include <map>
4444
#include <memory>
4545
#include <set>
46+
#include <shared_mutex>
4647
#include <string>
4748

4849
#ifdef HAVE_GNUTLS
@@ -172,17 +173,21 @@ class webserver {
172173
const std::string file_upload_dir;
173174
const bool generate_random_filename_on_upload;
174175
const bool deferred_enabled;
175-
bool single_resource;
176-
bool tcp_nodelay;
176+
const bool single_resource;
177+
const bool tcp_nodelay;
177178
pthread_mutex_t mutexwait;
178179
pthread_cond_t mutexcond;
179-
render_ptr not_found_resource;
180-
render_ptr method_not_allowed_resource;
181-
render_ptr internal_error_resource;
180+
const render_ptr not_found_resource;
181+
const render_ptr method_not_allowed_resource;
182+
const render_ptr internal_error_resource;
183+
std::shared_mutex registered_resources_mutex;
182184
std::map<details::http_endpoint, http_resource*> registered_resources;
183185
std::map<std::string, http_resource*> registered_resources_str;
184186

187+
std::shared_mutex bans_mutex;
185188
std::set<http::ip_representation> bans;
189+
190+
std::shared_mutex allowances_mutex;
186191
std::set<http::ip_representation> allowances;
187192

188193
struct MHD_Daemon* daemon;

src/webserver.cpp

Lines changed: 55 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
#include <iosfwd>
4545
#include <iostream>
4646
#include <memory>
47+
#include <mutex>
4748
#include <set>
4849
#include <stdexcept>
4950
#include <string>
@@ -196,6 +197,7 @@ bool webserver::register_resource(const std::string& resource, http_resource* hr
196197

197198
details::http_endpoint idx(resource, family, true, regex_checking);
198199

200+
std::unique_lock registered_resources_lock(registered_resources_mutex);
199201
pair<map<details::http_endpoint, http_resource*>::iterator, bool> result = registered_resources.insert(map<details::http_endpoint, http_resource*>::value_type(idx, hrm));
200202

201203
if (!family && result.second) {
@@ -370,12 +372,14 @@ bool webserver::stop() {
370372
void webserver::unregister_resource(const string& resource) {
371373
// family does not matter - it just checks the url_normalized anyhow
372374
details::http_endpoint he(resource, false, true, regex_checking);
375+
std::unique_lock registered_resources_lock(registered_resources_mutex);
373376
registered_resources.erase(he);
374377
registered_resources.erase(he.get_url_complete());
375378
registered_resources_str.erase(he.get_url_complete());
376379
}
377380

378381
void webserver::ban_ip(const string& ip) {
382+
std::unique_lock bans_lock(bans_mutex);
379383
ip_representation t_ip(ip);
380384
set<ip_representation>::iterator it = bans.find(t_ip);
381385
if (it != bans.end() && (t_ip.weight() < (*it).weight())) {
@@ -387,6 +391,7 @@ void webserver::ban_ip(const string& ip) {
387391
}
388392

389393
void webserver::allow_ip(const string& ip) {
394+
std::unique_lock allowances_lock(allowances_mutex);
390395
ip_representation t_ip(ip);
391396
set<ip_representation>::iterator it = allowances.find(t_ip);
392397
if (it != allowances.end() && (t_ip.weight() < (*it).weight())) {
@@ -398,10 +403,12 @@ void webserver::allow_ip(const string& ip) {
398403
}
399404

400405
void webserver::unban_ip(const string& ip) {
406+
std::unique_lock bans_lock(bans_mutex);
401407
bans.erase(ip_representation(ip));
402408
}
403409

404410
void webserver::disallow_ip(const string& ip) {
411+
std::unique_lock allowances_lock(allowances_mutex);
405412
allowances.erase(ip_representation(ip));
406413
}
407414

@@ -446,14 +453,17 @@ MHD_Result policy_callback(void *cls, const struct sockaddr* addr, socklen_t add
446453
// Parameter needed to respect MHD interface, but not needed here.
447454
std::ignore = addrlen;
448455

449-
if (!(static_cast<webserver*>(cls))->ban_system_enabled) return MHD_YES;
456+
const auto ws = static_cast<webserver*>(cls);
450457

451-
if ((((static_cast<webserver*>(cls))->default_policy == http_utils::ACCEPT) &&
452-
((static_cast<webserver*>(cls))->bans.count(ip_representation(addr))) &&
453-
(!(static_cast<webserver*>(cls))->allowances.count(ip_representation(addr)))) ||
454-
(((static_cast<webserver*>(cls))->default_policy == http_utils::REJECT) &&
455-
((!(static_cast<webserver*>(cls))->allowances.count(ip_representation(addr))) ||
456-
((static_cast<webserver*>(cls))->bans.count(ip_representation(addr)))))) {
458+
if (!ws->ban_system_enabled) return MHD_YES;
459+
460+
std::shared_lock bans_lock(ws->bans_mutex);
461+
std::shared_lock allowances_lock(ws->allowances_mutex);
462+
const bool is_banned = ws->bans.count(ip_representation(addr));
463+
const bool is_allowed = ws->allowances.count(ip_representation(addr));
464+
465+
if ((ws->default_policy == http_utils::ACCEPT && is_banned && !is_allowed) ||
466+
(ws->default_policy == http_utils::REJECT && (!is_allowed || is_banned))) {
457467
return MHD_NO;
458468
}
459469

@@ -676,51 +686,54 @@ MHD_Result webserver::finalize_answer(MHD_Connection* connection, struct details
676686

677687
bool found = false;
678688
struct MHD_Response* raw_response;
679-
if (!single_resource) {
680-
const char* st_url = mr->standardized_url->c_str();
681-
fe = registered_resources_str.find(st_url);
682-
if (fe == registered_resources_str.end()) {
683-
if (regex_checking) {
684-
map<details::http_endpoint, http_resource*>::iterator found_endpoint;
685-
686-
details::http_endpoint endpoint(st_url, false, false, false);
687-
688-
map<details::http_endpoint, http_resource*>::iterator it;
689-
690-
size_t len = 0;
691-
size_t tot_len = 0;
692-
for (it = registered_resources.begin(); it != registered_resources.end(); ++it) {
693-
size_t endpoint_pieces_len = (*it).first.get_url_pieces().size();
694-
size_t endpoint_tot_len = (*it).first.get_url_complete().size();
695-
if (!found || endpoint_pieces_len > len || (endpoint_pieces_len == len && endpoint_tot_len > tot_len)) {
696-
if ((*it).first.match(endpoint)) {
697-
found = true;
698-
len = endpoint_pieces_len;
699-
tot_len = endpoint_tot_len;
700-
found_endpoint = it;
689+
{
690+
std::shared_lock registered_resources_lock(registered_resources_mutex);
691+
if (!single_resource) {
692+
const char* st_url = mr->standardized_url->c_str();
693+
fe = registered_resources_str.find(st_url);
694+
if (fe == registered_resources_str.end()) {
695+
if (regex_checking) {
696+
map<details::http_endpoint, http_resource*>::iterator found_endpoint;
697+
698+
details::http_endpoint endpoint(st_url, false, false, false);
699+
700+
map<details::http_endpoint, http_resource*>::iterator it;
701+
702+
size_t len = 0;
703+
size_t tot_len = 0;
704+
for (it = registered_resources.begin(); it != registered_resources.end(); ++it) {
705+
size_t endpoint_pieces_len = (*it).first.get_url_pieces().size();
706+
size_t endpoint_tot_len = (*it).first.get_url_complete().size();
707+
if (!found || endpoint_pieces_len > len || (endpoint_pieces_len == len && endpoint_tot_len > tot_len)) {
708+
if ((*it).first.match(endpoint)) {
709+
found = true;
710+
len = endpoint_pieces_len;
711+
tot_len = endpoint_tot_len;
712+
found_endpoint = it;
713+
}
701714
}
702715
}
703-
}
704716

705-
if (found) {
706-
vector<string> url_pars = found_endpoint->first.get_url_pars();
717+
if (found) {
718+
vector<string> url_pars = found_endpoint->first.get_url_pars();
707719

708-
vector<string> url_pieces = endpoint.get_url_pieces();
709-
vector<int> chunks = found_endpoint->first.get_chunk_positions();
710-
for (unsigned int i = 0; i < url_pars.size(); i++) {
711-
mr->dhr->set_arg(url_pars[i], url_pieces[chunks[i]]);
712-
}
720+
vector<string> url_pieces = endpoint.get_url_pieces();
721+
vector<int> chunks = found_endpoint->first.get_chunk_positions();
722+
for (unsigned int i = 0; i < url_pars.size(); i++) {
723+
mr->dhr->set_arg(url_pars[i], url_pieces[chunks[i]]);
724+
}
713725

714-
hrm = found_endpoint->second;
726+
hrm = found_endpoint->second;
727+
}
715728
}
729+
} else {
730+
hrm = fe->second;
731+
found = true;
716732
}
717733
} else {
718-
hrm = fe->second;
734+
hrm = registered_resources.begin()->second;
719735
found = true;
720736
}
721-
} else {
722-
hrm = registered_resources.begin()->second;
723-
found = true;
724737
}
725738

726739
if (found) {

test/integ/basic.cpp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
*/
2020

2121
#include <curl/curl.h>
22+
#include <atomic>
2223
#include <cstdio>
2324
#include <cstring>
2425
#include <iostream>
@@ -27,6 +28,7 @@
2728
#include <numeric>
2829
#include <sstream>
2930
#include <string>
31+
#include <thread>
3032
#include <utility>
3133
#include <vector>
3234

@@ -1573,6 +1575,46 @@ LT_BEGIN_AUTO_TEST(basic_suite, method_not_allowed_header)
15731575
curl_easy_cleanup(curl);
15741576
LT_END_AUTO_TEST(method_not_allowed_header)
15751577

1578+
LT_BEGIN_AUTO_TEST(basic_suite, thread_safety)
1579+
simple_resource resource;
1580+
1581+
std::atomic_bool done = false;
1582+
auto register_thread = std::thread([&]() {
1583+
int i = 0;
1584+
while (!done) {
1585+
ws->register_resource(
1586+
std::string("/route") + std::to_string(++i), &resource);
1587+
}
1588+
});
1589+
1590+
auto get_thread = std::thread([&](){
1591+
unsigned int seed = 42;
1592+
while (!done) {
1593+
CURL *curl = curl_easy_init();
1594+
std::string s;
1595+
std::string url = "localhost:" PORT_STRING "/route" + std::to_string(
1596+
static_cast<int>((rand_r(&seed) * 10000000.0) / RAND_MAX));
1597+
curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
1598+
curl_easy_setopt(curl, CURLOPT_HTTPGET, 1L);
1599+
curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, writefunc);
1600+
curl_easy_setopt(curl, CURLOPT_WRITEDATA, &s);
1601+
curl_easy_perform(curl);
1602+
curl_easy_cleanup(curl);
1603+
}
1604+
});
1605+
1606+
using std::chrono_literals::operator""s;
1607+
std::this_thread::sleep_for(10s);
1608+
done = true;
1609+
if (register_thread.joinable()) {
1610+
register_thread.join();
1611+
}
1612+
if (get_thread.joinable()) {
1613+
get_thread.join();
1614+
}
1615+
LT_CHECK_EQ(1, 1);
1616+
LT_END_AUTO_TEST(thread_safety)
1617+
15761618
LT_BEGIN_AUTO_TEST_ENV()
15771619
AUTORUN_TESTS()
15781620
LT_END_AUTO_TEST_ENV()

0 commit comments

Comments
 (0)