diff --git a/ArduinoCore-Linux/cores/arduino/Ethernet.h b/ArduinoCore-Linux/cores/arduino/Ethernet.h index 7db5dea..8d613a5 100644 --- a/ArduinoCore-Linux/cores/arduino/Ethernet.h +++ b/ArduinoCore-Linux/cores/arduino/Ethernet.h @@ -170,15 +170,13 @@ class EthernetClient : public Client { // checks if we are connected - using a timeout virtual uint8_t connected() override { - if (!is_connected) return false; // connect has failed - if (p_sock->connected()) return true; // check socket - long timeout = millis() + getConnectionTimeout(); - uint8_t result = p_sock->connected(); - while (result <= 0 && millis() < timeout) { - delay(200); - result = p_sock->connected(); - } - return result; + if (!is_connected || !p_sock) return false; + + // A disconnected socket should be reported immediately. Retrying here + // turns normal peer shutdown into a multi-second stall for callers like + // TelnetClient::closeOnDisconnect(). + is_connected = p_sock->connected(); + return is_connected; } // support conversion to bool @@ -186,32 +184,41 @@ class EthernetClient : public Client { // opens a conection virtual int connect(IPAddress ipAddress, uint16_t port) override { + return connect(ipAddress, port, getConnectionTimeout()); + } + + int connect(IPAddress ipAddress, uint16_t port, int32_t timeout_ms) { String str = String(ipAddress[0]) + String(".") + String(ipAddress[1]) + String(".") + String(ipAddress[2]) + String(".") + String(ipAddress[3]); this->address = ipAddress; this->port = port; - return connect(str.c_str(), port); + return connect(str.c_str(), port, timeout_ms); } // opens a connection virtual int connect(const char* address, uint16_t port) override { + return connect(address, port, getConnectionTimeout()); + } + + int connect(const char* address, uint16_t port, int32_t timeout_ms) { Logger.info(WIFICLIENT, "connect"); this->port = port; if (connectedFast()) { p_sock->close(); } - IPAddress adr = resolveAddress(address, port); + IPAddress adr = resolveAddress(address); if (adr == IPAddress(0, 0, 0, 0)) { is_connected = false; return 0; } + // performs the actual connection String str = adr.toString(); Logger.info("Connecting to ", str.c_str()); - p_sock->connect(str.c_str(), port); - is_connected = true; - return 1; + int result = p_sock->connect(str.c_str(), port, timeout_ms); + is_connected = result > 0; + return is_connected ? 1 : 0; } virtual size_t write(char c) { return write((uint8_t)c); } @@ -287,11 +294,13 @@ class EthernetClient : public Client { } virtual size_t readBytes(char* buffer, size_t len) { - return read((uint8_t*)buffer, len); + int result = read((uint8_t*)buffer, len); + return result < 0 ? 0 : static_cast(result); } virtual size_t readBytes(uint8_t* buffer, size_t len) { - return read(buffer, len); + int result = read(buffer, len); + return result < 0 ? 0 : static_cast(result); } // peeks one character @@ -334,21 +343,31 @@ class EthernetClient : public Client { IPAddress address{0, 0, 0, 0}; uint16_t port = 0; + IPAddress resolveHostname(const char* hostname) { + struct addrinfo hints; + memset(&hints, 0, sizeof(hints)); + hints.ai_family = AF_INET; + hints.ai_socktype = SOCK_STREAM; + + struct addrinfo* result = nullptr; + if (getaddrinfo(hostname, nullptr, &hints, &result) != 0 || result == nullptr) { + Logger.error(WIFICLIENT, "Hostname resolution failed"); + return IPAddress(0, 0, 0, 0); + } + + auto* addr = reinterpret_cast(result->ai_addr); + IPAddress resolved(addr->sin_addr.s_addr); + freeaddrinfo(result); + return resolved; + } + // resolves the address and returns sockaddr_in - IPAddress resolveAddress(const char* address, uint16_t port) { + IPAddress resolveAddress(const char* address) { struct sockaddr_in serv_addr4; memset(&serv_addr4, 0, sizeof(serv_addr4)); serv_addr4.sin_family = AF_INET; - serv_addr4.sin_port = htons(port); if (::inet_pton(AF_INET, address, &serv_addr4.sin_addr) <= 0) { - // Not an IP, try to resolve hostname - struct hostent* he = ::gethostbyname(address); - if (he == nullptr || he->h_addr_list[0] == nullptr) { - Logger.error(WIFICLIENT, "Hostname resolution failed"); - serv_addr4.sin_addr.s_addr = 0; - } else { - memcpy(&serv_addr4.sin_addr, he->h_addr_list[0], he->h_length); - } + return resolveHostname(address); } return IPAddress(serv_addr4.sin_addr.s_addr); } @@ -367,11 +386,11 @@ class EthernetClient : public Client { int result = 0; long timeout = millis() + getTimeout(); result = p_sock->read(buffer, len); - while (result <= 0 && millis() < timeout) { + while (result == 0 && millis() < timeout) { delay(200); result = p_sock->read(buffer, len); } - //} + char lenStr[16]; sprintf(lenStr, "%d", result); Logger.debug(WIFICLIENT, "read->", lenStr); diff --git a/ArduinoCore-Linux/cores/arduino/NetworkClientSecure.h b/ArduinoCore-Linux/cores/arduino/NetworkClientSecure.h index ba9a480..e4a109e 100644 --- a/ArduinoCore-Linux/cores/arduino/NetworkClientSecure.h +++ b/ArduinoCore-Linux/cores/arduino/NetworkClientSecure.h @@ -63,7 +63,7 @@ class SocketImplSecure : public SocketImpl { } } // direct read - size_t read(uint8_t* buffer, size_t len) { + int read(uint8_t* buffer, size_t len) override { // size_t result = ::recv(sock, buffer, len, MSG_DONTWAIT ); if (ssl == nullptr) { wolfSSL_set_fd(ssl, sock); @@ -71,8 +71,20 @@ class SocketImplSecure : public SocketImpl { int result = ::wolfSSL_read(ssl, buffer, len); if (result < 0) { - result = 0; + int error = wolfSSL_get_error(ssl, result); + if (error == SSL_ERROR_WANT_READ || error == SSL_ERROR_WANT_WRITE) { + return 0; + } + + is_connected = false; + return -1; } + + if (result == 0) { + is_connected = false; + return -1; + } + // char lenStr[80]; sprintf(lenStr, "%ld -> %d", len, result); diff --git a/ArduinoCore-Linux/cores/arduino/SocketImpl.cpp b/ArduinoCore-Linux/cores/arduino/SocketImpl.cpp index ea40427..f51f5aa 100644 --- a/ArduinoCore-Linux/cores/arduino/SocketImpl.cpp +++ b/ArduinoCore-Linux/cores/arduino/SocketImpl.cpp @@ -26,10 +26,12 @@ #include #include #include +#include #include #include #include #include +#include #ifdef __APPLE__ #include #include @@ -48,30 +50,57 @@ uint8_t SocketImpl::connected() { if (sock < 0) return false; char buf[2]; int result = ::recv(sock, &buf, 1, MSG_PEEK | MSG_DONTWAIT); - // if peek is working we are connected - if not we do further checks - is_connected = result >= 0; - if (!is_connected) { - int error_code; - socklen_t error_code_size; - // int getsockopt(int sockfd, int level, int optname,void *optval, socklen_t - // *optlen); - int result = - getsockopt(sock, SOL_SOCKET, SO_ERROR, &error_code, &error_code_size); - if (result != 0) { - char msg[50]; - sprintf(msg, "%d", result); - Logger.debug(SOCKET_IMPL, "getsockopt->", msg); - } + if (result > 0) { + is_connected = true; + return true; + } + + if (result == 0) { + Logger.info(SOCKET_IMPL, "peer closed connection"); + close(); + is_connected = false; + return false; + } - is_connected = (result == 0); + if (errno == EAGAIN || errno == EWOULDBLOCK) { + is_connected = true; + return true; } - return is_connected; + int error_code = 0; + socklen_t error_code_size = sizeof(error_code); + // int getsockopt(int sockfd, int level, int optname,void *optval, socklen_t + // *optlen); + int sockopt_result = + getsockopt(sock, SOL_SOCKET, SO_ERROR, &error_code, &error_code_size); + if (sockopt_result != 0) { + char msg[50]; + sprintf(msg, "%d", sockopt_result); + Logger.debug(SOCKET_IMPL, "getsockopt->", msg); + } + + if (sockopt_result == 0 && error_code == 0) { + is_connected = true; + return true; + } + + close(); + is_connected = false; + return false; } // opens a conection int SocketImpl::connect(const char *address, uint16_t port) { + return connect(address, port, -1); +} + +int SocketImpl::connect(const char *address, uint16_t port, int32_t timeout_ms) { + if (sock >= 0) { + close(); + } + if ((sock = ::socket(AF_INET, SOCK_STREAM, 0)) < 0) { + is_connected = false; Logger.error(SOCKET_IMPL, "could not create socket"); return -1; } @@ -89,13 +118,76 @@ int SocketImpl::connect(const char *address, uint16_t port) { // Convert IPv4 and IPv6 addresses from text to binary form if (::inet_pton(AF_INET, address_ip, &serv_addr.sin_addr) <= 0) { + close(); Logger.error(SOCKET_IMPL, "invalid address"); return -2; } - if (::connect(sock, (struct sockaddr *)&serv_addr, sizeof(serv_addr)) < 0) { - Logger.error(SOCKET_IMPL, "could not connect"); - return -3; + if (timeout_ms < 0) { + if (::connect(sock, (struct sockaddr *)&serv_addr, sizeof(serv_addr)) < 0) { + close(); + Logger.error(SOCKET_IMPL, "could not connect"); + return -3; + } + } else { + int flags = fcntl(sock, F_GETFL, 0); + if (flags < 0) { + flags = 0; + } + + if (fcntl(sock, F_SETFL, flags | O_NONBLOCK) < 0) { + close(); + Logger.error(SOCKET_IMPL, "could not set nonblocking connect"); + return -3; + } + + int result = ::connect(sock, (struct sockaddr *)&serv_addr, sizeof(serv_addr)); + if (result < 0) { + if (errno != EINPROGRESS && errno != EWOULDBLOCK) { + fcntl(sock, F_SETFL, flags); + close(); + Logger.error(SOCKET_IMPL, "could not connect"); + return -3; + } + + fd_set writefds; + FD_ZERO(&writefds); + FD_SET(sock, &writefds); + + fd_set errorfds; + FD_ZERO(&errorfds); + FD_SET(sock, &errorfds); + + timeval timeout; + timeout.tv_sec = timeout_ms / 1000; + timeout.tv_usec = (timeout_ms % 1000) * 1000; + + result = select(sock + 1, nullptr, &writefds, &errorfds, &timeout); + if (result == 0) { + fcntl(sock, F_SETFL, flags); + close(); + Logger.error(SOCKET_IMPL, "connect timeout"); + return -4; + } + + if (result < 0) { + fcntl(sock, F_SETFL, flags); + close(); + Logger.error(SOCKET_IMPL, "select failed during connect"); + return -3; + } + + int error_code = 0; + socklen_t error_len = sizeof(error_code); + if (getsockopt(sock, SOL_SOCKET, SO_ERROR, &error_code, &error_len) != 0 || error_code != 0) { + fcntl(sock, F_SETFL, flags); + close(); + Logger.error(SOCKET_IMPL, "could not connect"); + return -3; + } + } + + fcntl(sock, F_SETFL, flags); } is_connected = true; @@ -112,8 +204,24 @@ size_t SocketImpl::write(const uint8_t *str, size_t len) { // provides the available bytes size_t SocketImpl::available() { + if (sock < 0) { + is_connected = false; + return 0; + } + int bytes_available; - ioctl(sock, FIONREAD, &bytes_available); + if (ioctl(sock, FIONREAD, &bytes_available) != 0) { + if (errno != EAGAIN && errno != EWOULDBLOCK) { + close(); + is_connected = false; + } + return 0; + } + + if (bytes_available == 0 && !connected()) { + return 0; + } + char msg[50]; sprintf(msg, "%d", bytes_available); Logger.debug(SOCKET_IMPL, "available->", msg); @@ -121,12 +229,35 @@ size_t SocketImpl::available() { } // direct read -size_t SocketImpl::read(uint8_t *buffer, size_t len) { - size_t result = ::recv(sock, buffer, len, MSG_DONTWAIT | MSG_NOSIGNAL); +int SocketImpl::read(uint8_t *buffer, size_t len) { + if (sock < 0) { + is_connected = false; + return -1; + } + + ssize_t result = ::recv(sock, buffer, len, MSG_DONTWAIT | MSG_NOSIGNAL); + if (result == 0) { + Logger.info(SOCKET_IMPL, "read EOF"); + close(); + is_connected = false; + return -1; + } + + if (result < 0) { + if (errno == EAGAIN || errno == EWOULDBLOCK) { + return 0; + } + + close(); + is_connected = false; + return -1; + } + char lenStr[80]; sprintf(lenStr, "%ld -> %ld", len, result); Logger.debug(SOCKET_IMPL, "read->", lenStr); - return result; + // Possible narrowing can't be helped because of Arduino API compatiblity + return static_cast(result); } // peeks one character @@ -139,7 +270,11 @@ int SocketImpl::peek() { void SocketImpl::close() { Logger.info(SOCKET_IMPL, "close"); - ::close(sock); + if (sock >= 0) { + ::close(sock); + sock = -1; + } + is_connected = false; } // Linux-compatible implementation: parse /proc/net/route for default interface diff --git a/ArduinoCore-Linux/cores/arduino/SocketImpl.h b/ArduinoCore-Linux/cores/arduino/SocketImpl.h index 1fab20d..9e993ec 100644 --- a/ArduinoCore-Linux/cores/arduino/SocketImpl.h +++ b/ArduinoCore-Linux/cores/arduino/SocketImpl.h @@ -45,12 +45,14 @@ class SocketImpl { virtual uint8_t connected(); // opens a conection virtual int connect(const char* address, uint16_t port); + // opens a connection with a timeout in milliseconds + virtual int connect(const char* address, uint16_t port, int32_t timeout_ms); // sends some data virtual size_t write(const uint8_t* str, size_t len); // provides the available bytes virtual size_t available(); // direct read - virtual size_t read(uint8_t* buffer, size_t len); + virtual int read(uint8_t* buffer, size_t len); // peeks one character virtual int peek(); // coloses the connection diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 38d61f6..a49f0eb 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -6,6 +6,7 @@ add_subdirectory("spi") add_subdirectory("serial2") add_subdirectory("using-arduino-library") add_subdirectory("pwm") +add_subdirectory("network-connect-timing") # BME280 Sensor Examples arduino_library(SparkFunBME280 "https://github.com/sparkfun/SparkFun_BME280_Arduino_Library" ) diff --git a/examples/network-connect-timing/CMakeLists.txt b/examples/network-connect-timing/CMakeLists.txt new file mode 100644 index 0000000..231379f --- /dev/null +++ b/examples/network-connect-timing/CMakeLists.txt @@ -0,0 +1,5 @@ +arduino_sketch(network-connect-timing network-connect-timing.ino) + +if(APPLE) + target_compile_options(network-connect-timing PRIVATE -Wno-deprecated-declarations) +endif() \ No newline at end of file diff --git a/examples/network-connect-timing/network-connect-timing.ino b/examples/network-connect-timing/network-connect-timing.ino new file mode 100644 index 0000000..da1ca70 --- /dev/null +++ b/examples/network-connect-timing/network-connect-timing.ino @@ -0,0 +1,183 @@ +/* + Network connect timing test. + + There are three test setups to exercise different situations: + + Instant connect failure: + For this test, there should be nothing listening on localhost port 8123 + + network-connect-timing should say "connect failed" with + elapsed=0ms or nearly 0. + + Connect success: + Start a plain TCP listener locally, for example: + nc -lk 127.0.0.1 8123 + + network-connect-timing should say "connect succeeded", + display some send/response/connected/closed messages, + and the elapsed time should be 1000 ms or so. + + Connect timeout: + For this test, you need to configure the host kernel to drop packets. + It does not matter if the listener is running or not. + + On MacOS you can drop packets as follows: + Add this to /etc/pf.conf : + anchor "connect-timeout-test" + load anchor "connect-timeout-test" from "/etc/pf.anchors/connect-timeout-test" + + Create a file /etc/pf.anchors/connect-timeout-test containing: + block drop quick on lo0 proto tcp from any to 127.0.0.1 port 8123 + block drop quick on lo0 proto tcp from 127.0.0.1 to any port 8123 + + Then run this to enable the rules: + sudo pfctl -f /etc/pf.conf + sudo pfctl -e + + For Linux you can drop packets by issuing: + sudo nft add table inet connect_test + sudo nft 'add chain inet connect_test output { type filter hook output priority 0; }' + sudo nft 'add chain inet connect_test input { type filter hook input priority 0; }' + sudo nft 'add rule inet connect_test output oifname "lo" tcp dport 8123 drop' + sudo nft 'add rule inet connect_test input iifname "lo" tcp sport 8123 drop' + + With those packet-dropping rules in place, + network-connect-test should say "connect failed" + and the elapsed time should be a little more than 2000 ms + for the first test and 3000 ms for the second. +*/ + +#include +#include +#include + +namespace { + +constexpr const char* ssid = "localhost-test"; +constexpr const char* password = "unused"; +constexpr const char* host = "127.0.0.1"; +constexpr uint16_t port = 8123; +constexpr int32_t explicit_connect_timeout_ms = 2000; +constexpr int32_t default_connect_timeout_ms = 3000; +constexpr unsigned long read_timeout_ms = 1000; + +WiFiClient explicit_timeout_client; +WiFiClient default_timeout_client; + +void print_test_elapsed(const char* label, unsigned long start_ms) { + Serial.print(label); + Serial.print(" elapsed="); + Serial.print(millis() - start_ms); + Serial.println("ms"); +} + +void wait_for_wifi() { + Serial.print("Attempting to connect to SSID: "); + Serial.println(ssid); + WiFi.begin(ssid, password); + Serial.print("Connected to "); + Serial.println(ssid); + Serial.println(); +} + +bool connect_with_explicit_timeout() { + Serial.print("Connecting with explicit timeout to "); + Serial.print(host); + Serial.print(":"); + Serial.print(port); + Serial.print(" timeout="); + Serial.print(explicit_connect_timeout_ms); + Serial.println("ms"); + + if (!explicit_timeout_client.connect(host, port, explicit_connect_timeout_ms)) { + Serial.println("Explicit-timeout connect failed"); + return false; + } + + Serial.println("Explicit-timeout connect succeeded"); + return true; +} + +bool connect_with_default_timeout() { + default_timeout_client.setConnectionTimeout(default_connect_timeout_ms); + + Serial.print("Connecting with default connect(host, port) using timeout="); + Serial.print(default_connect_timeout_ms); + Serial.println("ms"); + + if (!default_timeout_client.connect(host, port)) { + Serial.println("Default-timeout connect failed"); + return false; + } + + Serial.println("Default-timeout connect succeeded"); + return true; +} + +void send_plain_text(WiFiClient& client, const char* label) { + Serial.print(label); + client.println("hello from WiFiClient"); +} + +void drain_input(WiFiClient& client, const char* label) { + unsigned long deadline = millis() + read_timeout_ms; + + Serial.print(label); + Serial.println(" waiting for response bytes"); + + while (client.connected() && millis() < deadline) { + while (client.available() > 0) { + int ch = client.read(); + if (ch >= 0) { + Serial.write(static_cast(ch)); + deadline = millis() + read_timeout_ms; + } + } + delay(10); + } + + Serial.println(); + Serial.print(label); + Serial.print(" connected() -> "); + Serial.println(client.connected() ? "true" : "false"); +} + +void close_client(WiFiClient& client, const char* label) { + client.stop(); + Serial.print(label); + Serial.println(" closed"); +} + +} // namespace + +void setup() { + Serial.begin(115200); + delay(100); + + wait_for_wifi(); + + unsigned long explicit_test_start_ms = millis(); + if (connect_with_explicit_timeout()) { + send_plain_text(explicit_timeout_client, "explicit_timeout_client"); + drain_input(explicit_timeout_client, "explicit_timeout_client"); + close_client(explicit_timeout_client, "explicit_timeout_client"); + } + print_test_elapsed("explicit-timeout test", explicit_test_start_ms); + Serial.println(); + + unsigned long default_test_start_ms = millis(); + if (connect_with_default_timeout()) { + send_plain_text(default_timeout_client, "default_timeout_client"); + drain_input(default_timeout_client, "default_timeout_client"); + close_client(default_timeout_client, "default_timeout_client"); + } + print_test_elapsed("default-timeout test", default_test_start_ms); + + Serial.println(); + Serial.println("network-connect-timing test complete"); +} + +void loop() { + exit(0); + // delay(1000); +}