/*
 *  TCP/IP or UDP/IP networking functions
 *
 *  Copyright (C) 2006-2015, ARM Limited, All Rights Reserved
 *  Copyright (c) 2018, RISC OS Open Ltd
 *  SPDX-License-Identifier: Apache-2.0
 *
 *  Licensed under the Apache License, Version 2.0 (the "License"); you may
 *  not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *  http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 *
 *  This file is part of mbed TLS (https://tls.mbed.org)
 */

/* Enable definition of getaddrinfo() even when compiling with -std=c99. Must
 * be set before config.h, which pulls in glibc's features.h indirectly.
 * Harmless on other platforms. */
#define _POSIX_C_SOURCE 200112L

#if !defined(MBEDTLS_CONFIG_FILE)
#include "mbedtls/config.h"
#else
#include MBEDTLS_CONFIG_FILE
#endif

#ifndef RISCOS
#error "This is a platform specific file for RISC OS only"
#endif

#if defined(MBEDTLS_NET_C)

#if defined(MBEDTLS_PLATFORM_C)
#include "mbedtls/platform.h"
#else
#include <stdlib.h>
#define mbedtls_time_t   time_t
#endif

#include "mbedtls/net_sockets.h"

#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <sys/errno.h>
#include <sys/time.h>
#include <sys/filio.h>
#include <unistd.h>
#include <signal.h>
#include <netdb.h>
#include <socklib.h>

#include <stdlib.h>
#include <string.h>
#include <stdio.h>

#include "swis.h"
#include "Global/OsBytes.h"
#include "AsmUtils/callbacks.h"
/* SWI numbers for the Resolver_* calls, for use with _swix().
 * Only Resolver_GetHost is referenced in this file; the others are listed
 * alongside it for completeness.
 */
#define Resolver_GetHostByName          0x046000
#define Resolver_GetHost                0x046001
#define Resolver_GetCache               0x046002
#define Resolver_CacheControl           0x046003

/*
 * Read processor mode like _kernel_processor_mode()
 *
 * Returns the whole CPSR value unmasked; callers are expected to mask it
 * with MODE_BITS and compare against MODE_USR themselves.
 *
 * The MODE_* macros are #defined inside the function body. Being
 * preprocessor macros they are not scoped to the function, so they stay
 * visible to the code that follows in this file.
 */
static int arm_cpsr(void)
{
#define MODE_BITS    0xF   /* Just the mode */
#define MODE_M32_BIT 0x10  /* 32 versus 26 bit */
#define MODE_USR     0     /* USR32 or USR26 if masked with MODE_BITS */

	int procmode;

	/* Compiler inline assembler: copy the CPSR into procmode */
	__asm
	{
		MRS procmode, CPSR
	}

	return procmode;
}

/*
 * Hook for any setup the sockets interface needs before first use.
 * Nothing is required here, so it always reports success (0).
 */
static int net_prepare(void)
{
	const int status = 0;

	return status;
}

/*
 * Put a context into its initial state: no socket attached (fd == -1)
 */
void mbedtls_net_init(mbedtls_net_context *ctx)
{
	const int no_socket = -1;

	ctx->fd = no_socket;
}

/*
 * Initiate a TCP or UDP connection with host:port
 *
 * ctx    context whose fd receives the connected socket
 * url    host name passed to the resolver
 * port   port number as a decimal string (converted with atoi())
 * proto  MBEDTLS_NET_PROTO_UDP selects a datagram socket, any other value
 *        a stream socket
 *
 * Returns 0 on success, otherwise MBEDTLS_ERR_NET_UNKNOWN_HOST,
 * MBEDTLS_ERR_NET_SOCKET_FAILED or MBEDTLS_ERR_NET_CONNECT_FAILED (or
 * whatever non-zero value net_prepare() returned). A socket that fails to
 * connect is closed and ctx->fd is reset to -1 so the context never refers
 * to a closed socket.
 */
