/*
 *  TCP/IP or UDP/IP networking functions
 *
 *  Copyright The Mbed TLS Contributors
 *  SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
 *
 *  This file is provided under the Apache License 2.0, or the
 *  GNU General Public License v2.0 or later.
 *
 *  **********
 *  Apache License 2.0:
 *
 *  Licensed under the Apache License, Version 2.0 (the "License"); you may
 *  not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *  http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 *
 *  **********
 *
 *  **********
 *  GNU General Public License v2.0 or later:
 *
 *  This program is free software; you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation; either version 2 of the License, or
 *  (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License along
 *  with this program; if not, write to the Free Software Foundation, Inc.,
 *  51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 *  **********
 */

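/* Request POSIX.1-2001 definitions (such as getaddrinfo()) even when compiling
 * with -std=c99. This RISC OS port resolves names through the Resolver module
 * rather than getaddrinfo(), but the define is harmless here and on other
 * platforms. It must be set before config.h, which may pull in glibc's
 * features.h indirectly. */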
#define _POSIX_C_SOURCE 200112L

#if defined(__NetBSD__)
#define _XOPEN_SOURCE 600 /* sockaddr_storage */
#endif

#if !defined(MBEDTLS_CONFIG_FILE)
#include "mbedtls/config.h"
#else
#include MBEDTLS_CONFIG_FILE
#endif

#ifndef RISCOS
#error "This is a platform specific file for RISC OS only"
#endif

#if defined(MBEDTLS_NET_C)

#if defined(MBEDTLS_PLATFORM_C)
#include "mbedtls/platform.h"
#else
#include <stdlib.h>
#define mbedtls_time_t   time_t
#endif

#include "mbedtls/net_sockets.h"

#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <sys/errno.h>
#include <sys/time.h>
#include <sys/filio.h>
#include <unistd.h>
#include <signal.h>
#include <netdb.h>
#include <socklib.h>

#include <stdlib.h>
#include <string.h>
#include <stdio.h>

#include "swis.h"
#include "Global/OsBytes.h"
#include "AsmUtils/callbacks.h"
#define Resolver_GetHostByName          0x046000
#define Resolver_GetHost                0x046001
#define Resolver_GetCache               0x046002
#define Resolver_CacheControl           0x046003

/*
 * Give network stack some grace between returns, when nothing was available/possible
 *
 * Busy-waits until OS_ReadMonotonicTime advances by at least one tick.
 * While waiting, if the processor is NOT in USR mode, it repeatedly calls
 * usermode_donothing() (from AsmUtils/callbacks.h) so that pending
 * callbacks get a chance to run.
 * NOTE(review): presumably this lets the Internet stack make progress
 * when we are called from a privileged mode — confirm against callbacks.h.
 * In USR mode the loop is a pure spin on the monotonic timer.
 */
static void backoff(void)
{
/* Note: these macros are preprocessor definitions, so they remain visible
 * for the rest of the file, not just inside this function.
 * MODE_M32_BIT is not used below. */
#define MODE_BITS    0xF   /* Just the mode */
#define MODE_M32_BIT 0x10  /* 32 versus 26 bit */
#define MODE_USR     0     /* USR32 or USR26 if masked with MODE_BITS */

	int procmode;
	/* Timer value at entry; we return once it has changed */
	uint32_t mono = _swi(OS_ReadMonotonicTime, _RETURN(0));

	/* Read the current processor status register to learn our mode */
	__asm
	{
		MRS procmode, CPSR
	}

	do
	{
		/* Trigger callbacks (if not already in user mode) */
		if ((procmode & MODE_BITS) != MODE_USR) usermode_donothing();
	} while (mono == _swi(OS_ReadMonotonicTime, _RETURN(0)));
}

/*
 * Prepare for using the sockets interface.
 *
 * No set-up is performed in this port; the function always reports
 * success so callers can treat it like the other platforms' versions.
 */
static int net_prepare(void)
{
	const int status = 0; /* nothing to initialise */

	return status;
}

/*
 * Initialize a context: not attached to any socket yet, and in
 * blocking mode until mbedtls_net_set_nonblock() is called.
 */
void mbedtls_net_init(mbedtls_net_context *ctx)
{
	ctx->nbio = 0;  /* blocking by default */
	ctx->fd   = -1; /* no descriptor */
}

/*
 * Initiate a TCP or UDP connection with host:port and the given protocol.
 *
 * The host name is resolved to IPv4 addresses through the Resolver module
 * (there is no IPv6 support beneath us yet). Each returned address is then
 * tried in turn until one connects.
 *
 * Returns 0 on success, with ctx->fd holding the connected socket.
 * Otherwise it returns MBEDTLS_ERR_NET_UNKNOWN_HOST,
 * MBEDTLS_ERR_NET_SOCKET_FAILED or MBEDTLS_ERR_NET_CONNECT_FAILED.
 * Any socket that fails to connect is closed and ctx->fd is reset to -1,
 * so a later mbedtls_net_free() cannot close a stale descriptor.
 */
int mbedtls_net_connect(mbedtls_net_context *ctx,
                        const char *url, const char *port, int proto)
{
	int ret;
	struct hostent     *host = NULL;
	char              **cur;
	struct sockaddr_in  addr;

	ret = net_prepare();
	if (ret != 0)
	{
		return ret;
	}

	/* No IPv6 support beneath us yet, resolve name to an IPv4 address.
	 * EINPROGRESS is treated as "lookup still pending, ask again" and
	 * 0 as "answer is in 'host'". Any other value is a final failure and
	 * must end the loop rather than re-query forever.
	 */
	while (1)
	{
		_kernel_oserror *error;
		int errnum;

		error = _swix(Resolver_GetHost, _IN(0) | _OUTR(0,1),
		              url, &errnum, &host);
		if (error != NULL)
		{
			return MBEDTLS_ERR_NET_UNKNOWN_HOST;
		}
		if (errnum == EINPROGRESS)
		{
			backoff();
			continue;
		}
		if (errnum != 0)
		{
			host = NULL; /* Some other error: give up */
		}
		break;
	}

	/* Reject missing results and address lengths that would overflow
	 * sin_addr in the memcpy() below */
	if (host == NULL || host->h_addr_list == NULL ||
	    host->h_length <= 0 ||
	    (size_t)host->h_length > sizeof(addr.sin_addr))
	{
		return MBEDTLS_ERR_NET_UNKNOWN_HOST;
	}

	/* Try the sockaddrs until a connection succeeds */
	ret = MBEDTLS_ERR_NET_UNKNOWN_HOST;

	for (cur = host->h_addr_list; *cur != NULL; cur++)
	{
		ctx->fd = (int)socket(PF_INET,
		                      proto == MBEDTLS_NET_PROTO_UDP ? SOCK_DGRAM : SOCK_STREAM,
		                      proto == MBEDTLS_NET_PROTO_UDP ? IPPROTO_UDP : IPPROTO_TCP);
		if (ctx->fd < 0)
		{
			ctx->fd = -1;
			ret = MBEDTLS_ERR_NET_SOCKET_FAILED;
			continue;
		}

		memset(&addr, 0, sizeof(addr));
		/* sin_len is the size of the whole sockaddr, not of the address */
		addr.sin_len = (unsigned char)sizeof(addr);
		addr.sin_family = host->h_addrtype;
		addr.sin_port = htons(atoi(port));
		memcpy(&addr.sin_addr.s_addr, *cur, host->h_length);
		if (connect(ctx->fd, (struct sockaddr *)&addr, sizeof(addr)) == 0)
		{
			ret = 0;
			break;
		}

		close(ctx->fd);
		ctx->fd = -1; /* Don't leave a stale descriptor in the context */
		ret = MBEDTLS_ERR_NET_CONNECT_FAILED;
	}

	return ret;
}

/*
 * Create a listening socket on bind_ip:port
 *
 * If 'url' is NULL, the socket is bound to INADDR_ANY. Otherwise 'url' is
 * resolved to IPv4 addresses through the Resolver module (no IPv6 support
 * beneath us yet) and each address is tried until one binds. For TCP the
 * socket is also put into the listening state.
 *
 * Returns 0 on success, with ctx->fd holding the bound socket. Otherwise it
 * returns MBEDTLS_ERR_NET_UNKNOWN_HOST, MBEDTLS_ERR_NET_SOCKET_FAILED,
 * MBEDTLS_ERR_NET_BIND_FAILED or MBEDTLS_ERR_NET_LISTEN_FAILED.
 * Every socket that fails a step is closed and ctx->fd is reset to -1.
 */
int mbedtls_net_bind(mbedtls_net_context *ctx,
                     const char *url, const char *port, int proto)
{
	int n, ret;
	const struct hostent *host = NULL;
	char                **cur;
	struct sockaddr_in    addr;
	static const struct in_addr inaddr_any_addr =
	{
		INADDR_ANY                    /* s_addr */
	};
	static const char *inaddr_any_addr_list[] =
	{
		(char *)&inaddr_any_addr,
		NULL
	};
	/* Fake resolver result used when binding to every local address */
	static const struct hostent inaddr_any_host =
	{
		NULL,                         /* h_name */
		NULL,                         /* h_aliases */
		AF_INET,                      /* h_addrtype */
		sizeof(struct in_addr),       /* h_length */
		(char **)inaddr_any_addr_list /* h_addr_list */
	};

	ret = net_prepare();
	if (ret != 0)
	{
		return ret;
	}

	if (url == NULL)
	{
		/* Bind to INADDR_ANY */
		host = &inaddr_any_host;
	}
	else
	{
		/* No IPv6 support beneath us yet, resolve name to an IPv4
		 * address. EINPROGRESS is treated as "lookup still pending, ask
		 * again" and 0 as "answer is in 'host'". Any other value is a
		 * final failure and must end the loop rather than re-query
		 * forever.
		 */
		while (1)
		{
			_kernel_oserror *error;
			int errnum;

			error = _swix(Resolver_GetHost, _IN(0) | _OUTR(0,1),
			              url, &errnum, &host);
			if (error != NULL)
			{
				return MBEDTLS_ERR_NET_UNKNOWN_HOST;
			}
			if (errnum == EINPROGRESS)
			{
				backoff();
				continue;
			}
			if (errnum != 0)
			{
				host = NULL; /* Some other error: give up */
			}
			break;
		}
	}

	/* Reject missing results and address lengths that would overflow
	 * sin_addr in the memcpy() below */
	if (host == NULL || host->h_addr_list == NULL ||
	    host->h_length <= 0 ||
	    (size_t)host->h_length > sizeof(addr.sin_addr))
	{
		return MBEDTLS_ERR_NET_UNKNOWN_HOST;
	}

	/* Try the sockaddrs until a binding succeeds */
	ret = MBEDTLS_ERR_NET_UNKNOWN_HOST;
	for (cur = host->h_addr_list; *cur != NULL; cur++)
	{
		ctx->fd = (int)socket(PF_INET,
		                      proto == MBEDTLS_NET_PROTO_UDP ? SOCK_DGRAM : SOCK_STREAM,
		                      proto == MBEDTLS_NET_PROTO_UDP ? IPPROTO_UDP : IPPROTO_TCP);
		if (ctx->fd < 0)
		{
			ctx->fd = -1;
			ret = MBEDTLS_ERR_NET_SOCKET_FAILED;
			continue;
		}

		n = 1;
		if (setsockopt(ctx->fd, SOL_SOCKET, SO_REUSEADDR,
		               (const char *)&n, sizeof(n)) != 0)
		{
			close(ctx->fd);
			ctx->fd = -1;
			ret = MBEDTLS_ERR_NET_SOCKET_FAILED;
			continue;
		}

		memset(&addr, 0, sizeof(addr));
		/* sin_len is the size of the whole sockaddr, not of the address */
		addr.sin_len = (unsigned char)sizeof(addr);
		addr.sin_family = host->h_addrtype;
		addr.sin_port = htons(atoi(port));
		memcpy(&addr.sin_addr.s_addr, *cur, host->h_length);
		if (bind(ctx->fd, (struct sockaddr *)&addr, sizeof(addr)) != 0)
		{
			close(ctx->fd);
			ctx->fd = -1;
			ret = MBEDTLS_ERR_NET_BIND_FAILED;
			continue;
		}

		/* Listen only makes sense for TCP */
		if (proto == MBEDTLS_NET_PROTO_TCP)
		{
			if (listen(ctx->fd, MBEDTLS_NET_LISTEN_BACKLOG) != 0)
			{
				close(ctx->fd);
				ctx->fd = -1;
				ret = MBEDTLS_ERR_NET_LISTEN_FAILED;
				continue;
			}
		}

		/* If we ever get here, it's a success */
		ret = 0;
		break;
	}

	return ret;
}

/*
 * Decide whether the last failed operation on a non-blocking socket
 * merely "would have blocked" (errno EAGAIN / EWOULDBLOCK).
 *
 * Returns 1 in that case, 0 otherwise. A blocking socket always yields 0,
 * whatever errno says.
 */
static int net_would_block(const mbedtls_net_context *ctx)
{
	int would_block = 0;

	if (ctx->nbio)
	{
		switch (errno)
		{
#if defined EAGAIN
			case EAGAIN:
#endif
#if defined EWOULDBLOCK && EWOULDBLOCK != EAGAIN
			case EWOULDBLOCK:
#endif
				would_block = 1;
				break;

			default:
				break;
		}
	}

	return would_block;
}

/*
 * Accept a connection from a remote client
 *
 * TCP: accept() a pending connection on bind_ctx and store the new socket
 * in client_ctx.
 * UDP: wait for a datagram without consuming it (MSG_PEEK). The listening
 * socket is then connect()ed to the sender and handed over to client_ctx.
 * Finally a fresh socket is created and bound to the same local address,
 * so bind_ctx keeps listening.
 *
 * If client_ip is non-NULL, the peer's sin_addr.s_addr is copied there and
 * its size is stored in *ip_len (ip_len must then be non-NULL). *ip_len is
 * written before the size check, so on MBEDTLS_ERR_NET_BUFFER_TOO_SMALL it
 * holds the required length.
 *
 * Returns 0 on success. It returns MBEDTLS_ERR_SSL_WANT_READ (after
 * backoff()) when a non-blocking bind_ctx has nothing pending. On error it
 * returns MBEDTLS_ERR_NET_ACCEPT_FAILED, MBEDTLS_ERR_NET_SOCKET_FAILED,
 * MBEDTLS_ERR_NET_BIND_FAILED or MBEDTLS_ERR_NET_BUFFER_TOO_SMALL.
 */
int mbedtls_net_accept(mbedtls_net_context *bind_ctx,
                       mbedtls_net_context *client_ctx,
                       void *client_ip, size_t buf_size, size_t *ip_len)
{
	int ret;
	int type;
	struct sockaddr client_addr;
	/* Use socklen_t where the headers declare it, else fall back to int.
	 * The calls below pass these as (int *) either way. */
#if defined(__socklen_t_defined) || defined(_SOCKLEN_T) || \
    defined(_SOCKLEN_T_DECLARED) || defined(__DEFINED_socklen_t) || \
    ( defined(__NetBSD__) && defined(socklen_t) )
	socklen_t n = (socklen_t)sizeof(client_addr);
	socklen_t type_len = (socklen_t)sizeof(type);
#else
	int n = sizeof(client_addr);
	int type_len = sizeof(type);
#endif

	/* Is this a TCP or UDP socket? */
	if ((getsockopt(bind_ctx->fd, SOL_SOCKET, SO_TYPE, (void *)&type, (int *)&type_len) != 0) ||
	    (type != SOCK_STREAM && type != SOCK_DGRAM))
	{
		return MBEDTLS_ERR_NET_ACCEPT_FAILED;
	}

	if (type == SOCK_STREAM)
	{
		/* TCP: actual accept() */
		ret = client_ctx->fd = (int)accept(bind_ctx->fd,
		                                   (struct sockaddr *)&client_addr, (int *)&n);
	}
	else
	{
		/* UDP: wait for a message, but keep it in the queue */
		char buf[1] = { 0 };

		ret = (int)recvfrom(bind_ctx->fd, buf, sizeof(buf), MSG_PEEK,
		                    (struct sockaddr *)&client_addr, (int *)&n);
	}

	if (ret < 0)
	{
		/* Nothing pending on a non-blocking listener: ask caller to retry */
		if (net_would_block(bind_ctx) != 0)
		{
			backoff();
			return MBEDTLS_ERR_SSL_WANT_READ;
		}

		return MBEDTLS_ERR_NET_ACCEPT_FAILED;
	}

	/* UDP: hijack the listening socket to communicate with the client,
	 * then bind a new socket to accept new connections
	 */
	if (type != SOCK_STREAM)
	{
		struct sockaddr local_addr;
		int one = 1;
		
		if (connect(bind_ctx->fd, (struct sockaddr *)&client_addr, n) != 0)
		{
			return MBEDTLS_ERR_NET_ACCEPT_FAILED;
		}

		client_ctx->fd = bind_ctx->fd;
		bind_ctx->fd   = -1; /* In case we exit early */

		/* Recreate a listener on the same local address. If socket()
		 * succeeds but setsockopt()/bind() fails, the new descriptor is
		 * left in bind_ctx->fd for the caller to release. */
		n = sizeof(struct sockaddr);
		if ((getsockname(client_ctx->fd, (struct sockaddr *)&local_addr, (int *)&n) != 0) ||
		    ((bind_ctx->fd = (int)socket(local_addr.sa_family,
		                                 SOCK_DGRAM, IPPROTO_UDP)) < 0) ||
		    (setsockopt(bind_ctx->fd, SOL_SOCKET, SO_REUSEADDR,
		                (const char *)&one, sizeof(one)) != 0))
		{
			return MBEDTLS_ERR_NET_SOCKET_FAILED;
		}

		if (bind(bind_ctx->fd, (struct sockaddr *)&local_addr, n) != 0)
		{
			return MBEDTLS_ERR_NET_BIND_FAILED;
		}
	}

	/* Report the peer address; only IPv4 (sockaddr_in) is handled here */
	if (client_ip != NULL)
	{
		struct sockaddr_in *addr4 = (struct sockaddr_in *) &client_addr;
		*ip_len = sizeof(addr4->sin_addr.s_addr);

		if (buf_size < *ip_len)
		{
			return MBEDTLS_ERR_NET_BUFFER_TOO_SMALL;
		}

		memcpy(client_ip, &addr4->sin_addr.s_addr, *ip_len);
	}

	return 0;
}

/*
 * Set the socket blocking or non-blocking
 *
 * Switches ctx->fd to blocking mode with FIONBIO. Returns the socketioctl()
 * result, which is 0 on success. ctx->nbio is only cleared once the mode
 * change has succeeded. A failed ioctl therefore leaves the context
 * agreeing with the socket's real mode, which net_would_block() relies on.
 */
int mbedtls_net_set_block(mbedtls_net_context *ctx)
{
	int on = 0;
	int ret = socketioctl(ctx->fd, FIONBIO, &on);

	if (ret == 0)
	{
		ctx->nbio = on;
	}
	return ret;
}

/*
 * Switch ctx->fd to non-blocking mode with FIONBIO.
 *
 * Returns the socketioctl() result, which is 0 on success. ctx->nbio is
 * only set once the mode change has succeeded, so net_would_block() never
 * reports "would block" for a socket that is actually still blocking.
 */
int mbedtls_net_set_nonblock(mbedtls_net_context *ctx)
{
	int on = 1;
	int ret = socketioctl(ctx->fd, FIONBIO, &on);

	if (ret == 0)
	{
		ctx->nbio = on;
	}
	return ret;
}

/*
 * Portable usleep helper: sleep for roughly 'usec' microseconds by using
 * select() with no descriptors as a timer.
 */
void mbedtls_net_usleep(unsigned long usec)
{
	struct timeval delay;

	/* Split into whole seconds plus the microsecond remainder */
	delay.tv_usec = usec % 1000000;
	delay.tv_sec  = usec / 1000000;

	(void)select(0, NULL, NULL, NULL, &delay);
}

/*
 * Read at most 'len' characters
 */
int mbedtls_net_recv(void *ctx, unsigned char *buf, size_t len)
{
	int ret;
	int fd = ((mbedtls_net_context *)ctx)->fd;

	if (fd < 0)
	{
		return MBEDTLS_ERR_NET_INVALID_CONTEXT;
	}

	ret = (int)read(fd, buf, len);

	if (ret < 0)
	{
		if (net_would_block(ctx) != 0)
		{
			backoff();
			return MBEDTLS_ERR_SSL_WANT_READ;
		}

		if (errno == EPIPE || errno == ECONNRESET)
		{
			return MBEDTLS_ERR_NET_CONN_RESET;
		}

		if (errno == EINTR)
		{
			/* Escape */
			_swix(OS_Byte, _IN(0), OsByte_AcknowledgeEscape); 
			return MBEDTLS_ERR_SSL_WANT_READ;
		}

		return MBEDTLS_ERR_NET_RECV_FAILED;
	}

	return ret;
}

/*
 * Read at most 'len' characters, blocking for at most 'timeout' ms.
 * A timeout of 0 waits indefinitely.
 *
 * Returns the mbedtls_net_recv() result once data is ready. Otherwise it
 * returns MBEDTLS_ERR_SSL_TIMEOUT when the wait expires, or
 * MBEDTLS_ERR_SSL_WANT_READ when Escape interrupted the wait.
 * MBEDTLS_ERR_NET_INVALID_CONTEXT, MBEDTLS_ERR_NET_POLL_FAILED or
 * MBEDTLS_ERR_NET_RECV_FAILED are returned on error.
 */
int mbedtls_net_recv_timeout( void *ctx, unsigned char *buf,
                              size_t len, uint32_t timeout )
{
	int fd = ((mbedtls_net_context *)ctx)->fd;
	int ready;
	fd_set read_fds;
	struct timeval tv;
	struct timeval *wait_for;

	if (fd < 0)
	{
		return MBEDTLS_ERR_NET_INVALID_CONTEXT;
	}

	/* select() only handles descriptors strictly below FD_SETSIZE, since
	 * FD_SET on a larger one would overflow the fd_set. Refuse early.
	 */
	if (fd >= FD_SETSIZE)
	{
		return MBEDTLS_ERR_NET_POLL_FAILED;
	}

	FD_ZERO(&read_fds);
	FD_SET(fd, &read_fds);

	if (timeout == 0)
	{
		wait_for = NULL; /* no timeout: wait until readable */
	}
	else
	{
		tv.tv_sec  = timeout / 1000;
		tv.tv_usec = ( timeout % 1000 ) * 1000;
		wait_for   = &tv;
	}

	ready = select(fd + 1, &read_fds, NULL, NULL, wait_for);

	if (ready > 0)
	{
		/* Data is waiting, so this call will not block */
		return (mbedtls_net_recv(ctx, buf, len));
	}

	if (ready == 0)
	{
		return MBEDTLS_ERR_SSL_TIMEOUT;
	}

	if (errno == EINTR)
	{
		/* Escape */
		_swix(OS_Byte, _IN(0), OsByte_AcknowledgeEscape);
		return MBEDTLS_ERR_SSL_WANT_READ;
	}

	return MBEDTLS_ERR_NET_RECV_FAILED;
}

/*
 * Write at most 'len' characters
 */
int mbedtls_net_send(void *ctx, const unsigned char *buf, size_t len)
{
	int ret;
	int fd = ((mbedtls_net_context *)ctx)->fd;

	if (fd < 0)
	{
		return MBEDTLS_ERR_NET_INVALID_CONTEXT;
	}

	ret = (int)write(fd, buf, len);

	if (ret < 0)
	{
		if (net_would_block(ctx) != 0)
		{
			backoff();
			return MBEDTLS_ERR_SSL_WANT_WRITE;
		}

		if (errno == EPIPE || errno == ECONNRESET)
		{
			return MBEDTLS_ERR_NET_CONN_RESET;
		}

		if (errno == EINTR)
		{
			/* Escape */
			_swix(OS_Byte, _IN(0), OsByte_AcknowledgeEscape); 
			return MBEDTLS_ERR_SSL_WANT_WRITE;
		}

		return MBEDTLS_ERR_NET_SEND_FAILED;
	}

	return ret;
}

/*
 * Gracefully close the connection: shut down both directions, close the
 * descriptor and mark the context as unused. Safe to call repeatedly.
 */
void mbedtls_net_free(mbedtls_net_context *ctx)
{
	if (ctx->fd != -1)
	{
		shutdown(ctx->fd, 2); /* 2: no further sends or receives */
		close(ctx->fd);
		ctx->fd = -1;
	}
}

#endif /* MBEDTLS_NET_C */