diff --git a/be/src/service/arrow_flight/auth_server_middleware.cpp b/be/src/service/arrow_flight/auth_server_middleware.cpp new file mode 100644 index 0000000000..c0bf5b853b --- /dev/null +++ b/be/src/service/arrow_flight/auth_server_middleware.cpp @@ -0,0 +1,55 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "service/arrow_flight/auth_server_middleware.h" + +#include "service/arrow_flight/call_header_utils.h" + +namespace doris { +namespace flight { + +void NoOpHeaderAuthServerMiddleware::SendingHeaders( + arrow::flight::AddCallHeaders* outgoing_headers) { + outgoing_headers->AddHeader(kAuthHeader, std::string(kBearerPrefix) + kBearerDefaultToken); +} + +arrow::Status NoOpHeaderAuthServerMiddlewareFactory::StartCall( + const arrow::flight::CallInfo& info, const arrow::flight::ServerCallContext& context, + std::shared_ptr* middleware) { + std::string username, password; + ParseBasicHeader(context.incoming_headers(), username, password); + *middleware = std::make_shared(); + return arrow::Status::OK(); +} + +void NoOpBearerAuthServerMiddleware::SendingHeaders( + arrow::flight::AddCallHeaders* outgoing_headers) { + std::string bearer_token = + FindKeyValPrefixInCallHeaders(_incoming_headers, kAuthHeader, kBearerPrefix); + *_is_valid = (bearer_token == std::string(kBearerDefaultToken)); +} + +arrow::Status NoOpBearerAuthServerMiddlewareFactory::StartCall( + const arrow::flight::CallInfo& info, const arrow::flight::ServerCallContext& context, + std::shared_ptr* middleware) { + *middleware = std::make_shared(context.incoming_headers(), + &_is_valid); + return arrow::Status::OK(); +} + +} // namespace flight +} // namespace doris diff --git a/be/src/service/arrow_flight/auth_server_middleware.h b/be/src/service/arrow_flight/auth_server_middleware.h new file mode 100644 index 0000000000..e5f40cf626 --- /dev/null +++ b/be/src/service/arrow_flight/auth_server_middleware.h @@ -0,0 +1,88 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include + +#include "arrow/flight/server.h" +#include "arrow/flight/server_middleware.h" +#include "arrow/flight/types.h" + +namespace doris { +namespace flight { + +// Just return default bearer token. +class NoOpHeaderAuthServerMiddleware : public arrow::flight::ServerMiddleware { +public: + void SendingHeaders(arrow::flight::AddCallHeaders* outgoing_headers) override; + + void CallCompleted(const arrow::Status& status) override {} + + [[nodiscard]] std::string name() const override { return "NoOpHeaderAuthServerMiddleware"; } +}; + +// Factory for base64 header authentication. +// No actual authentication. +class NoOpHeaderAuthServerMiddlewareFactory : public arrow::flight::ServerMiddlewareFactory { +public: + NoOpHeaderAuthServerMiddlewareFactory() = default; + + arrow::Status StartCall(const arrow::flight::CallInfo& info, + const arrow::flight::ServerCallContext& context, + std::shared_ptr* middleware) override; +}; + +// A server middleware for validating incoming bearer header authentication. +// Just compare with default bearer token. +class NoOpBearerAuthServerMiddleware : public arrow::flight::ServerMiddleware { +public: + explicit NoOpBearerAuthServerMiddleware(const arrow::flight::CallHeaders& incoming_headers, + bool* isValid) + : _is_valid(isValid) { + _incoming_headers = incoming_headers; + } + + void SendingHeaders(arrow::flight::AddCallHeaders* outgoing_headers) override; + + void CallCompleted(const arrow::Status& status) override {} + + [[nodiscard]] std::string name() const override { return "NoOpBearerAuthServerMiddleware"; } + +private: + arrow::flight::CallHeaders _incoming_headers; + bool* _is_valid; +}; + +// Factory for base64 header authentication. +// No actual authentication. +class NoOpBearerAuthServerMiddlewareFactory : public arrow::flight::ServerMiddlewareFactory { +public: + NoOpBearerAuthServerMiddlewareFactory() : _is_valid(false) {} + + arrow::Status StartCall(const arrow::flight::CallInfo& info, + const arrow::flight::ServerCallContext& context, + std::shared_ptr* middleware) override; + + [[nodiscard]] bool GetIsValid() const { return _is_valid; } + +private: + bool _is_valid; +}; + +} // namespace flight +} // namespace doris diff --git a/be/src/service/arrow_flight/call_header_utils.h b/be/src/service/arrow_flight/call_header_utils.h new file mode 100644 index 0000000000..88990228bb --- /dev/null +++ b/be/src/service/arrow_flight/call_header_utils.h @@ -0,0 +1,65 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include + +#include "arrow/flight/types.h" +#include "arrow/util/base64.h" + +namespace doris { +namespace flight { + +const char kBearerDefaultToken[] = "bearertoken"; +const char kBasicPrefix[] = "Basic "; +const char kBearerPrefix[] = "Bearer "; +const char kAuthHeader[] = "authorization"; + +// Function to look in CallHeaders for a key that has a value starting with prefix and +// return the rest of the value after the prefix. +std::string FindKeyValPrefixInCallHeaders(const arrow::flight::CallHeaders& incoming_headers, + const std::string& key, const std::string& prefix) { + // Lambda function to compare characters without case sensitivity. + auto char_compare = [](const char& char1, const char& char2) { + return (::toupper(char1) == ::toupper(char2)); + }; + + auto iter = incoming_headers.find(key); + if (iter == incoming_headers.end()) { + return ""; + } + const std::string val(iter->second); + if (val.size() > prefix.length()) { + if (std::equal(val.begin(), val.begin() + prefix.length(), prefix.begin(), char_compare)) { + return val.substr(prefix.length()); + } + } + return ""; +} + +void ParseBasicHeader(const arrow::flight::CallHeaders& incoming_headers, std::string& username, + std::string& password) { + std::string encoded_credentials = + FindKeyValPrefixInCallHeaders(incoming_headers, kAuthHeader, kBasicPrefix); + std::stringstream decoded_stream(arrow::util::base64_decode(encoded_credentials)); + std::getline(decoded_stream, username, ':'); + std::getline(decoded_stream, password, ':'); +} + +} // namespace flight +} // namespace doris diff --git a/be/src/service/arrow_flight/flight_sql_service.cpp b/be/src/service/arrow_flight/flight_sql_service.cpp index 60add8698a..719f7a466c 100644 --- a/be/src/service/arrow_flight/flight_sql_service.cpp +++ b/be/src/service/arrow_flight/flight_sql_service.cpp @@ -74,7 +74,7 @@ public: } }; -FlightSqlServer::FlightSqlServer(std::shared_ptr impl) : impl_(std::move(impl)) {} +FlightSqlServer::FlightSqlServer(std::shared_ptr impl) : _impl(std::move(impl)) {} arrow::Result> FlightSqlServer::create() { std::shared_ptr impl = std::make_shared(); @@ -94,7 +94,7 @@ FlightSqlServer::~FlightSqlServer() { arrow::Result> FlightSqlServer::DoGetStatement( const arrow::flight::ServerCallContext& context, const arrow::flight::sql::StatementQueryTicket& command) { - return impl_->DoGetStatement(context, command); + return _impl->DoGetStatement(context, command); } Status FlightSqlServer::init(int port) { @@ -108,6 +108,18 @@ Status FlightSqlServer::init(int port) { arrow::flight::Location::ForGrpcTcp(BackendOptions::get_service_bind_address(), port) .Value(&bind_location)); arrow::flight::FlightServerOptions flight_options(bind_location); + + // Not authenticated in BE flight server. + // After the authentication between the ADBC Client and the FE flight server is completed, + // the FE flight server will put the query id in the Ticket and send it back to the Client. + // When the Client uses the Ticket to fetch data from the BE flight server, the BE flight + // server will verify the query id, this step is equivalent to authentication. + _header_middleware = std::make_shared(); + _bearer_middleware = std::make_shared(); + flight_options.auth_handler = std::make_unique(); + flight_options.middleware.push_back({"header-auth-server", _header_middleware}); + flight_options.middleware.push_back({"bearer-auth-server", _bearer_middleware}); + RETURN_DORIS_STATUS_IF_ERROR(Init(flight_options)); LOG(INFO) << "Arrow Flight Service bind to host: " << BackendOptions::get_service_bind_address() << ", port: " << port; diff --git a/be/src/service/arrow_flight/flight_sql_service.h b/be/src/service/arrow_flight/flight_sql_service.h index 4772e98d81..0334f1a313 100644 --- a/be/src/service/arrow_flight/flight_sql_service.h +++ b/be/src/service/arrow_flight/flight_sql_service.h @@ -21,6 +21,7 @@ #include "arrow/result.h" #include "common/status.h" #include "service/arrow_flight/arrow_flight_batch_reader.h" +#include "service/arrow_flight/auth_server_middleware.h" namespace doris { namespace flight { @@ -40,9 +41,12 @@ public: private: class Impl; - std::shared_ptr impl_; + std::shared_ptr _impl; bool _inited = false; + std::shared_ptr _header_middleware; + std::shared_ptr _bearer_middleware; + explicit FlightSqlServer(std::shared_ptr impl); }; diff --git a/fe/fe-common/src/main/java/org/apache/doris/common/Config.java b/fe/fe-common/src/main/java/org/apache/doris/common/Config.java index ec5c47b70e..a049c651da 100644 --- a/fe/fe-common/src/main/java/org/apache/doris/common/Config.java +++ b/fe/fe-common/src/main/java/org/apache/doris/common/Config.java @@ -390,7 +390,7 @@ public class Config extends ConfigBase { @ConfField(description = {"FE MySQL server 的端口号", "The port of FE MySQL server"}) public static int query_port = 9030; - @ConfField(description = {"FE Arrow-Flight-SQL server 的端口号", "The port of FE Arrow-Flight-SQ server"}) + @ConfField(description = {"FE Arrow-Flight-SQL server 的端口号", "The port of FE Arrow-Flight-SQL server"}) public static int arrow_flight_sql_port = -1; @ConfField(description = {"MySQL 服务的 IO 线程数", "The number of IO threads in MySQL service"}) @@ -2211,4 +2211,15 @@ public class Config extends ConfigBase { "min buckets of auto bucket" }) public static int autobucket_min_buckets = 1; + + @ConfField(description = {"Arrow Flight Server中所有用户token的缓存上限,超过后LRU淘汰,默认值为2000", + "The cache limit of all user tokens in Arrow Flight Server. which will be eliminated by" + + "LRU rules after exceeding the limit, the default value is 2000."}) + public static int arrow_flight_token_cache_size = 2000; + + @ConfField(description = {"Arrow Flight Server中用户token的存活时间,自上次写入后过期时间,单位分钟,默认值为4320,即3天", + "The alive time of the user token in Arrow Flight Server, expire after write, unit minutes," + + "the default value is 4320, which is 3 days"}) + public static int arrow_flight_token_alive_time = 4320; + } diff --git a/fe/fe-core/pom.xml b/fe/fe-core/pom.xml index d2bc2b4795..34f7858112 100644 --- a/fe/fe-core/pom.xml +++ b/fe/fe-core/pom.xml @@ -759,6 +759,10 @@ under the License. org.apache.arrow flight-sql + + org.immutables + value + diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/ConnectContext.java b/fe/fe-core/src/main/java/org/apache/doris/qe/ConnectContext.java index 302d3544de..91be948aa9 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/qe/ConnectContext.java +++ b/fe/fe-core/src/main/java/org/apache/doris/qe/ConnectContext.java @@ -75,6 +75,12 @@ public class ConnectContext { private static final String SSL_PROTOCOL = "TLS"; + public enum ConnectType { + MYSQL, + ARROW_FLIGHT + } + + protected volatile ConnectType connectType; // set this id before analyze protected volatile long stmtId; protected volatile long forwardedStmtId; @@ -90,6 +96,8 @@ public class ConnectContext { protected volatile int connectionId; // Timestamp when the connection is make protected volatile long loginTime; + // arrow flight token + protected volatile String peerIdentity; // mysql net protected volatile MysqlChannel mysqlChannel; // state @@ -268,11 +276,31 @@ public class ConnectContext { return notEvalNondeterministicFunction; } + public ConnectType getConnectType() { + return connectType; + } + public ConnectContext() { - this(null); + this((StreamConnection) null); + } + + public ConnectContext(String peerIdentity) { + this.connectType = ConnectType.ARROW_FLIGHT; + this.peerIdentity = peerIdentity; + state = new QueryState(); + returnRows = 0; + isKilled = false; + sessionVariable = VariableMgr.newSessionVariable(); + mysqlChannel = new DummyMysqlChannel(); + command = MysqlCommand.COM_SLEEP; + if (Config.use_fuzzy_session_variable) { + sessionVariable.initFuzzyModeVariables(); + } + setResultSinkType(TResultSinkType.ARROW_FLIGHT_PROTOCAL); } public ConnectContext(StreamConnection connection) { + connectType = ConnectType.MYSQL; state = new QueryState(); returnRows = 0; serverCapability = MysqlCapability.DEFAULT_CAPABILITY; @@ -507,6 +535,10 @@ public class ConnectContext { this.loginTime = System.currentTimeMillis(); } + public String getPeerIdentity() { + return peerIdentity; + } + public MysqlChannel getMysqlChannel() { return mysqlChannel; } @@ -662,15 +694,28 @@ public class ConnectContext { this.resultSinkType = resultSinkType; } + public String getRemoteHostPortString() { + if (connectType.equals(ConnectType.MYSQL)) { + return getMysqlChannel().getRemoteHostPortString(); + } else if (connectType.equals(ConnectType.ARROW_FLIGHT)) { + // TODO Get flight client IP:Port + return peerIdentity; + } + return ""; + } + // kill operation with no protect. public void kill(boolean killConnection) { - LOG.warn("kill query from {}, kill connection: {}", getMysqlChannel().getRemoteHostPortString(), - killConnection); + LOG.warn("kill query from {}, kill connection: {}", getRemoteHostPortString(), killConnection); if (killConnection) { isKilled = true; - // Close channel to break connection with client - getMysqlChannel().close(); + if (connectType.equals(ConnectType.MYSQL)) { + // Close channel to break connection with client + getMysqlChannel().close(); + } else if (connectType.equals(ConnectType.ARROW_FLIGHT)) { + connectScheduler.unregisterConnection(this); + } } // Now, cancel running query. cancelQuery(); @@ -695,7 +740,7 @@ public class ConnectContext { if (delta > sessionVariable.getWaitTimeoutS() * 1000L) { // Need kill this connection. LOG.warn("kill wait timeout connection, remote: {}, wait timeout: {}", - getMysqlChannel().getRemoteHostPortString(), sessionVariable.getWaitTimeoutS()); + getRemoteHostPortString(), sessionVariable.getWaitTimeoutS()); killFlag = true; killConnection = true; @@ -706,11 +751,11 @@ public class ConnectContext { if (executor != null && executor.isInsertStmt()) { timeoutTag = "insert"; } - //to ms + // to ms long timeout = getExecTimeout() * 1000L; if (delta > timeout) { LOG.warn("kill {} timeout, remote: {}, query timeout: {}", - timeoutTag, getMysqlChannel().getRemoteHostPortString(), timeout); + timeoutTag, getRemoteHostPortString(), timeout); killFlag = true; } } @@ -791,7 +836,7 @@ public class ConnectContext { } row.add("" + connectionId); row.add(ClusterNamespace.getNameFromFullName(qualifiedUser)); - row.add(getMysqlChannel().getRemoteHostPortString()); + row.add(getRemoteHostPortString()); row.add(TimeUtils.longToTimeString(loginTime)); row.add(defaultCatalog); row.add(ClusterNamespace.getNameFromFullName(currentDb)); diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/ConnectScheduler.java b/fe/fe-core/src/main/java/org/apache/doris/qe/ConnectScheduler.java index 5090e623a9..70bfd7e2d8 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/qe/ConnectScheduler.java +++ b/fe/fe-core/src/main/java/org/apache/doris/qe/ConnectScheduler.java @@ -21,6 +21,7 @@ import org.apache.doris.catalog.Env; import org.apache.doris.common.ThreadPoolManager; import org.apache.doris.common.util.DebugUtil; import org.apache.doris.mysql.privilege.PrivPredicate; +import org.apache.doris.qe.ConnectContext.ConnectType; import org.apache.doris.thrift.TUniqueId; import com.google.common.collect.Lists; @@ -45,6 +46,7 @@ public class ConnectScheduler { private final AtomicInteger nextConnectionId; private final Map connectionMap = Maps.newConcurrentMap(); private final Map connByUser = Maps.newConcurrentMap(); + private final Map flightToken2ConnectionId = Maps.newConcurrentMap(); // valid trace id -> query id private final Map traceId2QueryId = Maps.newConcurrentMap(); @@ -100,6 +102,9 @@ public class ConnectScheduler { return false; } connectionMap.put(ctx.getConnectionId(), ctx); + if (ctx.getConnectType().equals(ConnectType.ARROW_FLIGHT)) { + flightToken2ConnectionId.put(ctx.getPeerIdentity(), ctx.getConnectionId()); + } return true; } @@ -111,6 +116,9 @@ public class ConnectScheduler { conns.decrementAndGet(); } numberConnection.decrementAndGet(); + if (ctx.getConnectType().equals(ConnectType.ARROW_FLIGHT)) { + flightToken2ConnectionId.remove(ctx.getPeerIdentity()); + } } } @@ -118,6 +126,14 @@ public class ConnectScheduler { return connectionMap.get(connectionId); } + public ConnectContext getContext(String flightToken) { + if (flightToken2ConnectionId.containsKey(flightToken)) { + int connectionId = flightToken2ConnectionId.get(flightToken); + return getContext(connectionId); + } + return null; + } + public void cancelQuery(String queryId) { for (ConnectContext ctx : connectionMap.values()) { TUniqueId qid = ctx.queryId(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/QeService.java b/fe/fe-core/src/main/java/org/apache/doris/qe/QeService.java index f1e9a65345..00121319ec 100755 --- a/fe/fe-core/src/main/java/org/apache/doris/qe/QeService.java +++ b/fe/fe-core/src/main/java/org/apache/doris/qe/QeService.java @@ -18,7 +18,7 @@ package org.apache.doris.qe; import org.apache.doris.mysql.MysqlServer; -import org.apache.doris.service.arrowflight.FlightSqlService; +import org.apache.doris.service.arrowflight.DorisFlightSqlService; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -35,7 +35,7 @@ public class QeService { private MysqlServer mysqlServer; private int arrowFlightSQLPort; - private FlightSqlService flightSqlService; + private DorisFlightSqlService dorisFlightSqlService; @Deprecated public QeService(int port, int arrowFlightSQLPort) { @@ -63,8 +63,8 @@ public class QeService { System.exit(-1); } if (arrowFlightSQLPort != -1) { - this.flightSqlService = new FlightSqlService(arrowFlightSQLPort); - if (!flightSqlService.start()) { + this.dorisFlightSqlService = new DorisFlightSqlService(arrowFlightSQLPort); + if (!dorisFlightSqlService.start()) { System.exit(-1); } } else { diff --git a/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/FlightSqlServiceImpl.java b/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/DorisFlightSqlProducer.java similarity index 92% rename from fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/FlightSqlServiceImpl.java rename to fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/DorisFlightSqlProducer.java index 38e275b1d5..0e73fbb2ad 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/FlightSqlServiceImpl.java +++ b/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/DorisFlightSqlProducer.java @@ -22,6 +22,9 @@ package org.apache.doris.service.arrowflight; import org.apache.doris.common.util.DebugUtil; import org.apache.doris.common.util.Util; +import org.apache.doris.mysql.MysqlCommand; +import org.apache.doris.qe.ConnectContext; +import org.apache.doris.service.arrowflight.sessions.FlightSessionsManager; import com.google.protobuf.Any; import com.google.protobuf.ByteString; @@ -67,14 +70,16 @@ import org.apache.logging.log4j.Logger; import java.util.Collections; import java.util.List; -public class FlightSqlServiceImpl implements FlightSqlProducer, AutoCloseable { - private static final Logger LOG = LogManager.getLogger(FlightSqlServiceImpl.class); +public class DorisFlightSqlProducer implements FlightSqlProducer, AutoCloseable { + private static final Logger LOG = LogManager.getLogger(DorisFlightSqlProducer.class); private final Location location; private final BufferAllocator rootAllocator = new RootAllocator(); private final SqlInfoBuilder sqlInfoBuilder; + private final FlightSessionsManager flightSessionsManager; - public FlightSqlServiceImpl(final Location location) { + public DorisFlightSqlProducer(final Location location, FlightSessionsManager flightSessionsManager) { this.location = location; + this.flightSessionsManager = flightSessionsManager; sqlInfoBuilder = new SqlInfoBuilder(); sqlInfoBuilder.withFlightSqlServerName("DorisFE") .withFlightSqlServerVersion("1.0") @@ -103,9 +108,13 @@ public class FlightSqlServiceImpl implements FlightSqlProducer, AutoCloseable { @Override public FlightInfo getFlightInfoStatement(final CommandStatementQuery request, final CallContext context, final FlightDescriptor descriptor) { + ConnectContext connectContext = null; try { + connectContext = flightSessionsManager.getConnectContext(context.peerIdentity()); + // Only for ConnectContext check timeout. + connectContext.setCommand(MysqlCommand.COM_QUERY); final String query = request.getQuery(); - final FlightStatementExecutor flightStatementExecutor = new FlightStatementExecutor(query); + final FlightStatementExecutor flightStatementExecutor = new FlightStatementExecutor(query, connectContext); flightStatementExecutor.executeQuery(); @@ -123,8 +132,13 @@ public class FlightSqlServiceImpl implements FlightSqlProducer, AutoCloseable { if (schema == null) { throw CallStatus.INTERNAL.withDescription("fetch arrow flight schema is null").toRuntimeException(); } + // TODO Set in BE callback after query end, Client client will not callback by default. + connectContext.setCommand(MysqlCommand.COM_SLEEP); return new FlightInfo(schema, descriptor, endpoints, -1, -1); } catch (Exception e) { + if (null != connectContext) { + connectContext.setCommand(MysqlCommand.COM_SLEEP); + } LOG.warn("get flight info statement failed, " + e.getMessage(), e); throw CallStatus.INTERNAL.withDescription(Util.getRootCauseMessage(e)).withCause(e).toRuntimeException(); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/FlightSqlService.java b/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/DorisFlightSqlService.java similarity index 64% rename from fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/FlightSqlService.java rename to fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/DorisFlightSqlService.java index e0ec4bf10c..08e91a15ed 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/FlightSqlService.java +++ b/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/DorisFlightSqlService.java @@ -17,6 +17,13 @@ package org.apache.doris.service.arrowflight; +import org.apache.doris.common.Config; +import org.apache.doris.service.arrowflight.auth2.FlightBearerTokenAuthenticator; +import org.apache.doris.service.arrowflight.sessions.FlightSessionsManager; +import org.apache.doris.service.arrowflight.sessions.FlightSessionsWithTokenManager; +import org.apache.doris.service.arrowflight.tokens.FlightTokenManager; +import org.apache.doris.service.arrowflight.tokens.FlightTokenManagerImpl; + import org.apache.arrow.flight.FlightServer; import org.apache.arrow.flight.Location; import org.apache.arrow.memory.BufferAllocator; @@ -29,16 +36,23 @@ import java.io.IOException; /** * flight sql protocol implementation based on nio. */ -public class FlightSqlService { - private static final Logger LOG = LogManager.getLogger(FlightSqlService.class); +public class DorisFlightSqlService { + private static final Logger LOG = LogManager.getLogger(DorisFlightSqlService.class); private final FlightServer flightServer; private volatile boolean running; + private final FlightTokenManager flightTokenManager; + private final FlightSessionsManager flightSessionsManager; - public FlightSqlService(int port) { + public DorisFlightSqlService(int port) { BufferAllocator allocator = new RootAllocator(); Location location = Location.forGrpcInsecure("0.0.0.0", port); - FlightSqlServiceImpl producer = new FlightSqlServiceImpl(location); - flightServer = FlightServer.builder(allocator, location, producer).build(); + this.flightTokenManager = new FlightTokenManagerImpl(Config.arrow_flight_token_cache_size, + Config.arrow_flight_token_alive_time); + this.flightSessionsManager = new FlightSessionsWithTokenManager(flightTokenManager); + + DorisFlightSqlProducer producer = new DorisFlightSqlProducer(location, flightSessionsManager); + flightServer = FlightServer.builder(allocator, location, producer) + .headerAuthenticator(new FlightBearerTokenAuthenticator(flightTokenManager)).build(); } // start Arrow Flight SQL service, return true if success, otherwise false diff --git a/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/FlightStatementExecutor.java b/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/FlightStatementExecutor.java index ced03350de..8c9cdf124f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/FlightStatementExecutor.java +++ b/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/FlightStatementExecutor.java @@ -21,21 +21,15 @@ package org.apache.doris.service.arrowflight; import org.apache.doris.analysis.Expr; -import org.apache.doris.analysis.UserIdentity; -import org.apache.doris.catalog.Env; import org.apache.doris.common.Status; import org.apache.doris.common.util.DebugUtil; import org.apache.doris.proto.InternalService; import org.apache.doris.proto.Types; -import org.apache.doris.qe.AutoCloseConnectContext; import org.apache.doris.qe.ConnectContext; -import org.apache.doris.qe.SessionVariable; import org.apache.doris.qe.StmtExecutor; import org.apache.doris.rpc.BackendServiceProxy; import org.apache.doris.rpc.RpcException; -import org.apache.doris.system.SystemInfoService; import org.apache.doris.thrift.TNetworkAddress; -import org.apache.doris.thrift.TResultSinkType; import org.apache.doris.thrift.TStatusCode; import org.apache.doris.thrift.TUniqueId; @@ -55,8 +49,8 @@ import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; -public final class FlightStatementExecutor { - private AutoCloseConnectContext acConnectContext; +public final class FlightStatementExecutor implements AutoCloseable { + private ConnectContext connectContext; private final String query; private TUniqueId queryId; private TUniqueId finstId; @@ -64,9 +58,10 @@ public final class FlightStatementExecutor { private TNetworkAddress resultInternalServiceAddr; private ArrayList resultOutputExprs; - public FlightStatementExecutor(final String query) { + public FlightStatementExecutor(final String query, ConnectContext connectContext) { this.query = query; - acConnectContext = buildConnectContext(); + this.connectContext = connectContext; + connectContext.setThreadLocalInfo(); } public void setQueryId(TUniqueId queryId) { @@ -126,29 +121,14 @@ public final class FlightStatementExecutor { return Objects.hash(this); } - public static AutoCloseConnectContext buildConnectContext() { - ConnectContext connectContext = new ConnectContext(); - SessionVariable sessionVariable = connectContext.getSessionVariable(); - sessionVariable.internalSession = true; - sessionVariable.setEnablePipelineEngine(false); // TODO - sessionVariable.setEnablePipelineXEngine(false); // TODO - connectContext.setEnv(Env.getCurrentEnv()); - connectContext.setQualifiedUser(UserIdentity.ROOT.getQualifiedUser()); // TODO - connectContext.setCurrentUserIdentity(UserIdentity.ROOT); // TODO - connectContext.setStartTime(); - connectContext.setCluster(SystemInfoService.DEFAULT_CLUSTER); - connectContext.setResultSinkType(TResultSinkType.ARROW_FLIGHT_PROTOCAL); - return new AutoCloseConnectContext(connectContext); - } - public void executeQuery() { try { UUID uuid = UUID.randomUUID(); TUniqueId queryId = new TUniqueId(uuid.getMostSignificantBits(), uuid.getLeastSignificantBits()); setQueryId(queryId); - acConnectContext.connectContext.setQueryId(queryId); - StmtExecutor stmtExecutor = new StmtExecutor(acConnectContext.connectContext, getQuery()); - acConnectContext.connectContext.setExecutor(stmtExecutor); + connectContext.setQueryId(queryId); + StmtExecutor stmtExecutor = new StmtExecutor(connectContext, getQuery()); + connectContext.setExecutor(stmtExecutor); stmtExecutor.executeArrowFlightQuery(this); } catch (Exception e) { throw new RuntimeException("Failed to coord exec", e); @@ -221,4 +201,9 @@ public final class FlightStatementExecutor { DebugUtil.printId(tid), address), e); } } + + @Override + public void close() throws Exception { + ConnectContext.remove(); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/auth2/FlightAuthResult.java b/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/auth2/FlightAuthResult.java new file mode 100644 index 0000000000..d18925ae6e --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/auth2/FlightAuthResult.java @@ -0,0 +1,43 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// This file is copied from + +package org.apache.doris.service.arrowflight.auth2; + +import org.apache.doris.analysis.UserIdentity; + +import org.immutables.value.Value; + +/** + * Result of Authentication. + */ +@Value.Immutable +public interface FlightAuthResult { + String getUserName(); + + UserIdentity getUserIdentity(); + + String getRemoteIp(); + + static FlightAuthResult of(String userName, UserIdentity userIdentity, String remoteIp) { + return ImmutableFlightAuthResult.builder() + .userName(userName) + .userIdentity(userIdentity) + .remoteIp(remoteIp) + .build(); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/auth2/FlightAuthUtils.java b/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/auth2/FlightAuthUtils.java new file mode 100644 index 0000000000..b605dff66b --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/auth2/FlightAuthUtils.java @@ -0,0 +1,75 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.service.arrowflight.auth2; + +import org.apache.doris.analysis.UserIdentity; +import org.apache.doris.catalog.Env; +import org.apache.doris.common.AuthenticationException; +import org.apache.doris.service.arrowflight.tokens.FlightTokenManager; + +import com.google.common.base.Preconditions; +import com.google.common.collect.Lists; +import org.apache.arrow.flight.CallStatus; +import org.apache.logging.log4j.Logger; + +import java.util.List; + +/** + * A collection of common Flight server authentication methods. + */ +public final class FlightAuthUtils { + private FlightAuthUtils() { + } + + /** + * Authenticate against with the provided credentials. + * + * @param username username. + * @param password password. + * @param logger the slf4j logger for logging. + * @throws org.apache.arrow.flight.FlightRuntimeException if unable to authenticate against + * with the provided credentials. + */ + public static FlightAuthResult authenticateCredentials(String username, String password, String remoteIp, + Logger logger) { + try { + List currentUserIdentity = Lists.newArrayList(); + + Env.getCurrentEnv().getAuth().checkPlainPassword(username, remoteIp, password, currentUserIdentity); + Preconditions.checkState(currentUserIdentity.size() == 1); + return FlightAuthResult.of(username, currentUserIdentity.get(0), remoteIp); + } catch (AuthenticationException e) { + logger.error("Unable to authenticate user {}", username, e); + final String errMsg = "Unable to authenticate user " + username + ", exception: " + e.getMessage(); + throw CallStatus.UNAUTHENTICATED.withCause(e).withDescription(errMsg).toRuntimeException(); + } + } + + /** + * Creates a new Bearer Token. Returns the bearer token associated with the User. + * + * @param flightTokenManager the TokenManager. + * @param username the user to create a Flight server session for. + * @param flightAuthResult the FlightAuthResult. + * @return the token associated with the FlightTokenDetails created. + */ + public static String createToken(FlightTokenManager flightTokenManager, String username, + FlightAuthResult flightAuthResult) { + return flightTokenManager.createToken(username, flightAuthResult).getToken(); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/auth2/FlightBearerTokenAuthenticator.java b/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/auth2/FlightBearerTokenAuthenticator.java new file mode 100644 index 0000000000..ef6e28b034 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/auth2/FlightBearerTokenAuthenticator.java @@ -0,0 +1,114 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// This file is copied from +// https://github.com/dremio/dremio-oss/blob/master/services/arrow-flight/src/main/java/com/dremio/service/flight/auth2/DremioBearerTokenAuthenticator.java +// and modified by Doris + +package org.apache.doris.service.arrowflight.auth2; + +import org.apache.doris.service.arrowflight.tokens.FlightTokenManager; + +import org.apache.arrow.flight.CallHeaders; +import org.apache.arrow.flight.CallStatus; +import org.apache.arrow.flight.auth2.Auth2Constants; +import org.apache.arrow.flight.auth2.AuthUtilities; +import org.apache.arrow.flight.auth2.BasicCallHeaderAuthenticator; +import org.apache.arrow.flight.auth2.CallHeaderAuthenticator; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; + +/** + * Doris's custom implementation of CallHeaderAuthenticator for bearer token authentication. + * This class implements CallHeaderAuthenticator rather than BearerTokenAuthenticator. Doris + * creates FlightTokenDetails objects when the bearer token is created and requires access to the CallHeaders + * in getAuthResultWithBearerToken. + */ + +public class FlightBearerTokenAuthenticator implements CallHeaderAuthenticator { + private static final Logger LOG = LogManager.getLogger(FlightBearerTokenAuthenticator.class); + + private final CallHeaderAuthenticator initialAuthenticator; + private final FlightTokenManager flightTokenManager; + + public FlightBearerTokenAuthenticator(FlightTokenManager flightTokenManager) { + this.flightTokenManager = flightTokenManager; + this.initialAuthenticator = new BasicCallHeaderAuthenticator( + new FlightCredentialValidator(this.flightTokenManager)); + } + + /** + * If no bearer token is provided, the method initiates initial password and username + * authentication. Once authenticated, client properties are retrieved from incoming CallHeaders. + * Then it generates a token and creates a FlightTokenDetails with the retrieved client properties. + * associated with it. + *