int mbedtls_net_connect(mbedtls_net_context *ctx,
                        const char *url, const char *port, int proto)
{
	int ret;
	struct hostent     *host;
	char              **cur;
	struct sockaddr_in  addr;

	ret = net_prepare();
	if (ret != 0)
	{
		return ret;
	}

	/* No IPv6 support beneath us yet, resolve name to
	 * an IPv4 address. The resolver is polled until it stops
	 * reporting EINPROGRESS.
	 */
	while (1)
	{
		_kernel_oserror *error;
		int errnum;

		error = _swix(Resolver_GetHost, _IN(0) | _OUTR(0,1),
		              url, &errnum, &host);
		if (error != NULL)
		{
			return MBEDTLS_ERR_NET_UNKNOWN_HOST;
		}
		if (errnum == EINPROGRESS)
		{
			/* NOTE(review): presumably this lets pending callbacks run
			 * when we are not in a user mode - confirm against AsmUtils.
			 */
			if ((arm_cpsr() & MODE_BITS) != MODE_USR) usermode_donothing();
			continue;
		}
		if (errnum != 0)
		{
			/* Some other error: give up instead of polling forever */
			host = NULL;
		}
		break;
	}
	if (host == NULL)
	{
		return MBEDTLS_ERR_NET_UNKNOWN_HOST;
	}

	/* Try the sockaddrs until a connection succeeds */
	ret = MBEDTLS_ERR_NET_UNKNOWN_HOST;

	for (cur = host->h_addr_list; *cur != NULL; cur++)
	{
		ctx->fd = (int)socket(PF_INET,
		                      proto == MBEDTLS_NET_PROTO_UDP ? SOCK_DGRAM : SOCK_STREAM,
		                      proto == MBEDTLS_NET_PROTO_UDP ? IPPROTO_UDP : IPPROTO_TCP);
		if (ctx->fd < 0)
		{
			ret = MBEDTLS_ERR_NET_SOCKET_FAILED;
			continue;
		}

		memset(&addr, 0, sizeof(addr));
		/* NOTE(review): sin_len is given the address length, not
		 * sizeof(addr) - confirm the stack does not rely on it.
		 */
		addr.sin_len = host->h_length;
		addr.sin_family = host->h_addrtype;
		addr.sin_port = htons(atoi(port));
		memcpy(&addr.sin_addr.s_addr, *cur, host->h_length);
		if (connect(ctx->fd, (struct sockaddr *)&addr, sizeof(addr)) == 0)
		{
			ret = 0;
			break;
		}

		close(ctx->fd);
		ctx->fd = -1; /* Do not leave a closed handle in the context */
		ret = MBEDTLS_ERR_NET_CONNECT_FAILED;
	}

	return ret;
}

/*
 * Create a listening socket on bind_ip:port
 *
 * ctx    context whose fd receives the bound socket
 * url    address or host name to bind to, passed to the resolver
 * port   port number as a decimal string (converted with atoi())
 * proto  MBEDTLS_NET_PROTO_UDP selects a datagram socket, any other value
 *        a stream socket; listen() is only called for
 *        MBEDTLS_NET_PROTO_TCP
 *
 * Returns 0 on success, otherwise MBEDTLS_ERR_NET_UNKNOWN_HOST,
 * MBEDTLS_ERR_NET_SOCKET_FAILED, MBEDTLS_ERR_NET_BIND_FAILED or
 * MBEDTLS_ERR_NET_LISTEN_FAILED (or whatever non-zero value net_prepare()
 * returned). A socket that cannot be set up is closed and ctx->fd is reset
 * to -1 so the context never refers to a closed socket.
 *
 * NOTE(review): url is handed to the resolver unchecked - confirm that
 * callers never pass NULL to mean "any address".
 */
