/** * Copyright (c) 2021 OceanBase * OceanBase CE is licensed under Mulan PubL v2. * You can use this software according to the terms and conditions of the Mulan PubL v2. * You may obtain a copy of Mulan PubL v2 at: * http://license.coscl.org.cn/MulanPubL-2.0 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. * See the Mulan PubL v2 for more details. */ #include #include #include #include #include #include "ussl-hook.h" #include "ussl-deps.h" #include "loop/ussl_sock.h" #include "loop/ussl_link.h" #include "loop/ussl_ihash_map.h" #include "loop/auth-methods.h" #include "loop/ussl_eloop.h" #include "loop/ussl_factory.h" #include "loop/handle-event.h" #include "loop/ussl_listenfd.h" #include "loop/pipefd.h" #include "ssl/ssl_config.h" #include "ussl-loop.h" #include "message.h" #define USSL_MAX_FD_NUM 1024 * 1024 static __thread uint64_t ussl_dispatch_id; static __thread int ussl_server_ctx_id = -1; static uint64_t global_gid_arr[USSL_MAX_FD_NUM]; static int global_client_ctx_id_arr[USSL_MAX_FD_NUM]; static int global_send_negotiation_arr[USSL_MAX_FD_NUM]; int is_ussl_bg_thread_started = 0; static int ussl_is_stopped = 0; extern int sockaddr_compare_c(struct sockaddr_storage *left, struct sockaddr_storage *right); extern char *sockaddr_to_str_c(struct sockaddr_storage *sock_addr, char *buf, int len); static __attribute__((constructor(102))) void init_global_array() { for (int i = 0; i < USSL_MAX_FD_NUM; ++i) { global_gid_arr[i] = UINT64_MAX; global_client_ctx_id_arr[i] = -1; global_send_negotiation_arr[i] = -1; } } static void ussl_clear_client_opt(int fd) { if (fd >= 0 && fd < USSL_MAX_FD_NUM) { global_gid_arr[fd] = UINT64_MAX; global_client_ctx_id_arr[fd] = -1; global_send_negotiation_arr[fd] = -1; } } static int ussl_is_pipe(int fd) { struct stat st; return 0 == fstat(fd, &st) && S_ISFIFO(st.st_mode); } int ussl_connect(int sockfd, struct sockaddr_storage *address, socklen_t address_len) { int ret = 0; if ((ret = libc_connect(sockfd, (struct sockaddr*)address, address_len)) < 0) { if (EINPROGRESS != errno) { ussl_clear_client_opt(sockfd); } } else { struct sockaddr_storage self_addr; socklen_t len = sizeof(self_addr); if (0 == getsockname(sockfd, (struct sockaddr *)&self_addr, &len)) { if (sockaddr_compare_c(address, &self_addr) != 0) { ret = -1; errno = EIO; char str[128]; ussl_log_warn("connection to %s failed, self connect self", sockaddr_to_str_c(address, str, sizeof(str))); } } else { ret = -1; ussl_log_warn("getsockname failed, fd:%d, error:%s", sockfd, strerror(errno)); } /* if gid has been set and connect return success, we also let the ret to be -1 and errno to be * EINPROGRESS */ if (ret == 0 && sockfd >= 0 && sockfd < USSL_MAX_FD_NUM && global_gid_arr[sockfd] != UINT64_MAX) { ret = -1; errno = EINPROGRESS; } } return ret; } static int ussl_check_methods(int methods) { int ret = 0; int flag = methods & ~USSL_AUTH_NONE & ~USSL_AUTH_SSL_HANDSHAKE & ~USSL_AUTH_SSL_IO; if (methods <= 0 || 0 != flag) { ret = -1; errno = EINVAL; ussl_log_warn("invalid methods, all valid methods: USSL_AUTH_NONE(1) " "USSL_AUTH_SSL_HANDSHAKE(2), USSL_AUTH_SSL_IO(4).") } return ret; } static int ussl_check_fd(int fd) { int ret = 0; if (fd >= USSL_MAX_FD_NUM || fd < 0) { ret = -1; errno = EINVAL; ussl_log_error( "fd is expected to be a non-negative integer and less than %d, but currently fd is %d", USSL_MAX_FD_NUM, fd); } return ret; } static int ussl_check_ctx_id(int ctx_id) { int ret = 0; if (ctx_id < 0) { ret = -1; errno = EINVAL; ussl_log_error("ctx_id cannot be negative, ctx_id:%d", ctx_id); } return ret; } int ussl_setsockopt(int socket, int level, int optname, const void *optval, socklen_t optlen) { int ret = 0; if (ATOMIC_BCAS(&is_ussl_bg_thread_started, 0, 1)) { ret = ussl_init_bg_thread(); if (0 != ret) { ussl_log_error("start ussl-bk-thread failed!, ret:%d", ret); ATOMIC_STORE(&is_ussl_bg_thread_started, 0); } } if (0 == ret) { if (SOL_OB_SOCKET == level) { if (SO_OB_SET_CLIENT_GID == optname || SO_OB_SET_SERVER_GID == optname) { uint64_t gid = *(uint64_t *)optval; if (0 == (ret = ussl_check_fd(socket))) { global_gid_arr[socket] = gid; } } else if (SO_OB_SET_CLIENT_SSL_CTX_ID == optname) { int client_ctx_id = *(int *)optval; if (0 == (ret = ussl_check_fd(socket)) && 0 == (ret = ussl_check_ctx_id(client_ctx_id))) { global_client_ctx_id_arr[socket] = client_ctx_id; } } else if (SO_OB_SET_SERVER_SSL_CTX_ID == optname) { int server_ctx_id = *(int *)optval; if (0 == (ret = ussl_check_ctx_id(server_ctx_id))) { ussl_server_ctx_id = server_ctx_id; } } else if (SO_OB_SET_SEND_NEGOTIATION_FLAG == optname) { int need_send_flag = *(int *)optval; if (0 == (ret = ussl_check_fd(socket))) { global_send_negotiation_arr[socket] = need_send_flag; } } } else if (SOL_OB_CTX == level) { if (SO_OB_CTX_SERVER_AUTH_METHODS == optname) { int methods = *(int *)optval; if (0 == (ret = ussl_check_methods(methods))){ set_server_auth_methods(methods); } } else if (SO_OB_CTX_CLIENT_AUTH_METHODS == optname) { int methods = *(int *)optval; if (0 == (ret = ussl_check_methods(methods))){ set_client_auth_methods(methods); } } else if (SO_OB_CTX_SET_SSL_CONFIG == optname) { ret = ssl_load_config(socket, (ssl_config_item_t *)optval); } } else { ret = libc_setsockopt(socket, level, optname, optval, optlen); } } return ret; } void ussl_stop() { ATOMIC_STORE(&ussl_is_stopped, 1); } int ussl_is_stop() { return ATOMIC_LOAD(&ussl_is_stopped); } void ussl_wait() { ussl_wait_bg_thread(); } int ussl_listen(int socket, int backlog) { int ret = 0; if (socket < 0 || socket >= USSL_MAX_FD_NUM) { ret = -1; errno = EINVAL; ussl_fatal( "fd is expected to be a non-negative integer and less than %d, but currently fd is %d", USSL_MAX_FD_NUM, socket); } else { if (0 != ussl_loop_add_listen(socket, backlog)) { ret = -1; ussl_log_error("add listen filed, fd:%d", socket); } } return ret; } int ussl_accept(int socket, struct sockaddr *address, socklen_t *address_len) { return ussl_accept4(socket, address, address_len, 0); } int ussl_accept4(int socket, struct sockaddr *address, socklen_t *address_len, int flags) { int accept_fd = -1; if (ussl_is_pipe(socket)) { if (read(socket, &accept_fd, sizeof(accept_fd)) != sizeof(accept_fd)) { accept_fd = -1; ussl_log_error("read from pipe fd filed, fd:%d, errno:%d", socket, errno); } else if (NULL != address && getpeername(accept_fd, address, address_len) < 0) { accept_fd = -1; ussl_log_error("getpeername failed, fd:%d, errno:%d", accept_fd, errno); } } else { accept_fd = libc_accept4(socket, address, address_len, flags); } return accept_fd; } int ussl_epoll_ctl(int epfd, int op, int fd, struct epoll_event *event) { int ret = 0; // Normally, ussl_setsockopt ia called before ussl_epoll_ctl and there is a checker for // array out-of-bounds in ussl_setsockopt. If the global_gid_arr is out-of-bounds here, // it must be a serious error. if (fd < 0 || fd >= USSL_MAX_FD_NUM) { ret = -1; ussl_fatal( "fd is expected to be a non-negative integer and less than %d, but currently fd is %d", USSL_MAX_FD_NUM, fd); } else if (global_gid_arr[fd] != UINT64_MAX && EPOLL_CTL_ADD == op) { ret = ussl_loop_add_clientfd(fd, global_gid_arr[fd], global_client_ctx_id_arr[fd], global_send_negotiation_arr[fd], get_client_auth_methods(), epfd, event); ussl_clear_client_opt(fd); } else { ret = libc_epoll_ctl(epfd, op, fd, event); } return ret; } ssize_t ussl_read(int fd, char *buf, size_t nbytes) { return read_regard_ssl(fd, buf, nbytes); } ssize_t ussl_write(int fd, const void *buf, size_t nbytes) { return write_regard_ssl(fd, buf, nbytes); } int ussl_close(int fd) { ussl_reset_rpc_connection_type(fd); fd_disable_ssl(fd); return libc_close(fd); } ssize_t ussl_writev(int fildes, const struct iovec *iov, int iovcnt) { return writev_regard_ssl(fildes, iov, iovcnt); } #include "ussl-deps.c" #include "loop/auth-methods.c" #include "loop/ussl_eloop.c" #include "loop/ussl_factory.c" #include "loop/handle-event.c" #include "loop/ussl_listenfd.c" #include "loop/pipefd.c" #include "loop/ussl_sock.c" #include "loop/ussl_link.c" #include "loop/ussl_ihash_map.c" #include "ssl/ssl_config.c" #include "ussl-loop.c" #include "message.c"