Files
oceanbase/deps/ussl-hook/ussl-hook.c
2024-03-19 08:39:37 +00:00

305 lines
9.0 KiB
C

/**
* Copyright (c) 2021 OceanBase
* OceanBase CE is licensed under Mulan PubL v2.
* You can use this software according to the terms and conditions of the Mulan PubL v2.
* You may obtain a copy of Mulan PubL v2 at:
* http://license.coscl.org.cn/MulanPubL-2.0
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PubL v2 for more details.
*/
#include <sys/socket.h>
#include <sys/epoll.h>
#include <errno.h>
#include <stdlib.h>
#include <sys/stat.h>
#include "ussl-hook.h"
#include "ussl-deps.h"
#include "loop/ussl_sock.h"
#include "loop/ussl_link.h"
#include "loop/ussl_ihash_map.h"
#include "loop/auth-methods.h"
#include "loop/ussl_eloop.h"
#include "loop/ussl_factory.h"
#include "loop/handle-event.h"
#include "loop/ussl_listenfd.h"
#include "loop/pipefd.h"
#include "ssl/ssl_config.h"
#include "ussl-loop.h"
#include "message.h"
#define USSL_MAX_FD_NUM 1024 * 1024
static __thread uint64_t ussl_dispatch_id;
static __thread int ussl_server_ctx_id = -1;
static uint64_t global_gid_arr[USSL_MAX_FD_NUM];
static int global_client_ctx_id_arr[USSL_MAX_FD_NUM];
static int global_send_negotiation_arr[USSL_MAX_FD_NUM];
int is_ussl_bg_thread_started = 0;
static int ussl_is_stopped = 0;
extern int sockaddr_compare_c(struct sockaddr_storage *left, struct sockaddr_storage *right);
extern char *sockaddr_to_str_c(struct sockaddr_storage *sock_addr, char *buf, int len);
static __attribute__((constructor(102))) void init_global_array()
{
for (int i = 0; i < USSL_MAX_FD_NUM; ++i) {
global_gid_arr[i] = UINT64_MAX;
global_client_ctx_id_arr[i] = -1;
global_send_negotiation_arr[i] = -1;
}
}
static void ussl_clear_client_opt(int fd)
{
if (fd >= 0 && fd < USSL_MAX_FD_NUM) {
global_gid_arr[fd] = UINT64_MAX;
global_client_ctx_id_arr[fd] = -1;
global_send_negotiation_arr[fd] = -1;
}
}
static int ussl_is_pipe(int fd)
{
struct stat st;
return 0 == fstat(fd, &st) && S_ISFIFO(st.st_mode);
}
int ussl_connect(int sockfd, struct sockaddr_storage *address, socklen_t address_len)
{
int ret = 0;
if ((ret = libc_connect(sockfd, (struct sockaddr*)address, address_len)) < 0) {
if (EINPROGRESS != errno) {
ussl_clear_client_opt(sockfd);
}
} else {
struct sockaddr_storage self_addr;
socklen_t len = sizeof(self_addr);
if (0 == getsockname(sockfd, (struct sockaddr *)&self_addr, &len)) {
if (sockaddr_compare_c(address, &self_addr) != 0) {
ret = -1;
errno = EIO;
char str[128];
ussl_log_warn("connection to %s failed, self connect self", sockaddr_to_str_c(address, str, sizeof(str)));
}
} else {
ret = -1;
ussl_log_warn("getsockname failed, fd:%d, error:%s", sockfd, strerror(errno));
}
/* if gid has been set and connect return success, we also let the ret to be -1 and errno to be
* EINPROGRESS */
if (ret == 0 && sockfd >= 0 && sockfd < USSL_MAX_FD_NUM && global_gid_arr[sockfd] != UINT64_MAX) {
ret = -1;
errno = EINPROGRESS;
}
}
return ret;
}
static int ussl_check_methods(int methods)
{
int ret = 0;
int flag =
methods & ~USSL_AUTH_NONE & ~USSL_AUTH_SSL_HANDSHAKE & ~USSL_AUTH_SSL_IO;
if (methods <= 0 || 0 != flag) {
ret = -1;
errno = EINVAL;
ussl_log_warn("invalid methods, all valid methods: USSL_AUTH_NONE(1) "
"USSL_AUTH_SSL_HANDSHAKE(2), USSL_AUTH_SSL_IO(4).")
}
return ret;
}
static int ussl_check_fd(int fd)
{
int ret = 0;
if (fd >= USSL_MAX_FD_NUM || fd < 0) {
ret = -1;
errno = EINVAL;
ussl_log_error(
"fd is expected to be a non-negative integer and less than %d, but currently fd is %d",
USSL_MAX_FD_NUM, fd);
}
return ret;
}
static int ussl_check_ctx_id(int ctx_id)
{
int ret = 0;
if (ctx_id < 0) {
ret = -1;
errno = EINVAL;
ussl_log_error("ctx_id cannot be negative, ctx_id:%d", ctx_id);
}
return ret;
}
int ussl_setsockopt(int socket, int level, int optname, const void *optval, socklen_t optlen)
{
int ret = 0;
if (ATOMIC_BCAS(&is_ussl_bg_thread_started, 0, 1)) {
ret = ussl_init_bg_thread();
if (0 != ret) {
ussl_log_error("start ussl-bk-thread failed!, ret:%d", ret);
ATOMIC_STORE(&is_ussl_bg_thread_started, 0);
}
}
if (0 == ret) {
if (SOL_OB_SOCKET == level) {
if (SO_OB_SET_CLIENT_GID == optname || SO_OB_SET_SERVER_GID == optname) {
uint64_t gid = *(uint64_t *)optval;
if (0 == (ret = ussl_check_fd(socket))) {
global_gid_arr[socket] = gid;
}
} else if (SO_OB_SET_CLIENT_SSL_CTX_ID == optname) {
int client_ctx_id = *(int *)optval;
if (0 == (ret = ussl_check_fd(socket)) && 0 == (ret = ussl_check_ctx_id(client_ctx_id))) {
global_client_ctx_id_arr[socket] = client_ctx_id;
}
} else if (SO_OB_SET_SERVER_SSL_CTX_ID == optname) {
int server_ctx_id = *(int *)optval;
if (0 == (ret = ussl_check_ctx_id(server_ctx_id))) {
ussl_server_ctx_id = server_ctx_id;
}
} else if (SO_OB_SET_SEND_NEGOTIATION_FLAG == optname) {
int need_send_flag = *(int *)optval;
if (0 == (ret = ussl_check_fd(socket))) {
global_send_negotiation_arr[socket] = need_send_flag;
}
}
} else if (SOL_OB_CTX == level) {
if (SO_OB_CTX_SERVER_AUTH_METHODS == optname) {
int methods = *(int *)optval;
if (0 == (ret = ussl_check_methods(methods))){
set_server_auth_methods(methods);
}
} else if (SO_OB_CTX_CLIENT_AUTH_METHODS == optname) {
int methods = *(int *)optval;
if (0 == (ret = ussl_check_methods(methods))){
set_client_auth_methods(methods);
}
} else if (SO_OB_CTX_SET_SSL_CONFIG == optname) {
ret = ssl_load_config(socket, (ssl_config_item_t *)optval);
}
} else {
ret = libc_setsockopt(socket, level, optname, optval, optlen);
}
}
return ret;
}
void ussl_stop()
{
ATOMIC_STORE(&ussl_is_stopped, 1);
}
int ussl_is_stop()
{
return ATOMIC_LOAD(&ussl_is_stopped);
}
void ussl_wait()
{
ussl_wait_bg_thread();
}
int ussl_listen(int socket, int backlog)
{
int ret = 0;
if (socket < 0 || socket >= USSL_MAX_FD_NUM) {
ret = -1;
errno = EINVAL;
ussl_fatal(
"fd is expected to be a non-negative integer and less than %d, but currently fd is %d",
USSL_MAX_FD_NUM, socket);
} else {
if (0 != ussl_loop_add_listen(socket, backlog)) {
ret = -1;
ussl_log_error("add listen filed, fd:%d", socket);
}
}
return ret;
}
int ussl_accept(int socket, struct sockaddr *address, socklen_t *address_len)
{
return ussl_accept4(socket, address, address_len, 0);
}
int ussl_accept4(int socket, struct sockaddr *address, socklen_t *address_len, int flags)
{
int accept_fd = -1;
if (ussl_is_pipe(socket)) {
if (read(socket, &accept_fd, sizeof(accept_fd)) != sizeof(accept_fd)) {
accept_fd = -1;
ussl_log_error("read from pipe fd filed, fd:%d, errno:%d", socket, errno);
} else if (NULL != address && getpeername(accept_fd, address, address_len) < 0) {
accept_fd = -1;
ussl_log_error("getpeername failed, fd:%d, errno:%d", accept_fd, errno);
}
} else {
accept_fd = libc_accept4(socket, address, address_len, flags);
}
return accept_fd;
}
int ussl_epoll_ctl(int epfd, int op, int fd, struct epoll_event *event)
{
int ret = 0;
// Normally, ussl_setsockopt ia called before ussl_epoll_ctl and there is a checker for
// array out-of-bounds in ussl_setsockopt. If the global_gid_arr is out-of-bounds here,
// it must be a serious error.
if (fd < 0 || fd >= USSL_MAX_FD_NUM) {
ret = -1;
ussl_fatal(
"fd is expected to be a non-negative integer and less than %d, but currently fd is %d",
USSL_MAX_FD_NUM, fd);
} else if (global_gid_arr[fd] != UINT64_MAX && EPOLL_CTL_ADD == op) {
ret = ussl_loop_add_clientfd(fd, global_gid_arr[fd], global_client_ctx_id_arr[fd], global_send_negotiation_arr[fd],
get_client_auth_methods(), epfd, event);
ussl_clear_client_opt(fd);
} else {
ret = libc_epoll_ctl(epfd, op, fd, event);
}
return ret;
}
ssize_t ussl_read(int fd, char *buf, size_t nbytes)
{
return read_regard_ssl(fd, buf, nbytes);
}
ssize_t ussl_write(int fd, const void *buf, size_t nbytes)
{
return write_regard_ssl(fd, buf, nbytes);
}
int ussl_close(int fd)
{
ussl_reset_rpc_connection_type(fd);
fd_disable_ssl(fd);
return libc_close(fd);
}
ssize_t ussl_writev(int fildes, const struct iovec *iov, int iovcnt)
{
return writev_regard_ssl(fildes, iov, iovcnt);
}
#include "ussl-deps.c"
#include "loop/auth-methods.c"
#include "loop/ussl_eloop.c"
#include "loop/ussl_factory.c"
#include "loop/handle-event.c"
#include "loop/ussl_listenfd.c"
#include "loop/pipefd.c"
#include "loop/ussl_sock.c"
#include "loop/ussl_link.c"
#include "loop/ussl_ihash_map.c"
#include "ssl/ssl_config.c"
#include "ussl-loop.c"
#include "message.c"