int mbedtls_net_bind(mbedtls_net_context *ctx,
                     const char *url, const char *port, int proto)
{
	int n, ret;
	struct hostent     *host;
	char              **cur;
	struct sockaddr_in  addr;

	ret = net_prepare();
	if (ret != 0)
	{
		return ret;
	}

	/* No IPv6 support beneath us yet, resolve name to
	 * an IPv4 address. The resolver is polled until it stops
	 * reporting EINPROGRESS.
	 */
	while (1)
	{
		_kernel_oserror *error;
		int errnum;

		error = _swix(Resolver_GetHost, _IN(0) | _OUTR(0,1),
		              url, &errnum, &host);
		if (error != NULL)
		{
			return MBEDTLS_ERR_NET_UNKNOWN_HOST;
		}
		if (errnum == EINPROGRESS)
		{
			/* NOTE(review): presumably this lets pending callbacks run
			 * when we are not in a user mode - confirm against AsmUtils.
			 */
			if ((arm_cpsr() & MODE_BITS) != MODE_USR) usermode_donothing();
			continue;
		}
		if (errnum != 0)
		{
			/* Some other error: give up instead of polling forever */
			host = NULL;
		}
		break;
	}
	if (host == NULL)
	{
		return MBEDTLS_ERR_NET_UNKNOWN_HOST;
	}

	/* Try the sockaddrs until a binding succeeds */
	ret = MBEDTLS_ERR_NET_UNKNOWN_HOST;
	for (cur = host->h_addr_list; *cur != NULL; cur++)
	{
		ctx->fd = (int)socket(PF_INET,
		                      proto == MBEDTLS_NET_PROTO_UDP ? SOCK_DGRAM : SOCK_STREAM,
		                      proto == MBEDTLS_NET_PROTO_UDP ? IPPROTO_UDP : IPPROTO_TCP);
		if (ctx->fd < 0)
		{
			ret = MBEDTLS_ERR_NET_SOCKET_FAILED;
			continue;
		}

		n = 1;
		if (setsockopt(ctx->fd, SOL_SOCKET, SO_REUSEADDR,
		               (const char *)&n, sizeof(n)) != 0)
		{
			close(ctx->fd);
			ctx->fd = -1; /* Do not leave a closed handle in the context */
			ret = MBEDTLS_ERR_NET_SOCKET_FAILED;
			continue;
		}

		memset(&addr, 0, sizeof(addr));
		/* NOTE(review): sin_len is given the address length, not
		 * sizeof(addr) - confirm the stack does not rely on it.
		 */
		addr.sin_len = host->h_length;
		addr.sin_family = host->h_addrtype;
		addr.sin_port = htons(atoi(port));
		memcpy(&addr.sin_addr.s_addr, *cur, host->h_length);
		if (bind(ctx->fd, (struct sockaddr *)&addr, sizeof(addr)) != 0)
		{
			close(ctx->fd);
			ctx->fd = -1;
			ret = MBEDTLS_ERR_NET_BIND_FAILED;
			continue;
		}

		/* Listen only makes sense for TCP */
		if (proto == MBEDTLS_NET_PROTO_TCP)
		{
			if (listen(ctx->fd, MBEDTLS_NET_LISTEN_BACKLOG) != 0)
			{
				close(ctx->fd);
				ctx->fd = -1;
				ret = MBEDTLS_ERR_NET_LISTEN_FAILED;
				continue;
			}
		}

		/* If we ever get here, it's a success */
		ret = 0;
		break;
	}

	return ret;
}

/*
 * Check if the requested operation would be blocking on a non-blocking socket
 * and thus 'failed' with a negative return value.
 *
 * Must be called straight after the failing call: the decision is based on
 * the current value of errno, which is preserved across the query made here.
 *
 * Returns 1 if the socket is non-blocking and errno is EAGAIN/EWOULDBLOCK,
 * 0 otherwise.
 *
 * Note: on a blocking socket this function always returns 0!
 */
static int net_would_block(const mbedtls_net_context *ctx)
{
	int nbio = 0;
	int olderrno;

	/*
	 * Never return 'WOULD BLOCK' on a blocking socket. If the mode cannot
	 * be read the socket is treated as blocking, rather than acting on an
	 * indeterminate value.
	 */
	olderrno = errno;
	if (socketioctl(ctx->fd, _IOR('f', 126, int) /* Read FIONBIO */, &nbio) < 0)
	{
		nbio = 0;
	}
	errno = olderrno;
	if (!nbio)
	{
		return 0;
	}

	switch (errno)
	{
#if defined EAGAIN
		case EAGAIN:
#endif
#if defined EWOULDBLOCK && EWOULDBLOCK != EAGAIN
		case EWOULDBLOCK:
#endif
			return 1;

		default:
			break;
	}

	return 0;
}

/*
 * Accept a connection from a remote client
 *
 * bind_ctx    listening context (TCP or UDP)
 * client_ctx  receives the socket used to talk to the client
 * client_ip   optional buffer for the client's IPv4 address, may be NULL
 * buf_size    size of the client_ip buffer
 * ip_len      receives the number of address bytes; written whenever
 *             client_ip is not NULL
 *
 * Returns 0 on success, MBEDTLS_ERR_SSL_WANT_READ if the listening socket
 * is non-blocking and nothing is pending, or one of the
 * MBEDTLS_ERR_NET_ACCEPT_FAILED / SOCKET_FAILED / BIND_FAILED /
 * BUFFER_TOO_SMALL codes.
 */