+ * If a bearer token is provided, the method validates the provided token. + * + * @param incomingHeaders call headers to retrieve client properties and auth headers from. + * @return an AuthResult with the bearer token and peer identity. + */ + @Override + public AuthResult authenticate(CallHeaders incomingHeaders) { + final String bearerToken = AuthUtilities.getValueFromAuthHeader(incomingHeaders, + Auth2Constants.BEARER_PREFIX); + + if (bearerToken != null) { + return validateBearer(bearerToken); + } else { + final AuthResult result = initialAuthenticator.authenticate(incomingHeaders); + return createAuthResultWithBearerToken(result.getPeerIdentity()); + } + } + + /** + * Validates provided token. + * + * @param token the token to validate. + * @return an AuthResult with the bearer token and peer identity. + */ + AuthResult validateBearer(String token) { + try { + flightTokenManager.validateToken(token); + return createAuthResultWithBearerToken(token); + } catch (IllegalArgumentException e) { + LOG.error("Bearer token validation failed.", e); + throw CallStatus.UNAUTHENTICATED.toRuntimeException(); + } + } + + + /** + * Helper method to create an AuthResult. + * + * @param token the token to create a FlightTokenDetails for. + * @return a new AuthResult with functionality to add given bearer token to the outgoing header. + */ + private AuthResult createAuthResultWithBearerToken(String token) { + return new AuthResult() { + @Override + public void appendToOutgoingHeaders(CallHeaders outgoingHeaders) { + outgoingHeaders.insert(Auth2Constants.AUTHORIZATION_HEADER, + Auth2Constants.BEARER_PREFIX + token); + } + + @Override + public String getPeerIdentity() { + return token; + } + }; + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/auth2/FlightCredentialValidator.java b/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/auth2/FlightCredentialValidator.java new file mode 100644 index 0000000000..6676e8526e --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/auth2/FlightCredentialValidator.java @@ -0,0 +1,70 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// This file is copied from +// https://github.com/dremio/dremio-oss/blob/master/services/arrow-flight/src/main/java/com/dremio/service/flight/auth2/DremioCredentialValidator.java +// and modified by Doris + +package org.apache.doris.service.arrowflight.auth2; + +import org.apache.doris.service.arrowflight.tokens.FlightTokenManager; + +import org.apache.arrow.flight.auth2.BasicCallHeaderAuthenticator; +import org.apache.arrow.flight.auth2.CallHeaderAuthenticator.AuthResult; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; + +/** + * Authentication specialized CredentialValidator implementation. + */ +public class FlightCredentialValidator implements BasicCallHeaderAuthenticator.CredentialValidator { + private static final Logger LOG = LogManager.getLogger(FlightCredentialValidator.class); + + private final FlightTokenManager flightTokenManager; + + public FlightCredentialValidator(FlightTokenManager flightTokenManager) { + this.flightTokenManager = flightTokenManager; + } + + /** + * Authenticates against with the provided username and password. + * + * @param username username. + * @param password user password. + * @return AuthResult with username as the peer identity. + */ + @Override + public AuthResult validate(String username, String password) { + // TODO Add ClientAddress information while creating a Token + String remoteIp = "0.0.0.0"; + FlightAuthResult flightAuthResult = FlightAuthUtils.authenticateCredentials(username, password, remoteIp, LOG); + return getAuthResultWithBearerToken(flightAuthResult); + } + + + /** + * Generates a bearer token, parses client properties from incoming headers, then creates a + * FlightTokenDetails associated with the generated token and client properties. + * + * @param flightAuthResult the FlightAuthResult from initial authentication, with peer identity captured. + * @return an FlightAuthResult with the bearer token and peer identity. + */ + AuthResult getAuthResultWithBearerToken(FlightAuthResult flightAuthResult) { + final String username = flightAuthResult.getUserName(); + final String token = FlightAuthUtils.createToken(flightTokenManager, username, flightAuthResult); + return () -> token; + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/sessions/FlightSessionsManager.java b/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/sessions/FlightSessionsManager.java new file mode 100644 index 0000000000..ed01098c67 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/sessions/FlightSessionsManager.java @@ -0,0 +1,75 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// This file is copied from + +package org.apache.doris.service.arrowflight.sessions; + +import org.apache.doris.analysis.UserIdentity; +import org.apache.doris.catalog.Env; +import org.apache.doris.common.ErrorCode; +import org.apache.doris.qe.ConnectContext; +import org.apache.doris.service.ExecuteEnv; +import org.apache.doris.system.SystemInfoService; + +import org.apache.arrow.flight.CallStatus; + +/** + * Manages Flight User Session ConnectContext. + */ +public interface FlightSessionsManager { + + /** + * Resolves an existing ConnectContext for the given peerIdentity. + *

+ * + * @param peerIdentity identity after authorization + * @return The ConnectContext or null if no sessionId is given. + */ + ConnectContext getConnectContext(String peerIdentity); + + /** + * Creates a ConnectContext object and store it in the local cache, assuming that peerIdentity was already + * validated. + * + * @param peerIdentity identity after authorization + */ + ConnectContext createConnectContext(String peerIdentity); + + public static ConnectContext buildConnectContext(String peerIdentity, UserIdentity userIdentity, String remoteIP) { + ConnectContext connectContext = new ConnectContext(peerIdentity); + connectContext.setEnv(Env.getCurrentEnv()); + connectContext.setStartTime(); + connectContext.setCluster(SystemInfoService.DEFAULT_CLUSTER); + connectContext.getSessionVariable().setEnablePipelineEngine(false); // TODO + connectContext.getSessionVariable().setEnablePipelineXEngine(false); // TODO + connectContext.setQualifiedUser(userIdentity.getQualifiedUser()); + connectContext.setCurrentUserIdentity(userIdentity); + connectContext.setRemoteIP(remoteIP); + connectContext.setUserQueryTimeout( + connectContext.getEnv().getAuth().getQueryTimeout(connectContext.getQualifiedUser())); + connectContext.setUserInsertTimeout( + connectContext.getEnv().getAuth().getInsertTimeout(connectContext.getQualifiedUser())); + + connectContext.setConnectScheduler(ExecuteEnv.getInstance().getScheduler()); + if (!ExecuteEnv.getInstance().getScheduler().registerConnection(connectContext)) { + connectContext.getState().setError(ErrorCode.ERR_TOO_MANY_USER_CONNECTIONS, + "Reach limit of connections"); + throw CallStatus.UNAUTHENTICATED.withDescription("Reach limit of connections").toRuntimeException(); + } + return connectContext; + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/sessions/FlightSessionsWithTokenManager.java b/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/sessions/FlightSessionsWithTokenManager.java new file mode 100644 index 0000000000..ce12f610ea --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/sessions/FlightSessionsWithTokenManager.java @@ -0,0 +1,68 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.service.arrowflight.sessions; + +import org.apache.doris.qe.ConnectContext; +import org.apache.doris.service.ExecuteEnv; +import org.apache.doris.service.arrowflight.tokens.FlightTokenDetails; +import org.apache.doris.service.arrowflight.tokens.FlightTokenManager; + +import org.apache.arrow.flight.CallStatus; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; + +public class FlightSessionsWithTokenManager implements FlightSessionsManager { + private static final Logger LOG = LogManager.getLogger(FlightSessionsWithTokenManager.class); + + private final FlightTokenManager flightTokenManager; + + public FlightSessionsWithTokenManager(FlightTokenManager flightTokenManager) { + this.flightTokenManager = flightTokenManager; + } + + @Override + public ConnectContext getConnectContext(String peerIdentity) { + ConnectContext connectContext = ExecuteEnv.getInstance().getScheduler().getContext(peerIdentity); + if (null == connectContext) { + connectContext = createConnectContext(peerIdentity); + if (null == connectContext) { + flightTokenManager.invalidateToken(peerIdentity); + String err = "UserSession expire after access, need reauthorize."; + LOG.error(err); + throw CallStatus.UNAUTHENTICATED.withDescription(err).toRuntimeException(); + } + return connectContext; + } + return connectContext; + } + + @Override + public ConnectContext createConnectContext(String peerIdentity) { + try { + final FlightTokenDetails flightTokenDetails = flightTokenManager.validateToken(peerIdentity); + if (flightTokenDetails.getCreatedSession()) { + return null; + } + return FlightSessionsManager.buildConnectContext(peerIdentity, flightTokenDetails.getUserIdentity(), + flightTokenDetails.getRemoteIp()); + } catch (IllegalArgumentException e) { + LOG.error("Bearer token validation failed.", e); + throw CallStatus.UNAUTHENTICATED.toRuntimeException(); + } + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/tokens/FlightTokenDetails.java b/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/tokens/FlightTokenDetails.java new file mode 100644 index 0000000000..be9166eead --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/tokens/FlightTokenDetails.java @@ -0,0 +1,100 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.service.arrowflight.tokens; + +import org.apache.doris.analysis.UserIdentity; + +import com.google.common.base.Preconditions; + +/** + * Details of a token. + */ +public final class FlightTokenDetails { + + private final String token; + private final String username; + private final long issuedAt; + private final long expiresAt; + private final String remoteIp; + private final UserIdentity userIdentity; + private boolean createdSession = false; + + public FlightTokenDetails() { + this.token = ""; + this.username = ""; + this.issuedAt = 0; + this.expiresAt = 0; + this.remoteIp = ""; + this.userIdentity = new UserIdentity(username, remoteIp); + } + + public FlightTokenDetails(String token, String username, long issuedAt, long expiresAt, UserIdentity userIdentity, + String remoteIp) { + Preconditions.checkNotNull(token); + Preconditions.checkNotNull(username); + this.token = token; + this.username = username; + this.issuedAt = issuedAt; + this.expiresAt = expiresAt; + this.remoteIp = remoteIp; + this.userIdentity = userIdentity; + } + + public FlightTokenDetails(String token, String username, long expiresAt) { + Preconditions.checkNotNull(token); + Preconditions.checkNotNull(username); + this.token = token; + this.username = username; + this.expiresAt = expiresAt; + this.issuedAt = 0; + this.remoteIp = ""; + this.userIdentity = new UserIdentity(username, remoteIp); + } + + public String getToken() { + return token; + } + + public String getUsername() { + return username; + } + + public long getIssuedAt() { + return issuedAt; + } + + public long getExpiresAt() { + return expiresAt; + } + + public String getRemoteIp() { + return remoteIp; + } + + public UserIdentity getUserIdentity() { + return userIdentity; + } + + public void setCreatedSession(boolean createdSession) { + this.createdSession = createdSession; + } + + public boolean getCreatedSession() { + return createdSession; + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/tokens/FlightTokenManager.java b/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/tokens/FlightTokenManager.java new file mode 100644 index 0000000000..23435c3c04 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/tokens/FlightTokenManager.java @@ -0,0 +1,61 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// https://github.com/dremio/dremio-oss/blob/master/services/arrow-flight/src/main/java/com/dremio/service/tokens/TokenManager.java +// and modified by Doris + +package org.apache.doris.service.arrowflight.tokens; + +import org.apache.doris.service.arrowflight.auth2.FlightAuthResult; + +/** + * Token manager. + */ +public interface FlightTokenManager extends AutoCloseable { + + /** + * Generate a securely random token. + * + * @return a token string + */ + String newToken(); + + /** + * Create a token for the session, and return details about the token. + * + * @param username user name + * @param flightAuthResult auth result + * @return token details + */ + FlightTokenDetails createToken(String username, FlightAuthResult flightAuthResult); + + /** + * Validate the token, and return details about the token. + * + * @param token session token + * @return token details + * @throws IllegalArgumentException if the token is invalid or expired + */ + FlightTokenDetails validateToken(String token) throws IllegalArgumentException; + + /** + * Invalidate the token. + * + * @param token session token + */ + void invalidateToken(String token); + +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/tokens/FlightTokenManagerImpl.java b/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/tokens/FlightTokenManagerImpl.java new file mode 100644 index 0000000000..54e53e931d --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/tokens/FlightTokenManagerImpl.java @@ -0,0 +1,118 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// https://github.com/dremio/dremio-oss/blob/master/services/arrow-flight/src/main/java/com/dremio/service/tokens/TokenManagerImpl.java +// and modified by Doris + +package org.apache.doris.service.arrowflight.tokens; + +import org.apache.doris.service.arrowflight.auth2.FlightAuthResult; + +import com.google.common.base.Preconditions; +import com.google.common.cache.CacheBuilder; +import com.google.common.cache.CacheLoader; +import com.google.common.cache.LoadingCache; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; + +import java.math.BigInteger; +import java.security.SecureRandom; +import java.util.concurrent.TimeUnit; + +/** + * Token manager implementation. + */ +public class FlightTokenManagerImpl implements FlightTokenManager { + private static final Logger LOG = LogManager.getLogger(FlightTokenManagerImpl.class); + + private final SecureRandom generator = new SecureRandom(); + private final int cacheExpiration; + + private LoadingCache tokenCache; + + public FlightTokenManagerImpl(final int cacheSize, final int cacheExpiration) { + this.cacheExpiration = cacheExpiration; + + this.tokenCache = CacheBuilder.newBuilder() + .maximumSize(cacheSize) + .expireAfterWrite(cacheExpiration, TimeUnit.MINUTES) + .build(new CacheLoader() { + @Override + public FlightTokenDetails load(String key) { + return new FlightTokenDetails(); + } + }); + } + + // From https://stackoverflow.com/questions/41107/how-to-generate-a-random-alpha-numeric-string + // ... This works by choosing 130 bits from a cryptographically secure random bit generator, and encoding + // them in base-32. 128 bits is considered to be cryptographically strong, but each digit in a base 32 + // number can encode 5 bits, so 128 is rounded up to the next multiple of 5 ... Why 32? Because 32 = 2^5; + // each character will represent exactly 5 bits, and 130 bits can be evenly divided into characters. + @Override + public String newToken() { + return new BigInteger(130, generator).toString(32); + } + + @Override + public FlightTokenDetails createToken(final String username, final FlightAuthResult flightAuthResult) { + final String token = newToken(); + final long now = System.currentTimeMillis(); + final long expires = now + TimeUnit.MILLISECONDS.convert(cacheExpiration, TimeUnit.MINUTES); + + final FlightTokenDetails flightTokenDetails = new FlightTokenDetails(token, username, now, expires, + flightAuthResult.getUserIdentity(), flightAuthResult.getRemoteIp()); + + tokenCache.put(token, flightTokenDetails); + LOG.trace("Created flight token for user: {}", username); + return flightTokenDetails; + } + + @Override + public FlightTokenDetails validateToken(final String token) throws IllegalArgumentException { + final FlightTokenDetails value = getTokenDetails(token); + if (System.currentTimeMillis() >= value.getExpiresAt()) { + tokenCache.invalidate(token); // removes from the store as well + throw new IllegalArgumentException("token expired"); + } + + LOG.trace("Validated flight token for user: {}", value.getUsername()); + return value; + } + + @Override + public void invalidateToken(final String token) { + LOG.trace("Invalidate flight token, {}", token); + tokenCache.invalidate(token); // removes from the store as well + } + + private FlightTokenDetails getTokenDetails(final String token) { + Preconditions.checkNotNull(token, "invalid token"); + final FlightTokenDetails value; + try { + value = tokenCache.getUnchecked(token); + } catch (CacheLoader.InvalidCacheLoadException ignored) { + throw new IllegalArgumentException("invalid token"); + } + + return value; + } + + @Override + public void close() throws Exception { + tokenCache.invalidateAll(); + } +} diff --git a/fe/pom.xml b/fe/pom.xml index 12573bd057..0e8193bc53 100644 --- a/fe/pom.xml +++ b/fe/pom.xml @@ -314,6 +314,7 @@ under the License. 6.5.1 2.0.3 1.5.4 + 2.9.3 3.0.0 2.3.2 @@ -1524,6 +1525,12 @@ under the License. arrow-jdbc ${arrow.version} + + org.immutables + value + ${immutables.version} + provided +