int mbedtls_net_accept(mbedtls_net_context *bind_ctx,
                       mbedtls_net_context *client_ctx,
                       void *client_ip, size_t buf_size, size_t *ip_len)
{
	int status;
	int sock_type;
	struct sockaddr peer;
#if defined(__socklen_t_defined) || defined(_SOCKLEN_T) ||  \
    defined(_SOCKLEN_T_DECLARED) || defined(__DEFINED_socklen_t)
	socklen_t peer_len = (socklen_t)sizeof(peer);
	socklen_t opt_len = (socklen_t)sizeof(sock_type);
#else
	int peer_len = sizeof(peer);
	int opt_len = sizeof(sock_type);
#endif
	struct sockaddr_in *peer4 = (struct sockaddr_in *) &peer;

	/* Find out whether the listener is a stream or a datagram socket */
	if (getsockopt(bind_ctx->fd, SOL_SOCKET, SO_TYPE,
	               (void *)&sock_type, (int *)&opt_len) != 0)
	{
		return MBEDTLS_ERR_NET_ACCEPT_FAILED;
	}
	if (sock_type != SOCK_STREAM && sock_type != SOCK_DGRAM)
	{
		return MBEDTLS_ERR_NET_ACCEPT_FAILED;
	}

	if (sock_type == SOCK_STREAM)
	{
		/* TCP: a real accept(), the new socket goes to the client */
		client_ctx->fd = (int)accept(bind_ctx->fd,
		                             (struct sockaddr *)&peer, (int *)&peer_len);
		status = client_ctx->fd;
	}
	else
	{
		/* UDP: peek at the first datagram so that it stays queued */
		char peek[1] = { 0 };

		status = (int)recvfrom(bind_ctx->fd, peek, sizeof(peek), MSG_PEEK,
		                       (struct sockaddr *)&peer, (int *)&peer_len);
	}

	if (status < 0)
	{
		return (net_would_block(bind_ctx) != 0) ? MBEDTLS_ERR_SSL_WANT_READ
		                                        : MBEDTLS_ERR_NET_ACCEPT_FAILED;
	}

	/* UDP: the listening socket is handed over to the client and a
	 * replacement is bound to the same local address for new peers
	 */
	if (sock_type == SOCK_DGRAM)
	{
		struct sockaddr local;
		int reuse = 1;

		if (connect(bind_ctx->fd, (struct sockaddr *)&peer, peer_len) != 0)
		{
			return MBEDTLS_ERR_NET_ACCEPT_FAILED;
		}

		client_ctx->fd = bind_ctx->fd;
		bind_ctx->fd   = -1; /* In case we exit early */

		peer_len = sizeof(struct sockaddr);
		if (getsockname(client_ctx->fd, (struct sockaddr *)&local, (int *)&peer_len) != 0)
		{
			return MBEDTLS_ERR_NET_SOCKET_FAILED;
		}

		bind_ctx->fd = (int)socket(local.sa_family, SOCK_DGRAM, IPPROTO_UDP);
		if (bind_ctx->fd < 0)
		{
			return MBEDTLS_ERR_NET_SOCKET_FAILED;
		}

		if (setsockopt(bind_ctx->fd, SOL_SOCKET, SO_REUSEADDR,
		               (const char *)&reuse, sizeof(reuse)) != 0)
		{
			return MBEDTLS_ERR_NET_SOCKET_FAILED;
		}

		if (bind(bind_ctx->fd, (struct sockaddr *)&local, peer_len) != 0)
		{
			return MBEDTLS_ERR_NET_BIND_FAILED;
		}
	}

	/* The caller did not ask for the peer address */
	if (client_ip == NULL)
	{
		return 0;
	}

	*ip_len = sizeof(peer4->sin_addr.s_addr);
	if (buf_size < *ip_len)
	{
		return MBEDTLS_ERR_NET_BUFFER_TOO_SMALL;
	}
	memcpy(client_ip, &peer4->sin_addr.s_addr, *ip_len);

	return 0;
}

/*
 * Switch the socket to blocking mode by clearing FIONBIO.
 * Returns the result of socketioctl() unchanged.
 */
int mbedtls_net_set_block(mbedtls_net_context *ctx)
{
	int nonblocking = 0;
	int status;

	status = socketioctl(ctx->fd, FIONBIO, &nonblocking);
	return status;
}

/*
 * Switch the socket to non-blocking mode by setting FIONBIO.
 * Returns the result of socketioctl() unchanged.
 */
int mbedtls_net_set_nonblock(mbedtls_net_context *ctx)
{
	int nonblocking = 1;
	int status;

	status = socketioctl(ctx->fd, FIONBIO, &nonblocking);
	return status;
}

/*
 * Sleep for 'usec' microseconds, using select() with no descriptors
 * and the delay as its timeout
 */
void mbedtls_net_usleep(unsigned long usec)
{
	const unsigned long usec_per_sec = 1000000;
	struct timeval delay;

	delay.tv_usec = usec % usec_per_sec;
	delay.tv_sec  = usec / usec_per_sec;

	select(0, NULL, NULL, NULL, &delay);
}

/*
 * Read at most 'len' characters
 */
int mbedtls_net_recv(void *ctx, unsigned char *buf, size_t len)
{
	int ret;
	int fd = ((mbedtls_net_context *)ctx)->fd;

	if (fd < 0)
	{
		return MBEDTLS_ERR_NET_INVALID_CONTEXT;
	}

	ret = (int)read(fd, buf, len);

	if (ret < 0)
	{
		if (net_would_block(ctx) != 0)
		{
			return MBEDTLS_ERR_SSL_WANT_READ;
		}

		if (errno == EPIPE || errno == ECONNRESET)
		{
			return MBEDTLS_ERR_NET_CONN_RESET;
		}

		if (errno == EINTR)
		{
			/* Escape */
			_swix(OS_Byte, _IN(0), OsByte_AcknowledgeEscape); 
			return MBEDTLS_ERR_SSL_WANT_READ;
		}

		return MBEDTLS_ERR_NET_RECV_FAILED;
	}

	return ret;
}

/*
 * Read at most 'len' characters, blocking for at most 'timeout' ms
 *
 * A timeout of 0 waits without a time limit.
 *
 * Returns whatever mbedtls_net_recv() returns once data is readable, or
 * MBEDTLS_ERR_NET_INVALID_CONTEXT, MBEDTLS_ERR_SSL_TIMEOUT,
 * MBEDTLS_ERR_SSL_WANT_READ or MBEDTLS_ERR_NET_RECV_FAILED.
 */
int mbedtls_net_recv_timeout( void *ctx, unsigned char *buf,
                              size_t len, uint32_t timeout )
{
	int ready;
	struct timeval limit;
	struct timeval *wait = NULL; /* NULL makes select() wait indefinitely */
	fd_set readable;
	int sock = ((mbedtls_net_context *)ctx)->fd;

	if (sock < 0)
	{
		return MBEDTLS_ERR_NET_INVALID_CONTEXT;
	}

	FD_ZERO(&readable);
	FD_SET(sock, &readable);

	if (timeout != 0)
	{
		limit.tv_sec  = timeout / 1000;
		limit.tv_usec = ( timeout % 1000 ) * 1000;
		wait = &limit;
	}

	ready = select(sock + 1, &readable, NULL, NULL, wait);

	if (ready > 0)
	{
		/* Data is waiting, so this call will not block */
		return (mbedtls_net_recv(ctx, buf, len));
	}

	if (ready == 0)
	{
		/* Zero fds ready means we timed out */
		return MBEDTLS_ERR_SSL_TIMEOUT;
	}

	if (errno == EINTR)
	{
		/* Escape: acknowledge it and ask the caller to retry */
		_swix(OS_Byte, _IN(0), OsByte_AcknowledgeEscape);
		return MBEDTLS_ERR_SSL_WANT_READ;
	}

	return MBEDTLS_ERR_NET_RECV_FAILED;
}

/*
 * Write at most 'len' characters
 */
int mbedtls_net_send(void *ctx, const unsigned char *buf, size_t len)
{
	int ret;
	int fd = ((mbedtls_net_context *)ctx)->fd;

	if (fd < 0)
	{
		return MBEDTLS_ERR_NET_INVALID_CONTEXT;
	}

	ret = (int)write(fd, buf, len);

	if (ret < 0)
	{
		if (net_would_block(ctx) != 0)
		{
			return MBEDTLS_ERR_SSL_WANT_WRITE;
		}

		if (errno == EPIPE || errno == ECONNRESET)
		{
			return MBEDTLS_ERR_NET_CONN_RESET;
		}

		if (errno == EINTR)
		{
			/* Escape */
			_swix(OS_Byte, _IN(0), OsByte_AcknowledgeEscape); 
			return MBEDTLS_ERR_SSL_WANT_WRITE;
		}

		return MBEDTLS_ERR_NET_SEND_FAILED;
	}

	return ret;
}

/*
 * Gracefully close the connection
 *
 * Does nothing when no socket is attached (fd == -1); otherwise the socket
 * is shut down, closed, and the context marked as having no socket.
 */
void mbedtls_net_free(mbedtls_net_context *ctx)
{
	if (ctx->fd != -1)
	{
		shutdown(ctx->fd, 2);
		close(ctx->fd);

		ctx->fd = -1;
	}
}

#endif /* MBEDTLS_NET_C */